kernels-bot commited on
Commit
d8e8ea2
·
verified ·
1 Parent(s): a7f41ae

Uploaded using `kernel-builder` (batch 8/32).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h +0 -222
  2. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp +0 -67
  3. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h +0 -324
  4. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp +0 -528
  5. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h +0 -143
  6. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h +0 -187
  7. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h +0 -402
  8. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h +0 -644
  9. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h +0 -375
  10. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h +0 -141
  11. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h +0 -276
  12. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h +0 -573
  13. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h +0 -144
  14. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h +0 -186
  15. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h +0 -127
  16. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h +0 -157
  17. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h +0 -69
  18. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp +0 -369
  19. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp +0 -116
  20. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h +0 -111
  21. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h +0 -541
  22. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h +0 -591
  23. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h +0 -157
  24. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h +0 -38
  25. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +0 -472
  26. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp +0 -570
  27. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp +0 -341
  28. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h +0 -135
  29. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h +0 -94
  30. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h +0 -1549
  31. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h +0 -385
  32. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h +0 -350
  33. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h +0 -311
  34. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp +0 -146
  35. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h +0 -162
  36. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h +0 -168
  37. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h +0 -159
  38. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h +0 -355
  39. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h +0 -250
  40. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h +0 -2075
  41. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +0 -142
  42. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h +0 -514
  43. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h +0 -141
  44. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h +0 -186
  45. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp +0 -782
  46. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h +0 -802
  47. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h +0 -66
  48. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h +0 -531
  49. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h +0 -210
  50. build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h +0 -228
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h DELETED
@@ -1,222 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /* \file
32
- \brief Defines a math function
33
-
34
-
35
- */
36
-
37
- #pragma once
38
-
39
- #include <vector>
40
- #include <string>
41
- #include <memory>
42
- #include <algorithm>
43
- #include <unordered_map>
44
-
45
- // CUTLASS Library includes
46
- #include "cutlass/blas3.h"
47
- #include "cutlass/library/library.h"
48
- #include "cutlass/library/util.h"
49
- #include "cutlass/library/manifest.h"
50
-
51
- // Profiler includes
52
- #include "options.h"
53
- #include "device_context.h"
54
- #include "operation_profiler.h"
55
- #include "performance_result.h"
56
- #include "problem_space.h"
57
-
58
- /////////////////////////////////////////////////////////////////////////////////////////////////
59
-
60
- namespace cutlass {
61
- namespace profiler {
62
-
63
- /////////////////////////////////////////////////////////////////////////////////////////////////
64
-
65
- /// Abstract base class for each math function
66
- class TrmmOperationProfiler : public OperationProfiler {
67
- public:
68
-
69
- /// Problem structure obtained from problem space
70
- struct TrmmProblem {
71
- int64_t m;
72
- int64_t n;
73
- int64_t lda;
74
- int64_t ldb;
75
- int64_t ldd;
76
- SideMode side_mode;
77
- FillMode fill_mode;
78
- DiagType diag_type;
79
- std::vector<uint8_t> alpha;
80
- std::vector<uint8_t> beta;
81
- int64_t split_k_slices;
82
- int64_t batch_count;
83
-
84
- //
85
- // Methods
86
- //
87
-
88
- TrmmProblem():
89
- m(16), n(16), lda(0), ldb(0), ldd(0), split_k_slices(1), batch_count(1) { }
90
-
91
- /// Parses the problem
92
- Status parse(
93
- library::TrmmDescription const &operation_desc,
94
- ProblemSpace const &problem_space,
95
- ProblemSpace::Problem const &problem);
96
-
97
- /// Initializes a performance result
98
- void initialize_result(
99
- PerformanceResult &result,
100
- library::TrmmDescription const &operation_desc,
101
- ProblemSpace const &problem_space);
102
- };
103
-
104
- /// Workspace used
105
- struct TrmmWorkspace {
106
-
107
- DeviceAllocation *A;
108
- DeviceAllocation *B;
109
- DeviceAllocation *D;
110
- DeviceAllocation *Computed;
111
- DeviceAllocation *Reference;
112
-
113
- library::TrmmConfiguration configuration;
114
- library::TrmmArguments arguments;
115
-
116
- /// Buffer used for the operation's host workspace
117
- std::vector<uint8_t> host_workspace;
118
-
119
- /// Buffer used for the operations' device workspace
120
- DeviceAllocation device_workspace;
121
-
122
- //
123
- // Methods
124
- //
125
-
126
- TrmmWorkspace():
127
- A(nullptr), B(nullptr), D(nullptr), Computed(nullptr), Reference(nullptr) { }
128
- };
129
-
130
- protected:
131
-
132
- //
133
- // Data members
134
- //
135
-
136
- /// GEMM problem obtained from problem space
137
- TrmmProblem problem_;
138
-
139
- /// Device memory allocations
140
- TrmmWorkspace trmm_workspace_;
141
-
142
-
143
- public:
144
- //
145
- // Methods
146
- //
147
-
148
- /// Ctor
149
- TrmmOperationProfiler(Options const &options);
150
-
151
- /// Destructor
152
- virtual ~TrmmOperationProfiler();
153
-
154
- /// Prints usage statement for the math function
155
- virtual void print_usage(std::ostream &out) const;
156
-
157
- /// Prints examples
158
- virtual void print_examples(std::ostream &out) const;
159
-
160
- /// Extracts the problem dimensions
161
- virtual Status initialize_configuration(
162
- Options const &options,
163
- PerformanceReport &report,
164
- DeviceContext &device_context,
165
- library::Operation const *operation,
166
- ProblemSpace const &problem_space,
167
- ProblemSpace::Problem const &problem);
168
-
169
- /// Initializes workspace
170
- virtual Status initialize_workspace(
171
- Options const &options,
172
- PerformanceReport &report,
173
- DeviceContext &device_context,
174
- library::Operation const *operation,
175
- ProblemSpace const &problem_space,
176
- ProblemSpace::Problem const &problem);
177
-
178
- /// Verifies CUTLASS against references
179
- virtual bool verify_cutlass(
180
- Options const &options,
181
- PerformanceReport &report,
182
- DeviceContext &device_context,
183
- library::Operation const *operation,
184
- ProblemSpace const &problem_space,
185
- ProblemSpace::Problem const &problem);
186
-
187
- /// Measures performance results
188
- virtual bool profile(
189
- Options const &options,
190
- PerformanceReport &report,
191
- DeviceContext &device_context,
192
- library::Operation const *operation,
193
- ProblemSpace const &problem_space,
194
- ProblemSpace::Problem const &problem);
195
-
196
- protected:
197
-
198
- /// Initializes the performance result
199
- void initialize_result_(
200
- PerformanceResult &result,
201
- Options const &options,
202
- library::TrmmDescription const &operation_desc,
203
- ProblemSpace const &problem_space);
204
-
205
- /// Verifies CUTLASS against references
206
- bool verify_with_cublas_(
207
- Options const &options,
208
- PerformanceReport &report,
209
- DeviceContext &device_context,
210
- library::Operation const *operation,
211
- ProblemSpace const &problem_space,
212
- ProblemSpace::Problem const &problem);
213
-
214
- };
215
-
216
- /////////////////////////////////////////////////////////////////////////////////////////////////
217
-
218
- } // namespace profiler
219
- } // namespace cutlass
220
-
221
- /////////////////////////////////////////////////////////////////////////////////////////////////
222
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp DELETED
@@ -1,67 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include <cuda_runtime.h>
35
-
36
- struct GPU_Clock
37
- {
38
- GPU_Clock() {
39
- cudaEventCreate(&start_);
40
- cudaEventCreate(&stop_);
41
- cudaEventRecord(start_);
42
- }
43
-
44
- ~GPU_Clock() {
45
- cudaEventDestroy(start_);
46
- cudaEventDestroy(stop_);
47
- }
48
-
49
- void start() {
50
- cudaEventRecord(start_);
51
- }
52
-
53
- float milliseconds() {
54
- cudaEventRecord(stop_);
55
- cudaEventSynchronize(stop_);
56
- float time;
57
- cudaEventElapsedTime(&time, start_, stop_);
58
- return time;
59
- }
60
-
61
- float seconds() {
62
- return milliseconds() * float(1e-3);
63
- }
64
-
65
- private:
66
- cudaEvent_t start_, stop_;
67
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h DELETED
@@ -1,324 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- /**
35
- * \file
36
- * Utility for parsing command line arguments
37
- */
38
-
39
- #include <iostream>
40
- #include <limits>
41
- #include <sstream>
42
- #include <string>
43
- #include <vector>
44
- #include <unordered_map>
45
-
46
- #include <cuda_runtime.h>
47
-
48
- #include "cutlass/cutlass.h"
49
-
50
- namespace cutlass {
51
-
52
- /******************************************************************************
53
- * command_line
54
- ******************************************************************************/
55
-
56
- /**
57
- * Utility for parsing command line arguments
58
- */
59
- struct CommandLine {
60
- std::vector<std::string> keys;
61
- std::vector<std::string> values;
62
- std::vector<std::string> args;
63
-
64
- /**
65
- * Constructor
66
- */
67
- CommandLine(int argc, const char** argv) {
68
- using namespace std;
69
-
70
- for (int i = 1; i < argc; i++) {
71
- string arg = argv[i];
72
-
73
- if ((arg[0] != '-') || (arg[1] != '-')) {
74
- args.push_back(arg);
75
- continue;
76
- }
77
-
78
- string::size_type pos;
79
- string key, val;
80
- if ((pos = arg.find('=')) == string::npos) {
81
- key = string(arg, 2, arg.length() - 2);
82
- val = "";
83
- } else {
84
- key = string(arg, 2, pos - 2);
85
- val = string(arg, pos + 1, arg.length() - 1);
86
- }
87
-
88
- keys.push_back(key);
89
- values.push_back(val);
90
- }
91
- }
92
-
93
- /**
94
- * Constructor to represent a command line from a map of [argument] -> [value]
95
- */
96
- CommandLine(std::unordered_map<std::string, std::string>& arg_map) {
97
- for (const auto& [key, value] : arg_map) {
98
- keys.push_back(key);
99
- values.push_back(value);
100
- }
101
- }
102
-
103
- /**
104
- * Checks whether a flag "--<flag>" is present in the commandline
105
- */
106
- bool check_cmd_line_flag(const char* arg_name) const {
107
- using namespace std;
108
-
109
- for (int i = 0; i < int(keys.size()); ++i) {
110
- if (keys[i] == string(arg_name)) return true;
111
- }
112
- return false;
113
- }
114
-
115
- /**
116
- * Returns number of naked (non-flag and non-key-value) commandline parameters
117
- */
118
- size_t num_naked_args() const {
119
- return args.size();
120
- }
121
-
122
- /**
123
- * Print naked (non-flag and non-key-value) commandline parameters
124
- */
125
- void print_naked_args(std::ostream &out) const {
126
- for (auto arg : args) {
127
- out << " " << arg <<"\n";
128
- }
129
- }
130
-
131
- /**
132
- * Returns the commandline parameter for a given index (not including flags)
133
- */
134
- template <typename value_t>
135
- void get_cmd_line_argument(size_t index, value_t& val) const {
136
- using namespace std;
137
- if (index < args.size()) {
138
- istringstream str_stream(args[index]);
139
- str_stream >> val;
140
- }
141
- }
142
-
143
- /**
144
- * Obtains the boolean value specified for a given commandline parameter --<flag>=<bool>
145
- */
146
- void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const {
147
- val = _default;
148
- if (check_cmd_line_flag(arg_name)) {
149
- std::string value;
150
- get_cmd_line_argument(arg_name, value);
151
-
152
- val = !(value == "0" || value == "false");
153
- }
154
- }
155
-
156
- /**
157
- * Obtains the value specified for a given commandline parameter --<flag>=<value>
158
- */
159
- template <typename value_t>
160
- void get_cmd_line_argument(const char* arg_name,
161
- value_t& val) const {
162
-
163
- get_cmd_line_argument(arg_name, val, val);
164
- }
165
-
166
- /**
167
- * Obtains the value specified for a given commandline parameter --<flag>=<value>
168
- */
169
- template <typename value_t>
170
- void get_cmd_line_argument(const char* arg_name,
171
- value_t& val,
172
- value_t const& _default) const {
173
- using namespace std;
174
-
175
- val = _default;
176
-
177
- for (int i = 0; i < int(keys.size()); ++i) {
178
- if (keys[i] == string(arg_name)) {
179
- istringstream str_stream(values[i]);
180
- str_stream >> val;
181
- }
182
- }
183
- }
184
-
185
- /**
186
- * Returns the values specified for a given commandline parameter --<flag>=<value>,<value>*
187
- */
188
- template <typename value_t>
189
- void get_cmd_line_arguments(const char* arg_name,
190
- std::vector<value_t>& vals,
191
- char sep = ',') const {
192
- using namespace std;
193
-
194
- if (check_cmd_line_flag(arg_name)) {
195
- // Clear any default values
196
- vals.clear();
197
-
198
- // Recover from multi-value string
199
- for (size_t i = 0; i < keys.size(); ++i) {
200
- if (keys[i] == string(arg_name)) {
201
- string val_string(values[i]);
202
- separate_string(val_string, vals, sep);
203
- }
204
- }
205
- }
206
- }
207
-
208
- /**
209
- * Returns the values specified for a given commandline parameter
210
- * --<flag>=<value>,<value_start:value_end>*
211
- */
212
- void get_cmd_line_argument_pairs(const char* arg_name,
213
- std::vector<std::pair<std::string, std::string> >& tokens,
214
- char delim = ',',
215
- char sep = ':') const {
216
- if (check_cmd_line_flag(arg_name)) {
217
- std::string value;
218
- get_cmd_line_argument(arg_name, value);
219
-
220
- tokenize(tokens, value, delim, sep);
221
- }
222
- }
223
-
224
- /**
225
- * Returns a list of ranges specified for a given commandline parameter
226
- * --<flag>=<key:value>,<key:value>*
227
- */
228
- void get_cmd_line_argument_ranges(const char* arg_name,
229
- std::vector<std::vector<std::string> >& vals,
230
- char delim = ',',
231
- char sep = ':') const {
232
- std::vector<std::string> ranges;
233
- get_cmd_line_arguments(arg_name, ranges, delim);
234
-
235
- for (std::vector<std::string>::const_iterator range = ranges.begin();
236
- range != ranges.end(); ++range) {
237
-
238
- std::vector<std::string> range_vals;
239
- separate_string(*range, range_vals, sep);
240
- vals.push_back(range_vals);
241
- }
242
- }
243
-
244
- /**
245
- * The number of pairs parsed
246
- */
247
- int parsed_argc() const { return (int)keys.size(); }
248
-
249
- //-------------------------------------------------------------------------
250
- // Utility functions
251
- //-------------------------------------------------------------------------
252
-
253
- /// Tokenizes a comma-delimited list of string pairs delimited by ':'
254
- static void tokenize(std::vector<std::pair<std::string, std::string> >& tokens,
255
- std::string const& str,
256
- char delim = ',',
257
- char sep = ':') {
258
- // Home-built to avoid Boost dependency
259
- size_t s_idx = 0;
260
- size_t d_idx = std::string::npos;
261
- while (s_idx < str.size()) {
262
- d_idx = str.find_first_of(delim, s_idx);
263
-
264
- size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size());
265
- size_t sep_idx = str.find_first_of(sep, s_idx);
266
- size_t offset = 1;
267
- if (sep_idx == std::string::npos || sep_idx >= end_idx) {
268
- sep_idx = end_idx;
269
- offset = 0;
270
- }
271
-
272
- std::pair<std::string, std::string> item(
273
- str.substr(s_idx, sep_idx - s_idx),
274
- str.substr(sep_idx + offset, end_idx - sep_idx - offset));
275
-
276
- tokens.push_back(item);
277
- s_idx = end_idx + 1;
278
- }
279
- }
280
-
281
- /// Tokenizes a comma-delimited list of string pairs delimited by ':'
282
- static void tokenize(std::vector<std::string>& tokens,
283
- std::string const& str,
284
- char delim = ',',
285
- char sep = ':') {
286
- typedef std::vector<std::pair<std::string, std::string> > TokenVector;
287
- typedef TokenVector::const_iterator token_iterator;
288
-
289
- std::vector<std::pair<std::string, std::string> > token_pairs;
290
- tokenize(token_pairs, str, delim, sep);
291
- for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) {
292
- tokens.push_back(tok->first);
293
- }
294
- }
295
-
296
- template <typename value_t>
297
- static void separate_string(std::string const& str,
298
- std::vector<value_t>& vals,
299
- char sep = ',') {
300
- std::istringstream str_stream(str);
301
- std::string::size_type old_pos = 0;
302
- std::string::size_type new_pos = 0;
303
-
304
- // Iterate <sep>-delimited values
305
- value_t val;
306
- while ((new_pos = str.find(sep, old_pos)) != std::string::npos) {
307
- if (new_pos != old_pos) {
308
- str_stream.width(new_pos - old_pos);
309
- str_stream >> val;
310
- vals.push_back(val);
311
- }
312
-
313
- // skip over delimiter
314
- str_stream.ignore(1);
315
- old_pos = new_pos + 1;
316
- }
317
-
318
- // Read last value
319
- str_stream >> val;
320
- vals.push_back(val);
321
- }
322
- };
323
-
324
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp DELETED
@@ -1,528 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include <cuda_runtime.h>
35
- #include <cublas_v2.h>
36
-
37
- //-- BLAM_DEBUG_OUT ---------------------------------------------------------
38
- #ifdef BLAM_DEBUG
39
- # include <iostream>
40
- # ifndef BLAM_DEBUG_OUT
41
- # define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl
42
- # define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl
43
- # endif // BLAM_DEBUG_OUT
44
- #else
45
- # ifndef BLAM_DEBUG_OUT
46
- # define BLAM_DEBUG_OUT(msg)
47
- # define BLAM_DEBUG_OUT_2(msg)
48
- # endif // BLAM_DEBUG_OUT
49
- #endif // BLAM_DEBUG
50
-
51
- // User could potentially define ComplexFloat/ComplexDouble instead of std::
52
- #ifndef BLAM_COMPLEX_TYPES
53
- #define BLAM_COMPLEX_TYPES 1
54
- #include "cutlass/cutlass.h"
55
- #include CUDA_STD_HEADER(complex)
56
-
57
- namespace blam {
58
- template <typename T>
59
- using Complex = cuda::std::complex<T>;
60
- using ComplexFloat = cuda::std::complex<float>;
61
- using ComplexDouble = cuda::std::complex<double>;
62
- }
63
- #endif // BLAM_COMPLEX_TYPES
64
-
65
- // User could potentially define Half instead of cute::
66
- #ifndef BLAM_HALF_TYPE
67
- #define BLAM_HALF_TYPE 1
68
- #include <cute/numeric/numeric_types.hpp>
69
- namespace blam {
70
- using Half = cute::half_t;
71
- }
72
- #endif // BLAM_HALF_TYPE
73
-
74
- namespace blam
75
- {
76
- namespace cublas
77
- {
78
-
79
- inline const char*
80
- cublas_get_error(cublasStatus_t status)
81
- {
82
- switch (status) {
83
- case CUBLAS_STATUS_SUCCESS:
84
- return "CUBLAS_STATUS_SUCCESS";
85
- case CUBLAS_STATUS_NOT_INITIALIZED:
86
- return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized.";
87
- case CUBLAS_STATUS_ALLOC_FAILED:
88
- return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library.";
89
- case CUBLAS_STATUS_INVALID_VALUE:
90
- return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function.";
91
- case CUBLAS_STATUS_ARCH_MISMATCH:
92
- return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture.";
93
- case CUBLAS_STATUS_MAPPING_ERROR:
94
- return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed.";
95
- case CUBLAS_STATUS_EXECUTION_FAILED:
96
- return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute.";
97
- case CUBLAS_STATUS_INTERNAL_ERROR:
98
- return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed.";
99
- case CUBLAS_STATUS_NOT_SUPPORTED:
100
- return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported.";
101
- case CUBLAS_STATUS_LICENSE_ERROR:
102
- return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing.";
103
- default:
104
- return "CUBLAS_ERROR -- <unknown>";
105
- }
106
- }
107
-
108
- inline bool
109
- cublas_is_error(cublasStatus_t status)
110
- {
111
- return status != CUBLAS_STATUS_SUCCESS;
112
- }
113
-
114
-
115
- // hgemm
116
- inline cublasStatus_t
117
- gemm(cublasHandle_t handle,
118
- cublasOperation_t transA, cublasOperation_t transB,
119
- int m, int n, int k,
120
- const Half* alpha,
121
- const Half* A, int ldA,
122
- const Half* B, int ldB,
123
- const Half* beta,
124
- Half* C, int ldC)
125
- {
126
- BLAM_DEBUG_OUT("cublasHgemm");
127
-
128
- return cublasGemmEx(handle, transA, transB,
129
- m, n, k,
130
- reinterpret_cast<const __half*>(alpha),
131
- reinterpret_cast<const __half*>(A), CUDA_R_16F, ldA,
132
- reinterpret_cast<const __half*>(B), CUDA_R_16F, ldB,
133
- reinterpret_cast<const __half*>(beta),
134
- reinterpret_cast< __half*>(C), CUDA_R_16F, ldC,
135
- CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
136
- }
137
-
138
- // mixed hf gemm
139
- inline cublasStatus_t
140
- gemm(cublasHandle_t handle,
141
- cublasOperation_t transA, cublasOperation_t transB,
142
- int m, int n, int k,
143
- const float* alpha,
144
- const Half* A, int ldA,
145
- const Half* B, int ldB,
146
- const float* beta,
147
- float* C, int ldC)
148
- {
149
- BLAM_DEBUG_OUT("cublasGemmEx mixed half-float");
150
-
151
- return cublasGemmEx(handle, transA, transB,
152
- m, n, k,
153
- alpha,
154
- reinterpret_cast<const __half*>(A), CUDA_R_16F, ldA,
155
- reinterpret_cast<const __half*>(B), CUDA_R_16F, ldB,
156
- beta,
157
- C, CUDA_R_32F, ldC,
158
- CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
159
- }
160
-
161
- // igemm
162
- inline cublasStatus_t
163
- gemm(cublasHandle_t handle,
164
- cublasOperation_t transA, cublasOperation_t transB,
165
- int m, int n, int k,
166
- const int32_t* alpha,
167
- const int8_t* A, int ldA,
168
- const int8_t* B, int ldB,
169
- const int32_t* beta,
170
- int32_t* C, int ldC)
171
- {
172
- BLAM_DEBUG_OUT("cublasIgemm");
173
-
174
- return cublasGemmEx(handle, transA, transB,
175
- m, n, k,
176
- alpha,
177
- A, CUDA_R_8I, ldA,
178
- B, CUDA_R_8I, ldB,
179
- beta,
180
- C, CUDA_R_32I, ldC,
181
- CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
182
- }
183
-
184
- // sgemm
185
- inline cublasStatus_t
186
- gemm(cublasHandle_t handle,
187
- cublasOperation_t transA, cublasOperation_t transB,
188
- int m, int n, int k,
189
- const float* alpha,
190
- const float* A, int ldA,
191
- const float* B, int ldB,
192
- const float* beta,
193
- float* C, int ldC)
194
- {
195
- BLAM_DEBUG_OUT("cublasSgemm");
196
-
197
- return cublasSgemm(handle, transA, transB,
198
- m, n, k,
199
- alpha,
200
- A, ldA,
201
- B, ldB,
202
- beta,
203
- C, ldC);
204
- }
205
-
206
- // dgemm
207
- inline cublasStatus_t
208
- gemm(cublasHandle_t handle,
209
- cublasOperation_t transA, cublasOperation_t transB,
210
- int m, int n, int k,
211
- const double* alpha,
212
- const double* A, int ldA,
213
- const double* B, int ldB,
214
- const double* beta,
215
- double* C, int ldC)
216
- {
217
- BLAM_DEBUG_OUT("cublasDgemm");
218
-
219
- return cublasDgemm(handle, transA, transB,
220
- m, n, k,
221
- alpha,
222
- A, ldA,
223
- B, ldB,
224
- beta,
225
- C, ldC);
226
- }
227
-
228
- // cgemm
229
- inline cublasStatus_t
230
- gemm(cublasHandle_t handle,
231
- cublasOperation_t transA, cublasOperation_t transB,
232
- int m, int n, int k,
233
- const ComplexFloat* alpha,
234
- const ComplexFloat* A, int ldA,
235
- const ComplexFloat* B, int ldB,
236
- const ComplexFloat* beta,
237
- ComplexFloat* C, int ldC)
238
- {
239
- BLAM_DEBUG_OUT("cublasCgemm");
240
-
241
- return cublasCgemm(handle, transA, transB,
242
- m, n, k,
243
- reinterpret_cast<const cuFloatComplex*>(alpha),
244
- reinterpret_cast<const cuFloatComplex*>(A), ldA,
245
- reinterpret_cast<const cuFloatComplex*>(B), ldB,
246
- reinterpret_cast<const cuFloatComplex*>(beta),
247
- reinterpret_cast<cuFloatComplex*>(C), ldC);
248
- }
249
-
250
- // zgemm
251
- inline cublasStatus_t
252
- gemm(cublasHandle_t handle,
253
- cublasOperation_t transA, cublasOperation_t transB,
254
- int m, int n, int k,
255
- const ComplexDouble* alpha,
256
- const ComplexDouble* A, int ldA,
257
- const ComplexDouble* B, int ldB,
258
- const ComplexDouble* beta,
259
- ComplexDouble* C, int ldC)
260
- {
261
- BLAM_DEBUG_OUT("cublasZgemm");
262
-
263
- return cublasZgemm(handle, transA, transB,
264
- m, n, k,
265
- reinterpret_cast<const cuDoubleComplex*>(alpha),
266
- reinterpret_cast<const cuDoubleComplex*>(A), ldA,
267
- reinterpret_cast<const cuDoubleComplex*>(B), ldB,
268
- reinterpret_cast<const cuDoubleComplex*>(beta),
269
- reinterpret_cast<cuDoubleComplex*>(C), ldC);
270
- }
271
-
272
- // hgemm
273
- inline cublasStatus_t
274
- gemm_batch(cublasHandle_t handle,
275
- cublasOperation_t transA, cublasOperation_t transB,
276
- int m, int n, int k,
277
- const Half* alpha,
278
- const Half* A, int ldA, int loA,
279
- const Half* B, int ldB, int loB,
280
- const Half* beta,
281
- Half* C, int ldC, int loC,
282
- int batch_size)
283
- {
284
- BLAM_DEBUG_OUT("cublasHgemmStridedBatched");
285
-
286
- return cublasHgemmStridedBatched(handle, transA, transB,
287
- m, n, k,
288
- reinterpret_cast<const __half*>(alpha),
289
- reinterpret_cast<const __half*>(A), ldA, loA,
290
- reinterpret_cast<const __half*>(B), ldB, loB,
291
- reinterpret_cast<const __half*>(beta),
292
- reinterpret_cast<__half*>(C), ldC, loC,
293
- batch_size);
294
- }
295
-
296
- // sgemm
297
- inline cublasStatus_t
298
- gemm_batch(cublasHandle_t handle,
299
- cublasOperation_t transA, cublasOperation_t transB,
300
- int m, int n, int k,
301
- const float* alpha,
302
- const float* A, int ldA, int loA,
303
- const float* B, int ldB, int loB,
304
- const float* beta,
305
- float* C, int ldC, int loC,
306
- int batch_size)
307
- {
308
- BLAM_DEBUG_OUT("cublasSgemmStridedBatched");
309
-
310
- return cublasSgemmStridedBatched(handle, transA, transB,
311
- m, n, k,
312
- alpha,
313
- A, ldA, loA,
314
- B, ldB, loB,
315
- beta,
316
- C, ldC, loC,
317
- batch_size);
318
- }
319
-
320
- // dgemm
321
- inline cublasStatus_t
322
- gemm_batch(cublasHandle_t handle,
323
- cublasOperation_t transA, cublasOperation_t transB,
324
- int m, int n, int k,
325
- const double* alpha,
326
- const double* A, int ldA, int loA,
327
- const double* B, int ldB, int loB,
328
- const double* beta,
329
- double* C, int ldC, int loC,
330
- int batch_size)
331
- {
332
- BLAM_DEBUG_OUT("cublasDgemmStridedBatched");
333
-
334
- return cublasDgemmStridedBatched(handle, transA, transB,
335
- m, n, k,
336
- alpha,
337
- A, ldA, loA,
338
- B, ldB, loB,
339
- beta,
340
- C, ldC, loC,
341
- batch_size);
342
- }
343
-
344
- // cgemm
345
- inline cublasStatus_t
346
- gemm_batch(cublasHandle_t handle,
347
- cublasOperation_t transA, cublasOperation_t transB,
348
- int m, int n, int k,
349
- const ComplexFloat* alpha,
350
- const ComplexFloat* A, int ldA, int loA,
351
- const ComplexFloat* B, int ldB, int loB,
352
- const ComplexFloat* beta,
353
- ComplexFloat* C, int ldC, int loC,
354
- int batch_size)
355
- {
356
- BLAM_DEBUG_OUT("cublasCgemmStridedBatched");
357
-
358
- return cublasCgemmStridedBatched(handle, transA, transB,
359
- m, n, k,
360
- reinterpret_cast<const cuFloatComplex*>(alpha),
361
- reinterpret_cast<const cuFloatComplex*>(A), ldA, loA,
362
- reinterpret_cast<const cuFloatComplex*>(B), ldB, loB,
363
- reinterpret_cast<const cuFloatComplex*>(beta),
364
- reinterpret_cast<cuFloatComplex*>(C), ldC, loC,
365
- batch_size);
366
- }
367
-
368
- // zgemm
369
- inline cublasStatus_t
370
- gemm_batch(cublasHandle_t handle,
371
- cublasOperation_t transA, cublasOperation_t transB,
372
- int m, int n, int k,
373
- const ComplexDouble* alpha,
374
- const ComplexDouble* A, int ldA, int loA,
375
- const ComplexDouble* B, int ldB, int loB,
376
- const ComplexDouble* beta,
377
- ComplexDouble* C, int ldC, int loC,
378
- int batch_size)
379
- {
380
- BLAM_DEBUG_OUT("cublasZgemmStridedBatched");
381
-
382
- return cublasZgemmStridedBatched(handle, transA, transB,
383
- m, n, k,
384
- reinterpret_cast<const cuDoubleComplex*>(alpha),
385
- reinterpret_cast<const cuDoubleComplex*>(A), ldA, loA,
386
- reinterpret_cast<const cuDoubleComplex*>(B), ldB, loB,
387
- reinterpret_cast<const cuDoubleComplex*>(beta),
388
- reinterpret_cast<cuDoubleComplex*>(C), ldC, loC,
389
- batch_size);
390
- }
391
-
392
- // hgemm
393
- inline cublasStatus_t
394
- gemm_batch(cublasHandle_t handle,
395
- cublasOperation_t transA, cublasOperation_t transB,
396
- int m, int n, int k,
397
- const Half* alpha,
398
- const Half* const A[], int ldA,
399
- const Half* const B[], int ldB,
400
- const Half* beta,
401
- Half* const C[], int ldC,
402
- int batch_size)
403
- {
404
- BLAM_DEBUG_OUT("cublasHgemmBatched");
405
-
406
- return cublasHgemmBatched(handle, transA, transB,
407
- m, n, k,
408
- reinterpret_cast<const __half*>(alpha),
409
- reinterpret_cast<const __half**>(const_cast<const Half**>(A)), ldA,
410
- // A, ldA, // cuBLAS 9.2
411
- reinterpret_cast<const __half**>(const_cast<const Half**>(B)), ldB,
412
- // B, ldB, // cuBLAS 9.2
413
- reinterpret_cast<const __half*>(beta),
414
- reinterpret_cast<__half**>(const_cast<Half**>(C)), ldC,
415
- // C, ldC, // cuBLAS 9.2
416
- batch_size);
417
- }
418
-
419
- // sgemm
420
- inline cublasStatus_t
421
- gemm_batch(cublasHandle_t handle,
422
- cublasOperation_t transA, cublasOperation_t transB,
423
- int m, int n, int k,
424
- const float* alpha,
425
- const float* const A[], int ldA,
426
- const float* const B[], int ldB,
427
- const float* beta,
428
- float* const C[], int ldC,
429
- int batch_size)
430
- {
431
- BLAM_DEBUG_OUT("cublasSgemmBatched");
432
-
433
- return cublasSgemmBatched(handle, transA, transB,
434
- m, n, k,
435
- alpha,
436
- const_cast<const float**>(A), ldA,
437
- // A, ldA, // cuBLAS 9.2
438
- const_cast<const float**>(B), ldB,
439
- // B, ldB, // cuBLAS 9.2
440
- beta,
441
- const_cast<float**>(C), ldC,
442
- // C, ldC, // cuBLAS 9.2
443
- batch_size);
444
- }
445
-
446
- // dgemm
447
- inline cublasStatus_t
448
- gemm_batch(cublasHandle_t handle,
449
- cublasOperation_t transA, cublasOperation_t transB,
450
- int m, int n, int k,
451
- const double* alpha,
452
- const double* const A[], int ldA,
453
- const double* const B[], int ldB,
454
- const double* beta,
455
- double* const C[], int ldC,
456
- int batch_size)
457
- {
458
- BLAM_DEBUG_OUT("cublasDgemmBatched");
459
-
460
- return cublasDgemmBatched(handle, transA, transB,
461
- m, n, k,
462
- alpha,
463
- const_cast<const double**>(A), ldA,
464
- // A, ldA, // cuBLAS 9.2
465
- const_cast<const double**>(B), ldB,
466
- // B, ldB, // cuBLAS 9.2
467
- beta,
468
- const_cast<double**>(C), ldC,
469
- // C, ldC, // cuBLAS 9.2
470
- batch_size);
471
- }
472
-
473
- // cgemm
474
- inline cublasStatus_t
475
- gemm_batch(cublasHandle_t handle,
476
- cublasOperation_t transA, cublasOperation_t transB,
477
- int m, int n, int k,
478
- const ComplexFloat* alpha,
479
- const ComplexFloat* const A[], int ldA,
480
- const ComplexFloat* const B[], int ldB,
481
- const ComplexFloat* beta,
482
- ComplexFloat* const C[], int ldC,
483
- int batch_size)
484
- {
485
- BLAM_DEBUG_OUT("cublasCgemmBatched");
486
-
487
- return cublasCgemmBatched(handle, transA, transB,
488
- m, n, k,
489
- reinterpret_cast<const cuFloatComplex*>(alpha),
490
- const_cast<const cuFloatComplex**>(reinterpret_cast<const cuFloatComplex* const *>(A)), ldA,
491
- //reinterpret_cast<const cuFloatComplex* const *>(A), ldA, // cuBLAS 9.2
492
- const_cast<const cuFloatComplex**>(reinterpret_cast<const cuFloatComplex* const *>(B)), ldB,
493
- //reinterpret_cast<const cuFloatComplex* const *>(B), ldB, // cuBLAS 9.2
494
- reinterpret_cast<const cuFloatComplex*>(beta),
495
- const_cast<cuFloatComplex**>(reinterpret_cast<cuFloatComplex* const *>(C)), ldC,
496
- //reinterpret_cast<cuFloatComplex* const *>(C), ldC, // cuBLAS 9.2
497
- batch_size);
498
- }
499
-
500
- // zgemm
501
- inline cublasStatus_t
502
- gemm_batch(cublasHandle_t handle,
503
- cublasOperation_t transA, cublasOperation_t transB,
504
- int m, int n, int k,
505
- const ComplexDouble* alpha,
506
- const ComplexDouble* const A[], int ldA,
507
- const ComplexDouble* const B[], int ldB,
508
- const ComplexDouble* beta,
509
- ComplexDouble* const C[], int ldC,
510
- int batch_size)
511
- {
512
- BLAM_DEBUG_OUT("cublasZgemmBatched");
513
-
514
- return cublasZgemmBatched(handle, transA, transB,
515
- m, n, k,
516
- reinterpret_cast<const cuDoubleComplex*>(alpha),
517
- const_cast<const cuDoubleComplex**>(reinterpret_cast<const cuDoubleComplex* const *>(A)), ldA,
518
- //reinterpret_cast<const cuDoubleComplex* const *>(A), ldA, // cuBLAS 9.2
519
- const_cast<const cuDoubleComplex**>(reinterpret_cast<const cuDoubleComplex* const *>(B)), ldB,
520
- //reinterpret_cast<const cuDoubleComplex* const *>(B), ldB, // cuBLAS 9.2
521
- reinterpret_cast<const cuDoubleComplex*>(beta),
522
- const_cast<cuDoubleComplex**>(reinterpret_cast<cuDoubleComplex* const *>(C)), ldC,
523
- //reinterpret_cast<cuDoubleComplex* const *>(C), ldC, // cuBLAS 9.2
524
- batch_size);
525
- }
526
-
527
- } // end namespace cublas
528
- } // end namespace blam
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h DELETED
@@ -1,143 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief Contains code for debugging cutlass code
34
- */
35
-
36
- #pragma once
37
-
38
- #include "device_dump.h"
39
-
40
- ////////////////////////////////////////////////////////////////////////////////////////////////////
41
-
42
- /******************************************************************************
43
- * Debug and logging macros
44
- ******************************************************************************/
45
-
46
- /**
47
- * Formats and prints the given message to stdout
48
- */
49
- #if !defined(CUDA_LOG)
50
- #if !defined(__CUDA_ARCH__)
51
- #define CUDA_LOG(format, ...) printf(format, __VA_ARGS__)
52
- #else
53
- #define CUDA_LOG(format, ...) \
54
- printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \
55
- blockIdx.x, \
56
- blockIdx.y, \
57
- blockIdx.z, \
58
- threadIdx.x, \
59
- threadIdx.y, \
60
- threadIdx.z, \
61
- __VA_ARGS__);
62
- #endif
63
- #endif
64
-
65
- /**
66
- * Formats and prints the given message to stdout only if DEBUG is defined
67
- */
68
- #if !defined(CUDA_LOG_DEBUG)
69
- #ifdef DEBUG
70
- #define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__)
71
- #else
72
- #define CUDA_LOG_DEBUG(format, ...)
73
- #endif
74
- #endif
75
-
76
- /**
77
- * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code)
78
- * along with the supplied source context.
79
- *
80
- * \return The CUDA error.
81
- */
82
- __host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error,
83
- const char* expression,
84
- const char* filename,
85
- int line) {
86
- (void)filename;
87
- (void)line;
88
- if (error) {
89
- #if !defined(__CUDA_ARCH__)
90
- fprintf(
91
- stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error));
92
- fflush(stderr);
93
- #else
94
- printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression);
95
- #endif
96
- }
97
- return error;
98
- }
99
-
100
- /**
101
- * \brief Perror macro
102
- */
103
- #ifndef CUDA_PERROR
104
- #define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)
105
- #endif
106
-
107
- /**
108
- * \brief Perror macro with exit
109
- */
110
- #ifndef CUDA_PERROR_EXIT
111
- #define CUDA_PERROR_EXIT(e) \
112
- do { if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) { \
113
- exit(1); \
114
- } } while (0)
115
- #endif
116
-
117
- /**
118
- * \brief Perror macro only if DEBUG is defined
119
- */
120
- #ifndef CUDA_PERROR_DEBUG
121
- #ifdef DEBUG
122
- #define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e)
123
- #else
124
- #define CUDA_PERROR_DEBUG(e) (e)
125
- #endif
126
- #endif
127
-
128
- ////////////////////////////////////////////////////////////////////////////////////////////////////
129
-
130
- // A small helper class to dump a type at compile time
131
- // Usage:: DumpType<Class>::Class
132
- template <typename T>
133
- struct DebugType {};
134
-
135
- template <typename T>
136
- void DebugTypeFunc(T const& t) {
137
- T::t;
138
- }
139
-
140
- // A small helper class to dump a compile time constant at compile time
141
- // Usage: DumpValue<Class::kConstant>::kConstant
142
- template <int Value>
143
- struct DebugValue {};
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h DELETED
@@ -1,187 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include <cstdio>
35
- #include "cutlass/cutlass.h"
36
-
37
- /**
38
- * \file
39
- * \brief C++ interface to dump fragments and shared memory contents for
40
- * debugging.
41
- */
42
-
43
- namespace cutlass {
44
- namespace debug {
45
-
46
- /******************************************************************************
47
- * Dump the fragments
48
- ******************************************************************************/
49
-
50
- /// The first N threads dump the first M elements from their fragments with a
51
- /// stride of S elements. If N is not specified, dump the data of all the
52
- /// threads. If M is not specified, dump all the elements of the fragment.
53
- template <typename Fragment>
54
- CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0,
55
- int S = 1) {
56
- int total_threads = blockDim.x * blockDim.y * blockDim.z;
57
- int block_id =
58
- blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
59
- int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) +
60
- (threadIdx.y * blockDim.x) + threadIdx.x;
61
-
62
- if (N < 0 || N > total_threads) {
63
- if (thread_id == 0 && block_id == 0)
64
- printf("Thread number N = %d should between [1, %d].\n", N,
65
- total_threads);
66
-
67
- __syncthreads();
68
-
69
- return;
70
- }
71
-
72
- int total_elements = int(frag.size());
73
-
74
- if (M < 0 || M > total_elements) {
75
- if (thread_id == 0 && block_id == 0)
76
- printf("Element number M = %d should between [1, %d].\n", M,
77
- total_elements);
78
-
79
- __syncthreads();
80
-
81
- return;
82
- }
83
-
84
- if (N == 0) N = total_threads;
85
-
86
- if (M == 0) M = total_elements;
87
-
88
- if (S < 1 || S > M) {
89
- if (thread_id == 0 && block_id == 0)
90
- printf("Stride S = %d should between [1, %d].\n", S, M);
91
-
92
- __syncthreads();
93
-
94
- return;
95
- }
96
-
97
- if (thread_id == 0 && block_id == 0)
98
- printf("\n*******************Dumping the fragments*******************\n\n");
99
-
100
- CUTLASS_PRAGMA_NO_UNROLL
101
- for (int tid = 0; tid < N; ++tid) {
102
- if (tid == thread_id) {
103
- printf("TB%d W%d T%d: ", block_id, tid / 32, tid & 31);
104
- CUTLASS_PRAGMA_NO_UNROLL
105
- for (int i = 0; i < M; i += S) {
106
- printf("%.0f ", float(typename Fragment::value_type(frag[i])));
107
- }
108
- printf("\n");
109
- }
110
-
111
- __syncthreads();
112
- }
113
-
114
- if (thread_id == 0 && block_id == 0)
115
- printf("\n***********************************************************\n\n");
116
-
117
- __syncthreads();
118
-
119
- return;
120
- }
121
-
122
- /******************************************************************************
123
- * Dump the shared memory
124
- ******************************************************************************/
125
-
126
- #define SHMEM_ROW_SIZE 128
127
-
128
- /// Dump the shared memory contents. ptr is the begin address, size specifies
129
- /// the number of elements that need to be dumped, and S specifies the stride.
130
- template <typename Element>
131
- CUTLASS_DEVICE void dump_shmem(Element const* ptr, size_t size, int S = 1) {
132
- int block_id =
133
- blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
134
- int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) +
135
- (threadIdx.y * blockDim.x) + threadIdx.x;
136
-
137
- if (ptr == nullptr) {
138
- if (thread_id == 0 && block_id == 0) printf("ptr is null.\n");
139
-
140
- __syncthreads();
141
- return;
142
- }
143
-
144
- if (size < 1) {
145
- if (thread_id == 0 && block_id == 0)
146
- printf("Element size is less than 1\n");
147
-
148
- __syncthreads();
149
-
150
- return;
151
- }
152
-
153
- int row_elements = SHMEM_ROW_SIZE / sizeof(Element);
154
-
155
- if (S < 1 || S > row_elements) {
156
- if (thread_id == 0 && block_id == 0)
157
- printf("Stride S = %d should between [1, %d].\n", S, row_elements);
158
-
159
- __syncthreads();
160
-
161
- return;
162
- }
163
-
164
- __syncthreads();
165
-
166
- if (thread_id == 0)
167
- printf("\n********Dumping the shared memory of TB %d*******\n\n", block_id);
168
-
169
- if (thread_id == 0) {
170
- for (int i = 0; i < size; i += row_elements) {
171
- for (int j = 0; j < row_elements; j += S) {
172
- printf("%.0f ", float(ptr[i + j]));
173
- }
174
-
175
- printf("\n");
176
- }
177
- }
178
-
179
- if (thread_id == 0)
180
- printf("\n***********************************************************\n\n");
181
-
182
- __syncthreads();
183
-
184
- return;
185
- }
186
- } // namespace debug
187
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h DELETED
@@ -1,402 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- /**
35
- * \file
36
- * \brief cuda kernels to do group norm on a device memory tensor with NHWC layout. The tensor will be divided into [N, H, W, G, C'] and then we do normalization on [H, W, C'].
37
- */
38
-
39
- #include "cutlass/cutlass.h"
40
- #include "cutlass/layout/tensor.h"
41
- #include "cutlass/numeric_types.h"
42
- #include "cutlass/tensor_coord.h"
43
- #include "cutlass/tensor_ref.h"
44
- #include "device_utils.h"
45
- #include <cfloat>
46
-
47
- namespace cutlass {
48
-
49
- /** \brief interface to do group norm on a device memory tensor with NHWC layout.
50
- * \tparam T: data type
51
- */
52
- template <typename T>
53
- void groupnorm(cutlass::Tensor4DCoord input_size,
54
- const int num_groups,
55
- const float eps,
56
- TensorRef<T, layout::TensorNHWC> ref_output,
57
- TensorRef<T, layout::TensorNHWC> ref_input,
58
- TensorRef<T, layout::TensorNHWC> ref_gamma,
59
- TensorRef<T, layout::TensorNHWC> ref_beta,
60
- cudaStream_t stream);
61
-
62
- extern __shared__ char groupnorm_shm[];
63
-
64
- // For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory,
65
- // we store the input in the shared memory.
66
- // grid(num_groups, dim0)
67
- // block(BLOCKSIZE)
68
- // BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group
69
- template<typename TVec, typename T, int T_PER_TVec>
70
- __global__ void groupnorm_twopass_store_locally(T* output,
71
- const T* input,
72
- const T* gamma,
73
- const T* beta,
74
- int num_groups,
75
- int prod_dim1_to_last_dim,
76
- int last_dim,
77
- const float eps,
78
- const int TVecs_PER_THREAD)
79
- {
80
- const int bid = blockIdx.y; // index of batch
81
- const int gid = blockIdx.x; // index of group
82
- const int tid = threadIdx.x; // index of thread
83
- const int bdimx = blockDim.x;
84
- const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
85
- const int v_reduce_elements = s_reduce_elements / T_PER_TVec;
86
- const int s_group_stride = last_dim / num_groups;
87
- const int v_group_stride = s_group_stride / T_PER_TVec;
88
- const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec;
89
- const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group;
90
- TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group;
91
- T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid;
92
- float local_sum[1] = {0.0f};
93
-
94
- // load from global memory into shared memory
95
- #pragma unroll
96
- for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
97
- const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
98
- const int offset_in_group =
99
- ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
100
- / T_PER_TVec;
101
- if (current_load_start_idx < s_reduce_elements) {
102
- TVec tmp_vec = input_TVec_ptr[offset_in_group];
103
- T* tmp_vec_ptr = (T*)(&tmp_vec);
104
- const int local_val_offset = i * T_PER_TVec;
105
- #pragma unroll
106
- for (int j = 0; j < T_PER_TVec; j++) {
107
- float tmp = static_cast<float>(tmp_vec_ptr[j]);
108
- local_sum[0] += tmp;
109
- local_val[local_val_offset + j] = tmp_vec_ptr[j];
110
- }
111
- }
112
- }
113
- __shared__ float s_mean, s_variance;
114
-
115
- // reduction for mean
116
- if (bdimx <= 32) {
117
- warpReduceSum<float, 1>(local_sum);
118
- }
119
- else {
120
- blockReduceSum<float, 1>(local_sum);
121
- }
122
- if (tid == 0) {
123
- s_mean = local_sum[0] / s_reduce_elements;
124
- }
125
- __syncthreads();
126
-
127
- // reduction for std
128
- local_sum[0] = 0.0f;
129
- #pragma unroll
130
- for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
131
- const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
132
- if (current_load_start_idx < s_reduce_elements) {
133
- const int local_val_offset = i * T_PER_TVec;
134
- #pragma unroll
135
- for (int j = 0; j < T_PER_TVec; j++) {
136
- float tmp = static_cast<float>(local_val[local_val_offset + j]);
137
- tmp -= s_mean;
138
- local_sum[0] += tmp * tmp;
139
- }
140
- }
141
- }
142
- if (bdimx <= 32) {
143
- warpReduceSum<float, 1>(local_sum);
144
- }
145
- else {
146
- blockReduceSum<float, 1>(local_sum);
147
- }
148
- if (tid == 0) {
149
- s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps);
150
- }
151
- __syncthreads();
152
-
153
- // normalize
154
- const int gamma_offset_of_group = gid * v_group_stride;
155
- const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group;
156
- const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group;
157
- #pragma unroll
158
- for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
159
- const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
160
- const int offset_in_group =
161
- ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
162
- / T_PER_TVec;
163
- const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec;
164
- const int local_val_offset = i * T_PER_TVec;
165
- if (current_load_start_idx < s_reduce_elements) {
166
- TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group];
167
- TVec beta_val = beta_TVec_ptr[gamma_offset_in_group];
168
- T* gamma_val_ptr = (T*)(&gamma_val);
169
- T* beta_val_ptr = (T*)(&beta_val);
170
- TVec tmp_vec;
171
- T* tmp_vec_ptr = (T*)(&tmp_vec);
172
- #pragma unroll
173
- for (int j = 0; j < T_PER_TVec; j++) {
174
- float tmp = (static_cast<float>(local_val[local_val_offset + j]) - s_mean) * s_variance
175
- * static_cast<float>(gamma_val_ptr[j])
176
- + static_cast<float>(beta_val_ptr[j]);
177
- if (sizeof(T) == sizeof(half)) {
178
- tmp_vec_ptr[j] = T(__float2half_rn(tmp));
179
- }
180
- else {
181
- tmp_vec_ptr[j] = T(tmp);
182
- }
183
- }
184
- output_TVec_ptr[offset_in_group] = tmp_vec;
185
- }
186
- }
187
- }
188
-
189
- // For large prod_dim1_to_last_dim/num_groups,
190
- // in which the data cannot be stored locally,
191
- // we will load from global memory multiple times,
192
- // grid(num_groups, dim0)
193
- // block(BLOCKSIZE)
194
- // BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group
195
- template<typename TVec, typename T, int T_PER_TVec>
196
- __global__ void groupnorm_twopass_multiple_load(T* output,
197
- const T* input,
198
- const T* gamma,
199
- const T* beta,
200
- int num_groups,
201
- int prod_dim1_to_last_dim,
202
- int last_dim,
203
- const float eps,
204
- const int TVecs_PER_THREAD)
205
- {
206
- const int bid = blockIdx.y; // index of batch
207
- const int gid = blockIdx.x; // index of group
208
- const int tid = threadIdx.x; // index of thread
209
- const int bdimx = blockDim.x;
210
- const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
211
- const int v_reduce_elements = s_reduce_elements / T_PER_TVec;
212
- const int s_group_stride = last_dim / num_groups;
213
- const int v_group_stride = s_group_stride / T_PER_TVec;
214
- const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec;
215
- const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group;
216
- TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group;
217
- float local_sum[1] = {0.0f};
218
-
219
- #pragma unroll
220
- for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
221
- const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
222
- if (current_load_start_idx < s_reduce_elements) {
223
- const int offset_in_group =
224
- ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
225
- / T_PER_TVec;
226
- TVec tmp_vec = input_TVec_ptr[offset_in_group];
227
- T* tmp_vec_ptr = (T*)(&tmp_vec);
228
- #pragma unroll
229
- for (int j = 0; j < T_PER_TVec; j++) {
230
- float tmp = static_cast<float>(tmp_vec_ptr[j]);
231
- local_sum[0] += tmp;
232
- }
233
- }
234
- }
235
- __shared__ float s_mean, s_variance;
236
-
237
- // reduction for mean
238
- if (bdimx <= 32) {
239
- warpReduceSum<float, 1>(local_sum);
240
- }
241
- else {
242
- blockReduceSum<float, 1>(local_sum);
243
- }
244
- if (tid == 0) {
245
- s_mean = local_sum[0] / s_reduce_elements;
246
- }
247
- __syncthreads();
248
-
249
- // reduction for std
250
- local_sum[0] = 0.0f;
251
- #pragma unroll
252
- for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
253
- const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
254
- if (current_load_start_idx < s_reduce_elements) {
255
- const int offset_in_group =
256
- ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
257
- / T_PER_TVec;
258
- TVec tmp_vec = input_TVec_ptr[offset_in_group];
259
- T* tmp_vec_ptr = (T*)(&tmp_vec);
260
- #pragma unroll
261
- for (int j = 0; j < T_PER_TVec; j++) {
262
- float tmp = static_cast<float>(tmp_vec_ptr[j]);
263
- tmp -= s_mean;
264
- local_sum[0] += tmp * tmp;
265
- }
266
- }
267
- }
268
- if (bdimx <= 32) {
269
- warpReduceSum<float, 1>(local_sum);
270
- }
271
- else {
272
- blockReduceSum<float, 1>(local_sum);
273
- }
274
- if (tid == 0) {
275
- s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps);
276
- }
277
- __syncthreads();
278
-
279
- // normalize
280
- const int gamma_offset_of_group = gid * v_group_stride;
281
- const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group;
282
- const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group;
283
- #pragma unroll
284
- for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
285
- const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
286
- if (current_load_start_idx < s_reduce_elements) {
287
- const int offset_in_group =
288
- ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
289
- / T_PER_TVec;
290
- const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec;
291
- TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group];
292
- TVec beta_val = beta_TVec_ptr[gamma_offset_in_group];
293
- T* gamma_val_ptr = (T*)(&gamma_val);
294
- T* beta_val_ptr = (T*)(&beta_val);
295
- TVec tmp_vec = input_TVec_ptr[offset_in_group];
296
- T* tmp_vec_ptr = (T*)(&tmp_vec);
297
- TVec output_tmp_vec;
298
- T* output_tmp_vec_ptr = (T*)(&output_tmp_vec);
299
- #pragma unroll
300
- for (int j = 0; j < T_PER_TVec; j++) {
301
- float tmp =
302
- (static_cast<float>(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast<float>(gamma_val_ptr[j])
303
- + static_cast<float>(beta_val_ptr[j]);
304
- if (sizeof(T) == sizeof(half)) {
305
- output_tmp_vec_ptr[j] = T(__float2half_rn(tmp));
306
- }
307
- else {
308
- output_tmp_vec_ptr[j] = T(tmp);
309
- }
310
- }
311
- output_TVec_ptr[offset_in_group] = output_tmp_vec;
312
- }
313
- }
314
- }
315
-
316
- //ref_input & ref_output should be [N, H, W, C]
317
- //ref_gamma & ref_beta should be [1, 1, 1, C]
318
- template <typename T>
319
- void groupnorm(cutlass::Tensor4DCoord input_size,
320
- const int num_groups,
321
- const float eps,
322
- TensorRef<T, layout::TensorNHWC> ref_output,
323
- TensorRef<T, layout::TensorNHWC> ref_input,
324
- TensorRef<T, layout::TensorNHWC> ref_gamma,
325
- TensorRef<T, layout::TensorNHWC> ref_beta,
326
- cudaStream_t stream){
327
- const int N = input_size.n();
328
- const int H = input_size.h();
329
- const int W = input_size.w();
330
- const int C = input_size.c();
331
- if (C % num_groups != 0){
332
- printf("[ERROR] C should be a multiple of num_groups.\n");
333
- }
334
- T* output = ref_output.data();
335
- const T* input = ref_input.data();
336
- const T* gamma = ref_gamma.data();
337
- const T* beta = ref_beta.data();
338
-
339
- const int dim0 = N;
340
- const int last_dim = C;
341
- const int prod_dim1_to_last_dim = H*W*C;
342
- const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
343
- const int s_group_stride = last_dim / num_groups;
344
- dim3 grid(num_groups, dim0);
345
- int threadblock_size = 32;
346
- if (s_group_stride % 2 == 0) {
347
- const int T_PER_TVec = 2;
348
- while (threadblock_size < 1024) {
349
- if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8)
350
- break;
351
- threadblock_size *= 2;
352
- }
353
- dim3 block(threadblock_size);
354
- const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size;
355
- const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T);
356
- // for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32;
357
- // the size of grid & block may have better choice for different cases.
358
- // ensure shared memory is smaller than 48KB
359
- if (std::is_same<T, float>::value){
360
- if (shm_size < 48 * 1024) {
361
- groupnorm_twopass_store_locally<float2, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
362
- output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
363
- }
364
- else {
365
- groupnorm_twopass_multiple_load<float2, T, T_PER_TVec><<<grid, block, 0, stream>>>(
366
- output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
367
- }
368
- }
369
- else{
370
- if (shm_size < 48 * 1024) {
371
- groupnorm_twopass_store_locally<half2, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
372
- output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
373
- }
374
- else {
375
- groupnorm_twopass_multiple_load<half2, T, T_PER_TVec><<<grid, block, 0, stream>>>(
376
- output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
377
- }
378
- }
379
- }
380
- else {
381
- const int T_PER_TVec = 1;
382
- while (threadblock_size < 1024) {
383
- if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8)
384
- break;
385
- threadblock_size *= 2;
386
- }
387
- dim3 block(threadblock_size);
388
- const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size;
389
- const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T);
390
- if (shm_size < 48 * 1024) {
391
- groupnorm_twopass_store_locally<T, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
392
- output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
393
- }
394
- else {
395
- groupnorm_twopass_multiple_load<T, T, T_PER_TVec><<<grid, block, 0, stream>>>(
396
- output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
397
- }
398
- }
399
-
400
- }
401
-
402
- } //namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h DELETED
@@ -1,644 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- /**
35
- * \file
36
- * \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout.
37
- */
38
-
39
- #include "cutlass/cutlass.h"
40
- #include "cutlass/layout/tensor.h"
41
- #include "cutlass/numeric_types.h"
42
- #include "cutlass/tensor_coord.h"
43
- #include "cutlass/tensor_ref.h"
44
- #include "device_utils.h"
45
- #include <cfloat>
46
-
47
- namespace cutlass {
48
-
49
- /** \brief interface to do layernorm on a device memory tensor with RowMajor layout.
50
- * \tparam T: data type
51
- */
52
- template <typename T>
53
- void layernorm(cutlass::MatrixCoord tensor_size,
54
- TensorRef<T, layout::RowMajor> ref_output,
55
- TensorRef<T, layout::RowMajor> ref_input,
56
- TensorRef<T, layout::RowMajor> ref_gamma,
57
- TensorRef<T, layout::RowMajor> ref_beta,
58
- cudaStream_t stream);
59
-
60
- /**
61
- * output [m, n] row-major
62
- * input [m, n] row-major
63
- * gamma [n]
64
- * beta [n]
65
- * grid(m)
66
- * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements
67
- */
68
- template<typename T, int ITEM_PER_THREAD>
69
- __global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output,
70
- const T* input,
71
- const T* gamma,
72
- const T* beta,
73
- const int m,
74
- const int n)
75
- {
76
- const int m_idx = blockIdx.x;
77
- const int tid = threadIdx.x;
78
- const int bdimx = blockDim.x;
79
- __shared__ float s_mean, s_variance;
80
- T local_val[ITEM_PER_THREAD];
81
- float local_sums[1] = {0.0f};
82
- int offset = m_idx * n;
83
- input += offset;
84
- output += offset;
85
-
86
- const T zero = T(0.0f);
87
- #pragma unroll
88
- for (int i = 0 ; i < ITEM_PER_THREAD ; i++){
89
- int index = tid + i*bdimx;
90
- local_val[i] = index < n ? input[index] : zero;
91
- local_sums[0] += static_cast<float>(local_val[i]);
92
- }
93
- if (blockDim.x <= 32) {
94
- warpReduceSum<float, 1>(local_sums);
95
- }
96
- else {
97
- blockReduceSum<float, 1>(local_sums);
98
- }
99
- if (threadIdx.x == 0) {
100
- s_mean = local_sums[0] / n;
101
- }
102
- __syncthreads();
103
-
104
- local_sums[0] = 0.0f;
105
- #pragma unroll
106
- for (int i = 0 ; i < ITEM_PER_THREAD ; i++){
107
- int index = tid + i*bdimx;
108
- if (index < n){
109
- const float tmp = static_cast<float>(local_val[i]) - s_mean;
110
- local_sums[0] += tmp * tmp;
111
- }
112
- }
113
-
114
- if (blockDim.x <= 32) {
115
- warpReduceSum<float, 1>(local_sums);
116
- }
117
- else {
118
- blockReduceSum<float, 1>(local_sums);
119
- }
120
- if (threadIdx.x == 0) {
121
- s_variance = rsqrtf(local_sums[0] / n + 1e-5);
122
- }
123
- __syncthreads();
124
-
125
- #pragma unroll
126
- for (int i = 0 ; i < ITEM_PER_THREAD ; i++){
127
- int index = tid + i*bdimx;
128
- if (index < n) {
129
- const T gamma_val = gamma[index];
130
- const T beta_val = beta[index];
131
- output[index] = T((static_cast<float>(local_val[i]) - s_mean) * s_variance * static_cast<float>(gamma_val) + static_cast<float>(beta_val));
132
- }
133
- }
134
- }
135
-
136
- /**
137
- * output [m, n] row-major
138
- * input [m, n] row-major
139
- * gamma [n]
140
- * beta [n]
141
- * grid(m)
142
- * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements;
143
- */
144
- template<typename T2, typename T, int ITEM_PER_THREAD>
145
- __global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output,
146
- const T2* input,
147
- const T2* gamma,
148
- const T2* beta,
149
- const int m,
150
- const int n)
151
- {
152
- const int m_idx = blockIdx.x;
153
- const int tid = threadIdx.x;
154
- const int bdimx = blockDim.x;
155
- __shared__ float s_mean, s_variance;
156
- float local_sums[1] = {0.0f};
157
- T2 local_val[ITEM_PER_THREAD];
158
- const int n_2 = n / 2;
159
- int offset = m_idx * n_2;
160
- input += offset;
161
- output += offset;
162
-
163
- const T2 zero = {T(0.0f), T(0.0f)};
164
- #pragma UNROLL
165
- for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
166
- const int index = i*bdimx + tid;
167
- local_val[i] = index < n_2 ? input[index] : zero;
168
- local_sums[0] += static_cast<float>(local_val[i].x) + static_cast<float>(local_val[i].y);
169
- }
170
-
171
- if (blockDim.x <= 32) {
172
- warpReduceSum<float, 1>(local_sums);
173
- }
174
- else {
175
- blockReduceSum<float, 1>(local_sums);
176
- }
177
- if (threadIdx.x == 0) {
178
- s_mean = local_sums[0] / n;
179
- }
180
- __syncthreads();
181
-
182
- local_sums[0] = 0.0f;
183
- #pragma UNROLL
184
- for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
185
- const int index = i*bdimx + tid;
186
- if (index < n_2){
187
- const float2 tmp = {static_cast<float>(local_val[i].x) - s_mean,
188
- static_cast<float>(local_val[i].y) - s_mean};
189
- local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y;
190
- }
191
- }
192
- if (blockDim.x <= 32) {
193
- warpReduceSum<float, 1>(local_sums);
194
- }
195
- else {
196
- blockReduceSum<float, 1>(local_sums);
197
- }
198
- if (threadIdx.x == 0) {
199
- s_variance = rsqrtf(local_sums[0] / n + 1e-5);
200
- }
201
- __syncthreads();
202
-
203
- #pragma UNROLL
204
- for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
205
- const int index = i*bdimx + tid;
206
- if (index < n_2){
207
- const T2 gamma_val = gamma[index];
208
- const T2 beta_val = beta[index];
209
- T2 tmp;
210
- tmp.x = T((static_cast<float>(local_val[i].x) - s_mean)*s_variance*static_cast<float>(gamma_val.x) + static_cast<float>(beta_val.x));
211
- tmp.y = T((static_cast<float>(local_val[i].y) - s_mean)*s_variance*static_cast<float>(gamma_val.y) + static_cast<float>(beta_val.y));
212
- output[index] = tmp;
213
- }
214
- }
215
- }
216
-
217
- /**
218
- * output [m, n] row-major
219
- * input [m, n] row-major
220
- * gamma [n]
221
- * beta [n]
222
- * grid(m)
223
- * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements;
224
- */
225
- template<typename T4, typename T, int ITEM_PER_THREAD>
226
- __global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output,
227
- const T4* input,
228
- const T4* gamma,
229
- const T4* beta,
230
- const int m,
231
- const int n)
232
- {
233
- const int m_idx = blockIdx.x;
234
- const int tid = threadIdx.x;
235
- const int bdimx = blockDim.x;
236
- __shared__ float s_mean, s_variance;
237
- float local_sums[1] = {0.0f};
238
- T4 local_val[ITEM_PER_THREAD];
239
- const int n_4 = n / 4;
240
- int offset = m_idx * n_4;
241
- input += offset;
242
- output += offset;
243
-
244
- const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)};
245
- #pragma UNROLL
246
- for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
247
- const int index = i*bdimx + tid;
248
- local_val[i] = index < n_4 ? input[index] : zero;
249
- local_sums[0] += static_cast<float>(local_val[i].x) + static_cast<float>(local_val[i].y) +
250
- static_cast<float>(local_val[i].z) + static_cast<float>(local_val[i].w);
251
- }
252
-
253
- if (blockDim.x <= 32) {
254
- warpReduceSum<float, 1>(local_sums);
255
- }
256
- else {
257
- blockReduceSum<float, 1>(local_sums);
258
- }
259
- if (threadIdx.x == 0) {
260
- s_mean = local_sums[0] / n;
261
- }
262
- __syncthreads();
263
-
264
- local_sums[0] = 0.0f;
265
- #pragma UNROLL
266
- for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
267
- const int index = i*bdimx + tid;
268
- if (index < n_4){
269
- const float4 tmp = {static_cast<float>(local_val[i].x) - s_mean,
270
- static_cast<float>(local_val[i].y) - s_mean,
271
- static_cast<float>(local_val[i].z) - s_mean,
272
- static_cast<float>(local_val[i].w) - s_mean};
273
- local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w;
274
- }
275
- }
276
- if (blockDim.x <= 32) {
277
- warpReduceSum<float, 1>(local_sums);
278
- }
279
- else {
280
- blockReduceSum<float, 1>(local_sums);
281
- }
282
- if (threadIdx.x == 0) {
283
- s_variance = rsqrtf(local_sums[0] / n + 1e-5);
284
- }
285
- __syncthreads();
286
-
287
- #pragma UNROLL
288
- for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
289
- const int index = i*bdimx + tid;
290
- if (index < n_4){
291
- const T4 gamma_val = gamma[index];
292
- const T4 beta_val = beta[index];
293
- T4 tmp;
294
- tmp.x = T((static_cast<float>(local_val[i].x) - s_mean)*s_variance*static_cast<float>(gamma_val.x) + static_cast<float>(beta_val.x));
295
- tmp.y = T((static_cast<float>(local_val[i].y) - s_mean)*s_variance*static_cast<float>(gamma_val.y) + static_cast<float>(beta_val.y));
296
- tmp.z = T((static_cast<float>(local_val[i].z) - s_mean)*s_variance*static_cast<float>(gamma_val.z) + static_cast<float>(beta_val.z));
297
- tmp.w = T((static_cast<float>(local_val[i].w) - s_mean)*s_variance*static_cast<float>(gamma_val.w) + static_cast<float>(beta_val.w));
298
- output[index] = tmp;
299
- }
300
- }
301
- }
302
-
303
- /**
304
- * output [m, n] row-major
305
- * input [m, n] row-major
306
- * gamma [n]
307
- * beta [n]
308
- * grid(m)
309
- * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements
310
- */
311
- template<typename T>
312
- __global__ void layernorm_twoPassAlgo_e1(T* output,
313
- const T* input,
314
- const T* gamma,
315
- const T* beta,
316
- const int m,
317
- const int n)
318
- {
319
- const int m_idx = blockIdx.x;
320
- const int tid = threadIdx.x;
321
- const int bdimx = blockDim.x;
322
- __shared__ float s_mean, s_variance;
323
- float local_sums[1] = {0.0f};
324
- int offset = m_idx * n;
325
- input += offset;
326
- output += offset;
327
-
328
- for (int index = tid ; index < n ; index += bdimx){
329
- float local_val = static_cast<float>(input[index]);
330
- local_sums[0] += local_val;
331
- }
332
- if (blockDim.x <= 32) {
333
- warpReduceSum<float, 1>(local_sums);
334
- }
335
- else {
336
- blockReduceSum<float, 1>(local_sums);
337
- }
338
- if (threadIdx.x == 0) {
339
- s_mean = local_sums[0] / n;
340
- }
341
- __syncthreads();
342
-
343
- local_sums[0] = 0.0f;
344
- for (int index = tid ; index < n ; index += bdimx){
345
- float local_val = static_cast<float>(input[index]);
346
- local_val = local_val - s_mean;
347
- local_sums[0] += local_val * local_val;
348
- }
349
-
350
- if (blockDim.x <= 32) {
351
- warpReduceSum<float, 1>(local_sums);
352
- }
353
- else {
354
- blockReduceSum<float, 1>(local_sums);
355
- }
356
- if (threadIdx.x == 0) {
357
- s_variance = rsqrtf(local_sums[0] / n + 1e-5);
358
- }
359
- __syncthreads();
360
-
361
- for (int index = tid ; index < n ; index += bdimx){
362
- const T gamma_val = gamma[index];
363
- const T beta_val = beta[index];
364
- const T local_val = input[index];
365
- output[index] = T((static_cast<float>(local_val) - s_mean) * s_variance * static_cast<float>(gamma_val) + static_cast<float>(beta_val));
366
- }
367
- }
368
-
369
- /**
370
- * output [m, n] row-major
371
- * input [m, n] row-major
372
- * gamma [n]
373
- * beta [n]
374
- * grid(m)
375
- * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements;
376
- */
377
- template<typename T2, typename T>
378
- __global__ void layernorm_twoPassAlgo_e2(T2* output,
379
- const T2* input,
380
- const T2* gamma,
381
- const T2* beta,
382
- const int m,
383
- const int n)
384
- {
385
- const int m_idx = blockIdx.x;
386
- const int tid = threadIdx.x;
387
- const int bdimx = blockDim.x;
388
- __shared__ float s_mean, s_variance;
389
- float local_sums[1] = {0.0f};
390
- const int n_2 = n / 2;
391
- int offset = m_idx * n_2;
392
- input += offset;
393
- output += offset;
394
-
395
- for (int index = tid; index < n_2; index += bdimx) {
396
- const T2 local_val = input[index];
397
- local_sums[0] += static_cast<float>(local_val.x) + static_cast<float>(local_val.y);
398
- }
399
-
400
- if (blockDim.x <= 32) {
401
- warpReduceSum<float, 1>(local_sums);
402
- }
403
- else {
404
- blockReduceSum<float, 1>(local_sums);
405
- }
406
- if (threadIdx.x == 0) {
407
- s_mean = local_sums[0] / n;
408
- }
409
- __syncthreads();
410
-
411
- local_sums[0] = 0.0f;
412
- for (int index = tid; index < n_2; index += bdimx) {
413
- const T2 local_val = input[index];
414
- const float2 tmp = {static_cast<float>(local_val.x) - s_mean,
415
- static_cast<float>(local_val.y) - s_mean};
416
- local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y;
417
- }
418
- if (blockDim.x <= 32) {
419
- warpReduceSum<float, 1>(local_sums);
420
- }
421
- else {
422
- blockReduceSum<float, 1>(local_sums);
423
- }
424
- if (threadIdx.x == 0) {
425
- s_variance = rsqrtf(local_sums[0] / n + 1e-5);
426
- }
427
- __syncthreads();
428
-
429
- for (int index = tid; index < n_2; index += bdimx) {
430
- const T2 local_val = input[index];
431
- const T2 gamma_val = gamma[index];
432
- const T2 beta_val = beta[index];
433
- T2 tmp;
434
- tmp.x = T((static_cast<float>(local_val.x) - s_mean)*s_variance*static_cast<float>(gamma_val.x) + static_cast<float>(beta_val.x));
435
- tmp.y = T((static_cast<float>(local_val.y) - s_mean)*s_variance*static_cast<float>(gamma_val.y) + static_cast<float>(beta_val.y));
436
- output[index] = tmp;
437
- }
438
- }
439
-
440
- template <typename T>
441
- void layernorm(cutlass::MatrixCoord tensor_size,
442
- TensorRef<T, layout::RowMajor> ref_output,
443
- TensorRef<T, layout::RowMajor> ref_input,
444
- TensorRef<T, layout::RowMajor> ref_gamma,
445
- TensorRef<T, layout::RowMajor> ref_beta,
446
- cudaStream_t stream){
447
- const int m = tensor_size.row();
448
- const int n = tensor_size.column();
449
- T* output = ref_output.data();
450
- const T* input = ref_input.data();
451
- const T* gamma = ref_gamma.data();
452
- const T* beta = ref_beta.data();
453
- dim3 grid(m);
454
- dim3 block((n + 31)/32*32);
455
- if (block.x > 1024){
456
- block.x = 1024;
457
- }
458
- // TODO : There should be better configs for different cases, we only use several samples to show how to use here
459
- // TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels.
460
- if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) {
461
- block.x = (n/4 + 31)/32*32;
462
- if (std::is_same<T, float>::value) {
463
- layernorm_twoPassAlgo_stored_locally_e4<float4, float, 1><<<grid, block, 0, stream>>>(
464
- (float4*)output,
465
- (const float4*)input,
466
- (const float4*)gamma,
467
- (const float4*)beta,
468
- m,
469
- n);
470
- } // if (std::is_same<T, float>::value)
471
- else {
472
- layernorm_twoPassAlgo_stored_locally_e4<half4, half, 1><<<grid, block, 0, stream>>>(
473
- (half4*)output,
474
- (const half4*)input,
475
- (const half4*)gamma,
476
- (const half4*)beta,
477
- m,
478
- n);
479
- }
480
- } //if ((n % 4 == 0) && (n >= 128) && (n <= 4096))
481
- else if (n % 2 == 0) {
482
- if (n / 2 <= 1024) {
483
- block.x = (n/2 + 31)/32*32;
484
- if (std::is_same<T, float>::value) {
485
- layernorm_twoPassAlgo_stored_locally_e2<float2, float, 1><<<grid, block, 0, stream>>>(
486
- (float2*)output,
487
- (const float2*)input,
488
- (const float2*)gamma,
489
- (const float2*)beta,
490
- m,
491
- n);
492
- } //if (std::is_same<T, float>::value)
493
- else {
494
- layernorm_twoPassAlgo_stored_locally_e2<half2, half, 1><<<grid, block, 0, stream>>>(
495
- (half2*)output,
496
- (const half2*)input,
497
- (const half2*)gamma,
498
- (const half2*)beta,
499
- m,
500
- n);
501
- }
502
- } // if (n / 2 <= 1024)
503
- else if (n <= 8192) {
504
- block.x = ((n + 7)/8 + 31)/32*32;
505
- if (std::is_same<T, float>::value) {
506
- layernorm_twoPassAlgo_stored_locally_e2<float2, float, 4><<<grid, block, 0, stream>>>(
507
- (float2*)output,
508
- (const float2*)input,
509
- (const float2*)gamma,
510
- (const float2*)beta,
511
- m,
512
- n);
513
- } // if (std::is_same<T, float>::value)
514
- else {
515
- layernorm_twoPassAlgo_stored_locally_e2<half2, half, 4><<<grid, block, 0, stream>>>(
516
- (half2*)output,
517
- (const half2*)input,
518
- (const half2*)gamma,
519
- (const half2*)beta,
520
- m,
521
- n);
522
- }
523
- } // if (n <= 8192)
524
- else if (n <= 16384) {
525
- block.x = ((n + 15)/ 16 + 31)/32*32;
526
- if (std::is_same<T, float>::value) {
527
- layernorm_twoPassAlgo_stored_locally_e2<float2, float, 8><<<grid, block, 0, stream>>>(
528
- (float2*)output,
529
- (const float2*)input,
530
- (const float2*)gamma,
531
- (const float2*)beta,
532
- m,
533
- n);
534
- } // if (std::is_same<T, float>::value)
535
- else {
536
- layernorm_twoPassAlgo_stored_locally_e2<half2, half, 8><<<grid, block, 0, stream>>>(
537
- (half2*)output,
538
- (const half2*)input,
539
- (const half2*)gamma,
540
- (const half2*)beta,
541
- m,
542
- n);
543
- }
544
- } // if (n <= 16384)
545
- else if (n <= 32768) {
546
- block.x = ((n + 31)/32 + 31)/32*32;
547
- if (std::is_same<T, float>::value) {
548
- layernorm_twoPassAlgo_stored_locally_e2<float2, float, 16><<<grid, block, 0, stream>>>(
549
- (float2*)output,
550
- (const float2*)input,
551
- (const float2*)gamma,
552
- (const float2*)beta,
553
- m,
554
- n);
555
- } // if (std::is_same<T, float>::value)
556
- else {
557
- layernorm_twoPassAlgo_stored_locally_e2<half2, half, 16><<<grid, block, 0, stream>>>(
558
- (half2*)output,
559
- (const half2*)input,
560
- (const half2*)gamma,
561
- (const half2*)beta,
562
- m,
563
- n);
564
- }
565
- } // if (n <= 32768)
566
- else {
567
- if (block.x > 512)
568
- block.x = 512;
569
- if (std::is_same<T, float>::value) {
570
- layernorm_twoPassAlgo_e2<float2, float><<<grid, block, 0, stream>>>(
571
- (float2 *)output,
572
- (const float2 *)input,
573
- (const float2 *)gamma,
574
- (const float2 *)beta,
575
- m,
576
- n);
577
- } // if (std::is_same<T, float>::value)
578
- else {
579
- layernorm_twoPassAlgo_e2<half2, half><<<grid, block, 0, stream>>>(
580
- (half2 *)output,
581
- (const half2 *)input,
582
- (const half2 *)gamma,
583
- (const half2 *)beta,
584
- m,
585
- n);
586
- }
587
- }
588
- } // if (n % 2 == 0)
589
- else {
590
- if (n <= 1024) {
591
- layernorm_twoPassAlgo_stored_locally_e1<T, 1><<<grid, block, 0, stream>>>(
592
- output,
593
- input,
594
- gamma,
595
- beta,
596
- m,
597
- n);
598
- } // if (n <= 1024)
599
- else if (n <= 8192) {
600
- block.x = ((n + 7)/8 + 31)/32*32;
601
- layernorm_twoPassAlgo_stored_locally_e1<T, 8><<<grid, block, 0, stream>>>(
602
- output,
603
- input,
604
- gamma,
605
- beta,
606
- m,
607
- n);
608
- } // if (n <= 8192)
609
- else if (n <= 16384) {
610
- block.x = ((n + 15)/16 + 32)/32*32;
611
- layernorm_twoPassAlgo_stored_locally_e1<T, 16><<<grid, block, 0, stream>>>(
612
- output,
613
- input,
614
- gamma,
615
- beta,
616
- m,
617
- n);
618
- } // if (n <= 16384)
619
- else if (n <= 32768) {
620
- block.x = ((n + 31)/32 + 31)/32*32;
621
- layernorm_twoPassAlgo_stored_locally_e1<T, 32><<<grid, block, 0, stream>>>(
622
- output,
623
- input,
624
- gamma,
625
- beta,
626
- m,
627
- n);
628
- } // if (n <= 32768)
629
- else{
630
- if (block.x > 512) {
631
- block.x = 512;
632
- }
633
- layernorm_twoPassAlgo_e1<<<grid, block, 0, stream>>>(
634
- output,
635
- input,
636
- gamma,
637
- beta,
638
- m,
639
- n);
640
- }
641
- }
642
- }
643
-
644
- } //namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h DELETED
@@ -1,375 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- /**
35
- * \file
36
- * \brief C++ interface to CUDA device memory management functions.
37
- */
38
-
39
- #include <memory>
40
- #include <sstream>
41
-
42
- #include "cutlass/platform/platform.h"
43
- #include "cutlass/numeric_types.h"
44
- #include "cutlass/trace.h"
45
- #include "exceptions.h"
46
-
47
- namespace cutlass {
48
- namespace device_memory {
49
-
50
- /******************************************************************************
51
- * Allocation lifetime
52
- ******************************************************************************/
53
-
54
- /// Allocate a buffer of \p count elements of type \p T on the current CUDA device
55
- template <typename T>
56
- T* allocate(size_t count = 1) {
57
-
58
- T* ptr = 0;
59
- size_t bytes = count * sizeof_bits<T>::value / 8;
60
-
61
- cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes);
62
-
63
- if (cuda_error != cudaSuccess) {
64
- #if (CUTLASS_DEBUG_TRACE_LEVEL > 0)
65
- std::ostringstream os;
66
- os << "cutlass::device_memory::allocate: cudaMalloc failed: bytes=" << bytes;
67
- CUTLASS_TRACE_HOST(os.str());
68
- #endif
69
- throw cuda_exception("Failed to allocate memory", cuda_error);
70
- }
71
- #if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
72
- else {
73
- std::ostringstream os;
74
- os << "cutlass::device_memory::allocate: Successful cudaMalloc: bytes=" << bytes;
75
- CUTLASS_TRACE_HOST(os.str());
76
- }
77
- #endif
78
-
79
- return ptr;
80
- }
81
-
82
- /// Free the buffer pointed to by \p ptr
83
- template <typename T>
84
- void free(T* ptr) {
85
- if (ptr) {
86
- cudaError_t cuda_error = (cudaFree(ptr));
87
- if (cuda_error != cudaSuccess) {
88
- throw cuda_exception("Failed to free device memory", cuda_error);
89
- }
90
- }
91
- }
92
-
93
- /******************************************************************************
94
- * Data movement
95
- ******************************************************************************/
96
-
97
- template <typename T>
98
- void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) {
99
- size_t bytes = count * sizeof_bits<T>::value / 8;
100
- if (bytes == 0 && count > 0) {
101
- bytes = 1;
102
- }
103
- cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind));
104
- if (cuda_error != cudaSuccess) {
105
- std::ostringstream os;
106
- os << "cutlass::device_memory::copy: cudaMemcpy() failed: "
107
- << "dst=" << dst << ", src=" << src
108
- << ", bytes=" << bytes << ", count=" << count;
109
- if (kind == cudaMemcpyHostToDevice) {
110
- os << ", kind=cudaMemcpyHostToDevice";
111
- }
112
- else if (kind == cudaMemcpyDeviceToHost) {
113
- os << ", kind=cudaMemcpyDeviceToHost";
114
- }
115
- else if (kind == cudaMemcpyDeviceToDevice) {
116
- os << ", kind=cudaMemcpyDeviceToDevice";
117
- }
118
- else if (kind == cudaMemcpyHostToHost) {
119
- os << ", kind=cudaMemcpyHostToHost";
120
- }
121
- else if (kind == cudaMemcpyDefault) {
122
- os << ", kind=cudaMemcpyDefault";
123
- }
124
- else {
125
- os << ", kind=Unknown";
126
- }
127
- os << ", error: " << cudaGetErrorString(cuda_error);
128
-
129
- throw cuda_exception(os.str().c_str(), cuda_error);
130
- }
131
- }
132
-
133
- template <typename T>
134
- void copy_to_device(T* dst, T const* src, size_t count = 1) {
135
- copy(dst, src, count, cudaMemcpyHostToDevice);
136
- }
137
-
138
- template <typename T>
139
- void copy_to_host(T* dst, T const* src, size_t count = 1) {
140
- copy(dst, src, count, cudaMemcpyDeviceToHost);
141
- }
142
-
143
- template <typename T>
144
- void copy_device_to_device(T* dst, T const* src, size_t count = 1) {
145
- copy(dst, src, count, cudaMemcpyDeviceToDevice);
146
- }
147
-
148
- template <typename T>
149
- void copy_host_to_host(T* dst, T const* src, size_t count = 1) {
150
- copy(dst, src, count, cudaMemcpyHostToHost);
151
- }
152
-
153
- /// Copies elements from device memory to host-side range
154
- template <typename OutputIterator, typename T>
155
- void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) {
156
- size_t elements = end - begin;
157
- copy_to_host(&*begin, device_begin, elements);
158
- }
159
-
160
- /// Copies elements to device memory from host-side range
161
- template <typename T, typename InputIterator>
162
- void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) {
163
- size_t elements = end - begin;
164
- copy_to_device(device_begin, &*begin, elements);
165
- }
166
-
167
- /////////////////////////////////////////////////////////////////////////////////////////////////
168
-
169
- } // namespace device_memory
170
-
171
- /////////////////////////////////////////////////////////////////////////////////////////////////
172
-
173
- template <typename T>
174
- class DeviceAllocation {
175
- public:
176
-
177
- /// Delete functor for CUDA device memory
178
- struct deleter {
179
- void operator()(T* ptr) {
180
- cudaError_t cuda_error = (cudaFree(ptr));
181
- if (cuda_error != cudaSuccess) {
182
- // noexcept
183
- // throw cuda_exception("cudaFree() failed", cuda_error);
184
- return;
185
- }
186
- }
187
- };
188
-
189
- public:
190
- //
191
- // Data members
192
- //
193
-
194
- /// Number of elements of T allocated on the current CUDA device
195
- size_t capacity;
196
-
197
- /// Smart pointer
198
- platform::unique_ptr<T, deleter> smart_ptr;
199
-
200
- public:
201
-
202
- //
203
- // Static methods
204
- //
205
-
206
- /// Static member to compute the number of bytes needed for a given number of elements
207
- static size_t bytes(size_t elements) {
208
- if (sizeof_bits<T>::value < 8) {
209
- size_t const kElementsPerByte = 8 / sizeof_bits<T>::value;
210
- return elements / kElementsPerByte;
211
- }
212
- else {
213
- size_t const kBytesPerElement = sizeof_bits<T>::value / 8;
214
- return elements * kBytesPerElement;
215
- }
216
- }
217
-
218
- public:
219
-
220
- //
221
- // Methods
222
- //
223
-
224
- /// Constructor: allocates no memory
225
- DeviceAllocation() : capacity(0) {}
226
-
227
- /// Constructor: allocates \p capacity elements on the current CUDA device
228
- DeviceAllocation(size_t _capacity) :
229
- smart_ptr(device_memory::allocate<T>(_capacity)), capacity(_capacity) {}
230
-
231
- /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation
232
- DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {}
233
-
234
- /// Copy constructor
235
- DeviceAllocation(DeviceAllocation const &p):
236
- smart_ptr(device_memory::allocate<T>(p.capacity)), capacity(p.capacity) {
237
-
238
- device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity);
239
- }
240
-
241
- /// Move constructor
242
- DeviceAllocation(DeviceAllocation &&p): capacity(0) {
243
- std::swap(smart_ptr, p.smart_ptr);
244
- std::swap(capacity, p.capacity);
245
- }
246
-
247
- /// Destructor
248
- ~DeviceAllocation() { reset(); }
249
-
250
- /// Returns a pointer to the managed object
251
- T* get() const { return smart_ptr.get(); }
252
-
253
- /// Releases the ownership of the managed object (without deleting) and resets capacity to zero
254
- T* release() {
255
- capacity = 0;
256
- return smart_ptr.release();
257
- }
258
-
259
- /// Deletes the managed object and resets capacity to zero
260
- void reset() {
261
- capacity = 0;
262
- smart_ptr.reset();
263
- }
264
-
265
- /// Deletes managed object, if owned, and allocates a new object
266
- void reset(size_t _capacity) {
267
- reset(device_memory::allocate<T>(_capacity), _capacity);
268
- }
269
-
270
- /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity
271
- void reset(T* _ptr, size_t _capacity) {
272
- smart_ptr.reset(_ptr);
273
- capacity = _capacity;
274
- }
275
-
276
- /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released.
277
- void reallocate(size_t new_capacity) {
278
-
279
- platform::unique_ptr<T, deleter> new_allocation(device_memory::allocate<T>(new_capacity));
280
-
281
- device_memory::copy_device_to_device(
282
- new_allocation.get(),
283
- smart_ptr.get(),
284
- std::min(new_capacity, capacity));
285
-
286
- std::swap(smart_ptr, new_allocation);
287
- std::swap(new_capacity, capacity);
288
- }
289
-
290
- /// Returns the number of elements
291
- size_t size() const {
292
- return capacity;
293
- }
294
-
295
- /// Returns the number of bytes needed to store the allocation
296
- size_t bytes() const {
297
- return bytes(capacity);
298
- }
299
-
300
- /// Returns a pointer to the object owned by *this
301
- T* operator->() const { return smart_ptr.get(); }
302
-
303
- /// Returns the deleter object which would be used for destruction of the managed object.
304
- deleter& get_deleter() { return smart_ptr.get_deleter(); }
305
-
306
- /// Returns the deleter object which would be used for destruction of the managed object (const)
307
- const deleter& get_deleter() const { return smart_ptr.get_deleter(); }
308
-
309
- /// Copies a device-side memory allocation
310
- DeviceAllocation & operator=(DeviceAllocation const &p) {
311
- if (capacity != p.capacity) {
312
- smart_ptr.reset(device_memory::allocate<T>(p.capacity));
313
- capacity = p.capacity;
314
- }
315
- device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity);
316
- return *this;
317
- }
318
-
319
- /// Move assignment
320
- DeviceAllocation & operator=(DeviceAllocation && p) {
321
- std::swap(smart_ptr, p.smart_ptr);
322
- std::swap(capacity, p.capacity);
323
- return *this;
324
- }
325
-
326
- /// Copies the entire allocation from another location in device memory.
327
- void copy_from_device(T const *ptr) const {
328
- copy_from_device(ptr, capacity);
329
- }
330
-
331
- /// Copies a given number of elements from device memory
332
- void copy_from_device(T const *ptr, size_t elements) const {
333
- device_memory::copy_device_to_device(get(), ptr, elements);
334
- }
335
-
336
- void copy_to_device(T *ptr) const {
337
- copy_to_device(ptr, capacity);
338
- }
339
-
340
- void copy_to_device(T *ptr, size_t elements) const {
341
- device_memory::copy_device_to_device(ptr, get(), elements);
342
- }
343
-
344
- void copy_from_host(T const *ptr) const {
345
- copy_from_host(ptr, capacity);
346
- }
347
-
348
- void copy_from_host(T const *ptr, size_t elements) const {
349
- device_memory::copy_to_device(get(), ptr, elements);
350
- }
351
-
352
- void copy_to_host(T *ptr) const {
353
- copy_to_host(ptr, capacity);
354
- }
355
-
356
- void copy_to_host(T *ptr, size_t elements) const {
357
- device_memory::copy_to_host(ptr, get(), elements);
358
- }
359
- };
360
-
361
- /////////////////////////////////////////////////////////////////////////////////////////////////
362
-
363
- namespace device_memory {
364
-
365
- /// Device allocation abstraction that tracks size and capacity
366
- template <typename T>
367
- using allocation = cutlass::DeviceAllocation<T>;
368
-
369
- } // namespace device_memory
370
-
371
- /////////////////////////////////////////////////////////////////////////////////////////////////
372
-
373
- } // namespace cutlass
374
-
375
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h DELETED
@@ -1,141 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- /**
35
- * \file
36
- * \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout.
37
- */
38
-
39
- #include "cutlass/cutlass.h"
40
- #include "cutlass/layout/tensor.h"
41
- #include "cutlass/numeric_types.h"
42
- #include "cutlass/tensor_coord.h"
43
- #include "cutlass/tensor_ref.h"
44
-
45
- namespace cutlass {
46
-
47
- /** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout.
48
- * \tparam T: data type
49
- */
50
- template <typename T>
51
- void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size,
52
- cutlass::Tensor4DCoord output_tensor_size,
53
- TensorRef<T, layout::TensorNCHW> ref_input,
54
- TensorRef<T, layout::TensorNHWC> ref_output,
55
- cudaStream_t stream);
56
-
57
- template <typename T>
58
- __global__ void nchw_to_nhwc_kernel(T *output,
59
- const T *input,
60
- const int n,
61
- const int h,
62
- const int w,
63
- const int c) {
64
- const int hw = h*w;
65
- const int chw = c*hw;
66
- __shared__ T shbuf[32 * (32 + 1)];
67
- const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x;
68
- const int32_t wid = tid / 32;
69
- const int32_t lid = tid % 32;
70
- const int32_t ni = blockIdx.z;
71
- const int32_t ci0 = blockIdx.y * 32;
72
- const int32_t hwi0 = blockIdx.x * 32;
73
-
74
- const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0;
75
- const T *A = input + input_idx;
76
- if (hwi0 + lid < hw) {
77
- const int lid_x_33 = lid * 33;
78
- if ((ci0 + 32) <= c) {
79
- int ci = wid; // between 0 and 7
80
- CUTLASS_PRAGMA_UNROLL
81
- for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) {
82
- shbuf[lid_x_33 + ci] = A[lid];
83
- A = &A[8 * hw];
84
- ci += 8;
85
- }
86
- } else {
87
- for (int ci = wid; ci < 32; ci += 8) {
88
- if ((ci + ci0) < c) {
89
- shbuf[lid_x_33 + ci] = A[lid];
90
- }
91
- A = &A[8 * hw];
92
- }
93
- }
94
- }
95
- __syncthreads();
96
-
97
- const int32_t ciOut = ci0 + lid;
98
- output = &output[ni * chw + ciOut];
99
- if (ciOut < c) {
100
- if (hwi0 + 32 < hw) {
101
- int hwI = wid;
102
- CUTLASS_PRAGMA_UNROLL
103
- for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) {
104
- output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid];
105
- hwI += 8;
106
- }
107
- } else {
108
- for (int hwI = wid; hwI < 32; hwI += 8) {
109
- if (hwi0 + hwI < hw) {
110
- output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid];
111
- }
112
- }
113
- }
114
- }
115
- }
116
-
117
- template <typename T>
118
- void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size,
119
- cutlass::Tensor4DCoord output_tensor_size,
120
- TensorRef<T, layout::TensorNCHW> ref_input,
121
- TensorRef<T, layout::TensorNHWC> ref_output,
122
- cudaStream_t stream) {
123
-
124
- assert(
125
- input_tensor_size.n() == output_tensor_size.n() &&
126
- input_tensor_size.c() == output_tensor_size.h() &&
127
- input_tensor_size.h() == output_tensor_size.w() &&
128
- input_tensor_size.w() == output_tensor_size.c());
129
-
130
- int n = output_tensor_size.n();
131
- int h = output_tensor_size.h();
132
- int w = output_tensor_size.w();
133
- int c = output_tensor_size.c();
134
-
135
- dim3 grid((h*w + 31)/32, (c + 31)/32, n);
136
- dim3 block(32, 8);
137
- nchw_to_nhwc_kernel<<<grid, block, 0, stream>>>(ref_output.data(), ref_input.data(),
138
- n, h, w, c);
139
- }
140
-
141
- } //namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h DELETED
@@ -1,276 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- /**
35
- * \file
36
- * \brief cuda kernels for padding in device memory with NHWC layout.
37
- */
38
-
39
- #include "cutlass/cutlass.h"
40
- #include "cutlass/layout/tensor.h"
41
- #include "cutlass/numeric_types.h"
42
- #include "cutlass/tensor_coord.h"
43
- #include "cutlass/tensor_ref.h"
44
-
45
- namespace cutlass {
46
-
47
- /** \brief interface for padding in a device memory tensor with NHWC layout
48
- * \tparam T: data type
49
- */
50
- template <typename T>
51
- void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size,
52
- cutlass::Tensor4DCoord output_tensor_size,
53
- TensorRef<T, layout::TensorNHWC> ref_input,
54
- TensorRef<T, layout::TensorNHWC> ref_output,
55
- cudaStream_t stream);
56
-
57
-
58
- template <typename T>
59
- __global__ void nhwc_padding_kernel(const int32_t n,
60
- const int32_t h,
61
- const int32_t w,
62
- const int32_t c_in,
63
- const int32_t c_out,
64
- const T zero,
65
- const T *input,
66
- T *output){
67
-
68
- const int32_t idx_jump = blockDim.x * gridDim.x;
69
- const int32_t total_elements = n * h * w * c_out;
70
-
71
- int32_t c_idx, w_idx, h_idx, n_idx, resudial;
72
-
73
- T value;
74
- for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) {
75
-
76
- c_idx = idx%c_out;
77
- if (c_idx >= c_in){
78
- value = zero;
79
- }
80
- else{
81
- resudial = idx/c_out;
82
- w_idx = resudial%w;
83
- resudial = resudial/w;
84
- h_idx = resudial%h;
85
- n_idx = resudial/h;
86
- resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx;
87
- value = input[resudial];
88
- }
89
- output[idx] = value;
90
- }
91
- }
92
-
93
-
94
- // fast kernel for c_in = 3 & c_out = 4
95
- template <typename Tio, typename Telement, int element_in_Tio>
96
- __global__ void nhwc_padding_channel_3To4_kernel(const int32_t n,
97
- const int32_t h,
98
- const int32_t w,
99
- const Tio *input,
100
- Tio *output,
101
- const int32_t max_output_element,
102
- const int32_t max_input_element,
103
- const Tio zero_io,
104
- const Telement zero_element){
105
- __shared__ Tio shm[192];
106
- const int tidx = blockIdx.x * 192 + threadIdx.x;
107
- const int threadidx = threadIdx.x;
108
-
109
- shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx];
110
- __syncthreads();
111
-
112
- const int output_offset = blockIdx.x * 256;
113
- const int lower_bound = max_output_element < output_offset + 256 ? max_output_element : output_offset + 256;
114
- for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192)
115
- {
116
- const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4;
117
- Telement array[element_in_Tio];
118
- CUTLASS_PRAGMA_UNROLL
119
- for (int k = 0 ; k < element_in_Tio ; k++)
120
- array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k];
121
- output[i] = *((const Tio *)array);
122
- }
123
- }
124
-
125
- // fast kernel for c_in = 3 & c_out = 8
126
- template <typename Tio, typename Telement, int element_in_Tio>
127
- __global__ void nhwc_padding_channel_3To8_kernel(const int32_t n,
128
- const int32_t h,
129
- const int32_t w,
130
- const Tio *input,
131
- Tio *output,
132
- const int32_t max_output_element,
133
- const int32_t max_input_element,
134
- const Tio zero_io,
135
- const Telement zero_element){
136
- __shared__ Tio shm[192];
137
- const int tidx = blockIdx.x * 192 + threadIdx.x;
138
- const int threadidx = threadIdx.x;
139
-
140
- shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx];
141
- __syncthreads();
142
-
143
- const int output_offset = blockIdx.x * 512;
144
- const int lower_bound = max_output_element < output_offset + 512 ? max_output_element : output_offset + 512;
145
- for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192)
146
- {
147
- const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3;
148
- Telement array[element_in_Tio];
149
- //float
150
- if (element_in_Tio == 4){
151
- CUTLASS_PRAGMA_UNROLL
152
- for (int k = 0 ; k < element_in_Tio ; k++)
153
- array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? zero_element : shm_element[k]);
154
- }
155
- //half
156
- else{
157
- CUTLASS_PRAGMA_UNROLL
158
- for (int k = 0 ; k < element_in_Tio ; k++)
159
- array[k] = (k >= 3) ? zero_element : shm_element[k];
160
- }
161
- output[i] = *((const Tio *)array);
162
- }
163
- }
164
-
165
- template <typename T>
166
- void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size,
167
- cutlass::Tensor4DCoord output_tensor_size,
168
- TensorRef<T, layout::TensorNHWC> ref_input,
169
- TensorRef<T, layout::TensorNHWC> ref_output,
170
- cudaStream_t stream){
171
- assert(
172
- input_tensor_size.n() == output_tensor_size.n() &&
173
- input_tensor_size.h() == output_tensor_size.h() &&
174
- input_tensor_size.w() == output_tensor_size.w() &&
175
- input_tensor_size.c() <= output_tensor_size.c());
176
-
177
- int n = input_tensor_size.n();
178
- int h = input_tensor_size.h();
179
- int w = input_tensor_size.w();
180
- int c_in = input_tensor_size.c();
181
- int c_out = output_tensor_size.c();
182
-
183
- //case 1 : channel == 3 padding to 4 or 8
184
- if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){
185
- dim3 block(192);
186
- const int nhw = n*h*w;
187
- const int nhwc = nhw*c_in;
188
- //for half_t
189
- if (cutlass::sizeof_bits<T>::value == 16){
190
- const int element_in_Tio = 8;
191
- const int max_input_element = nhwc/element_in_Tio;
192
- const int max_output_element = nhw*c_out/element_in_Tio;
193
- const int4 zero_io = {0, 0, 0, 0};
194
- const half_t zero_element = static_cast<half_t>(0.0f);
195
- dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio));
196
- if (c_out == 4){
197
- nhwc_padding_channel_3To4_kernel<int4, half_t, element_in_Tio><<<grid, block, 0, stream>>>
198
- (n, h, w,
199
- (const int4 *)ref_input.data(),
200
- (int4 *)ref_output.data(),
201
- max_output_element,
202
- max_input_element,
203
- zero_io,
204
- zero_element);
205
- }
206
- else if (c_out == 8){
207
- nhwc_padding_channel_3To8_kernel<int4, half_t, element_in_Tio><<<grid, block, 0, stream>>>
208
- (n, h, w,
209
- (const int4 *)ref_input.data(),
210
- (int4 *)ref_output.data(),
211
- max_output_element,
212
- max_input_element,
213
- zero_io,
214
- zero_element);
215
- }
216
- }
217
- //for float
218
- else{
219
- const int element_in_Tio = 4;
220
- const int max_input_element = nhwc/element_in_Tio;
221
- const int max_output_element = nhw*c_out/element_in_Tio;
222
- const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f};
223
- const float zero_element = 0.0f;
224
- dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio));
225
- if (c_out == 4){
226
- nhwc_padding_channel_3To4_kernel<float4, float, element_in_Tio><<<grid, block, 0, stream>>>
227
- (n, h, w,
228
- (const float4 *)ref_input.data(),
229
- (float4 *)ref_output.data(),
230
- max_output_element,
231
- max_input_element,
232
- zero_io,
233
- zero_element);
234
- }
235
- else if (c_out == 8){
236
- nhwc_padding_channel_3To8_kernel<float4, float, element_in_Tio><<<grid, block, 0, stream>>>
237
- (n, h, w,
238
- (const float4 *)ref_input.data(),
239
- (float4 *)ref_output.data(),
240
- max_output_element,
241
- max_input_element,
242
- zero_io,
243
- zero_element);
244
- }
245
- }
246
- }
247
- //case 2 : even channel
248
- else if ((c_out % 2) == 0 && (c_in % 2) == 0){
249
- int32_t total_elements = n * h * w * c_out / 2;
250
- int block_size = 256;
251
- dim3 grid((total_elements + 255)/256);
252
- dim3 block(block_size);
253
- //for half_t
254
- if (cutlass::sizeof_bits<T>::value == 16){
255
- const __half2 zero = {0.0f, 0.0f};
256
- nhwc_padding_kernel<<<grid, block, 0, stream>>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data());
257
- }
258
- //for float
259
- else{
260
- const float2 zero = {0.0f, 0.0f};
261
- nhwc_padding_kernel<<<grid, block, 0, stream>>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data());
262
- }
263
- }
264
- //case 3 : odd channel
265
- else{
266
- int32_t total_elements = n * h * w * c_out;
267
- int block_size = 256;
268
- dim3 grid((total_elements + 255)/256);
269
- dim3 block(block_size);
270
- const T zero = static_cast<T>(0.0f);
271
- nhwc_padding_kernel<<<grid, block, 0, stream>>>(n, h, w, c_in, c_out, zero, ref_input.data(), ref_output.data());
272
- }
273
- }
274
-
275
-
276
- } //namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h DELETED
@@ -1,573 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- /**
35
- * \file
36
- * \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout.
37
- */
38
-
39
- #include "cutlass/cutlass.h"
40
- #include "cutlass/layout/tensor.h"
41
- #include "cutlass/numeric_types.h"
42
- #include "cutlass/tensor_coord.h"
43
- #include "cutlass/tensor_ref.h"
44
- #include "device_utils.h"
45
- #include <cfloat>
46
-
47
- namespace cutlass {
48
-
49
- /** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout.
50
- * \tparam T: data type
51
- */
52
- template <typename T>
53
- void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size,
54
- cutlass::Tensor4DCoord filter_tensor_size,
55
- cutlass::Tensor4DCoord output_tensor_size,
56
- cutlass::MatrixCoord padding,
57
- cutlass::MatrixCoord stride,
58
- TensorRef<T, layout::TensorNHWC> ref_input,
59
- TensorRef<T, layout::TensorNHWC> ref_output,
60
- int poolingType, //0 for avg pooling ; 1 for max pooling
61
- cudaStream_t stream);
62
-
63
- /** get the output size of pooling
64
- */
65
- inline int getOutputSize(int H_W, int padding, int kernel_size, int stride)
66
- {
67
- return (H_W + 2 * padding - kernel_size) / stride + 1;
68
- }
69
-
70
- /**
71
- * input is [N, H, W, C]
72
- * assume stride == kernel_size
73
- * output_h = (H + 2*padding_H - kernel_H)/stride_H
74
- * output_w = (W + 2*padding_W - kernel_W)/stride_W
75
- * output is [N, output_h, output_w, C]
76
- * grid(N, output_h, output_w)
77
- * block(min(C, 256)) :
78
- * each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output)
79
- */
80
- template<typename T, bool IS_AVG_POOLING>
81
- __global__ void pooling_nhwc_element1_kernel(T* output,
82
- const T* input,
83
- const int N,
84
- const int H,
85
- const int W,
86
- const int C,
87
- const int output_H,
88
- const int output_W,
89
- const int kernel_H,
90
- const int kernel_W,
91
- const int stride_H,
92
- const int stride_W,
93
- const int padding_H,
94
- const int padding_W)
95
- {
96
- const int tid = threadIdx.x;
97
- const int n_idx = blockIdx.x;
98
- const int output_h_idx = blockIdx.y;
99
- const int output_w_idx = blockIdx.z;
100
-
101
- int h_start_idx = output_h_idx * stride_H - padding_H;
102
- int h_end_idx = h_start_idx + kernel_H;
103
- h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx;
104
- h_end_idx = h_end_idx > H ? H : h_end_idx;
105
-
106
- int w_start_idx = output_w_idx * stride_W - padding_W;
107
- int w_end_idx = w_start_idx + kernel_W;
108
- w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx;
109
- w_end_idx = w_end_idx > W ? W : w_end_idx;
110
-
111
- input += n_idx * H * W * C;
112
- output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C;
113
- const int kernel_size2 = kernel_H * kernel_W;
114
- for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) {
115
- float pooling;
116
- if (IS_AVG_POOLING){
117
- pooling = 0.0f;
118
- }
119
- else{
120
- pooling = -FLT_MAX;
121
- }
122
- for (int h = h_start_idx; h < h_end_idx; h++) {
123
- for (int w = w_start_idx; w < w_end_idx; w++) {
124
- const int idx = (h * W + w) * C;
125
- const float tmp = static_cast<float>(input[idx + c_idx]);
126
- if (IS_AVG_POOLING){
127
- pooling = pooling + tmp;
128
- }
129
- else{
130
- pooling = pooling > tmp ? pooling : tmp;
131
- }
132
- }
133
- }
134
-
135
- T output_val;
136
- if (IS_AVG_POOLING){
137
- output_val = T(pooling/kernel_size2);
138
- }
139
- else{
140
- output_val = T(pooling);
141
- }
142
- output[c_idx] = output_val;
143
- }
144
- }
145
-
146
- template<typename T2, typename T, bool IS_AVG_POOLING>
147
- __global__ void pooling_nhwc_element2_kernel(T2* output,
148
- const T2* input,
149
- const int N,
150
- const int H,
151
- const int W,
152
- const int C,
153
- const int output_H,
154
- const int output_W,
155
- const int kernel_H,
156
- const int kernel_W,
157
- const int stride_H,
158
- const int stride_W,
159
- const int padding_H,
160
- const int padding_W)
161
- {
162
- const int tid = threadIdx.x;
163
- const int n_idx = blockIdx.x;
164
- const int output_h_idx = blockIdx.y;
165
- const int output_w_idx = blockIdx.z;
166
-
167
- int h_start_idx = output_h_idx * stride_H - padding_H;
168
- int h_end_idx = h_start_idx + kernel_H;
169
- h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx;
170
- h_end_idx = h_end_idx > H ? H : h_end_idx;
171
-
172
- int w_start_idx = output_w_idx * stride_W - padding_W;
173
- int w_end_idx = w_start_idx + kernel_W;
174
- w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx;
175
- w_end_idx = w_end_idx > W ? W : w_end_idx;
176
-
177
- input += n_idx * H * W * C;
178
- output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C;
179
- const int kernel_size2 = kernel_H * kernel_W;
180
- for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) {
181
- float2 pooling;
182
- if (IS_AVG_POOLING) {
183
- pooling = {0.0f, 0.0f};
184
- }
185
- else {
186
- pooling = {-FLT_MAX, -FLT_MAX};
187
- }
188
- for (int h = h_start_idx; h < h_end_idx; h++) {
189
- for (int w = w_start_idx; w < w_end_idx; w++) {
190
- const int idx = (h * W + w) * C;
191
- const T2 tmp = input[idx + c_idx];
192
- const float2 tmp_flt2 = {static_cast<float>(tmp.x), static_cast<float>(tmp.y)};
193
- if (IS_AVG_POOLING) {
194
- pooling.x += tmp_flt2.x;
195
- pooling.y += tmp_flt2.y;
196
- }
197
- else {
198
- pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x;
199
- pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y;
200
- }
201
- }
202
- }
203
-
204
- T2 output_val;
205
- if (IS_AVG_POOLING) {
206
- output_val.x = T(pooling.x/kernel_size2);
207
- output_val.y = T(pooling.y/kernel_size2);
208
- }
209
- else {
210
- output_val.x = T(pooling.x);
211
- output_val.y = T(pooling.y);
212
- }
213
- output[c_idx] = output_val;
214
- }
215
- }
216
-
217
- /**
218
- * output [N, 1, 1, C]
219
- * input [N, H, W, C]
220
- * grid(C, N)
221
- * block(block_size) -- each block deals with H*W/block_size elements;
222
- */
223
- template<typename T, bool IS_AVG_POOLING>
224
- __global__ void pooling_nxhTo1x1_element1_kernel(
225
- T* output, const T* input, const int N, const int HW, const int C)
226
- {
227
- const int c_idx = blockIdx.x;
228
- const int n_idx = blockIdx.y;
229
- float pooling[1];
230
- if (IS_AVG_POOLING) {
231
- pooling[0] = 0.0f;
232
- }
233
- else {
234
- pooling[0] = -FLT_MAX;
235
- }
236
- const size_t input_offset = n_idx * HW * C + c_idx;
237
- input += input_offset;
238
- const size_t output_offset = n_idx * C + c_idx;
239
- output += output_offset;
240
- int tid = threadIdx.x;
241
-
242
- for (int index = tid; index < HW; index += blockDim.x) {
243
- float val = static_cast<float>(input[index * C]);
244
- if (IS_AVG_POOLING) {
245
- pooling[0] += val;
246
- }
247
- else {
248
- pooling[0] = pooling[0] > val ? pooling[0] : val;
249
- }
250
- }
251
- if (blockDim.x <= 32) {
252
- if (IS_AVG_POOLING) {
253
- warpReduceSum<float, 1>(pooling);
254
- }
255
- else {
256
- warpReduceMax<float, 1>(pooling);
257
- }
258
- }
259
- else {
260
- if (IS_AVG_POOLING) {
261
- blockReduceSum<float, 1>(pooling);
262
- }
263
- else {
264
- blockReduceMax<float, 1>(pooling);
265
- }
266
- }
267
- __syncthreads();
268
- if (threadIdx.x == 0) {
269
- T output_val;
270
- if (IS_AVG_POOLING) {
271
- output_val = T(pooling[0] / HW);
272
- }
273
- else {
274
- output_val = T(pooling[0]);
275
- }
276
- output[0] = output_val;
277
- }
278
- }
279
-
280
-
281
- /**
282
- * output [N, 1, 1, C]
283
- * input [N, H, W, C]
284
- * grid(C/2, N)
285
- * block(block_size) -- each thread deals with H*W/block_size * 2 elements;
286
- */
287
- template<typename T2, typename T, bool IS_AVG_POOLING>
288
- __global__ void pooling_nxhTo1x1_element2_kernel(
289
- T2* output, const T2* input, const int N, const int HW, const int C)
290
- {
291
- const int c_idx = blockIdx.x;
292
- const int n_idx = blockIdx.y;
293
- float pooling[2];
294
- if (IS_AVG_POOLING) {
295
- pooling[0] = pooling[1] = 0.0f;
296
- }
297
- else {
298
- pooling[0] = pooling[1] = -FLT_MAX;
299
- }
300
- const int C_2 = C / 2;
301
- const size_t input_offset = n_idx * HW * C_2 + c_idx;
302
- input += input_offset;
303
- const size_t output_offset = n_idx * C_2 + c_idx;
304
- output += output_offset;
305
- int tid = threadIdx.x;
306
-
307
- for (int index = tid; index < HW; index += blockDim.x) {
308
- T2 val = input[index * C_2];
309
- float2 val_flt2 = {static_cast<float>(val.x), static_cast<float>(val.y)};
310
- if (IS_AVG_POOLING) {
311
- pooling[0] += val_flt2.x;
312
- pooling[1] += val_flt2.y;
313
- }
314
- else {
315
- pooling[0] = pooling[0] > val_flt2.x ? pooling[0] : val_flt2.x;
316
- pooling[1] = pooling[1] > val_flt2.y ? pooling[1] : val_flt2.y;
317
- }
318
- }
319
- if (blockDim.x <= 32) {
320
- if (IS_AVG_POOLING) {
321
- warpReduceSum<float, 2>(pooling);
322
- }
323
- else {
324
- warpReduceMax<float, 2>(pooling);
325
- }
326
- }
327
- else {
328
- if (IS_AVG_POOLING) {
329
- blockReduceSum<float, 2>(pooling);
330
- }
331
- else {
332
- blockReduceMax<float, 2>(pooling);
333
- }
334
- }
335
- __syncthreads();
336
- if (threadIdx.x == 0) {
337
- T2 output_val;
338
- if (IS_AVG_POOLING) {
339
- output_val.x = T(pooling[0] / HW);
340
- output_val.y = T(pooling[1] / HW);
341
- }
342
- else {
343
- output_val.x = T(pooling[0]);
344
- output_val.y = T(pooling[1]);
345
- }
346
- output[0] = output_val;
347
- }
348
- }
349
-
350
- template <typename T>
351
- void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size,
352
- cutlass::Tensor4DCoord filter_tensor_size,
353
- cutlass::Tensor4DCoord output_tensor_size,
354
- cutlass::Tensor4DCoord padding,
355
- cutlass::MatrixCoord stride,
356
- TensorRef<T, layout::TensorNHWC> ref_input,
357
- TensorRef<T, layout::TensorNHWC> ref_output,
358
- int poolingType, //0 for avg pooling ; 1 for max pooling
359
- cudaStream_t stream) {
360
-
361
- assert(input_tensor_size.n() == output_tensor_size.n() &&
362
- input_tensor_size.c() == output_tensor_size.c());
363
-
364
- const int N = input_tensor_size.n();
365
- const int H = input_tensor_size.h();
366
- const int W = input_tensor_size.w();
367
- const int C = input_tensor_size.c();
368
- const int padding_H = padding.h();
369
- const int padding_W = padding.w();
370
- const int kernel_H = filter_tensor_size.h();
371
- const int kernel_W = filter_tensor_size.w();
372
- const int stride_H = stride.row();
373
- const int stride_W = stride.column();
374
-
375
- const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H);
376
- const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W);
377
-
378
- assert(output_tensor_size.h() == output_H &&
379
- output_tensor_size.w() == output_W);
380
-
381
- if (C % 2 != 0) {
382
- if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) {
383
- dim3 grid(C, N);
384
- dim3 block(256);
385
- if (H*W < block.x){
386
- block.x = (H*W + 31)/32*32;
387
- }
388
- if (poolingType == 0) {
389
- pooling_nxhTo1x1_element1_kernel<T, true><<<grid, block, 0, stream>>>(
390
- ref_output.data(),
391
- ref_input.data(),
392
- N,
393
- H*W,
394
- C);
395
- } // if (poolingType == 0)
396
- else {
397
- pooling_nxhTo1x1_element1_kernel<T, false><<<grid, block, 0, stream>>>(
398
- ref_output.data(),
399
- ref_input.data(),
400
- N,
401
- H*W,
402
- C);
403
- }
404
- } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0))
405
- else {
406
- dim3 grid(N, output_H, output_W);
407
- dim3 block(256);
408
- if (C < block.x) {
409
- block.x = C;
410
- }
411
- if (poolingType == 0) {
412
- pooling_nhwc_element1_kernel<T, true><<<grid, block, 0, stream>>>(
413
- ref_output.data(),
414
- ref_input.data(),
415
- N,
416
- H,
417
- W,
418
- C,
419
- output_H,
420
- output_W,
421
- kernel_H,
422
- kernel_W,
423
- stride_H,
424
- stride_W,
425
- padding_H,
426
- padding_W);
427
- } // if (poolingType == 0)
428
- else {
429
- pooling_nhwc_element1_kernel<T, false><<<grid, block, 0, stream>>>(
430
- ref_output.data(),
431
- ref_input.data(),
432
- N,
433
- H,
434
- W,
435
- C,
436
- output_H,
437
- output_W,
438
- kernel_H,
439
- kernel_W,
440
- stride_H,
441
- stride_W,
442
- padding_H,
443
- padding_W);
444
- }
445
- }
446
- } // if (C % 2 != 0))
447
- else {
448
- if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) {
449
- dim3 grid(C/2, N);
450
- dim3 block(256);
451
- if (H*W < block.x){
452
- block.x = (H*W + 31)/32*32;
453
- }
454
- if (poolingType == 0) {
455
- if (std::is_same<T, float>::value) {
456
- pooling_nxhTo1x1_element2_kernel<float2, float, true><<<grid, block, 0, stream>>>(
457
- (float2*)(ref_output.data()),
458
- (const float2*)(ref_input.data()),
459
- N,
460
- H*W,
461
- C);
462
- } // if (std::is_same<T, float>::value)
463
- else {
464
- pooling_nxhTo1x1_element2_kernel<half2, half, true><<<grid, block, 0, stream>>>(
465
- (half2*)(ref_output.data()),
466
- (const half2*)(ref_input.data()),
467
- N,
468
- H*W,
469
- C);
470
- }
471
- } // if (poolingType == 0)
472
- else {
473
- if (std::is_same<T, float>::value) {
474
- pooling_nxhTo1x1_element2_kernel<float2, float, false><<<grid, block, 0, stream>>>(
475
- (float2*)(ref_output.data()),
476
- (const float2*)(ref_input.data()),
477
- N,
478
- H*W,
479
- C);
480
- } // if (std::is_same<T, float>::value)
481
- else {
482
- pooling_nxhTo1x1_element2_kernel<half2, half, false><<<grid, block, 0, stream>>>(
483
- (half2*)(ref_output.data()),
484
- (const half2*)(ref_input.data()),
485
- N,
486
- H*W,
487
- C);
488
- }
489
- }
490
- } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0))
491
- else {
492
- dim3 grid(N, output_H, output_W);
493
- dim3 block(256);
494
- if (C/2 < block.x) {
495
- block.x = C/2;
496
- }
497
- if (poolingType == 0) {
498
- if (std::is_same<T, float>::value) {
499
- pooling_nhwc_element2_kernel<float2, float, true><<<grid, block, 0, stream>>>(
500
- (float2*)(ref_output.data()),
501
- (const float2*)(ref_input.data()),
502
- N,
503
- H,
504
- W,
505
- C/2,
506
- output_H,
507
- output_W,
508
- kernel_H,
509
- kernel_W,
510
- stride_H,
511
- stride_W,
512
- padding_H,
513
- padding_W);
514
- } // if (std::is_same<T, float>::value)
515
- else {
516
- pooling_nhwc_element2_kernel<half2, half, true><<<grid, block, 0, stream>>>(
517
- (half2*)(ref_output.data()),
518
- (const half2*)(ref_input.data()),
519
- N,
520
- H,
521
- W,
522
- C/2,
523
- output_H,
524
- output_W,
525
- kernel_H,
526
- kernel_W,
527
- stride_H,
528
- stride_W,
529
- padding_H,
530
- padding_W);
531
- }
532
- } // if (poolingType == 0)
533
- else {
534
- if (std::is_same<T, float>::value) {
535
- pooling_nhwc_element2_kernel<float2, float, false><<<grid, block, 0, stream>>>(
536
- (float2*)(ref_output.data()),
537
- (const float2*)(ref_input.data()),
538
- N,
539
- H,
540
- W,
541
- C/2,
542
- output_H,
543
- output_W,
544
- kernel_H,
545
- kernel_W,
546
- stride_H,
547
- stride_W,
548
- padding_H,
549
- padding_W);
550
- } // if (std::is_same<T, float>::value)
551
- else {
552
- pooling_nhwc_element2_kernel<half2, half, false><<<grid, block, 0, stream>>>(
553
- (half2*)(ref_output.data()),
554
- (const half2*)(ref_input.data()),
555
- N,
556
- H,
557
- W,
558
- C/2,
559
- output_H,
560
- output_W,
561
- kernel_H,
562
- kernel_W,
563
- stride_H,
564
- stride_W,
565
- padding_H,
566
- padding_W);
567
- }
568
- }
569
- }
570
- }
571
- }
572
-
573
- } //namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h DELETED
@@ -1,144 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- /**
35
- * \file
36
- * \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout.
37
- */
38
-
39
- #include "cutlass/cutlass.h"
40
- #include "cutlass/layout/tensor.h"
41
- #include "cutlass/numeric_types.h"
42
- #include "cutlass/tensor_coord.h"
43
- #include "cutlass/tensor_ref.h"
44
-
45
- namespace cutlass {
46
-
47
- /** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout.
48
- * \tparam T: data type
49
- */
50
- template <typename T>
51
- void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size,
52
- cutlass::Tensor4DCoord output_tensor_size,
53
- TensorRef<T, layout::TensorNHWC> ref_input,
54
- TensorRef<T, layout::TensorNCHW> ref_output,
55
- cudaStream_t stream);
56
-
57
-
58
- template <typename T>
59
- __global__ void nhwc_to_nchw_kernel(T *output,
60
- const T *input,
61
- const int n,
62
- const int h,
63
- const int w,
64
- const int c) {
65
-
66
- const int hw = h*w;
67
- const int hwc = hw*c;
68
- __shared__ T shbuf[32 * (32 + 1)];
69
- const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x;
70
- const int32_t wid = tid / 32;
71
- const int32_t lid = tid % 32;
72
- const int32_t ni = blockIdx.z;
73
- const int32_t hwi0 = blockIdx.y * 32;
74
- const int32_t ci0 = blockIdx.x * 32;
75
-
76
- const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0;
77
- const T *A = input + input_idx;
78
- if (ci0 + lid < c) {
79
- const int lid_x_33 = lid * 33;
80
- if ((hwi0 + 32) <= hw) {
81
- int hwi = wid; // between 0 and 7
82
- CUTLASS_PRAGMA_UNROLL
83
- for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) {
84
- shbuf[lid_x_33 + hwi] = A[lid];
85
- A = &A[8 * c];
86
- hwi += 8;
87
- }
88
- } else {
89
- for (int hwi = wid; hwi < 32; hwi += 8) {
90
- if ((hwi + hwi0) < hw) {
91
- shbuf[lid_x_33 + hwi] = A[lid];
92
- }
93
- A = &A[8 * c];
94
- }
95
- }
96
- }
97
- __syncthreads();
98
-
99
- const int32_t hwiOut = hwi0 + lid;
100
- output = &output[ni * hwc + hwiOut];
101
- if (hwiOut < hw) {
102
- if (ci0 + 32 < c) {
103
- int cI = wid;
104
- CUTLASS_PRAGMA_UNROLL
105
- for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) {
106
- output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid];
107
- cI += 8;
108
- }
109
- } else {
110
- for (int cI = wid; cI < 32; cI += 8) {
111
- if (ci0 + cI < c) {
112
- output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid];
113
- }
114
- }
115
- }
116
- }
117
- }
118
-
119
- template <typename T>
120
- void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size,
121
- cutlass::Tensor4DCoord output_tensor_size,
122
- TensorRef<T, layout::TensorNHWC> ref_input,
123
- TensorRef<T, layout::TensorNCHW> ref_output,
124
- cudaStream_t stream) {
125
-
126
- assert(
127
- input_tensor_size.n() == output_tensor_size.n() &&
128
- input_tensor_size.h() == output_tensor_size.c() &&
129
- input_tensor_size.w() == output_tensor_size.h() &&
130
- input_tensor_size.c() == output_tensor_size.w());
131
-
132
- int n = input_tensor_size.n();
133
- int h = input_tensor_size.h();
134
- int w = input_tensor_size.w();
135
- int c = input_tensor_size.c();
136
-
137
- dim3 grid((c + 31)/32, (h*w + 31)/32, n);
138
- dim3 block(32, 8);
139
- nhwc_to_nchw_kernel<<<grid, block, 0, stream>>>(ref_output.data(), ref_input.data(),
140
- n, h, w, c);
141
-
142
- }
143
-
144
- } //namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h DELETED
@@ -1,186 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include "cutlass/cutlass.h"
35
- #include "cutlass/layout/tensor.h"
36
- #include "cutlass/numeric_types.h"
37
- #include "cutlass/tensor_coord.h"
38
- #include "cutlass/tensor_ref.h"
39
- #include "cutlass/util/device_utils.h"
40
- #include <cfloat>
41
-
42
- namespace cutlass {
43
-
44
- __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
45
- const float4 *weight,
46
- const int m, const int n, float epsilon) {
47
- const int m_idx = blockIdx.x;
48
- const int tid = threadIdx.x;
49
- const int bdimx = blockDim.x;
50
- __shared__ float s_mean;
51
- float local_sums[1] = {0.0f};
52
- const int n_8 = n / 8;
53
- int offset = m_idx * n_8;
54
- input += offset;
55
- output += offset;
56
-
57
- for (int index = tid; index < n_8; index += bdimx) {
58
- const float4 local_val = input[index];
59
- const half2 *h1 = (half2 *)&local_val.x;
60
- const half2 *h2 = (half2 *)&local_val.y;
61
- const half2 *h3 = (half2 *)&local_val.z;
62
- const half2 *h4 = (half2 *)&local_val.w;
63
- local_sums[0] += static_cast<float>(h1->x) * static_cast<float>(h1->x) +
64
- static_cast<float>(h1->y) * static_cast<float>(h1->y) +
65
- static_cast<float>(h2->x) * static_cast<float>(h2->x) +
66
- static_cast<float>(h2->y) * static_cast<float>(h2->y) +
67
- static_cast<float>(h3->x) * static_cast<float>(h3->x) +
68
- static_cast<float>(h3->y) * static_cast<float>(h3->y) +
69
- static_cast<float>(h4->x) * static_cast<float>(h4->x) +
70
- static_cast<float>(h4->y) * static_cast<float>(h4->y);
71
- }
72
-
73
- if (blockDim.x <= 32) {
74
- warpReduceSum<float, 1>(local_sums);
75
- } else {
76
- blockReduceSum<float, 1>(local_sums);
77
- }
78
- if (threadIdx.x == 0) {
79
- s_mean = rsqrtf(local_sums[0] / n + epsilon);
80
- }
81
- __syncthreads();
82
-
83
- for (int index = tid; index < n_8; index += bdimx) {
84
- const float4 local_val = input[index];
85
- const float4 weight_val = weight[index];
86
-
87
- const half2 *l1 = (half2 *)&local_val.x;
88
- const half2 *l2 = (half2 *)&local_val.y;
89
- const half2 *l3 = (half2 *)&local_val.z;
90
- const half2 *l4 = (half2 *)&local_val.w;
91
-
92
- const half2 *g1 = (half2 *)&weight_val.x;
93
- const half2 *g2 = (half2 *)&weight_val.y;
94
- const half2 *g3 = (half2 *)&weight_val.z;
95
- const half2 *g4 = (half2 *)&weight_val.w;
96
-
97
- float4 tmp;
98
- half2 *h1 = (half2 *)&tmp.x;
99
- half2 *h2 = (half2 *)&tmp.y;
100
- half2 *h3 = (half2 *)&tmp.z;
101
- half2 *h4 = (half2 *)&tmp.w;
102
-
103
- h1->x = half(static_cast<float>(l1->x) * s_mean * static_cast<float>(g1->x));
104
- h1->y = half(static_cast<float>(l1->y) * s_mean * static_cast<float>(g1->y));
105
- h2->x = half(static_cast<float>(l2->x) * s_mean * static_cast<float>(g2->x));
106
- h2->y = half(static_cast<float>(l2->y) * s_mean * static_cast<float>(g2->y));
107
- h3->x = half(static_cast<float>(l3->x) * s_mean * static_cast<float>(g3->x));
108
- h3->y = half(static_cast<float>(l3->y) * s_mean * static_cast<float>(g3->y));
109
- h4->x = half(static_cast<float>(l4->x) * s_mean * static_cast<float>(g4->x));
110
- h4->y = half(static_cast<float>(l4->y) * s_mean * static_cast<float>(g4->y));
111
-
112
- output[index] = tmp;
113
- }
114
- }
115
-
116
- template<typename T>
117
- __global__ void rmsnorm_twoPassAlgo_e1(T* output,
118
- const T* input,
119
- const T* weight,
120
- const int m, const int n,
121
- float epsilon)
122
- {
123
- const int m_idx = blockIdx.x;
124
- const int tid = threadIdx.x;
125
- const int bdimx = blockDim.x;
126
- __shared__ float s_mean;
127
- float local_sums[1] = {0.0f};
128
- int offset = m_idx * n;
129
- input += offset;
130
- output += offset;
131
-
132
- for (int index = tid ; index < n ; index += bdimx){
133
- float local_val = static_cast<float>(input[index]);
134
- local_sums[0] += local_val * local_val;
135
- }
136
- if (blockDim.x <= 32) {
137
- warpReduceSum<float, 1>(local_sums);
138
- }
139
- else {
140
- blockReduceSum<float, 1>(local_sums);
141
- }
142
- if (threadIdx.x == 0) {
143
- s_mean = rsqrtf(local_sums[0] / n + epsilon);
144
- }
145
- __syncthreads();
146
-
147
- for (int index = tid ; index < n ; index += bdimx){
148
- const T weight_val = weight[index];
149
- const T local_val = input[index];
150
- output[index] = T(static_cast<float>(local_val) * s_mean * static_cast<float>(weight_val));
151
- }
152
- }
153
-
154
- template <typename T>
155
- void rmsnorm(cutlass::MatrixCoord tensor_size,
156
- TensorRef<T, layout::RowMajor> ref_output,
157
- TensorRef<T, layout::RowMajor> ref_input,
158
- TensorRef<T, layout::RowMajor> ref_weight,
159
- cudaStream_t stream, float epsilon = 1e-5f){
160
- const int m = tensor_size.row();
161
- const int n = tensor_size.column();
162
- T* output = ref_output.data();
163
- const T* input = ref_input.data();
164
- const T* weight = ref_weight.data();
165
- dim3 grid(m);
166
-
167
- if (n % 8 == 0 && std::is_same<T, cutlass::half_t>::value) {
168
- dim3 block(cutlass::platform::min(1024, (n / 8 + 31) / 32 * 32));
169
-
170
- rmsnorm_twoPassAlgo_e8<<<grid, block, 0, stream>>>(
171
- (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon);
172
- } else {
173
- dim3 block(cutlass::platform::min(1024, ((n + 31)/32 + 31)/32*32));
174
-
175
- rmsnorm_twoPassAlgo_e1<<<grid, block, 0, stream>>>(
176
- output, input, weight, m, n, epsilon);
177
- }
178
-
179
- auto result = cudaGetLastError();
180
- if (result != cudaSuccess) {
181
- std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl;
182
- abort();
183
- }
184
- }
185
-
186
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h DELETED
@@ -1,127 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief utils code for device cutlass code
34
- */
35
-
36
- #pragma once
37
-
38
- #include <cuda_fp16.h>
39
- #include <cfloat>
40
- #define FINAL_MASK 0xffffffff
41
-
42
- struct half4 {
43
- half x, y, z, w;
44
- };
45
-
46
- template<typename T, int NUM>
47
- __inline__ __device__ T warpReduceSum(T* val)
48
- {
49
- #pragma unroll
50
- for (int i = 0; i < NUM; i++) {
51
- #pragma unroll
52
- for (int mask = 16; mask > 0; mask >>= 1)
53
- val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
54
- }
55
- return (T)(0.0f);
56
- }
57
-
58
- template<typename T, int NUM>
59
- __inline__ __device__ T blockReduceSum(T* val)
60
- {
61
- __shared__ T shared[NUM][33];
62
- int lane = threadIdx.x & 0x1f;
63
- int wid = threadIdx.x >> 5;
64
-
65
- warpReduceSum<T, NUM>(val);
66
-
67
- if (lane == 0) {
68
- #pragma unroll
69
- for (int i = 0; i < NUM; i++) {
70
- shared[i][wid] = val[i];
71
- }
72
- }
73
-
74
- __syncthreads();
75
-
76
- bool is_mask = threadIdx.x < (blockDim.x / 32.f);
77
- #pragma unroll
78
- for (int i = 0; i < NUM; i++) {
79
- val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
80
- }
81
- warpReduceSum<T, NUM>(val);
82
- return (T)0.0f;
83
- }
84
-
85
- template<typename T, int NUM>
86
- __inline__ __device__ T warpReduceMax(T* val)
87
- {
88
- #pragma unroll
89
- for (int i = 0; i < NUM; i++) {
90
- #pragma unroll
91
- for (int mask = 16; mask > 0; mask >>= 1)
92
- val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
93
- }
94
- return (T)(0.0f);
95
- }
96
-
97
- template<typename T, int NUM>
98
- __inline__ __device__ T blockReduceMax(T* val)
99
- {
100
- static __shared__ T shared[32][NUM];
101
- int lane = threadIdx.x & 0x1f; // in-warp idx
102
- int wid = threadIdx.x >> 5; // warp idx
103
-
104
- warpReduceMax<T, NUM>(val); // get maxx in each warp
105
-
106
- if (lane == 0) // record in-warp maxx by warp Idx
107
- {
108
- #pragma unroll
109
- for (int i = 0; i < NUM; i++) {
110
- shared[wid][i] = val[i];
111
- }
112
- }
113
-
114
- __syncthreads();
115
-
116
- // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
117
- // blockDim.x is not divided by 32
118
- bool is_mask = threadIdx.x < (blockDim.x / 32.f);
119
- #pragma unroll
120
- for (int i = 0; i < NUM; i++) {
121
- val[i] = is_mask ? shared[lane][i] : (T)(-FLT_MAX);
122
- }
123
- warpReduceMax<T, NUM>(val);
124
-
125
- return (T)0.0f;
126
- }
127
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h DELETED
@@ -1,157 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- /*! \file
34
- \brief This header contains a class to parametrize a statistical distribution function.
35
- */
36
-
37
- #include <ostream>
38
-
39
- namespace cutlass {
40
-
41
- ////////////////////////////////////////////////////////////////////////////////////////////////////
42
-
43
- /// Distribution type
44
- struct Distribution {
45
- /// Variant types
46
- enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes };
47
-
48
- /// Distribution state
49
- union {
50
- /// Uniform distribution
51
- struct {
52
- double min;
53
- double max;
54
- // Percent elements set to NaN
55
- double pnan;
56
- } uniform;
57
-
58
- /// Gaussian distribution
59
- struct {
60
- double mean;
61
- double stddev;
62
- double pnz;
63
- double pnzA;
64
- double pnzB;
65
- double pnzC;
66
- } gaussian;
67
-
68
- /// Elements are linear combination of row and column index
69
- struct {
70
- double start;
71
- double delta;
72
- } sequential;
73
- };
74
-
75
- /// Active variant kind
76
- Kind kind;
77
-
78
- /// Random values are cast to integer after scaling by this power of two
79
- int int_scale;
80
-
81
- //
82
- // Methods
83
- //
84
-
85
- Distribution() : kind(Invalid), int_scale(0) {}
86
-
87
- /// Configures distribution as uniform random
88
- Distribution &set_uniform(double _min, double _max, int _int_scale = 0, double _pnan = 0) {
89
- kind = Uniform;
90
- uniform.min = _min;
91
- uniform.max = _max;
92
- int_scale = _int_scale;
93
- uniform.pnan = _pnan;
94
- return *this;
95
- }
96
-
97
- /// Configures distribution as Gaussian distribution
98
- Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0, double _pnz = 1.0) {
99
- kind = Gaussian;
100
- gaussian.mean = _mean;
101
- gaussian.stddev = _stddev;
102
- gaussian.pnz = _pnz;
103
- gaussian.pnzA = _pnz;
104
- gaussian.pnzB = _pnz;
105
- gaussian.pnzC = _pnz;
106
- int_scale = _int_scale;
107
- return *this;
108
- }
109
-
110
- /// Sets identity
111
- Distribution &set_identity() {
112
- kind = Identity;
113
- return *this;
114
- }
115
-
116
- /// Sets sequential
117
- Distribution &set_sequential(double start, double delta, int _int_scale = 0) {
118
- kind = Sequential;
119
- sequential.start = start;
120
- sequential.delta = delta;
121
- int_scale = _int_scale;
122
- return *this;
123
- }
124
- };
125
-
126
- } // namespace cutlass
127
-
128
- ////////////////////////////////////////////////////////////////////////////////////////////////////
129
-
130
- /// Prints a Distribution to ostream
131
- inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) {
132
- switch (dist.kind) {
133
- case cutlass::Distribution::Uniform:
134
- out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max
135
- << ", pnan: " << dist.uniform.pnan;
136
- break;
137
- case cutlass::Distribution::Gaussian:
138
- out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev
139
- << ", pnzA: " << dist.gaussian.pnzA << ", pnzB: "
140
- << dist.gaussian.pnzB << ", pnzC: " << dist.gaussian.pnzC;
141
- break;
142
- case cutlass::Distribution::Identity:
143
- out << "identity";
144
- break;
145
- case cutlass::Distribution::Sequential:
146
- out << "sequential";
147
- break;
148
- default:
149
- out << "unknown";
150
- }
151
-
152
- out << ", int_scale: " << dist.int_scale;
153
-
154
- return out;
155
- }
156
-
157
- ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h DELETED
@@ -1,69 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- ******************************************************************************/
31
-
32
- #pragma once
33
-
34
- /**
35
- * \file
36
- * \brief C++ exception semantics for CUDA error codes
37
- */
38
-
39
- #include <cuda_runtime.h>
40
- #include <iosfwd>
41
- #include <stdexcept>
42
-
43
- #include "cutlass/platform/platform.h"
44
-
45
- namespace cutlass {
46
-
47
- /// C++ exception wrapper for CUDA \p cudaError_t
48
- class cuda_exception : public std::exception {
49
- public:
50
- /// Constructor
51
- cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {}
52
-
53
- /// Returns the underlying CUDA \p cudaError_t
54
- cudaError_t cudaError() const { return err; }
55
-
56
- protected:
57
- /// Explanatory string
58
- const char* msg;
59
-
60
- /// Underlying CUDA \p cudaError_t
61
- cudaError_t err;
62
- };
63
-
64
- /// Writes a cuda_exception instance to an output stream
65
- inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) {
66
- return out << e.what() << ": " << cudaGetErrorString(e.cudaError());
67
- }
68
-
69
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp DELETED
@@ -1,369 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief GETT command line parser to gather semantic modes, their stride order, and extents.
33
- */
34
- #pragma once
35
-
36
- #include <iostream>
37
- #include <iomanip>
38
- #include <utility>
39
- #include <type_traits>
40
- #include <vector>
41
- #include <map>
42
- #include <algorithm>
43
- #include <numeric>
44
-
45
- #include "cutlass/util/command_line.h"
46
-
47
- namespace cutlass {
48
-
49
- // Output shortcuts
50
- std::ostream& operator<<(std::ostream& os, std::vector<char> data) {
51
- for (auto& a : data) os << a;
52
- return os;
53
- }
54
-
55
- template <class T>
56
- std::ostream& operator<<(std::ostream& os, std::vector<T> data) {
57
- for (auto& a : data) os << a << " ";
58
- return os;
59
- }
60
-
61
- struct GettCommandLine {
62
- struct GettProblem {
63
- using extent_type = int;
64
- using stride_type = int64_t;
65
-
66
- // Row modes: appear in A and C/D
67
- std::vector<extent_type> M;
68
- std::vector<stride_type> ldAm;
69
- std::vector<stride_type> ldCm;
70
-
71
- // Column modes: appear in B and C/D
72
- std::vector<extent_type> N;
73
- std::vector<stride_type> ldBn;
74
- std::vector<stride_type> ldCn;
75
-
76
- // Reduction modes: appear in A and B
77
- std::vector<extent_type> K;
78
- std::vector<stride_type> ldAk;
79
- std::vector<stride_type> ldBk;
80
-
81
- // Batch modes: appear in all in/out tensors
82
- std::vector<extent_type> L;
83
- std::vector<stride_type> ldAl;
84
- std::vector<stride_type> ldBl;
85
- std::vector<stride_type> ldCl;
86
- };
87
-
88
- static GettProblem
89
- parse(int argc, char const* argv[], bool parse_verbose = false) {
90
- using extent_type = typename GettProblem::extent_type;
91
- using stride_type = typename GettProblem::stride_type;
92
-
93
- cutlass::CommandLine cmd(argc, argv);
94
-
95
- // modeA
96
- std::vector<char> a_mode;
97
- cmd.get_cmd_line_arguments("modeA", a_mode);
98
-
99
- // modeB
100
- std::vector<char> b_mode;
101
- cmd.get_cmd_line_arguments("modeB", b_mode);
102
-
103
- // modeC
104
- std::vector<char> c_mode;
105
- cmd.get_cmd_line_arguments("modeC", c_mode);
106
-
107
-
108
- // mode_sizes
109
- std::map<char,extent_type> mode_size;
110
- // First, initialize all modes in a, b, c to make sure they're in map
111
- for (char a : a_mode) mode_size[a] = 1;
112
- for (char b : b_mode) mode_size[b] = 1;
113
- for (char c : c_mode) mode_size[c] = 1;
114
-
115
- // Then, overwrite the ones in -extent
116
- std::vector<std::pair<std::string, std::string> > extent_tokens;
117
- cmd.get_cmd_line_argument_pairs("extents", extent_tokens);
118
- for (auto e : extent_tokens) {
119
- if (std::get<0>(e).size() > 1) {
120
- std::cerr << "ERROR: Mode name must only be 1 character long.\n";
121
- print_usage();
122
- exit(1);
123
- }
124
- char label = std::get<0>(e)[0];
125
- int size = std::stoi(std::get<1>(e));
126
- mode_size[label] = size;
127
- }
128
-
129
- // Print out symbolic modes and their extents
130
- if (parse_verbose) {
131
- std::cout << "C_" << c_mode << " = A_" << a_mode << " * B_" << b_mode << "\n";
132
- for (auto e : mode_size) std::cout << " " << std::get<0>(e) << " : " << std::get<1>(e) << "\n";
133
- }
134
-
135
- //
136
- // Collect/Compute strides
137
- //
138
-
139
- std::map<char,stride_type> mode_ldA;
140
- std::map<char,stride_type> mode_ldB;
141
- std::map<char,stride_type> mode_ldC;
142
-
143
- {
144
- stride_type current;
145
-
146
- current = 1;
147
- for (char a : a_mode) { mode_ldA[a] = current; current *= mode_size[a]; }
148
-
149
- current = 1;
150
- for (char b : b_mode) { mode_ldB[b] = current; current *= mode_size[b]; }
151
-
152
- current = 1;
153
- for (char c : c_mode) { mode_ldC[c] = current; current *= mode_size[c]; }
154
- }
155
-
156
- //
157
- // Collect mode categories
158
- //
159
-
160
- std::vector<char> row_mode; // rows
161
- std::vector<char> col_mode; // columns
162
- std::vector<char> red_mode; // reductions
163
- std::vector<char> bat_mode; // batches
164
-
165
- {
166
- std::vector<char> a_label = a_mode;
167
- std::vector<char> b_label = b_mode;
168
- std::vector<char> c_label = c_mode;
169
-
170
- std::sort(std::begin(a_label), std::end(a_label));
171
- std::sort(std::begin(b_label), std::end(b_label));
172
- std::sort(std::begin(c_label), std::end(c_label));
173
-
174
- // std::set_intersections to find semantic category of each symbolic mode
175
- std::set_intersection(std::begin(a_label), std::end(a_label),
176
- std::begin(c_label), std::end(c_label),
177
- std::back_inserter(row_mode));
178
-
179
- std::set_intersection(std::begin(b_label), std::end(b_label),
180
- std::begin(c_label), std::end(c_label),
181
- std::back_inserter(col_mode));
182
-
183
- std::set_intersection(std::begin(a_label), std::end(a_label),
184
- std::begin(b_label), std::end(b_label),
185
- std::back_inserter(red_mode));
186
-
187
- std::set_intersection(std::begin(row_mode), std::end(row_mode),
188
- std::begin(col_mode), std::end(col_mode),
189
- std::back_inserter(bat_mode));
190
-
191
- // std::set_difference to remove batch modes from other semantic modes
192
- for (char l : bat_mode) {
193
- row_mode.erase(std::remove(std::begin(row_mode), std::end(row_mode), l), std::end(row_mode));
194
- col_mode.erase(std::remove(std::begin(col_mode), std::end(col_mode), l), std::end(col_mode));
195
- red_mode.erase(std::remove(std::begin(red_mode), std::end(red_mode), l), std::end(red_mode));
196
- }
197
- }
198
-
199
- // Print out the semantic association of each symbolic mode
200
- if (parse_verbose) {
201
- std::cout << " rows : " << row_mode << '\n';
202
- std::cout << " cols : " << col_mode << '\n';
203
- std::cout << " reds : " << red_mode << '\n';
204
- std::cout << " bats : " << bat_mode << '\n';
205
- }
206
-
207
- //
208
- // Permute modes
209
- //
210
-
211
- // Permute the batched modes to promote coalescing
212
- // Sort the batched modes by min(ldAl,ldBl) and in case of a tie by the size
213
- std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) {
214
- return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1])
215
- < std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]);
216
- });
217
- // Compute sizes and strides of ordered reduction modes
218
- std::vector<extent_type> L;
219
- std::vector<stride_type> ldAl;
220
- std::vector<stride_type> ldBl;
221
- std::vector<stride_type> ldCl;
222
- for (char l : bat_mode) {
223
- L.push_back(mode_size[l]);
224
- ldAl.push_back(mode_ldA[l]);
225
- ldBl.push_back(mode_ldB[l]);
226
- ldCl.push_back(mode_ldC[l]);
227
- }
228
-
229
- // Permute the reduction modes to promote coalescing
230
- // Sort the reduction modes by min(ldAk,ldBk) and in case of a tie by the size
231
- std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) {
232
- return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1])
233
- < std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]);
234
- });
235
- // Compute sizes and strides of ordered reduction modes
236
- std::vector<extent_type> K;
237
- std::vector<stride_type> ldAk;
238
- std::vector<stride_type> ldBk;
239
- for (char k : red_mode) {
240
- K.push_back(mode_size[k]);
241
- ldAk.push_back(mode_ldA[k]);
242
- ldBk.push_back(mode_ldB[k]);
243
- }
244
-
245
- // Permute the row modes to promote coalescing
246
- // Sort the row modes by min(ldAm,ldCm) and in case of a tie by ldAm
247
- std::sort(std::begin(row_mode), std::end(row_mode), [&](char m1, char m2) {
248
- return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1])
249
- < std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]);
250
- });
251
- // Compute sizes and strides of ordered row modes
252
- std::vector<extent_type> M;
253
- std::vector<stride_type> ldAm;
254
- std::vector<stride_type> ldCm;
255
- for (char m : row_mode) {
256
- M.push_back(mode_size[m]);
257
- ldAm.push_back(mode_ldA[m]);
258
- ldCm.push_back(mode_ldC[m]);
259
- }
260
-
261
- // Permute the col modes to promote coalescing
262
- // Sort the col modes by min(ldBn,ldCn) and in case of a tie by ldBn
263
- std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) {
264
- return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1])
265
- < std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]);
266
- });
267
- // Compute sizes and strides of ordered col modes
268
- std::vector<extent_type> N;
269
- std::vector<stride_type> ldBn;
270
- std::vector<stride_type> ldCn;
271
- for (char n : col_mode) {
272
- N.push_back(mode_size[n]);
273
- ldBn.push_back(mode_ldB[n]);
274
- ldCn.push_back(mode_ldC[n]);
275
- }
276
-
277
- if (parse_verbose) {
278
- std::cout << "C_";
279
- if (! row_mode.empty()) {
280
- std::cout << "(" << row_mode << ")";
281
- }
282
- if (! col_mode.empty()) {
283
- std::cout << "(" << col_mode << ")";
284
- }
285
- if (! bat_mode.empty()) {
286
- std::cout << "(" << bat_mode << ")";
287
- }
288
- std::cout << " = A_";
289
- if (! row_mode.empty()) {
290
- std::cout << "(" << row_mode << ")";
291
- }
292
- if (! red_mode.empty()) {
293
- std::cout << "(" << red_mode << ")";
294
- }
295
- if (! bat_mode.empty()) {
296
- std::cout << "(" << bat_mode << ")";
297
- }
298
- std::cout << " * B_";
299
- if (! col_mode.empty()) {
300
- std::cout << "(" << col_mode << ")";
301
- }
302
- if (! red_mode.empty()) {
303
- std::cout << "(" << red_mode << ")";
304
- }
305
- if (! bat_mode.empty()) {
306
- std::cout << "(" << bat_mode << ")";
307
- }
308
- std::cout << '\n';
309
-
310
- int M_size = std::accumulate(std::begin(M), std::end(M), 1, std::multiplies<>{});
311
- int N_size = std::accumulate(std::begin(N), std::end(N), 1, std::multiplies<>{});
312
- int K_size = std::accumulate(std::begin(K), std::end(K), 1, std::multiplies<>{});
313
- int L_size = std::accumulate(std::begin(L), std::end(L), 1, std::multiplies<>{});
314
-
315
- std::cout << " M : (" << M_size << ") ";
316
- for (char m : row_mode) std::cout << m << ":" << mode_size[m] << " ";
317
- std::cout << '\n';
318
- std::cout << " N : (" << N_size << ") ";
319
- for (char n : col_mode) std::cout << n << ":" << mode_size[n] << " ";
320
- std::cout << '\n';
321
- std::cout << " K : (" << K_size << ") ";
322
- for (char k : red_mode) std::cout << k << ":" << mode_size[k] << " ";
323
- std::cout << '\n';
324
- std::cout << " L : (" << L_size << ") ";
325
- for (char l : bat_mode) std::cout << l << ":" << mode_size[l] << " ";
326
- std::cout << '\n';
327
-
328
- std::cout << " ldAm : " << ldAm << '\n';
329
- std::cout << " ldAk : " << ldAk << '\n';
330
- std::cout << " ldAl : " << ldAl << '\n';
331
- std::cout << " ldBn : " << ldBn << '\n';
332
- std::cout << " ldBk : " << ldBk << '\n';
333
- std::cout << " ldBl : " << ldBl << '\n';
334
- std::cout << " ldCm : " << ldCm << '\n';
335
- std::cout << " ldCn : " << ldCn << '\n';
336
- std::cout << " ldCl : " << ldCl << '\n';
337
- }
338
-
339
- return {M, ldAm, ldCm,
340
- N, ldBn, ldCn,
341
- K, ldAk, ldBk,
342
- L, ldAl, ldBl, ldCl};
343
- }
344
-
345
- static void
346
- print_usage() {
347
- std::cout <<
348
- "GETT problem command line parser:\n"
349
- " --modeA=<m0,...>\n"
350
- " A comma delimited list of characters that correspond to the row, reduction, and batch modes in A tensor.\n"
351
- " The semantic association of each symbolic mode is determined automatically.\n\n"
352
-
353
- " --modeB=<m0,...>\n"
354
- " A comma delimited list of characters that correspond to the column, reduction, and batch modes in B tensor.\n"
355
- " The semantic association of each symbolic mode is determined automatically.\n\n"
356
-
357
- " --modeC=<m0,...>\n"
358
- " A comma delimited list of characters that correspond to the row, column, and batch modes in B tensor.\n"
359
- " The semantic association of each symbolic mode is determined automatically.\n\n"
360
-
361
- " --extents=<mode:extent,....>\n"
362
- " A command delimited list of symbolic mode and its corresponding extent.\n"
363
- " Extents are defaulted to 1 if any are not provided.\n\n"
364
-
365
- "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096\n";
366
- }
367
- };
368
-
369
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp DELETED
@@ -1,116 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include <cuda.h>
35
-
36
- #include <cute/util/debug.hpp>
37
-
38
- namespace cute
39
- {
40
-
41
- void
42
- device_init(int device_id, bool quiet = false)
43
- {
44
- cudaDeviceProp device_prop;
45
- std::size_t device_free_physmem;
46
- std::size_t device_total_physmem;
47
-
48
- CUTE_CHECK_ERROR(cudaSetDevice(device_id));
49
- CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem));
50
- CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id));
51
-
52
- if (device_prop.major < 1) {
53
- fprintf(stderr, "Device does not support CUDA.\n");
54
- exit(1);
55
- }
56
-
57
- //float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000;
58
-
59
- if (!quiet) {
60
- printf("Using device %d: %s (SM%d, %d SMs)\n",
61
- device_id, device_prop.name,
62
- device_prop.major * 10 + device_prop.minor,
63
- device_prop.multiProcessorCount);
64
- fflush(stdout);
65
- }
66
- }
67
-
68
- /**
69
- * Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores.
70
- */
71
- inline int
72
- _ConvertSMVer2Cores(int major, int minor)
73
- {
74
- // Defines for GPU Architecture types (using the SM version to determine
75
- // the # of cores per SM
76
- typedef struct {
77
- int SM; // 0xMm (hexadecimal notation), M = SM Major version,
78
- // and m = SM minor version
79
- int Cores;
80
- } sSMtoCores;
81
-
82
- sSMtoCores nGpuArchCoresPerSM[] = {
83
- {0x30, 192},
84
- {0x32, 192},
85
- {0x35, 192},
86
- {0x37, 192},
87
- {0x50, 128},
88
- {0x52, 128},
89
- {0x53, 128},
90
- {0x60, 64},
91
- {0x61, 128},
92
- {0x62, 128},
93
- {0x70, 64},
94
- {0x72, 64},
95
- {0x75, 64},
96
- {-1, -1}};
97
-
98
- int index = 0;
99
-
100
- while (nGpuArchCoresPerSM[index].SM != -1) {
101
- if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) {
102
- return nGpuArchCoresPerSM[index].Cores;
103
- }
104
- index++;
105
- }
106
-
107
- // If we don't find the values, we default use the previous one
108
- // to run properly
109
- printf("MapSMtoCores for SM %d.%d is undefined."
110
- " Default to use %d Cores/SM\n",
111
- major, minor, nGpuArchCoresPerSM[index - 1].Cores);
112
-
113
- return nGpuArchCoresPerSM[index - 1].Cores;
114
- }
115
-
116
- } // end namespace cute
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h DELETED
@@ -1,111 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief reorder data from the host side
34
- */
35
-
36
- #pragma once
37
-
38
- #include "cutlass/coord.h"
39
- #include "cutlass/util/host_tensor.h"
40
- #include "cutlass/tensor_view.h"
41
- #include "cutlass/util/tensor_view_io.h"
42
- #include "cutlass/util/reference/host/gemm.h"
43
-
44
- namespace cutlass {
45
-
46
- /// This is needed for the interleaved integer tensor core kernels. The purpose
47
- /// is to use skip the shared memory part in the epilogue.
48
- template <int Interleaved, typename Element, typename Layout>
49
- void reorder_column(TensorRef<Element, Layout> dest,
50
- TensorRef<Element, Layout> src,
51
- cutlass::gemm::GemmCoord problem_size) {
52
- const int InstructionShapeCol = 8;
53
- // 4 threads per Quad
54
- const int ElementsPerThread = InstructionShapeCol / 4;
55
- // 4 threads per Quad
56
- const int ReorderedElementsPerThread =
57
- Interleaved / 4;
58
-
59
- for (int n = 0; n < problem_size.n(); n++) {
60
- for (int k = 0; k < problem_size.k(); k++) {
61
- dest.at({k, (n / Interleaved) * Interleaved +
62
- ((n % ReorderedElementsPerThread) / ElementsPerThread) *
63
- InstructionShapeCol +
64
- ((n % Interleaved) / ReorderedElementsPerThread) *
65
- ElementsPerThread +
66
- (n % ElementsPerThread)}) = src.at({k, n});
67
- }
68
- }
69
- }
70
-
71
- template <int ColumnInterleaved, int LayoutInterleaved = ColumnInterleaved, typename Element, typename Layout>
72
- void reorder_convK(TensorRef<Element, Layout> dest,
73
- TensorRef<Element, Layout> src,
74
- cutlass::gemm::GemmCoord problem_size) {
75
-
76
- TensorRef<Element, layout::RowMajorInterleaved<LayoutInterleaved>> mappedDest(dest.data(), dest.stride(0));
77
- TensorRef<Element, layout::RowMajorInterleaved<LayoutInterleaved>> mappedSrc(src.data(), src.stride(0));
78
-
79
- reorder_column<ColumnInterleaved>(
80
- mappedDest, mappedSrc, problem_size);
81
- }
82
-
83
- /// This is needed for the sparse tensor core kernels. The purpose
84
- /// is to use ldmatrix to load from shared memory to the register file.
85
- template <typename Element, typename LayoutDest, typename LayoutSrc>
86
- void reorder_meta(TensorRef<Element, LayoutDest> dest,
87
- TensorRef<Element, LayoutSrc> src,
88
- cutlass::gemm::GemmCoord problem_size) {
89
- for (int m = 0; m < problem_size.m(); m++) {
90
- for (int k = 0; k < problem_size.k(); k++) {
91
- // First reorder the rows.
92
- int group = (sizeof(Element) == 2) ? 32 : 16;
93
- int interweave = (sizeof(Element) == 2) ? 4 : 2;
94
-
95
- int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
96
- int dest_col = k;
97
-
98
- // Next swizzle the 2x2 blocks from Z to N.
99
- if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) {
100
- ++dest_row;
101
- --dest_col;
102
- } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) {
103
- --dest_row;
104
- ++dest_col;
105
- }
106
-
107
- dest.at({dest_row, dest_col}) = src.at({m, k});
108
- }
109
- }
110
- }
111
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h DELETED
@@ -1,541 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- /*! \file
34
- \brief HostTensor contributes management for both host and device memory.
35
-
36
- HostTensor allocates host and device memory upon construction. Basic element-wise operations on
37
- host memory synchronize device memory automatically. Explicit copy operations provide abstractions
38
- for CUDA memcpy operations.
39
-
40
- Call {host, device}_{data, ref, view}() for accessing host or device memory.
41
-
42
- See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
43
- */
44
-
45
- #include <vector>
46
-
47
- #include "cutlass/cutlass.h"
48
- #include "cutlass/tensor_ref.h"
49
- #include "cutlass/tensor_view.h"
50
- #include "cutlass/fast_math.h"
51
-
52
- #include "device_memory.h"
53
-
54
- namespace cutlass {
55
-
56
- ///////////////////////////////////////////////////////////////////////////////////////////////////
57
-
58
- /// Host tensor
59
- template <
60
- /// Data type of element stored within tensor (concept: NumericType)
61
- typename Element_,
62
- /// Defines a mapping from logical coordinate to linear memory (concept: Layout)
63
- typename Layout_
64
- >
65
- class HostTensor {
66
- public:
67
-
68
- /// Data type of individual access
69
- using Element = Element_;
70
-
71
- /// Mapping function from logical coordinate to linear memory
72
- using Layout = Layout_;
73
-
74
- /// Logical rank of tensor index space
75
- static int const kRank = Layout::kRank;
76
-
77
- /// Index type
78
- using Index = typename Layout::Index;
79
-
80
- /// Long index used for pointer offsets
81
- using LongIndex = typename Layout::LongIndex;
82
-
83
- /// Coordinate in logical tensor space
84
- using TensorCoord = typename Layout::TensorCoord;
85
-
86
- /// Layout's stride vector
87
- using Stride = typename Layout::Stride;
88
-
89
- /// Tensor reference to device memory
90
- using TensorRef = TensorRef<Element, Layout>;
91
-
92
- /// Tensor reference to constant device memory
93
- using ConstTensorRef = typename TensorRef::ConstTensorRef;
94
-
95
- /// Tensor reference to device memory
96
- using TensorView = TensorView<Element, Layout>;
97
-
98
- /// Tensor reference to constant device memory
99
- using ConstTensorView = typename TensorView::ConstTensorView;
100
-
101
- /// Reference to element in tensor
102
- using Reference = typename TensorRef::Reference;
103
-
104
- /// Constant reference to element in tensor
105
- using ConstReference = typename ConstTensorRef::Reference;
106
-
107
- private:
108
- using StorageUnit = typename platform::conditional_t<std::is_same_v<Element, bool>, uint8_t, // Avoid the std::vector<bool> specialization
109
- typename platform::conditional_t<sizeof_bits<Element>::value % 8 == 0, // Handle subbyte types
110
- Element, uint8_t>>;
111
- using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator<Element, StorageUnit>;
112
- static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits;
113
- static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements;
114
- static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes;
115
- static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit;
116
-
117
- //
118
- // Data members
119
- //
120
-
121
- /// Extent of tensor in logical dimensions
122
- TensorCoord extent_;
123
-
124
- /// Layout object
125
- Layout layout_;
126
-
127
- /// Host-side memory allocation
128
- std::vector<StorageUnit> host_;
129
-
130
- /// Device-side memory
131
- device_memory::allocation<StorageUnit> device_;
132
-
133
- /// number of containers
134
- size_t count_to_container_storage_unit_count(size_t count) {
135
- return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit;
136
- }
137
-
138
- public:
139
- //
140
- // Device and Host Methods
141
- //
142
-
143
- /// Default constructor
144
- HostTensor() {}
145
-
146
- /// Constructs a tensor given an extent. Assumes a packed layout
147
- HostTensor(
148
- TensorCoord const &extent,
149
- bool device_backed = true
150
- ) {
151
-
152
- this->reset(extent, Layout::packed(extent), device_backed);
153
- }
154
-
155
- /// Constructs a tensor given an extent and layout
156
- HostTensor(
157
- TensorCoord const &extent,
158
- Layout const &layout,
159
- bool device_backed = true
160
- ) {
161
-
162
- this->reset(extent, layout, device_backed);
163
- }
164
-
165
- ~HostTensor() { }
166
-
167
- /// Clears the HostTensor allocation to size/capacity = 0
168
- void reset() {
169
- extent_ = TensorCoord();
170
- layout_ = Layout::packed(extent_);
171
-
172
- host_.clear();
173
- device_.reset();
174
- }
175
-
176
- /// Resizes internal memory allocations without affecting layout or extent
177
- void reserve(
178
- size_t count, ///< size of tensor in elements
179
- bool device_backed_ = true) { ///< if true, device memory is also allocated
180
- #if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
181
- CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")");
182
- #endif
183
-
184
- device_.reset();
185
- host_.clear();
186
-
187
- size_t count_container = count_to_container_storage_unit_count(count);
188
- #if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
189
- CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")");
190
- #endif
191
- host_.resize(count_container);
192
-
193
- // Allocate memory
194
- StorageUnit* device_memory = nullptr;
195
- if (device_backed_) {
196
- #if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
197
- CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")");
198
- #endif
199
- device_memory = device_memory::allocate<StorageUnit>(count_container);
200
- }
201
- device_.reset(device_memory, device_backed_ ? count_container : 0);
202
- }
203
-
204
- /// Updates the extent and layout of the HostTensor. Allocates memory according to the new
205
- /// extent and layout.
206
- void reset(
207
- TensorCoord const &extent, ///< extent of logical tensor
208
- Layout const &layout, ///< layout object of tensor
209
- bool device_backed_ = true) { ///< if true, device memory is also allocated.
210
-
211
- extent_ = extent;
212
- layout_ = layout;
213
-
214
- reserve(size_t(layout_.capacity(extent_)), device_backed_);
215
- }
216
-
217
- /// Updates the extent and layout of the HostTensor. Allocates memory according to the new
218
- /// extent and layout. Assumes a packed tensor configuration.
219
- void reset(
220
- TensorCoord const &extent, ///< extent of logical tensor
221
- bool device_backed_ = true) { ///< if true, device memory is also allocated.
222
-
223
- reset(extent, Layout::packed(extent), device_backed_);
224
- }
225
-
226
- /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
227
- /// To force allocation, call reset().
228
- void resize(
229
- TensorCoord const &extent, ///< extent of logical tensor
230
- Layout const &layout, ///< layout object of tensor
231
- bool device_backed_ = true) { ///< if true, device memory is also allocated.
232
-
233
- extent_ = extent;
234
- layout_ = layout;
235
-
236
- LongIndex new_size = size_t(layout_.capacity(extent_));
237
- LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_)));
238
-
239
- if (static_cast<decltype(host_.size())>(new_size_container) > host_.size()) {
240
- reserve(new_size, device_backed_);
241
- }
242
- }
243
-
244
- /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
245
- /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
246
- void resize(
247
- TensorCoord const &extent, ///< extent of logical tensor
248
- bool device_backed_ = true) { ///< if true, device memory is also allocated.
249
-
250
- resize(extent, Layout::packed(extent), device_backed_);
251
- }
252
-
253
- /// Returns the logical number of elements stored in the host tensor
254
- size_t size() const {
255
- return layout_.capacity(extent_);
256
- }
257
-
258
- /// Returns the logical capacity in terms of number of elements. May be larger than the size().
259
- LongIndex capacity() const {
260
- return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements;
261
- }
262
-
263
- /// Gets pointer to host data
264
- Element * host_data() { return reinterpret_cast<Element *>(host_.data()); }
265
-
266
- /// Gets pointer to host data with a pointer offset
267
- Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(host_data(), ptr_element_offset); }
268
-
269
- /// Gets a reference to an element in host memory
270
- Reference host_data(LongIndex idx) {
271
- return ReferenceFactory<Element>::get(host_data(), idx);
272
- }
273
-
274
- /// Gets pointer to host data
275
- Element const * host_data() const { return reinterpret_cast<Element const *>(host_.data()); }
276
-
277
- /// Gets pointer to host data with a pointer offset
278
- Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory<Element>::get(host_data(), ptr_element_offset); }
279
-
280
- /// Gets a constant reference to an element in host memory
281
- ConstReference host_data(LongIndex idx) const {
282
- return ReferenceFactory<Element const>::get(host_data(), idx);
283
- }
284
-
285
- /// Gets pointer to device data
286
- Element * device_data() { return reinterpret_cast<Element *>(device_.get()); }
287
-
288
- /// Gets pointer to device data
289
- Element const * device_data() const { return reinterpret_cast<Element const *>(device_.get()); }
290
-
291
- /// Gets pointer to device data with a pointer offset
292
- Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
293
-
294
- /// Gets pointer to device data with a pointer offset
295
- Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
296
-
297
- /// Accesses the tensor reference pointing to data
298
- TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); }
299
-
300
- /// Accesses the tensor reference pointing to data
301
- ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); }
302
-
303
- /// Accesses the tensor reference pointing to data
304
- TensorRef device_ref(LongIndex ptr_element_offset=0) {
305
- return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_);
306
- }
307
-
308
- /// Accesses the tensor reference pointing to data
309
- ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
310
- return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_);
311
- }
312
-
313
- /// Accesses the tensor reference pointing to data
314
- TensorView host_view(LongIndex ptr_element_offset=0) {
315
- return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_);
316
- }
317
-
318
- /// Accesses the tensor reference pointing to data
319
- ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
320
- return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_);
321
- }
322
-
323
- /// Accesses the tensor reference pointing to data
324
- TensorView device_view(LongIndex ptr_element_offset=0) {
325
- return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_);
326
- }
327
-
328
- /// Accesses the tensor reference pointing to data
329
- ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
330
- return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_);
331
- }
332
-
333
- /// Returns true if device memory is allocated
334
- bool device_backed() const {
335
- return (device_.get() == nullptr) ? false : true;
336
- }
337
-
338
-
339
- /// Returns the layout object
340
- Layout & layout() {
341
- return layout_;
342
- }
343
-
344
- /// Returns the layout object
345
- Layout layout() const {
346
- return layout_;
347
- }
348
-
349
- /// Returns the layout object's stride vector
350
- Stride stride() const {
351
- return layout_.stride();
352
- }
353
-
354
- /// Returns the layout object's stride vector
355
- Stride & stride() {
356
- return layout_.stride();
357
- }
358
-
359
- /// Returns the layout object's stride in a given physical dimension
360
- LongIndex stride(int dim) const {
361
- return layout_.stride().at(dim);
362
- }
363
-
364
- /// Returns the layout object's stride in a given physical dimension
365
- LongIndex & stride(int dim) {
366
- return layout_.stride().at(dim);
367
- }
368
-
369
- /// Computes the offset of an index from the origin of the tensor
370
- LongIndex offset(TensorCoord const& coord) const {
371
- return layout_(coord);
372
- }
373
-
374
- /// Returns a reference to the element at the logical Coord in host memory
375
- Reference at(TensorCoord const& coord) {
376
- return host_data(offset(coord));
377
- }
378
-
379
- /// Returns a const reference to the element at the logical Coord in host memory
380
- ConstReference at(TensorCoord const& coord) const {
381
- return host_data(offset(coord));
382
- }
383
-
384
- /// Returns the extent of the tensor
385
- TensorCoord extent() const {
386
- return extent_;
387
- }
388
-
389
- /// Returns the extent of the tensor
390
- TensorCoord & extent() {
391
- return extent_;
392
- }
393
-
394
- /// Copies data from device to host
395
- void sync_host() {
396
- if (device_backed()) {
397
- device_memory::copy_to_host(
398
- host_.data(), device_.get(), device_.size());
399
- }
400
- }
401
-
402
- /// Copies data from host to device
403
- void sync_device() {
404
- if (device_backed()) {
405
- device_memory::copy_to_device(
406
- device_.get(), host_.data(), host_.size());
407
- }
408
- }
409
-
410
- /// Copy data from a caller-supplied device pointer into host memory.
411
- void copy_in_device_to_host(
412
- Element const* ptr_device, ///< source device memory
413
- LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
414
-
415
- if (count < 0) {
416
- count = capacity();
417
- }
418
- else {
419
- count = __NV_STD_MIN(capacity(), count);
420
- }
421
- size_t container_count = count_to_container_storage_unit_count(count);
422
- device_memory::copy_to_host(
423
- host_.data(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
424
- }
425
-
426
- /// Copy data from a caller-supplied device pointer into host memory.
427
- void copy_in_device_to_device(
428
- Element const* ptr_device, ///< source device memory
429
- LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
430
-
431
- if (count < 0) {
432
- count = capacity();
433
- }
434
- else {
435
- count = __NV_STD_MIN(capacity(), count);
436
- }
437
- size_t container_count = count_to_container_storage_unit_count(count);
438
- device_memory::copy_device_to_device(
439
- device_.get(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
440
- }
441
-
442
- /// Copy data from a caller-supplied device pointer into host memory.
443
- void copy_in_host_to_device(
444
- Element const* ptr_host, ///< source host memory
445
- LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
446
-
447
- if (count < 0) {
448
- count = capacity();
449
- }
450
- else {
451
- count = __NV_STD_MIN(capacity(), count);
452
- }
453
- size_t container_count = count_to_container_storage_unit_count(count);
454
- device_memory::copy_to_device(
455
- device_.get(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
456
- }
457
-
458
- /// Copy data from a caller-supplied device pointer into host memory.
459
- void copy_in_host_to_host(
460
- Element const* ptr_host, ///< source host memory
461
- LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
462
-
463
- if (count < 0) {
464
- count = capacity();
465
- }
466
- else {
467
- count = __NV_STD_MIN(capacity(), count);
468
- }
469
- size_t container_count = count_to_container_storage_unit_count(count);
470
- device_memory::copy_host_to_host(
471
- host_.data(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
472
- }
473
-
474
- /// Copy data from a caller-supplied device pointer into host memory.
475
- void copy_out_device_to_host(
476
- Element * ptr_host, ///< source device memory
477
- LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
478
-
479
- if (count < 0) {
480
- count = capacity();
481
- }
482
- else {
483
- count = __NV_STD_MIN(capacity(), count);
484
- }
485
- size_t container_count = count_to_container_storage_unit_count(count);
486
- device_memory::copy_to_host(
487
- reinterpret_cast<StorageUnit *>(ptr_host), device_.get(), container_count);
488
- }
489
-
490
- /// Copy data from a caller-supplied device pointer into host memory.
491
- void copy_out_device_to_device(
492
- Element * ptr_device, ///< source device memory
493
- LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
494
-
495
- if (count < 0) {
496
- count = capacity();
497
- }
498
- else {
499
- count = __NV_STD_MIN(capacity(), count);
500
- }
501
- size_t container_count = count_to_container_storage_unit_count(count);
502
- device_memory::copy_device_to_device(
503
- reinterpret_cast<StorageUnit *>(ptr_device), device_.get(), container_count);
504
- }
505
-
506
- /// Copy data from a caller-supplied device pointer into host memory.
507
- void copy_out_host_to_device(
508
- Element * ptr_device, ///< source host memory
509
- LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
510
-
511
- if (count < 0) {
512
- count = capacity();
513
- }
514
- else {
515
- count = __NV_STD_MIN(capacity(), count);
516
- }
517
- size_t container_count = count_to_container_storage_unit_count(count);
518
- device_memory::copy_to_device(
519
- reinterpret_cast<StorageUnit *>(ptr_device), host_.data(), container_count);
520
- }
521
-
522
- /// Copy data from a caller-supplied device pointer into host memory.
523
- void copy_out_host_to_host(
524
- Element * ptr_host, ///< source host memory
525
- LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
526
-
527
- if (count < 0) {
528
- count = capacity();
529
- }
530
- else {
531
- count = __NV_STD_MIN(capacity(), count);
532
- }
533
- size_t container_count = count_to_container_storage_unit_count(count);
534
- device_memory::copy_host_to_host(
535
- reinterpret_cast<StorageUnit *>(ptr_host), host_.data(), container_count);
536
- }
537
- };
538
-
539
- ///////////////////////////////////////////////////////////////////////////////////////////////////
540
-
541
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h DELETED
@@ -1,591 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- /*! \file
34
- \brief HostTensor contributes management for both host and device memory.
35
-
36
- HostTensor allocates host and device memory upon construction. Basic element-wise operations on
37
- host memory synchronize device memory automatically. Explicit copy operations provide abstractions
38
- for CUDA memcpy operations.
39
-
40
- Call {host, device}_{data, ref, view}() for accessing host or device memory.
41
-
42
- See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
43
- */
44
-
45
- #include <vector>
46
-
47
- #include "cutlass/cutlass.h"
48
-
49
- #include "cutlass/tensor_ref_planar_complex.h"
50
- #include "cutlass/tensor_view_planar_complex.h"
51
-
52
- #include "device_memory.h"
53
-
54
- namespace cutlass {
55
-
56
- ///////////////////////////////////////////////////////////////////////////////////////////////////
57
-
58
- /// Host tensor
59
- template <
60
- /// Data type of element stored within tensor (concept: NumericType)
61
- typename Element_,
62
- /// Defines a mapping from logical coordinate to linear memory (concept: Layout)
63
- typename Layout_
64
- >
65
- class HostTensorPlanarComplex {
66
- public:
67
-
68
- /// Data type of individual access
69
- using Element = Element_;
70
-
71
- /// Mapping function from logical coordinate to linear memory
72
- using Layout = Layout_;
73
-
74
- /// Logical rank of tensor index space
75
- static int const kRank = Layout::kRank;
76
-
77
- /// Index type
78
- using Index = typename Layout::Index;
79
-
80
- /// Long index used for pointer offsets
81
- using LongIndex = typename Layout::LongIndex;
82
-
83
- /// Coordinate in logical tensor space
84
- using TensorCoord = typename Layout::TensorCoord;
85
-
86
- /// Layout's stride vector
87
- using Stride = typename Layout::Stride;
88
-
89
- /// Tensor reference to device memory
90
- using TensorRef = TensorRefPlanarComplex<Element, Layout>;
91
-
92
- /// Tensor reference to constant device memory
93
- using ConstTensorRef = typename TensorRef::ConstTensorRef;
94
-
95
- /// Tensor reference to device memory
96
- using TensorView = TensorViewPlanarComplex<Element, Layout>;
97
-
98
- /// Tensor reference to constant device memory
99
- using ConstTensorView = typename TensorView::ConstTensorView;
100
-
101
- /// Reference to element in tensor
102
- using Reference = typename TensorRef::Reference;
103
-
104
- /// Constant reference to element in tensor
105
- using ConstReference = typename ConstTensorRef::Reference;
106
-
107
- private:
108
-
109
- //
110
- // Data members
111
- //
112
-
113
- /// Extent of tensor in logical dimensions
114
- TensorCoord extent_;
115
-
116
- /// Layout object
117
- Layout layout_;
118
-
119
- /// Host-side memory allocation
120
- std::vector<Element> host_;
121
-
122
- /// Device-side memory
123
- device_memory::allocation<Element> device_;
124
-
125
- public:
126
- //
127
- // Device and Host Methods
128
- //
129
-
130
- /// Default constructor
131
- HostTensorPlanarComplex() {}
132
-
133
- /// Constructs a tensor given an extent. Assumes a packed layout
134
- HostTensorPlanarComplex(
135
- TensorCoord const &extent,
136
- bool device_backed = true
137
- ) {
138
-
139
- this->reset(extent, Layout::packed(extent), device_backed);
140
- }
141
-
142
- /// Constructs a tensor given an extent and layout
143
- HostTensorPlanarComplex(
144
- TensorCoord const &extent,
145
- Layout const &layout,
146
- bool device_backed = true
147
- ) {
148
-
149
- this->reset(extent, layout, device_backed);
150
- }
151
-
152
- ~HostTensorPlanarComplex() { }
153
-
154
- /// Clears the HostTensor allocation to size/capacity = 0
155
- void reset() {
156
- extent_ = TensorCoord();
157
- layout_ = Layout::packed(extent_);
158
-
159
- host_.clear();
160
- device_.reset();
161
- }
162
-
163
- /// Resizes internal memory allocations without affecting layout or extent
164
- void reserve(
165
- size_t count, ///< size of tensor in elements
166
- bool device_backed_ = true) { ///< if true, device memory is also allocated
167
-
168
- device_.reset();
169
- host_.clear();
170
-
171
- host_.resize(count * 2);
172
-
173
- // Allocate memory
174
- Element* device_memory = nullptr;
175
- if (device_backed_) {
176
- device_memory = device_memory::allocate<Element>(count * 2);
177
- }
178
- device_.reset(device_memory, device_backed_ ? count * 2 : 0);
179
- }
180
-
181
- /// Updates the extent and layout of the HostTensor. Allocates memory according to the new
182
- /// extent and layout.
183
- void reset(
184
- TensorCoord const &extent, ///< extent of logical tensor
185
- Layout const &layout, ///< layout object of tensor
186
- bool device_backed_ = true) { ///< if true, device memory is also allocated.
187
-
188
- extent_ = extent;
189
- layout_ = layout;
190
-
191
- reserve(size_t(layout_.capacity(extent_)), device_backed_);
192
- }
193
-
194
- /// Updates the extent and layout of the HostTensor. Allocates memory according to the new
195
- /// extent and layout. Assumes a packed tensor configuration.
196
- void reset(
197
- TensorCoord const &extent, ///< extent of logical tensor
198
- bool device_backed_ = true) { ///< if true, device memory is also allocated.
199
-
200
- reset(extent, Layout::packed(extent), device_backed_);
201
- }
202
-
203
- /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
204
- /// To force allocation, call reset().
205
- void resize(
206
- TensorCoord const &extent, ///< extent of logical tensor
207
- Layout const &layout, ///< layout object of tensor
208
- bool device_backed_ = true) { ///< if true, device memory is also allocated.
209
-
210
- extent_ = extent;
211
- layout_ = layout;
212
-
213
- LongIndex new_size = size_t(layout_.capacity(extent_));
214
-
215
- if (static_cast<decltype(host_.size())>(new_size * 2) > host_.size()) {
216
- reserve(new_size);
217
- }
218
- }
219
-
220
- /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
221
- /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
222
- void resize(
223
- TensorCoord const &extent, ///< extent of logical tensor
224
- bool device_backed_ = true) { ///< if true, device memory is also allocated.
225
-
226
- resize(extent, Layout::packed(extent), device_backed_);
227
- }
228
-
229
- /// Returns the number of elements stored in the host tensor
230
- size_t size() const {
231
- return host_.size() / 2;
232
- }
233
-
234
- /// Returns the logical capacity based on extent and layout. May differ from size().
235
- LongIndex capacity() const {
236
- return layout_.capacity(extent_);
237
- }
238
-
239
- /// Stride between real and imaginary parts
240
- LongIndex imaginary_stride() const {
241
- return host_.size() / 2;
242
- }
243
-
244
- /// Gets pointer to host data
245
- Element * host_data() { return host_.data(); }
246
-
247
- /// Gets pointer to host data imaginary part
248
- Element * host_data_imag() { return host_.data() + imaginary_stride(); }
249
-
250
- /// Gets pointer to host data with a pointer offset
251
- Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; }
252
-
253
- /// Gets pointer to host data with a pointer offset
254
- Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; }
255
-
256
- /// Gets a reference to an element in host memory
257
- Reference host_data(LongIndex idx) {
258
- return PlanarComplexReference<Element>(host_data() + idx, host_data_imag() + idx);
259
- }
260
-
261
- /// Gets pointer to host data
262
- Element const * host_data() const { return host_.data(); }
263
-
264
- /// Gets pointer to host data imaginary part
265
- Element const * host_data_imag() const { return host_.data() + imaginary_stride(); }
266
-
267
- /// Gets a constant reference to an element in host memory
268
- ConstReference host_data(LongIndex idx) const {
269
- return PlanarComplexReference<Element const>(host_data() + idx, host_data_imag() + idx);
270
- }
271
-
272
- /// Gets pointer to device data
273
- Element * device_data() { return device_.get(); }
274
-
275
- /// Gets pointer to device data with a pointer offset
276
- Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; }
277
-
278
- /// Gets pointer to device data
279
- Element const * device_data() const { return device_.get(); }
280
-
281
- /// Gets pointer to device data with a pointer offset
282
- Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; }
283
-
284
- /// Gets a pointer to the device data imaginary part
285
- Element * device_data_imag() { return device_.get() + imaginary_stride(); }
286
-
287
- /// Accesses the tensor reference pointing to data
288
- TensorRef host_ref(LongIndex ptr_element_offset=0) {
289
- return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
290
- }
291
-
292
- /// Returns a tensor reference to the real part of the tensor
293
- cutlass::TensorRef<Element, Layout> host_ref_real() {
294
- return cutlass::TensorRef<Element, Layout>(host_data(), layout_);
295
- }
296
-
297
- /// Returns a tensor reference to the real part of the tensor
298
- cutlass::TensorRef<Element, Layout> host_ref_imag() {
299
- return cutlass::TensorRef<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_);
300
- }
301
-
302
- /// Accesses the tensor reference pointing to data
303
- ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const {
304
- return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
305
- }
306
-
307
- /// Accesses the tensor reference pointing to data
308
- TensorRef device_ref(LongIndex ptr_element_offset=0) {
309
- return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
310
- }
311
-
312
- /// Accesses the tensor reference pointing to data
313
- ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
314
- return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
315
- }
316
-
317
- /// Returns a tensor reference to the real part of the tensor
318
- cutlass::TensorRef<Element, Layout> device_ref_real() {
319
- return cutlass::TensorRef<Element, Layout>(device_data(), layout_);
320
- }
321
-
322
- /// Returns a tensor reference to the real part of the tensor
323
- cutlass::TensorRef<Element, Layout> device_ref_imag() {
324
- return cutlass::TensorRef<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_);
325
- }
326
-
327
- /// Accesses the tensor reference pointing to data
328
- TensorView host_view(LongIndex ptr_element_offset=0) {
329
- return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
330
- }
331
-
332
- /// Accesses the tensor reference pointing to data
333
- ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
334
- return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
335
- }
336
-
337
- /// Accesses the tensor reference pointing to data
338
- cutlass::TensorView<Element, Layout> host_view_real() {
339
- return cutlass::TensorView<Element, Layout>(host_data(), layout_, extent_);
340
- }
341
-
342
- /// Accesses the tensor reference pointing to data
343
- cutlass::TensorView<Element, Layout> host_view_imag() {
344
- return cutlass::TensorView<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_, extent_);
345
- }
346
-
347
- /// Accesses the tensor reference pointing to data
348
- TensorView device_view(LongIndex ptr_element_offset=0) {
349
- return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
350
- }
351
-
352
- /// Accesses the tensor reference pointing to data
353
- ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
354
- return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
355
- }
356
-
357
- /// Accesses the tensor reference pointing to data
358
- cutlass::TensorView<Element, Layout> device_view_real() {
359
- return cutlass::TensorView<Element, Layout>(device_data(), layout_, extent_);
360
- }
361
-
362
- /// Accesses the tensor reference pointing to data
363
- cutlass::TensorView<Element, Layout> device_view_imag() {
364
- return cutlass::TensorView<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_, extent_);
365
- }
366
-
367
- /// Returns true if device memory is allocated
368
- bool device_backed() const {
369
- return (device_.get() == nullptr) ? false : true;
370
- }
371
-
372
- /// Returns the layout object
373
- Layout layout() const {
374
- return layout_;
375
- }
376
-
377
- /// Returns the layout object's stride vector
378
- Stride stride() const {
379
- return layout_.stride();
380
- }
381
-
382
- /// Returns the layout object's stride in a given physical dimension
383
- Index stride(int dim) const {
384
- return layout_.stride().at(dim);
385
- }
386
-
387
- /// Computes the offset of an index from the origin of the tensor
388
- LongIndex offset(TensorCoord const& coord) const {
389
- return layout_(coord);
390
- }
391
-
392
- /// Returns a reference to the element at the logical Coord in host memory
393
- Reference at(TensorCoord const& coord) {
394
- return host_data(offset(coord));
395
- }
396
-
397
- /// Returns a const reference to the element at the logical Coord in host memory
398
- ConstReference at(TensorCoord const& coord) const {
399
- return host_data(offset(coord));
400
- }
401
-
402
- /// Returns the extent of the tensor
403
- TensorCoord extent() const {
404
- return extent_;
405
- }
406
-
407
- /// Returns the extent of the tensor
408
- TensorCoord & extent() {
409
- return extent_;
410
- }
411
-
412
- /// Copies data from device to host
413
- void sync_host() {
414
- if (device_backed()) {
415
- device_memory::copy_to_host(
416
- host_data(), device_data(), imaginary_stride() * 2);
417
- }
418
- }
419
-
420
- /// Copies data from host to device
421
- void sync_device() {
422
- if (device_backed()) {
423
- device_memory::copy_to_device(
424
- device_data(), host_data(), imaginary_stride() * 2);
425
- }
426
- }
427
-
428
- /// Copy data from a caller-supplied device pointer into host memory.
429
- void copy_in_device_to_host(
430
- Element const* ptr_device_real, ///< source device memory
431
- Element const* ptr_device_imag, ///< source device memory
432
- LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
433
-
434
- if (count < 0) {
435
- count = capacity();
436
- }
437
- else {
438
- count = __NV_STD_MIN(capacity(), count);
439
- }
440
-
441
- device_memory::copy_to_host(
442
- host_data(), ptr_device_real, count);
443
-
444
- device_memory::copy_to_host(
445
- host_data_imag(), ptr_device_imag, count);
446
- }
447
-
448
- /// Copy data from a caller-supplied device pointer into host memory.
449
- void copy_in_device_to_device(
450
- Element const* ptr_device_real, ///< source device memory
451
- Element const* ptr_device_imag, ///< source device memory
452
- LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
453
-
454
- if (count < 0) {
455
- count = capacity();
456
- }
457
- else {
458
- count = __NV_STD_MIN(capacity(), count);
459
- }
460
-
461
- device_memory::copy_device_to_device(
462
- device_data(), ptr_device_real, count);
463
-
464
- device_memory::copy_device_to_device(
465
- device_data_imag(), ptr_device_imag, count);
466
- }
467
-
468
- /// Copy data from a caller-supplied device pointer into host memory.
469
- void copy_in_host_to_device(
470
- Element const* ptr_host_real, ///< source host memory
471
- Element const* ptr_host_imag, ///< source host memory
472
- LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
473
-
474
- if (count < 0) {
475
- count = capacity();
476
- }
477
- else {
478
- count = __NV_STD_MIN(capacity(), count);
479
- }
480
-
481
- device_memory::copy_to_device(
482
- device_data(), ptr_host_real, count);
483
-
484
- device_memory::copy_to_device(
485
- device_data_imag(), ptr_host_imag, count);
486
- }
487
-
488
- /// Copy data from a caller-supplied device pointer into host memory.
489
- void copy_in_host_to_host(
490
- Element const* ptr_host_real, ///< source host memory
491
- Element const* ptr_host_imag, ///< source host memory
492
- LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
493
-
494
- if (count < 0) {
495
- count = capacity();
496
- }
497
- else {
498
- count = __NV_STD_MIN(capacity(), count);
499
- }
500
-
501
- device_memory::copy_host_to_host(
502
- host_data(), ptr_host_real, count);
503
-
504
- device_memory::copy_host_to_host(
505
- host_data_imag(), ptr_host_imag, count);
506
- }
507
-
508
- /// Copy data from a caller-supplied device pointer into host memory.
509
- void copy_out_device_to_host(
510
- Element * ptr_host_real, ///< source device memory
511
- Element * ptr_host_imag, ///< source device memory
512
- LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
513
-
514
- if (count < 0) {
515
- count = capacity();
516
- }
517
- else {
518
- count = __NV_STD_MIN(capacity(), count);
519
- }
520
-
521
- device_memory::copy_to_host(
522
- ptr_host_real, device_data(), count);
523
-
524
- device_memory::copy_to_host(
525
- ptr_host_imag, device_data_imag(), count);
526
- }
527
-
528
- /// Copy data from a caller-supplied device pointer into host memory.
529
- void copy_out_device_to_device(
530
- Element * ptr_device_real, ///< source device memory
531
- Element * ptr_device_imag, ///< source device memory
532
- LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
533
-
534
- if (count < 0) {
535
- count = capacity();
536
- }
537
- else {
538
- count = __NV_STD_MIN(capacity(), count);
539
- }
540
-
541
- device_memory::copy_device_to_device(
542
- ptr_device_real, device_data(), count);
543
-
544
- device_memory::copy_device_to_device(
545
- ptr_device_imag, device_data_imag(), count);
546
- }
547
-
548
- /// Copy data from a caller-supplied device pointer into host memory.
549
- void copy_out_host_to_device(
550
- Element * ptr_device_real, ///< source device memory
551
- Element * ptr_device_imag, ///< source device memory
552
- LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
553
-
554
- if (count < 0) {
555
- count = capacity();
556
- }
557
- else {
558
- count = __NV_STD_MIN(capacity(), count);
559
- }
560
-
561
- device_memory::copy_to_device(
562
- ptr_device_real, host_data(), count);
563
-
564
- device_memory::copy_to_device(
565
- ptr_device_imag, host_data_imag(), count);
566
- }
567
-
568
- /// Copy data from a caller-supplied device pointer into host memory.
569
- void copy_out_host_to_host(
570
- Element * ptr_host_real, ///< source host memory
571
- Element * ptr_host_imag, ///< source host memory
572
- LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
573
-
574
- if (count < 0) {
575
- count = capacity();
576
- }
577
- else {
578
- count = __NV_STD_MIN(capacity(), count);
579
- }
580
-
581
- device_memory::copy_host_to_host(
582
- ptr_host_real, host_data(), count);
583
-
584
- device_memory::copy_host_to_host(
585
- ptr_host_imag, host_data_imag(), count);
586
- }
587
- };
588
-
589
- ///////////////////////////////////////////////////////////////////////////////////////////////////
590
-
591
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h DELETED
@@ -1,157 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief uncompress sparse matrix from the host side
34
- */
35
- #pragma once
36
-
37
- #include "cutlass/coord.h"
38
- #include "cutlass/util/host_tensor.h"
39
- #include "cutlass/tensor_view.h"
40
- #include "cutlass/util/tensor_view_io.h"
41
- #include "cutlass/util/reference/host/gemm.h"
42
-
43
- namespace cutlass {
44
-
45
- // uncompress sparse tensor core A matrix
46
- template <typename ElementA, typename LayoutA, typename ElementE,
47
- typename LayoutE>
48
- void uncompress(TensorRef<ElementA, LayoutA> uncompressed_tensor_a,
49
- TensorRef<ElementA, LayoutA> tensor_a,
50
- TensorRef<ElementE, LayoutE> tensor_e, int row, int col) {
51
- // How many uncompressed data we can get with ElementE meta data
52
- int DecompressedElementsPerElementE =
53
- 256 / cutlass::sizeof_bits<ElementA>::value;
54
-
55
- // Process 4bit meta data a time
56
- int step;
57
-
58
- // 1:2 or 2:4 or 4:8
59
- int a, b;
60
-
61
- if (cutlass::sizeof_bits<ElementA>::value == 4) {
62
- step = 8;
63
- a = 4;
64
- b = 8;
65
- } else if (cutlass::sizeof_bits<ElementA>::value == 8) {
66
- step = 4;
67
- a = 2;
68
- b = 4;
69
- } else if (cutlass::sizeof_bits<ElementA>::value == 16) {
70
- step = 4;
71
- a = 2;
72
- b = 4;
73
- } else if (cutlass::sizeof_bits<ElementA>::value == 32) {
74
- step = 2;
75
- a = 1;
76
- b = 2;
77
- }
78
-
79
- int ElementsPerE = (cutlass::sizeof_bits<ElementA>::value == 4) ? 2 : 1;
80
-
81
- for (int r = 0; r < row; ++r) {
82
- for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) {
83
-
84
- ElementE meta = tensor_e.at(MatrixCoord(r, c));
85
-
86
- for (int i = 0; i < DecompressedElementsPerElementE; i += step) {
87
- int e = (meta >> (i / step * 4)) & 0xf;
88
- int idx0 = e & 0x3;
89
- int idx1 = e >> 2;
90
-
91
- if (a == 1) idx0 = idx0 / 2;
92
-
93
- for (int ii = 0; ii < step; ii += ElementsPerE) {
94
- int real_col =
95
- c * DecompressedElementsPerElementE + i + ii;
96
- int compressed_col = (real_col / b) * a;
97
-
98
- if (ii == (idx0 * ElementsPerE)) {
99
- uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
100
- tensor_a.at(MatrixCoord(r, compressed_col));
101
- if (ElementsPerE == 2)
102
- uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
103
- tensor_a.at(MatrixCoord(r, compressed_col + 1));
104
- } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) {
105
- uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
106
- tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE));
107
- if (ElementsPerE == 2)
108
- uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
109
- tensor_a.at(
110
- MatrixCoord(r, compressed_col + ElementsPerE + 1));
111
- } else {
112
- uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
113
- ElementA(0);
114
- if (ElementsPerE == 2)
115
- uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
116
- ElementA(0);
117
- }
118
- }
119
- }
120
- }
121
- }
122
- }
123
-
124
- // uncompress ELL block sparse matrix
125
- template <typename ElementA, typename LayoutA,
126
- typename ElementE, typename LayoutE>
127
- void uncompress_ell_block_sparse(
128
- TensorRef<ElementA, LayoutA> uncompressed_tensor_a,
129
- TensorRef<ElementA, LayoutA> tensor_a,
130
- TensorRef<ElementE, LayoutE> ell_idx,
131
- int rows, int cols,
132
- int ell_num_cols, int ell_blocksize) {
133
-
134
- for (int r = 0; r < rows / ell_blocksize; ++r) {
135
- for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) {
136
-
137
- ElementE idx = ell_idx.at(MatrixCoord(r, c));
138
-
139
- if (idx != -1) {
140
- int row_begin = r * ell_blocksize;
141
- int col_begin_real = idx * ell_blocksize;
142
- int col_begin = c * ell_blocksize;
143
-
144
- for (int i = 0; i < ell_blocksize; ++i) {
145
- for (int j = 0; j < ell_blocksize; ++j) {
146
- uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) =
147
- tensor_a.at(
148
- MatrixCoord(row_begin + i, col_begin +j));
149
- }
150
- }
151
- }
152
- }
153
- }
154
- }
155
-
156
- } // namespace cutlass
157
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h DELETED
@@ -1,38 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include "cutlass/cutlass.h"
35
- #include "cutlass/numeric_types.h"
36
-
37
- // integer_sequence moved to cutlass/numeric_types.h
38
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp DELETED
@@ -1,472 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Utilities for mixed input data type kernels.
33
- */
34
-
35
- #pragma once
36
-
37
- #include <cuda.h>
38
- #include "cute/layout.hpp"
39
- #include "cute/tensor.hpp"
40
- #include "cute/arch/mma_sm90.hpp"
41
- #include "cutlass/cutlass.h"
42
- #include "cutlass/util/device_memory.h"
43
- #include "cutlass/util/reference/device/tensor_fill.h"
44
- #include "cute/util/type_traits.hpp"
45
-
46
- namespace cutlass {
47
-
48
- #define CUDA_CHECK(status) \
49
- { \
50
- cudaError_t error = status; \
51
- if (error != cudaSuccess) { \
52
- std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
53
- << " at line: " << __LINE__ << std::endl; \
54
- exit(EXIT_FAILURE); \
55
- } \
56
- }
57
-
58
- template <
59
- class QuantizedElement,
60
- class DequantizedElement,
61
- class OperandLayout,
62
- class ElementScale,
63
- class ElementZero,
64
- class ScaleBroadCastLayout,
65
- class ThrLayout>
66
- __global__ void dequantize_kernel(DequantizedElement* dq_buffer,
67
- QuantizedElement const* q_buffer,
68
- OperandLayout const operand_layout,
69
- ElementScale const* scale_buffer,
70
- ElementZero const* zero_buffer,
71
- ScaleBroadCastLayout const broadcasted_scale_layout,
72
- ThrLayout thr_layout) {
73
- using namespace cute;
74
-
75
- // Represent the full tensors to gmem elements.
76
- // These are expected to have shape [MN, K, L]
77
- cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
78
- cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr<QuantizedElement const>(q_buffer), operand_layout);
79
- // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
80
- // It is expected that K % G == 0
81
- cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
82
- cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
83
-
84
- // Assign 1 thread per element in the thread block
85
- auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); //
86
- auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
87
-
88
- // Tile across the block
89
- auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
90
- auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
91
- auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
92
- auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);
93
-
94
- auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
95
- auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
96
- auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
97
- auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);
98
-
99
- // Make a fragment of registers to hold gmem loads
100
- cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
101
- cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
102
- cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
103
- cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
104
- cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
105
- cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);
106
-
107
- cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
108
- auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
109
- auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);
110
-
111
- const auto num_iters = cute::size<3>(tOpDq_gOpDq);
112
-
113
- for (int ii = 0; ii < num_iters; ++ii) {
114
- const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
115
- if (thread_offset < cute::size<0>(operand_layout)) {
116
- cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
117
- cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
118
- cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
119
- cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
120
- cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
121
- cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{});
122
- cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{});
123
- cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
124
- cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
125
- }
126
- }
127
- }
128
-
129
- template <
130
- class QuantizedElement,
131
- class DequantizedElement,
132
- class OperandLayout,
133
- class ElementScale,
134
- class ElementZero,
135
- class ScaleLayout>
136
- static void dequantize(DequantizedElement* dq_buffer,
137
- QuantizedElement const* q_buffer,
138
- OperandLayout const operand_layout,
139
- ElementScale const* scale_buffer,
140
- ElementZero const* zero_buffer,
141
- ScaleLayout const scale_layout,
142
- int const group_size,
143
- cudaStream_t &stream) {
144
- using namespace cute;
145
-
146
- constexpr int tpb = 128;
147
- auto thr_layout = make_layout(make_shape(Int<tpb>{}));
148
-
149
- const auto num_rows = get<0>(shape(operand_layout));
150
- const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
151
- const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
152
- const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
153
-
154
- if (num_rows != size<0>(scale_layout)) {
155
- std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
156
- << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
157
- << std::endl;
158
- exit(-1);
159
- }
160
-
161
- const auto scale_stride0 = get<0>(stride(scale_layout));
162
- const auto scale_stride1 = get<1>(stride(scale_layout));
163
- const auto scale_stride2 = get<2>(stride(scale_layout));
164
-
165
- auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
166
- auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
167
- auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
168
-
169
- const auto blocks_x = gemm_k;
170
- const auto blocks_y = batches;
171
-
172
- dim3 blocks(blocks_x, blocks_y, 1);
173
- dequantize_kernel<<<blocks, tpb, 0, stream>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
174
- CUDA_CHECK(cudaStreamSynchronize(stream));
175
- }
176
-
177
- template <typename T>
178
- class packed_scale_t {
179
- public:
180
- static_assert(cute::is_same_v<T, cutlass::int8_t> ||
181
- cute::is_same_v<T, cutlass::uint8_t> ||
182
- cute::is_same_v<T, cutlass::float_e4m3_t> ||
183
- cute::is_same_v<T, cutlass::float_e5m2_t>,
184
- "only 8 bit arithmetic types are supported.");
185
- CUTLASS_HOST_DEVICE
186
- explicit packed_scale_t(T val) {
187
- if constexpr (!cute::is_unsigned_v<T>) {
188
- // Only pack negative values. The positive values are generated in flight in the mainloop.
189
- storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
190
- storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
191
- }
192
- else {
193
- storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
194
- storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
195
- }
196
- }
197
- CUTLASS_HOST_DEVICE
198
- packed_scale_t() = default;
199
- CUTLASS_HOST_DEVICE
200
- explicit operator float() const {
201
- return float(get());
202
- }
203
- CUTLASS_HOST_DEVICE
204
- bool operator==(packed_scale_t const& rhs) const {
205
- return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
206
- }
207
- CUTLASS_HOST_DEVICE
208
- bool operator!=(packed_scale_t const& rhs) const {
209
- return !(*this == rhs);
210
- }
211
- CUTLASS_HOST_DEVICE
212
- friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
213
- return packed_scale_t(lhs.get() + rhs.get());
214
- }
215
- CUTLASS_HOST_DEVICE
216
- friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
217
- return packed_scale_t(lhs.get() - rhs.get());
218
- }
219
- CUTLASS_HOST_DEVICE
220
- friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
221
- return packed_scale_t(lhs.get() * rhs.get());
222
- }
223
- CUTLASS_HOST_DEVICE
224
- friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
225
- return packed_scale_t(lhs.get() / rhs.get());
226
- }
227
-
228
- private:
229
- using Storage = uint32_t;
230
- using Stage = uint8_t;
231
-
232
- Storage storage[2] {};
233
-
234
- CUTLASS_HOST_DEVICE
235
- static Storage pack4(T c1, T c2, T c3, T c4) {
236
- Storage result = 0;
237
- result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
238
- result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
239
- result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
240
- result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
241
- return result;
242
- }
243
- CUTLASS_HOST_DEVICE
244
- T get() const {
245
- auto stage = static_cast<Stage>(storage[0] >> 8);
246
- #if defined(__CUDA_ARCH__)
247
- return reinterpret_cast<T const&>(stage);
248
- #else
249
- T tmp;
250
- std::memcpy(&tmp, &stage, sizeof(Stage));
251
- return tmp;
252
- #endif
253
- }
254
- CUTLASS_HOST_DEVICE
255
- T get(int idx) const {
256
- Stage stage;
257
- if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
258
- else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
259
- #if defined(__CUDA_ARCH__)
260
- return reinterpret_cast<T const&>(stage);
261
- #else
262
- T tmp;
263
- std::memcpy(&tmp, &stage, sizeof(Stage));
264
- return tmp;
265
- #endif
266
- }
267
- };
268
-
269
- // In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
270
- // Here the encodings of positive values and negative values are unified (except for the sign bit).
271
- // For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
272
- static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) {
273
-
274
- using StorageType = cutlass::int4b_t::Storage;
275
- constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
276
- const size_t host_buf_size = block_size / pack;
277
- std::vector<StorageType> host_buf(host_buf_size);
278
- cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size);
279
-
280
- for (auto&& d : host_buf) {
281
- StorageType out = 0;
282
- StorageType mask = 0x0f;
283
- for (int i = 0; i < pack; i++) {
284
- cutlass::int4b_t curr;
285
- curr.storage = (d >> (i * 4)) & 0x0f;
286
- switch (curr) {
287
- case 1: curr.storage = StorageType(0b0111); break; // 2's complement
288
- case 2: curr.storage = StorageType(0b0110); break; // 2's complement
289
- case 3: curr.storage = StorageType(0b0101); break; // 2's complement
290
- case 4: curr.storage = StorageType(0b0100); break; // 2's complement
291
- case 5: curr.storage = StorageType(0b0011); break; // 2's complement
292
- case 6: curr.storage = StorageType(0b0010); break; // 2's complement
293
- case 7: curr.storage = StorageType(0b0001); break; // 2's complement
294
- default: break;
295
- }
296
- out |= (curr.storage << (4 * i)) & mask;
297
- mask <<= 4;
298
- }
299
- d = out;
300
- }
301
-
302
- cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size);
303
- return true;
304
- }
305
-
306
- template <class ElementScale>
307
- static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array<ElementScale, 8> *block_out, const size_t block_size) {
308
- std::vector<ElementScale> data_in(block_size);
309
- std::vector<cutlass::Array<ElementScale, 8>> data_out(block_size);
310
-
311
- try {
312
- cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size);
313
- }
314
- catch (cutlass::cuda_exception const& e) {
315
- std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
316
- return false;
317
- }
318
-
319
- for (size_t i = 0; i < block_size; i++) {
320
- cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
321
- data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
322
- }
323
-
324
- try {
325
- cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size);
326
- }
327
- catch (cutlass::cuda_exception const& e) {
328
- std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
329
- return false;
330
- }
331
- return true;
332
- }
333
-
334
- template <class T, class = void>
335
- struct UnderlyingElement {
336
- using type = T;
337
- };
338
-
339
- template <class T>
340
- struct UnderlyingElement<T, cute::void_t<typename T::Element>> {
341
- using type = typename T::Element;
342
- };
343
-
344
- // Given a type of MMA instruction, compute a memory reordering atom that places all values
345
- // owned by each thread in contiguous memory locations. This improves smem load vectorization,
346
- // particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
347
- // of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
348
- // In addition, we can reorder the values across several MMA instructions to get even wider
349
- // vectorization (AtomLayout parameter) and permute the values within each instruction to get
350
- // more optimal conversion instruction sequences (ValLayout parameter).
351
- template <class ElementMma,
352
- class AtomLayout = cute::Layout<cute::_1>,
353
- class ValLayout = cute::Layout<cute::_1>>
354
- constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
355
- {
356
- using namespace cute;
357
-
358
- static_assert(is_static_v<ValLayout>, "ValLayout must be static");
359
- static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");
360
-
361
- // 1. Choose an MMA atom to access TV layout and MN shape
362
- // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
363
- using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
364
- using MmaTraits = MMA_Traits<MmaAtom>;
365
- auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
366
- auto tv_layout_mma = typename MmaTraits::ALayout{};
367
- static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");
368
-
369
- // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
370
- // Note: this assumes A is partitioned between warps along M mode
371
- auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
372
- auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
373
- auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
374
- auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);
375
-
376
- // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
377
- auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);
378
-
379
- // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
380
- auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
381
- auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
382
- auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
383
- auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);
384
-
385
- return layout_atom;
386
- }
387
-
388
- template <class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
389
- __global__ void reorder_tensor_kernel(
390
- cute::Tensor<EngineSrc, LayoutSrc> S,
391
- cute::Tensor<EngineDst, LayoutDst> D,
392
- TiledCopy tiled_copy)
393
- {
394
- using namespace cute;
395
-
396
- using T = typename EngineDst::value_type;
397
-
398
- Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
399
- Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
400
-
401
- auto thread_copy = tiled_copy.get_slice(threadIdx.x);
402
- Tensor tS = thread_copy.partition_S(gS);
403
- Tensor tD = thread_copy.partition_D(gD);
404
-
405
- copy(tiled_copy, tS, tD);
406
- }
407
-
408
- template <class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
409
- void reorder_tensor(
410
- cute::Tensor<EngineSrc, LayoutSrc> S,
411
- cute::Tensor<EngineDst, LayoutDst> D)
412
- {
413
- using namespace cute;
414
-
415
- using T = typename EngineDst::value_type;
416
- static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
417
-
418
- // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread
419
- // This avoids a race condition when writing out subbyte types (e.g. int4b_t).
420
- auto has_major_mode = [](auto s) {
421
- return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; });
422
- };
423
- static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
424
- "Could not find stride-1 mode in destination layout");
425
- constexpr int N = shape_div(Int<8>{}, Int<sizeof_bits_v<T>>{});
426
- auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
427
- make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
428
- make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));
429
-
430
- // Make a tiled copy with a simple row-major thread order and above layout
431
- int constexpr NumThreads = 128;
432
- auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
433
- auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);
434
-
435
- // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
436
- using TileShape = Shape<_16>;
437
- auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
438
- dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};
439
-
440
- reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
441
- CUDA_CHECK(cudaDeviceSynchronize());
442
- }
443
-
444
- // In-place version
445
- template <class T, class LayoutSrc, class LayoutDst>
446
- void reorder_tensor(
447
- T const* src,
448
- LayoutSrc const& layout_src,
449
- T * dst,
450
- LayoutDst const& layout_dst)
451
- {
452
- using namespace cute;
453
- reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
454
- make_tensor(make_gmem_ptr<T>(dst), layout_dst));
455
- }
456
-
457
- // In-place version
458
- template <class T, class LayoutSrc, class LayoutDst>
459
- void reorder_tensor(
460
- T * data,
461
- LayoutSrc const& layout_src,
462
- LayoutDst const& layout_dst)
463
- {
464
- using namespace cute;
465
- cutlass::DeviceAllocation<T> temp(size(layout_src));
466
- reorder_tensor(data, layout_src, temp.get(), layout_dst);
467
- cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
468
- }
469
-
470
- #undef CUDA_CHECK
471
-
472
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp DELETED
@@ -1,570 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cute/layout.hpp"
38
- #include "cute/container/array.hpp" // cute::array
39
- #include "cutlass/conv/convolution.h" // cutlass::conv::Operator
40
-
41
- /////////////////////////////////////////////////////////////////////////////////////////////////
42
-
43
- namespace cutlass {
44
-
45
- /////////////////////////////////////////////////////////////////////////////////////////////////
46
-
47
- // Strides without batch mode
48
-
49
- template <class IntT>
50
- CUTLASS_HOST_DEVICE
51
- cute::Stride<IntT, cute::Int<1>>
52
- make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int,int,int> shape_MKL) {
53
- static_assert(std::is_integral_v<IntT>,
54
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
55
- auto s_copy = s;
56
- cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MKL));
57
- return s_copy;
58
- }
59
-
60
- template <class IntT>
61
- CUTLASS_HOST_DEVICE
62
- cute::Stride<cute::Int<1>, IntT>
63
- make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int,int,int> shape_MKL) {
64
- static_assert(std::is_integral_v<IntT>,
65
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
66
- auto s_copy = s;
67
- cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL));
68
- return s_copy;
69
- }
70
-
71
- /////////////////////////////////////////////////////////////////////////////////////////////////
72
-
73
- // Strides with batch mode
74
-
75
- template <class IntT>
76
- CUTLASS_HOST_DEVICE
77
- cute::Stride<IntT, cute::Int<1>, int64_t>
78
- make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
79
- static_assert(std::is_integral_v<IntT>,
80
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
81
- auto s_copy = s;
82
- cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MKL));
83
- int batch_count = cute::get<2>(shape_MKL);
84
- if (batch_count > 1) {
85
- cute::get<2>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
86
- }
87
- else {
88
- cute::get<2>(s_copy) = static_cast<IntT>(0);
89
- }
90
- return s_copy;
91
- }
92
-
93
- template <class IntT>
94
- CUTLASS_HOST_DEVICE
95
- cute::Stride<cute::Int<1>, IntT, int64_t>
96
- make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
97
- static_assert(std::is_integral_v<IntT>,
98
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
99
- auto s_copy = s;
100
- cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL));
101
- int batch_count = cute::get<2>(shape_MKL);
102
- if (batch_count > 1) {
103
- cute::get<2>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
104
- }
105
- else {
106
- cute::get<2>(s_copy) = static_cast<IntT>(0);
107
- }
108
- return s_copy;
109
- }
110
-
111
- /////////////////////////////////////////////////////////////////////////////////////////////////
112
-
113
- // Strides with group mode
114
-
115
- template <class StrideIntT>
116
- CUTLASS_HOST_DEVICE
117
- cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
118
- make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
119
- static_assert(std::is_integral_v<StrideIntT>,
120
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
121
- auto s_copy = s;
122
- cute::get<0>(s_copy) = static_cast<StrideIntT>(cute::get<1>(shape_MKL));
123
- return s_copy;
124
- }
125
-
126
- template <class StrideIntT>
127
- CUTLASS_HOST_DEVICE
128
- cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
129
- make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
130
- static_assert(std::is_integral_v<StrideIntT>,
131
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
132
- auto s_copy = s;
133
- cute::get<1>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL));
134
- return s_copy;
135
- }
136
-
137
- /////////////////////////////////////////////////////////////////////////////////////////////////
138
-
139
- // Strides for convolutions
140
-
141
- // Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0)
142
- // Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order
143
- // and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout
144
- // right in KTRSC order and can be coalesced to just k.
145
- // We enforce this condition here with asserts.
146
- template <class IntT, size_t RankT_>
147
- CUTLASS_HOST_DEVICE
148
- cute::Stride<IntT, cute::Int<1>, cute::Int<0>>
149
- make_cute_packed_stride(
150
- cute::Stride<IntT, cute::Int<1>, cute::Int<0>> s,
151
- cute::array<int32_t, RankT_> shape_output,
152
- cute::array<IntT, RankT_> stride_output,
153
- cutlass::conv::Operator conv_op) {
154
- static_assert(std::is_integral_v<IntT>,
155
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
156
- static_assert(RankT_ >= 3u);
157
- constexpr static int RankT = static_cast<int>(RankT_);
158
-
159
- assert(stride_output[RankT-1] == 1);
160
- cute::for_each(cute::make_seq<RankT-2>{}, [&](auto i) {
161
- assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]);
162
- });
163
-
164
- auto s_copy = s;
165
- cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ?
166
- stride_output[0] :
167
- stride_output[RankT-2];
168
- return s_copy;
169
- }
170
-
171
- //
172
- // Activation tensor ((w, h, d, n), _1) for fprop kernel
173
- //
174
-
175
- // Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1)
176
- template <class IntT>
177
- CUTLASS_HOST_DEVICE
178
- cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>>
179
- make_cute_packed_stride(
180
- cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>> s,
181
- cute::array<IntT, 3> stride_nwc,
182
- conv::Operator ConvOp) {
183
- static_assert(std::is_integral_v<IntT>,
184
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
185
- assert(stride_nwc[2] == 1);
186
- auto s_copy = s;
187
- cute::get<0,0>(s_copy) = stride_nwc[1];
188
- cute::get<0,1>(s_copy) = stride_nwc[0];
189
- return s_copy;
190
- }
191
-
192
- // Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1)
193
- template <class IntT>
194
- CUTLASS_HOST_DEVICE
195
- cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>>
196
- make_cute_packed_stride(
197
- cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>> s,
198
- cute::array<IntT, 4> stride_nhwc,
199
- conv::Operator ConvOp) {
200
- static_assert(std::is_integral_v<IntT>,
201
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
202
- assert(stride_nhwc[3] == 1);
203
- auto s_copy = s;
204
- cute::for_each(cute::make_seq<3>{}, [&](auto i) {
205
- cute::get<0,i>(s_copy) = stride_nhwc[2-i];
206
- });
207
- return s_copy;
208
- }
209
-
210
- // Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1)
211
- template <class IntT>
212
- CUTLASS_HOST_DEVICE
213
- cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>>
214
- make_cute_packed_stride(
215
- cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>> s,
216
- cute::array<IntT, 5> stride_ndhwc,
217
- conv::Operator ConvOp) {
218
- static_assert(std::is_integral_v<IntT>,
219
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
220
-
221
- assert(stride_ndhwc[4] == 1);
222
- auto s_copy = s;
223
- cute::for_each(cute::make_seq<4>{}, [&](auto i) {
224
- cute::get<0,i>(s_copy) = stride_ndhwc[3-i];
225
- });
226
- return s_copy;
227
- }
228
-
229
- //
230
- // Filter tensor (k, (_1, s, r, t)) for fprop kernel
231
- //
232
-
233
- // Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s))
234
- template <class IntT>
235
- CUTLASS_HOST_DEVICE
236
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>>
237
- make_cute_packed_stride(
238
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>> s,
239
- cute::array<IntT, 3> stride_ksc,
240
- conv::Operator ConvOp) {
241
- static_assert(std::is_integral_v<IntT>,
242
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
243
-
244
- assert(stride_ksc[2] == 1);
245
- auto s_copy = s;
246
- cute::get<0,0>(s_copy) = stride_ksc[0];
247
- cute::get<1,1>(s_copy) = stride_ksc[1];
248
- return s_copy;
249
- }
250
-
251
- // Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r))
252
- template <class IntT>
253
- CUTLASS_HOST_DEVICE
254
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>>
255
- make_cute_packed_stride(
256
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>> s,
257
- cute::array<IntT, 4> stride_krsc,
258
- conv::Operator ConvOp) {
259
- static_assert(std::is_integral_v<IntT>,
260
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
261
-
262
- assert(stride_krsc[3] == 1);
263
- auto s_copy = s;
264
- cute::get<0,0>(s_copy) = stride_krsc[0];
265
- cute::for_each(cute::make_seq<2>{}, [&](auto i) {
266
- cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
267
- });
268
- return s_copy;
269
- }
270
-
271
- // Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t))
272
- template <class IntT>
273
- CUTLASS_HOST_DEVICE
274
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>>
275
- make_cute_packed_stride(
276
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>> s,
277
- cute::array<IntT, 5> stride_ktrsc,
278
- conv::Operator ConvOp) {
279
- static_assert(std::is_integral_v<IntT>,
280
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
281
-
282
- assert(stride_ktrsc[4] == 1);
283
- auto s_copy = s;
284
- cute::get<0,0>(s_copy) = stride_ktrsc[0];
285
- cute::for_each(cute::make_seq<3>{}, [&](auto i) {
286
- cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
287
- });
288
- return s_copy;
289
- }
290
-
291
- //
292
- // Activation tensor (_1, (w, h, d, n)) for wgrad kernel
293
- //
294
- // It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel
295
- //
296
-
297
- // Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad
298
- // Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad
299
- template <class IntT>
300
- CUTLASS_HOST_DEVICE
301
- cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>>
302
- make_cute_packed_stride(
303
- cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>> s,
304
- cute::array<IntT, 3> stride_nwc,
305
- conv::Operator ConvOp) {
306
- static_assert(std::is_integral_v<IntT>,
307
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
308
-
309
- assert(stride_nwc[2] == 1);
310
- auto s_copy = s;
311
- if (ConvOp == cutlass::conv::Operator::kWgrad) {
312
- cute::get<1,0>(s_copy) = stride_nwc[1];
313
- cute::get<1,1>(s_copy) = stride_nwc[0];
314
- }
315
- else if (ConvOp == cutlass::conv::Operator::kDgrad) {
316
- // stride_nwc in dgrad is ksc.
317
- cute::get<1,0>(s_copy) = stride_nwc[0];
318
- cute::get<1,1>(s_copy) = stride_nwc[1];
319
- }
320
- return s_copy;
321
- }
322
-
323
- // Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad
324
- // Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad
325
- template <class IntT>
326
- CUTLASS_HOST_DEVICE
327
- cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>>
328
- make_cute_packed_stride(
329
- cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>> s,
330
- cute::array<IntT, 4> stride_nhwc,
331
- conv::Operator ConvOp) {
332
- static_assert(std::is_integral_v<IntT>,
333
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
334
-
335
- assert(stride_nhwc[3] == 1);
336
- auto s_copy = s;
337
- if (ConvOp == cutlass::conv::Operator::kWgrad) {
338
- cute::for_each(cute::make_seq<3>{}, [&](auto i) {
339
- cute::get<1,i>(s_copy) = stride_nhwc[2-i];
340
- });
341
- }
342
- else if (ConvOp == cutlass::conv::Operator::kDgrad) {
343
- // stride_nhwc in dgrad is krsc.
344
- cute::get<1,0>(s_copy) = stride_nhwc[0];
345
- cute::for_each(cute::make_seq<2>{}, [&](auto i) {
346
- cute::get<1,2-i>(s_copy) = stride_nhwc[i+1];
347
- });
348
- }
349
- return s_copy;
350
- }
351
-
352
- // Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad
353
- // Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad
354
- template <class IntT>
355
- CUTLASS_HOST_DEVICE
356
- cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>>
357
- make_cute_packed_stride(
358
- cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>> s,
359
- cute::array<IntT, 5> stride_ndhwc,
360
- conv::Operator ConvOp) {
361
- static_assert(std::is_integral_v<IntT>,
362
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
363
-
364
- assert(stride_ndhwc[4] == 1);
365
- auto s_copy = s;
366
- if (ConvOp == cutlass::conv::Operator::kWgrad) {
367
- cute::for_each(cute::make_seq<4>{}, [&](auto i) {
368
- cute::get<1,i>(s_copy) = stride_ndhwc[3-i];
369
- });
370
- }
371
- else if (ConvOp == cutlass::conv::Operator::kDgrad) {
372
- // stride_ndhwc in dgrad is ktrsc.
373
- cute::get<1,0>(s_copy) = stride_ndhwc[0];
374
- cute::for_each(cute::make_seq<3>{}, [&](auto i) {
375
- cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1];
376
- });
377
- }
378
- return s_copy;
379
- }
380
-
381
- //
382
- // NZPQ tensor (_1, nzpq) for wgrad kernel
383
- //
384
-
385
- // cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq)
386
- template <class IntT>
387
- CUTLASS_HOST_DEVICE
388
- cute::Stride<cute::Int<1>, IntT>
389
- make_cute_packed_stride(
390
- cute::Stride<cute::Int<1>, IntT> s,
391
- cute::array<IntT, 3> stride_nqk,
392
- conv::Operator ConvOp) {
393
- static_assert(std::is_integral_v<IntT>,
394
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
395
-
396
- assert(stride_nqk[2] == 1);
397
- auto s_copy = s;
398
- cute::get<1>(s_copy) = stride_nqk[1];
399
- return s_copy;
400
- }
401
-
402
- // cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq)
403
- template <class IntT>
404
- CUTLASS_HOST_DEVICE
405
- cute::Stride<cute::Int<1>, IntT>
406
- make_cute_packed_stride(
407
- cute::Stride<cute::Int<1>, IntT> s,
408
- cute::array<IntT, 4> stride_npqk,
409
- conv::Operator ConvOp) {
410
- static_assert(std::is_integral_v<IntT>,
411
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
412
-
413
- assert(stride_npqk[3] == 1);
414
- auto s_copy = s;
415
- cute::get<1>(s_copy) = stride_npqk[2];
416
- return s_copy;
417
- }
418
-
419
- // cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq)
420
- template <class IntT>
421
- CUTLASS_HOST_DEVICE
422
- cute::Stride<cute::Int<1>, IntT>
423
- make_cute_packed_stride(
424
- cute::Stride<cute::Int<1>, IntT> s,
425
- cute::array<IntT, 5> stride_nzpqk,
426
- conv::Operator ConvOp) {
427
- static_assert(std::is_integral_v<IntT>,
428
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
429
-
430
- assert(stride_nzpqk[4] == 1);
431
- auto s_copy = s;
432
- cute::get<1>(s_copy) = stride_nzpqk[3];
433
- return s_copy;
434
- }
435
-
436
-
437
-
438
- //
439
- // Wgrad output tensor (k, (_1, s, r, t), _0)
440
- //
441
-
442
- // Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0)
443
- template <class IntT>
444
- CUTLASS_HOST_DEVICE
445
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>>
446
- make_cute_packed_stride(
447
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>> s,
448
- [[maybe_unused]] cute::array<int32_t, 3> shape_output,
449
- cute::array<IntT, 3> stride_ksc,
450
- conv::Operator ConvOp) {
451
- static_assert(std::is_integral_v<IntT>,
452
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
453
-
454
- assert(stride_ksc[2] == 1);
455
- auto s_copy = s;
456
- cute::get<0,0>(s_copy) = stride_ksc[0];
457
- cute::get<1,1>(s_copy) = stride_ksc[1];
458
- return s_copy;
459
- }
460
-
461
- // Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0)
462
- template <class IntT>
463
- CUTLASS_HOST_DEVICE
464
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>>
465
- make_cute_packed_stride(
466
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>> s,
467
- [[maybe_unused]] cute::array<int32_t, 4> shape_output,
468
- cute::array<IntT, 4> stride_krsc,
469
- conv::Operator ConvOp) {
470
- static_assert(std::is_integral_v<IntT>,
471
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
472
-
473
- assert(stride_krsc[3] == 1);
474
- auto s_copy = s;
475
- cute::get<0,0>(s_copy) = stride_krsc[0];
476
- cute::for_each(cute::make_seq<2>{}, [&](auto i) {
477
- cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
478
- });
479
- return s_copy;
480
- }
481
-
482
- // Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0)
483
- template <class IntT>
484
- CUTLASS_HOST_DEVICE
485
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>>
486
- make_cute_packed_stride(
487
- cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>> s,
488
- [[maybe_unused]] cute::array<int32_t, 5> shape_output,
489
- cute::array<IntT, 5> stride_ktrsc,
490
- conv::Operator ConvOp) {
491
- static_assert(std::is_integral_v<IntT>,
492
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
493
-
494
- assert(stride_ktrsc[4] == 1);
495
- auto s_copy = s;
496
- cute::get<0,0>(s_copy) = stride_ktrsc[0];
497
- cute::for_each(cute::make_seq<3>{}, [&](auto i) {
498
- cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
499
- });
500
- return s_copy;
501
- }
502
-
503
-
504
- //
505
- // Wgrad output tensor ((_1, s, r, t), k, _0)
506
- //
507
-
508
- // Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0)
509
- template <class IntT>
510
- CUTLASS_HOST_DEVICE
511
- cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>>
512
- make_cute_packed_stride(
513
- cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>> s,
514
- [[maybe_unused]] cute::array<int32_t, 3> shape_output,
515
- cute::array<IntT, 3> stride_ksc,
516
- conv::Operator ConvOp) {
517
- static_assert(std::is_integral_v<IntT>,
518
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
519
-
520
- assert(stride_ksc[2] == 1);
521
- auto s_copy = s;
522
- cute::get<1,0>(s_copy) = stride_ksc[0];
523
- cute::get<0,1>(s_copy) = stride_ksc[1];
524
- return s_copy;
525
- }
526
-
527
- // Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0)
528
- template <class IntT>
529
- CUTLASS_HOST_DEVICE
530
- cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>>
531
- make_cute_packed_stride(
532
- cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>> s,
533
- [[maybe_unused]] cute::array<int32_t, 4> shape_output,
534
- cute::array<IntT, 4> stride_krsc,
535
- conv::Operator ConvOp) {
536
- static_assert(std::is_integral_v<IntT>,
537
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
538
-
539
- assert(stride_krsc[3] == 1);
540
- auto s_copy = s;
541
- cute::get<1,0>(s_copy) = stride_krsc[0];
542
- cute::for_each(cute::make_seq<2>{}, [&](auto i) {
543
- cute::get<0,2-i>(s_copy) = stride_krsc[i+1];
544
- });
545
- return s_copy;
546
- }
547
-
548
- // Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0)
549
- template <class IntT>
550
- CUTLASS_HOST_DEVICE
551
- cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>>
552
- make_cute_packed_stride(
553
- cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>> s,
554
- [[maybe_unused]] cute::array<int32_t, 5> shape_output,
555
- cute::array<IntT, 5> stride_ktrsc,
556
- conv::Operator ConvOp) {
557
- static_assert(std::is_integral_v<IntT>,
558
- "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
559
-
560
- assert(stride_ktrsc[4] == 1);
561
- auto s_copy = s;
562
- cute::get<1,0>(s_copy) = stride_ktrsc[0];
563
- cute::for_each(cute::make_seq<3>{}, [&](auto i) {
564
- cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1];
565
- });
566
- return s_copy;
567
- }
568
- /////////////////////////////////////////////////////////////////////////////////////////////////
569
-
570
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp DELETED
@@ -1,341 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include <array>
35
- #include <cassert>
36
- #include <cmath>
37
- #include <iostream>
38
- #include <type_traits>
39
-
40
- #include <cute/util/type_traits.hpp>
41
- #include <cute/tensor.hpp>
42
-
43
- #include <cute/numeric/numeric_types.hpp>
44
- #include <cute/numeric/complex.hpp>
45
-
46
- #include <cutlass/layout/layout.h>
47
-
48
- // The computed infinity norm does not include
49
- // any NaN column absolute-value sums.
50
- struct matrix_inf_norm_result {
51
- // Accumulate errors in double, as this is generally
52
- // the highest precision that the examples use.
53
- double inf_norm = 0.0;
54
- bool found_nan = false;
55
- };
56
-
57
- // In theory, cute::Tensor<ViewEngine<T*>, T> could be treated as a view type,
58
- // and thus passed by value (as std::span or std::string_view would be).
59
- // However, generic cute::Tensor are more like containers
60
- // and thus are best passed by reference or const reference.
61
- template <typename EngineType, typename LayoutType>
62
- matrix_inf_norm_result
63
- matrix_inf_norm(cute::Tensor<EngineType, LayoutType> const& host_matrix)
64
- {
65
- using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
66
- using element_type = typename EngineType::value_type;
67
-
68
- error_type inf_norm = 0.0;
69
- bool found_nan = false;
70
-
71
- // Computing the infinity norm requires that we be able
72
- // to treat the input as a matrix, with rows and columns.
73
- const int64_t num_rows = cute::size<0>(host_matrix);
74
- const int64_t num_cols = cute::size<1>(host_matrix);
75
-
76
- auto abs_fn = [] (element_type A_ij) {
77
- if constexpr (not std::is_unsigned_v<element_type>) {
78
- using std::abs;
79
- return abs(A_ij);
80
- }
81
- else {
82
- return A_ij;
83
- }
84
- };
85
-
86
- for (int64_t i = 0; i < num_rows; ++i) {
87
- error_type row_abs_sum = 0.0;
88
- for(int64_t j = 0; j < num_cols; ++j) {
89
- row_abs_sum += abs_fn(host_matrix(i, j));
90
- }
91
- if (std::isnan(row_abs_sum)) {
92
- found_nan = true;
93
- }
94
- else {
95
- inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
96
- }
97
- }
98
-
99
- return {inf_norm, found_nan};
100
- }
101
-
102
- // Infinity norm of (X - Y).
103
- template <typename EngineType, typename LayoutType>
104
- matrix_inf_norm_result
105
- matrix_diff_inf_norm(cute::Tensor<EngineType, LayoutType> const& X,
106
- cute::Tensor<EngineType, LayoutType> const& Y)
107
- {
108
- using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
109
- using element_type = typename EngineType::value_type;
110
-
111
- auto abs_fn = [] (element_type A_ij) {
112
- if constexpr (not std::is_unsigned_v<element_type>) {
113
- using std::abs;
114
- return abs(A_ij);
115
- }
116
- else {
117
- return A_ij;
118
- }
119
- };
120
-
121
- assert(cute::size<0>(X) == cute::size<0>(Y));
122
- assert(cute::size<1>(X) == cute::size<1>(Y));
123
-
124
- // Computing the infinity norm requires that we be able
125
- // to treat the input as a matrix, with rows and columns.
126
- const int64_t num_rows = cute::size<0>(X);
127
- const int64_t num_cols = cute::size<1>(X);
128
-
129
- error_type inf_norm = 0.0;
130
- bool found_nan = false;
131
-
132
- for (int64_t i = 0; i < num_rows; ++i) {
133
- error_type row_abs_sum = 0.0;
134
- for (int64_t j = 0; j < num_cols; ++j) {
135
- row_abs_sum += error_type(abs_fn(element_type(X(i,j)) -
136
- element_type(Y(i,j))));
137
- }
138
- if (std::isnan(row_abs_sum)) {
139
- found_nan = true;
140
- }
141
- else {
142
- inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
143
- }
144
- }
145
-
146
- return {inf_norm, found_nan};
147
- }
148
-
149
- template <typename EngineType_A, typename LayoutType_A,
150
- typename EngineType_B, typename LayoutType_B,
151
- typename EngineType_C, typename LayoutType_C,
152
- typename EngineType_C_ref, typename LayoutType_C_ref>
153
- auto
154
- print_matrix_multiply_mollified_relative_error(
155
- char const A_value_type_name[],
156
- cute::Tensor<EngineType_A, LayoutType_A> const& A,
157
- char const B_value_type_name[],
158
- cute::Tensor<EngineType_B, LayoutType_B> const& B,
159
- char const C_value_type_name[],
160
- cute::Tensor<EngineType_C, LayoutType_C> const& C,
161
- cute::Tensor<EngineType_C_ref, LayoutType_C_ref> const& C_ref)
162
- {
163
- const auto [A_norm, A_has_nan] = matrix_inf_norm(A);
164
- const auto [B_norm, B_has_nan] = matrix_inf_norm(B);
165
- const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref);
166
- const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref);
167
-
168
- const auto A_norm_times_B_norm = A_norm * B_norm;
169
- const auto relative_error = A_norm_times_B_norm == 0.0 ?
170
- diff_norm : (diff_norm / A_norm_times_B_norm);
171
-
172
- // For expected error bounds, please refer to the LAPACK Users' Guide,
173
- // in particular https://netlib.org/lapack/lug/node108.html .
174
- // Printing the infinity norm of C is a way to check
175
- // that both the function being tested (C)
176
- // and the reference implementation (C_ref)
177
- // don't just do nothing (or fill with zeros).
178
- using std::cout;
179
- using cute::shape;
180
- cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n'
181
- << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n'
182
- << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n'
183
- << std::scientific
184
- << "Infinity norm of A: " << A_norm << '\n'
185
- << "Infinity norm of B: " << B_norm << '\n'
186
- << "Infinity norm of C: " << C_norm << '\n'
187
- << "Infinity norm of (C - C_ref): " << diff_norm << '\n';
188
-
189
- if(A_norm_times_B_norm == 0.0) {
190
- cout << "Mollified relative error: " << relative_error << '\n';
191
- } else {
192
- cout << "Relative error: " << relative_error << '\n';
193
- }
194
-
195
- if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) {
196
- cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
197
- << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
198
- << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n'
199
- << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n';
200
- }
201
- return relative_error;
202
- }
203
-
204
- template <typename EngineType, typename LayoutType>
205
- auto
206
- print_matrix_multiply_mollified_relative_error(
207
- const char value_type_name[],
208
- const cute::Tensor<EngineType, LayoutType>& A,
209
- const cute::Tensor<EngineType, LayoutType>& B,
210
- const cute::Tensor<EngineType, LayoutType>& C_computed,
211
- const cute::Tensor<EngineType, LayoutType>& C_expected)
212
- {
213
- return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B,
214
- value_type_name, C_computed, C_expected);
215
- }
216
-
217
- // Take a CUTLASS HostTensor (or the like) as input,
218
- // and return a const CuTe Tensor.
219
- // This is useful for use with the above error printing functions.
220
- // This implicitly "transposes" if the layout is RowMajor.
221
- // Note that the HostTensor must be captured by nonconst reference
222
- // in order for X.host_ref().data() to compile.
223
- // (CUTLASS is a bit more container-y than CuTe.)
224
- template<class CutlassHostTensorType>
225
- auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X)
226
- {
227
- // The tensors were created with post-transposed extents.
228
- const auto extents = X.extent();
229
- const auto shape = cute::Shape<int, int>{extents[0], extents[1]};
230
- // Both RowMajor and ColumnMajor only store one stride.
231
- const int LDX = X.stride(0);
232
- const auto strides = [&]() {
233
- using input_layout_type = typename std::decay_t<decltype(X)>::Layout;
234
- if constexpr (std::is_same_v<input_layout_type, cutlass::layout::ColumnMajor>) {
235
- return cute::Stride<int, int>{1, LDX};
236
- }
237
- else {
238
- static_assert(std::is_same_v<input_layout_type, cutlass::layout::RowMajor>);
239
- return cute::Stride<int, int>{LDX, 1};
240
- }
241
- }();
242
- const auto layout = cute::make_layout(shape, strides);
243
- auto X_data = X.host_ref().data();
244
- auto X_data_const = const_cast<std::add_const_t< decltype(X_data)> >(X_data);
245
- return cute::make_tensor(X_data_const, layout);
246
- };
247
-
248
-
249
- // Returns EXIT_SUCCESS if the 2-norm relative error is exactly zero, else returns EXIT_FAILURE.
250
- // This makes the return value suitable as the return value of main().
251
- template <typename T1, typename T2>
252
- int
253
- print_relative_error(
254
- std::size_t n,
255
- T1 const& data,
256
- T2 const& reference,
257
- bool print_verbose = false,
258
- bool print_error = true,
259
- double error_margin = 0.00001) {
260
- using std::abs; using std::sqrt;
261
-
262
- // Use either double or complex<double> for error computation
263
- using value_type = cute::remove_cvref_t<decltype(reference[0])>;
264
- using error_type = std::conditional_t<cute::is_complex<value_type>::value,
265
- cute::complex<double>,
266
- double>;
267
-
268
- if (print_verbose) {
269
- std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl;
270
- }
271
-
272
- double eps = 1e-200;
273
-
274
- double tot_error_sq = 0;
275
- double tot_norm_sq = 0;
276
- double tot_ind_rel_err = 0;
277
- double max_ind_rel_err = 0;
278
- double max_diff = 0;
279
- for (std::size_t i = 0; i < n; ++i) {
280
- error_type val = data[i];
281
- error_type ref = reference[i];
282
-
283
- double aref = abs(ref);
284
- double diff = abs(ref - val);
285
- double rel_error = diff / (aref + eps);
286
-
287
- // Individual relative error
288
- tot_ind_rel_err += rel_error;
289
-
290
- // Maximum relative error
291
- max_ind_rel_err = std::max(max_ind_rel_err, rel_error);
292
-
293
- // Maximum delta in value error
294
- max_diff = std::max(max_diff, diff);
295
-
296
- // Total relative error
297
- tot_error_sq += diff * diff;
298
- tot_norm_sq += aref * aref;
299
-
300
- if (print_verbose) {
301
- std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl;
302
- }
303
- }
304
-
305
- double ave_rel_err = tot_ind_rel_err / double(n);
306
- if (print_error) {
307
- printf("Average relative error: %.3e\n", ave_rel_err);
308
- }
309
-
310
- if (print_error) {
311
- printf("Maximum relative error: %.3e\n", max_ind_rel_err);
312
- }
313
-
314
- if (print_error) {
315
- printf("Maximum difference : %.3e\n", max_diff);
316
- }
317
-
318
- double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps));
319
- if (print_error) {
320
- printf("Vector relative error: %.3e\n", tot_rel_err);
321
- }
322
-
323
- printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq));
324
-
325
- return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE;
326
- }
327
-
328
- // Overload for cute::Tensor<>
329
- template <class Engine, class Layout>
330
- int
331
- print_relative_error(
332
- cute::Tensor<Engine, Layout> data,
333
- cute::Tensor<Engine, Layout> reference,
334
- bool print_verbose = false,
335
- bool print_error = true,
336
- double error_margin = 0.00001) {
337
- assert(size(data) == size(reference));
338
- return print_relative_error(static_cast<std::size_t>(size(data)),
339
- data, reference,
340
- print_verbose, print_error, error_margin);
341
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h DELETED
@@ -1,135 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for GEMM in host-side code.
33
- */
34
- #pragma once
35
-
36
- #include "cutlass/cutlass.h"
37
- #include "cutlass/array.h"
38
-
39
- namespace cutlass {
40
- namespace reference {
41
- namespace detail {
42
-
43
- ////////////////////////////////////////////////////////////////////////////////////////////////////
44
-
45
- /// Template function to compute an inner product.
46
- #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a
47
- // host-only type
48
- template <typename Atype, typename Btype, typename Ctype>
49
- CUTLASS_HOST_DEVICE
50
- Ctype inner_product(Atype a, Btype b, Ctype c) {
51
- return Ctype(a) * Ctype(b) + c;
52
- }
53
-
54
- /// Specialization for matrix multiplication with binary operands
55
- template <>
56
- CUTLASS_HOST_DEVICE
57
- int inner_product<Array<bin1_t, 32>, Array<bin1_t, 32>, int>(
58
- Array<bin1_t, 32> a,
59
- Array<bin1_t, 32> b,
60
- int c) {
61
-
62
- int accum = 0;
63
- for (int bit = 0; bit < 32; bit++) {
64
- accum += a[bit] ^ b[bit];
65
- }
66
- return accum + c;
67
- }
68
-
69
- /*
70
- /// Specialization for matrix multiplication with signed 4-bit integer operands
71
- template <>
72
- CUTLASS_HOST_DEVICE
73
- int inner_product<Array<int4b_t, 8>, Array<int4b_t, 8>, int>(
74
- Array<int4b_t, 8> a,
75
- Array<int4b_t, 8> b,
76
- int c) {
77
-
78
- int accum = 0;
79
- for (int k = 0; k < 8; k++) {
80
- accum += a[k] * b[k];
81
- }
82
- return accum + c;
83
- }
84
-
85
- /// Specialization for matrix multiplication with unsigned 4-bit integer operands
86
- template <>
87
- CUTLASS_HOST_DEVICE
88
- int inner_product<Array<uint4b_t, 8>, Array<uint4b_t, 8>, int>(
89
- Array<uint4b_t, 8> a,
90
- Array<uint4b_t, 8> b,
91
- int c) {
92
-
93
- int accum = 0;
94
- for (int k = 0; k < 8; k++) {
95
- accum += a[k] * b[k];
96
- }
97
- return accum + c;
98
- }
99
- */
100
-
101
- ////////////////////////////////////////////////////////////////////////////////////////////////////
102
-
103
- template <typename SrcType, typename DstType>
104
- struct Cast {
105
- // Default behavior: convert to the destination type
106
- #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
107
- // host-only type
108
- CUTLASS_HOST_DEVICE
109
- static DstType apply(SrcType src) { return static_cast<DstType>(src); };
110
- };
111
-
112
- template <>
113
- struct Cast<float, int8_t> {
114
- CUTLASS_HOST_DEVICE
115
- static int8_t apply(float src) {
116
- // Clamp to the range of signed 8-bit integers.
117
- return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
118
- };
119
- };
120
-
121
- template <>
122
- struct Cast<float, uint8_t> {
123
- CUTLASS_HOST_DEVICE
124
- static uint8_t apply(float src) {
125
- // Clamp to the range of signed 8-bit integers.
126
- return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
127
- };
128
- };
129
-
130
- ////////////////////////////////////////////////////////////////////////////////////////////////////
131
-
132
- } // namespace detail
133
- } // namespace reference
134
- } // namespace cutlass
135
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h DELETED
@@ -1,94 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for GEMM in host-side code.
33
- */
34
- #pragma once
35
-
36
- #include "cutlass/cutlass.h"
37
- #include "cutlass/coord.h"
38
-
39
- /////////////////////////////////////////////////////////////////////////////////////////////////
40
-
41
- namespace cutlass {
42
- namespace reference {
43
- namespace detail {
44
-
45
- /////////////////////////////////////////////////////////////////////////////////////////////////
46
-
47
- template <int Rank, int Index>
48
- struct LinearToCoordinateHelper {
49
-
50
- CUTLASS_HOST_DEVICE
51
- void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &extent) const {
52
-
53
- int64_t prod = 1;
54
-
55
- CUTLASS_PRAGMA_UNROLL
56
- for (int i = Rank - Index; i < Rank; ++i) {
57
- prod *= int64_t(extent[i]);
58
- }
59
-
60
- coord[Rank - Index - 1] = int(idx / prod);
61
-
62
- int64_t residual = idx % prod;
63
- LinearToCoordinateHelper<Rank, Index - 1>()(coord, residual, extent);
64
- }
65
- };
66
-
67
- template <int Rank>
68
- struct LinearToCoordinateHelper<Rank, 0> {
69
-
70
- CUTLASS_HOST_DEVICE
71
- void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &) const {
72
- coord[Rank - 1] = int(idx);
73
- }
74
- };
75
-
76
- /////////////////////////////////////////////////////////////////////////////////////////////////
77
-
78
- template <int Rank>
79
- struct LinearToCoordinate {
80
-
81
- CUTLASS_HOST_DEVICE
82
- void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &extent) const {
83
- LinearToCoordinateHelper<Rank, Rank - 1>()(coord, idx, extent);
84
- }
85
- };
86
-
87
- /////////////////////////////////////////////////////////////////////////////////////////////////
88
-
89
- } // namespace detail
90
- } // namespace reference
91
- } // namespace cutlass
92
-
93
- /////////////////////////////////////////////////////////////////////////////////////////////////
94
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h DELETED
@@ -1,1549 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief Reference implementation for convolution in device-side code.
34
- */
35
-
36
- #pragma once
37
-
38
- #include "cutlass/coord.h"
39
- #include "cutlass/functional.h"
40
- #include "cutlass/layout/tensor.h"
41
- #include "cutlass/matrix_shape.h"
42
- #include "cutlass/numeric_conversion.h"
43
- #include "cutlass/numeric_types.h"
44
- #include "cutlass/tensor_ref.h"
45
- #include "cutlass/conv/convolution.h"
46
- #include "cutlass/conv/conv2d_problem_size.h"
47
- #include "cutlass/conv/conv3d_problem_size.h"
48
-
49
- namespace cutlass {
50
- namespace reference {
51
- namespace device {
52
-
53
- /////////////////////////////////////////////////////////////////////////////////////////////////
54
-
55
- namespace kernel {
56
-
57
- ////////////////////////////////////////////////////////////////////////////////////////////////////
58
- /// Conv2d device reference kernel
59
- ////////////////////////////////////////////////////////////////////////////////////////////////////
60
-
61
- // Conv2d Fprop kernel - y = fprop(x, w)
62
- template <
63
- typename ElementA,
64
- typename LayoutA,
65
- typename ElementB,
66
- typename LayoutB,
67
- typename ElementC,
68
- typename LayoutC,
69
- typename ElementCompute,
70
- typename ElementAccumulator = ElementCompute,
71
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
72
- typename InnerProductOp = multiply_add<ElementAccumulator>,
73
- int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
74
- int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
75
- int kCtaShapeM = 16, // shape of a threadblock in units of threads
76
- int kCtaShapeN = 8 // shape of a threadblock in units of threads
77
- >
78
- __global__ void Conv2dFprop(
79
- conv::Conv2dProblemSize problem_size,
80
- TensorRef<ElementA, LayoutA> tensor_x,
81
- TensorRef<ElementB, LayoutB> tensor_w,
82
- TensorRef<ElementC, LayoutC> tensor_y_in,
83
- TensorRef<ElementC, LayoutC> tensor_y_out,
84
- ElementCompute alpha,
85
- ElementCompute beta
86
- ) {
87
-
88
- ConvertOp convert_op;
89
- InnerProductOp inner_product_op;
90
-
91
- ElementAccumulator element_A[kThreadM];
92
- ElementAccumulator element_B[kThreadN];
93
- ElementAccumulator accum[kThreadM][kThreadN];
94
-
95
- int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
96
- int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
97
-
98
- int thread_n[kThreadM];
99
- int thread_p[kThreadM];
100
- int thread_q[kThreadM];
101
-
102
- // Compute N, P, Q coordinates for each row of a thread's tile
103
- int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
104
-
105
- CUTLASS_PRAGMA_UNROLL
106
- for (int m = 0; m < kThreadM; ++m) {
107
-
108
- int64_t npq = npq_start + m;
109
-
110
- thread_n[m] = int(npq / PQ);
111
-
112
- int64_t residual = npq % PQ;
113
- thread_p[m] = int(residual / problem_size.Q);
114
- thread_q[m] = int(residual % problem_size.Q);
115
- }
116
-
117
- // Clear accumulators
118
- CUTLASS_PRAGMA_UNROLL
119
- for (int m = 0; m < kThreadM; ++m) {
120
- CUTLASS_PRAGMA_UNROLL
121
- for (int n = 0; n < kThreadN; ++n) {
122
- accum[m][n] = ElementAccumulator();
123
- }
124
- }
125
-
126
- int c_per_group = problem_size.C / problem_size.groups;
127
- int k_per_group = problem_size.K / problem_size.groups;
128
-
129
- // Compute convolution
130
- for (int R = 0; R < problem_size.R; ++R) {
131
- for (int S = 0; S < problem_size.S; ++S) {
132
- for (int C = 0; C < problem_size.C; ++C) {
133
-
134
- // Get group id of currnet channel
135
- int c_group_idx = C / c_per_group;
136
-
137
- // Load from activations tensor
138
- int filter_r = R;
139
- int filter_s = S;
140
-
141
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
142
- filter_r = problem_size.R - 1 - R;
143
- filter_s = problem_size.S - 1 - S;
144
- }
145
-
146
- CUTLASS_PRAGMA_UNROLL
147
- for (int m = 0; m < kThreadM; ++m) {
148
- int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
149
- int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
150
-
151
- if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) {
152
- element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C}));
153
- }
154
- else {
155
- element_A[m] = ElementAccumulator();
156
- }
157
- }
158
-
159
- // Load from filters tensor
160
- CUTLASS_PRAGMA_UNROLL
161
- for (int n = 0; n < kThreadN; ++n) {
162
- int thread_k = k_start + n;
163
- int k_group_idx = thread_k / k_per_group;
164
-
165
- if (thread_k < problem_size.K && k_group_idx == c_group_idx) {
166
- element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group}));
167
- }
168
- else {
169
- element_B[n] = ElementAccumulator();
170
- }
171
- }
172
-
173
- // Accumulate matrix product
174
- CUTLASS_PRAGMA_UNROLL
175
- for (int m = 0; m < kThreadM; ++m) {
176
- CUTLASS_PRAGMA_UNROLL
177
- for (int n = 0; n < kThreadN; ++n) {
178
- accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
179
- }
180
- }
181
- }
182
- }
183
- }
184
-
185
- // Write out the results
186
- CUTLASS_PRAGMA_UNROLL
187
- for (int m = 0; m < kThreadM; ++m) {
188
- if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) {
189
- CUTLASS_PRAGMA_UNROLL
190
- for (int n = 0; n < kThreadN; ++n) {
191
- int thread_k = k_start + n;
192
- if (thread_k < problem_size.K) {
193
-
194
- ElementCompute c_ref = ElementCompute();
195
- if (beta != ElementCompute()) {
196
- c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
197
- }
198
-
199
- tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
200
- alpha * ElementCompute(accum[m][n]) + beta * c_ref);
201
- }
202
- }
203
- }
204
- }
205
- }
206
-
207
- // Conv3d Fprop kernel - y = fprop(x, w)
208
- template <
209
- typename ElementA,
210
- typename LayoutA,
211
- typename ElementB,
212
- typename LayoutB,
213
- typename ElementC,
214
- typename LayoutC,
215
- typename ElementCompute,
216
- typename ElementAccumulator = ElementCompute,
217
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
218
- typename InnerProductOp = multiply_add<ElementAccumulator>,
219
- int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
220
- int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
221
- int kCtaShapeM = 16, // shape of a threadblock in units of threads
222
- int kCtaShapeN = 8 // shape of a threadblock in units of threads
223
- >
224
- __global__ void Conv3dFprop(
225
- conv::Conv3dProblemSize problem_size,
226
- TensorRef<ElementA, LayoutA> tensor_x,
227
- TensorRef<ElementB, LayoutB> tensor_w,
228
- TensorRef<ElementC, LayoutC> tensor_y_in,
229
- TensorRef<ElementC, LayoutC> tensor_y_out,
230
- ElementCompute alpha,
231
- ElementCompute beta
232
- ) {
233
-
234
- ConvertOp convert_op;
235
- InnerProductOp inner_product_op;
236
-
237
- ElementAccumulator element_A[kThreadM];
238
- ElementAccumulator element_B[kThreadN];
239
- ElementAccumulator accum[kThreadM][kThreadN];
240
-
241
- int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
242
- int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
243
-
244
- int thread_n[kThreadM];
245
- int thread_z[kThreadM];
246
- int thread_p[kThreadM];
247
- int thread_q[kThreadM];
248
-
249
- // Compute N, Z, P, Q coordinates for each row of a thread's tile
250
- int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
251
- int64_t ZPQ = PQ * problem_size.Z;
252
-
253
- CUTLASS_PRAGMA_UNROLL
254
- for (int m = 0; m < kThreadM; ++m) {
255
-
256
- int64_t nzpq = nzpq_start + m;
257
-
258
- thread_n[m] = int(nzpq / ZPQ);
259
-
260
- int64_t residual = nzpq % ZPQ;
261
- thread_z[m] = int(residual / PQ);
262
-
263
- residual = residual % PQ;
264
- thread_p[m] = int(residual / problem_size.Q);
265
- thread_q[m] = int(residual % problem_size.Q);
266
- }
267
-
268
- // Clear accumulators
269
- CUTLASS_PRAGMA_UNROLL
270
- for (int m = 0; m < kThreadM; ++m) {
271
- CUTLASS_PRAGMA_UNROLL
272
- for (int n = 0; n < kThreadN; ++n) {
273
- accum[m][n] = ElementAccumulator();
274
- }
275
- }
276
-
277
- // Compute convolution
278
- for (int T = 0; T < problem_size.T; ++T) {
279
- for (int R = 0; R < problem_size.R; ++R) {
280
- for (int S = 0; S < problem_size.S; ++S) {
281
- for (int C = 0; C < problem_size.C; ++C) {
282
-
283
- // Load from activations tensor
284
- int filter_t = T;
285
- int filter_r = R;
286
- int filter_s = S;
287
-
288
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
289
- filter_t = problem_size.T - 1 - T;
290
- filter_r = problem_size.R - 1 - R;
291
- filter_s = problem_size.S - 1 - S;
292
- }
293
-
294
- CUTLASS_PRAGMA_UNROLL
295
- for (int m = 0; m < kThreadM; ++m) {
296
- int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
297
- int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
298
- int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
299
-
300
- if (thread_n[m] < problem_size.N &&
301
- d >= 0 && d < problem_size.D &&
302
- h >= 0 && h < problem_size.H &&
303
- w >= 0 && w < problem_size.W) {
304
-
305
- element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C}));
306
- }
307
- else {
308
- element_A[m] = ElementAccumulator();
309
- }
310
- }
311
-
312
- // Load from filters tensor
313
- CUTLASS_PRAGMA_UNROLL
314
- for (int n = 0; n < kThreadN; ++n) {
315
- int thread_k = k_start + n;
316
-
317
- if (thread_k < problem_size.K) {
318
- element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C}));
319
- }
320
- else {
321
- element_B[n] = ElementAccumulator();
322
- }
323
- }
324
-
325
- // Accumulate matrix product
326
- CUTLASS_PRAGMA_UNROLL
327
- for (int m = 0; m < kThreadM; ++m) {
328
- CUTLASS_PRAGMA_UNROLL
329
- for (int n = 0; n < kThreadN; ++n) {
330
- accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
331
- }
332
- }
333
-
334
- } // for (C)
335
- } // for (S)
336
- } // for (R)
337
- } // for (T)
338
-
339
- // Write out the results
340
- CUTLASS_PRAGMA_UNROLL
341
- for (int m = 0; m < kThreadM; ++m) {
342
-
343
- if (thread_n[m] < problem_size.N &&
344
- thread_z[m] < problem_size.Z &&
345
- thread_p[m] < problem_size.P &&
346
- thread_q[m] < problem_size.Q) {
347
-
348
- CUTLASS_PRAGMA_UNROLL
349
- for (int n = 0; n < kThreadN; ++n) {
350
- int thread_k = k_start + n;
351
- if (thread_k < problem_size.K) {
352
-
353
- ElementCompute c_ref = ElementCompute();
354
- if (beta != ElementCompute()) {
355
- c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}));
356
- }
357
-
358
- tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
359
- alpha * ElementCompute(accum[m][n]) + beta * c_ref);
360
- }
361
- } // for (n)
362
-
363
- }
364
- } // for (m)
365
- }
366
-
367
- ///////////////////////////////////////////////////////////////////////////////////////////////////
368
-
369
- // Conv2d dgrad kernel - dx = dgrad(dy, w)
370
- template <
371
- typename ElementA,
372
- typename LayoutA,
373
- typename ElementB,
374
- typename LayoutB,
375
- typename ElementC,
376
- typename LayoutC,
377
- typename ElementCompute,
378
- typename ElementAccumulator = ElementCompute,
379
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
380
- typename InnerProductOp = multiply_add<ElementAccumulator>,
381
- int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
382
- int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
383
- int kCtaShapeM = 16, // shape of a threadblock in units of threads
384
- int kCtaShapeN = 8 // shape of a threadblock in units of threads
385
- >
386
- __global__ void Conv2dDgrad(
387
- conv::Conv2dProblemSize problem_size,
388
- TensorRef<ElementA, LayoutA> tensor_dy,
389
- TensorRef<ElementB, LayoutB> tensor_w,
390
- TensorRef<ElementC, LayoutC> tensor_dx_in,
391
- TensorRef<ElementC, LayoutC> tensor_dx_out,
392
- ElementCompute alpha,
393
- ElementCompute beta
394
- ) {
395
-
396
- ConvertOp convert_op;
397
- InnerProductOp inner_product_op;
398
-
399
- ElementAccumulator element_A[kThreadM];
400
- ElementAccumulator element_B[kThreadN];
401
- ElementAccumulator accum[kThreadM][kThreadN];
402
-
403
- int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
404
- int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
405
-
406
- int thread_n[kThreadM];
407
- int thread_h[kThreadM];
408
- int thread_w[kThreadM];
409
-
410
- // Compute N, H, W coordinates for each row of a thread's tile
411
- int64_t HW = int64_t(problem_size.H) * problem_size.W;
412
-
413
- CUTLASS_PRAGMA_UNROLL
414
- for (int m = 0; m < kThreadM; ++m) {
415
-
416
- int64_t nhw = nhw_start + m;
417
-
418
- thread_n[m] = int(nhw / HW);
419
-
420
- int64_t residual = nhw % HW;
421
- thread_h[m] = int(residual / problem_size.W);
422
- thread_w[m] = int(residual % problem_size.W);
423
- }
424
-
425
- // Clear accumulators
426
- CUTLASS_PRAGMA_UNROLL
427
- for (int m = 0; m < kThreadM; ++m) {
428
- CUTLASS_PRAGMA_UNROLL
429
- for (int n = 0; n < kThreadN; ++n) {
430
- accum[m][n] = ElementAccumulator();
431
- }
432
- }
433
-
434
- // Compute convolution
435
- for (int R = 0; R < problem_size.R; ++R) {
436
- for (int S = 0; S < problem_size.S; ++S) {
437
- for (int K = 0; K < problem_size.K; ++K) {
438
-
439
- // Load from activations tensor
440
- int filter_r = R;
441
- int filter_s = S;
442
-
443
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
444
- filter_r = problem_size.R - 1 - R;
445
- filter_s = problem_size.S - 1 - S;
446
- }
447
-
448
- CUTLASS_PRAGMA_UNROLL
449
- for (int m = 0; m < kThreadM; ++m) {
450
-
451
- int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
452
- int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;
453
-
454
- element_A[m] = ElementAccumulator();
455
-
456
- if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) {
457
-
458
- p = p / problem_size.stride_h;
459
- q = q / problem_size.stride_w;
460
-
461
- if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) {
462
- element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K}));
463
- }
464
- }
465
- }
466
-
467
- // Load from filters tensor
468
- CUTLASS_PRAGMA_UNROLL
469
- for (int n = 0; n < kThreadN; ++n) {
470
- int thread_c = c_start + n;
471
-
472
- if (thread_c < problem_size.C) {
473
- element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c}));
474
- }
475
- else {
476
- element_B[n] = ElementAccumulator();
477
- }
478
- }
479
-
480
- // Accumulate matrix product
481
- CUTLASS_PRAGMA_UNROLL
482
- for (int m = 0; m < kThreadM; ++m) {
483
- CUTLASS_PRAGMA_UNROLL
484
- for (int n = 0; n < kThreadN; ++n) {
485
- accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
486
- }
487
- }
488
- }
489
- }
490
- }
491
-
492
- // Write out the results
493
- CUTLASS_PRAGMA_UNROLL
494
- for (int m = 0; m < kThreadM; ++m) {
495
-
496
- if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) {
497
-
498
- CUTLASS_PRAGMA_UNROLL
499
- for (int n = 0; n < kThreadN; ++n) {
500
- int thread_c = c_start + n;
501
- if (thread_c < problem_size.C) {
502
-
503
- ElementCompute c_ref = ElementCompute();
504
- if (beta != ElementCompute()) {
505
- c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c}));
506
- }
507
-
508
- tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
509
- alpha * ElementCompute(accum[m][n]) + beta * c_ref);
510
- }
511
- }
512
- }
513
- }
514
- }
515
-
516
- // Conv3d dgrad kernel - dx = dgrad(dy, w)
517
- template <
518
- typename ElementA,
519
- typename LayoutA,
520
- typename ElementB,
521
- typename LayoutB,
522
- typename ElementC,
523
- typename LayoutC,
524
- typename ElementCompute,
525
- typename ElementAccumulator = ElementCompute,
526
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
527
- typename InnerProductOp = multiply_add<ElementAccumulator>,
528
- int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
529
- int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
530
- int kCtaShapeM = 16, // shape of a threadblock in units of threads
531
- int kCtaShapeN = 8 // shape of a threadblock in units of threads
532
- >
533
- __global__ void Conv3dDgrad(
534
- conv::Conv3dProblemSize problem_size,
535
- TensorRef<ElementA, LayoutA> tensor_dy,
536
- TensorRef<ElementB, LayoutB> tensor_w,
537
- TensorRef<ElementC, LayoutC> tensor_dx_in,
538
- TensorRef<ElementC, LayoutC> tensor_dx_out,
539
- ElementCompute alpha,
540
- ElementCompute beta
541
- ) {
542
-
543
- ConvertOp convert_op;
544
- InnerProductOp inner_product_op;
545
-
546
- ElementAccumulator element_A[kThreadM];
547
- ElementAccumulator element_B[kThreadN];
548
- ElementAccumulator accum[kThreadM][kThreadN];
549
-
550
- int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
551
- int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
552
-
553
- int thread_n[kThreadM];
554
- int thread_d[kThreadM];
555
- int thread_h[kThreadM];
556
- int thread_w[kThreadM];
557
-
558
- // Compute N, H, W coordinates for each row of a thread's tile
559
- int64_t HW = int64_t(problem_size.H) * problem_size.W;
560
- int64_t DHW = HW * problem_size.D;
561
-
562
- CUTLASS_PRAGMA_UNROLL
563
- for (int m = 0; m < kThreadM; ++m) {
564
-
565
- int64_t ndhw = ndhw_start + m;
566
-
567
- thread_n[m] = int(ndhw / DHW);
568
-
569
- int64_t residual = ndhw % DHW;
570
- thread_d[m] = int(residual / HW);
571
-
572
- residual = residual % HW;
573
- thread_h[m] = int(residual / problem_size.W);
574
- thread_w[m] = int(residual % problem_size.W);
575
- }
576
-
577
- // Clear accumulators
578
- CUTLASS_PRAGMA_UNROLL
579
- for (int m = 0; m < kThreadM; ++m) {
580
- CUTLASS_PRAGMA_UNROLL
581
- for (int n = 0; n < kThreadN; ++n) {
582
- accum[m][n] = ElementAccumulator();
583
- }
584
- }
585
-
586
- // Compute convolution
587
- for (int T = 0; T < problem_size.T; ++T) {
588
- for (int R = 0; R < problem_size.R; ++R) {
589
- for (int S = 0; S < problem_size.S; ++S) {
590
- for (int K = 0; K < problem_size.K; ++K) {
591
-
592
- // Load from activations tensor
593
- int filter_t = T;
594
- int filter_r = R;
595
- int filter_s = S;
596
-
597
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
598
- filter_t = problem_size.T - 1 - T;
599
- filter_r = problem_size.R - 1 - R;
600
- filter_s = problem_size.S - 1 - S;
601
- }
602
-
603
- CUTLASS_PRAGMA_UNROLL
604
- for (int m = 0; m < kThreadM; ++m) {
605
-
606
- int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d;
607
- int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
608
- int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;
609
-
610
- element_A[m] = ElementAccumulator();
611
-
612
- if (z >= 0 && !(z % problem_size.stride_d) &&
613
- p >= 0 && !(p % problem_size.stride_h) &&
614
- q >= 0 && !(q % problem_size.stride_w)) {
615
-
616
- z = z / problem_size.stride_d;
617
- p = p / problem_size.stride_h;
618
- q = q / problem_size.stride_w;
619
-
620
- if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {
621
- element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K}));
622
- }
623
- }
624
- }
625
-
626
- // Load from filters tensor
627
- CUTLASS_PRAGMA_UNROLL
628
- for (int n = 0; n < kThreadN; ++n) {
629
- int thread_c = c_start + n;
630
-
631
- if (thread_c < problem_size.C) {
632
- element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c}));
633
- }
634
- else {
635
- element_B[n] = ElementAccumulator();
636
- }
637
- }
638
-
639
- // Accumulate matrix product
640
- CUTLASS_PRAGMA_UNROLL
641
- for (int m = 0; m < kThreadM; ++m) {
642
- CUTLASS_PRAGMA_UNROLL
643
- for (int n = 0; n < kThreadN; ++n) {
644
- accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
645
- }
646
- }
647
-
648
- } // for (C)
649
- } // for (S)
650
- } // for (R)
651
- } // for (T)
652
-
653
- // Write out the results
654
- CUTLASS_PRAGMA_UNROLL
655
- for (int m = 0; m < kThreadM; ++m) {
656
-
657
- if (thread_n[m] < problem_size.N &&
658
- thread_d[m] < problem_size.D &&
659
- thread_h[m] < problem_size.H &&
660
- thread_w[m] < problem_size.W) {
661
-
662
- CUTLASS_PRAGMA_UNROLL
663
- for (int n = 0; n < kThreadN; ++n) {
664
- int thread_c = c_start + n;
665
- if (thread_c < problem_size.C) {
666
-
667
- ElementCompute c_ref = ElementCompute();
668
- if (beta != ElementCompute()) {
669
- c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}));
670
- }
671
-
672
- tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
673
- alpha * ElementCompute(accum[m][n]) + beta * c_ref);
674
- }
675
- }
676
- }
677
- }
678
- }
679
-
680
- ///////////////////////////////////////////////////////////////////////////////////////////////////
681
-
682
- // Conv2d wgrad kernel - dw = wgrad(dy, x)
683
- template <
684
- typename ElementA,
685
- typename LayoutA,
686
- typename ElementB,
687
- typename LayoutB,
688
- typename ElementC,
689
- typename LayoutC,
690
- typename ElementCompute,
691
- typename ElementAccumulator = ElementCompute,
692
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
693
- typename InnerProductOp = multiply_add<ElementAccumulator>,
694
- int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
695
- int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
696
- int kCtaShapeM = 8, // shape of a threadblock in units of threads
697
- int kCtaShapeN = 16 // shape of a threadblock in units of threads
698
- >
699
- __global__ void Conv2dWgrad(
700
- conv::Conv2dProblemSize problem_size,
701
- TensorRef<ElementA, LayoutA> tensor_dy,
702
- TensorRef<ElementB, LayoutB> tensor_x,
703
- TensorRef<ElementC, LayoutC> tensor_dw_in,
704
- TensorRef<ElementC, LayoutC> tensor_dw_out,
705
- ElementCompute alpha,
706
- ElementCompute beta
707
- ) {
708
-
709
- ConvertOp convert_op;
710
- InnerProductOp inner_product_op;
711
-
712
- ElementAccumulator element_A[kThreadM];
713
- ElementAccumulator element_B[kThreadN];
714
- ElementAccumulator accum[kThreadM][kThreadN];
715
-
716
- int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
717
- int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
718
-
719
- int thread_r[kThreadN];
720
- int thread_s[kThreadN];
721
- int thread_c[kThreadN];
722
-
723
- // Compute R, S, C coordinates for each row of a thread's tile
724
- int64_t SC = int64_t(problem_size.S) * problem_size.C;
725
-
726
- CUTLASS_PRAGMA_UNROLL
727
- for (int n = 0; n < kThreadN; ++n) {
728
-
729
- int64_t rsc = rsc_start + n;
730
- int64_t residual = rsc % SC;
731
-
732
- thread_r[n] = int(rsc / SC);
733
- thread_s[n] = int(residual / problem_size.C);
734
- thread_c[n] = int(residual % problem_size.C);
735
- }
736
-
737
- // Clear accumulators
738
- CUTLASS_PRAGMA_UNROLL
739
- for (int m = 0; m < kThreadM; ++m) {
740
- CUTLASS_PRAGMA_UNROLL
741
- for (int n = 0; n < kThreadN; ++n) {
742
- accum[m][n] = ElementAccumulator();
743
- }
744
- }
745
-
746
- // Compute convolution
747
- for (int N = 0; N < problem_size.N; ++N) {
748
- for (int P = 0; P < problem_size.P; ++P) {
749
- for (int Q = 0; Q < problem_size.Q; ++Q) {
750
-
751
- CUTLASS_PRAGMA_UNROLL
752
- for (int m = 0; m < kThreadM; ++m) {
753
- int thread_k = k_start + m;
754
-
755
- element_A[m] = ElementAccumulator();
756
-
757
- if (thread_k < problem_size.K) {
758
- element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k}));
759
- }
760
- }
761
-
762
- // Load from filters tensor
763
- CUTLASS_PRAGMA_UNROLL
764
- for (int n = 0; n < kThreadN; ++n) {
765
-
766
- // Load from activations tensor
767
- int filter_r = thread_r[n];
768
- int filter_s = thread_s[n];
769
-
770
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
771
- filter_r = problem_size.R - 1 - filter_r;
772
- filter_s = problem_size.S - 1 - filter_s;
773
- }
774
-
775
- int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
776
- int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
777
-
778
- element_B[n] = ElementAccumulator();
779
-
780
- if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) {
781
- element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]}));
782
- }
783
- }
784
-
785
- // Accumulate matrix product
786
- CUTLASS_PRAGMA_UNROLL
787
- for (int m = 0; m < kThreadM; ++m) {
788
- CUTLASS_PRAGMA_UNROLL
789
- for (int n = 0; n < kThreadN; ++n) {
790
- accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
791
- }
792
- }
793
- }
794
- }
795
- }
796
-
797
- // Write out the results
798
- CUTLASS_PRAGMA_UNROLL
799
- for (int m = 0; m < kThreadM; ++m) {
800
- int thread_k = k_start + m;
801
-
802
- if (thread_k < problem_size.K) {
803
-
804
- CUTLASS_PRAGMA_UNROLL
805
- for (int n = 0; n < kThreadN; ++n) {
806
-
807
- if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) {
808
-
809
- ElementCompute c_ref = ElementCompute();
810
-
811
- if (beta != ElementCompute()) {
812
- c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}));
813
- }
814
-
815
- tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
816
- alpha * ElementCompute(accum[m][n]) + beta * c_ref);
817
- }
818
- }
819
- }
820
- }
821
- }
822
-
823
- // Conv3d wgrad kernel - dw = wgrad(dy, x)
824
- template <
825
- typename ElementA,
826
- typename LayoutA,
827
- typename ElementB,
828
- typename LayoutB,
829
- typename ElementC,
830
- typename LayoutC,
831
- typename ElementCompute,
832
- typename ElementAccumulator = ElementCompute,
833
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
834
- typename InnerProductOp = multiply_add<ElementAccumulator>,
835
- int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
836
- int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
837
- int kCtaShapeM = 8, // shape of a threadblock in units of threads
838
- int kCtaShapeN = 16 // shape of a threadblock in units of threads
839
- >
840
- __global__ void Conv3dWgrad(
841
- conv::Conv3dProblemSize problem_size,
842
- TensorRef<ElementA, LayoutA> tensor_dy,
843
- TensorRef<ElementB, LayoutB> tensor_x,
844
- TensorRef<ElementC, LayoutC> tensor_dw_in,
845
- TensorRef<ElementC, LayoutC> tensor_dw_out,
846
- ElementCompute alpha,
847
- ElementCompute beta
848
- ) {
849
-
850
- ConvertOp convert_op;
851
- InnerProductOp inner_product_op;
852
-
853
- ElementAccumulator element_A[kThreadM];
854
- ElementAccumulator element_B[kThreadN];
855
- ElementAccumulator accum[kThreadM][kThreadN];
856
-
857
- int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
858
- int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
859
-
860
- int thread_t[kThreadN];
861
- int thread_r[kThreadN];
862
- int thread_s[kThreadN];
863
- int thread_c[kThreadN];
864
-
865
- // Compute R, S, C coordinates for each row of a thread's tile
866
- int64_t SC = int64_t(problem_size.S) * problem_size.C;
867
- int64_t RSC = SC * problem_size.R;
868
-
869
- CUTLASS_PRAGMA_UNROLL
870
- for (int n = 0; n < kThreadN; ++n) {
871
-
872
- int64_t trsc = trsc_start + n;
873
-
874
- thread_t[n] = int(trsc / RSC);
875
-
876
- int64_t residual = trsc % RSC;
877
- thread_r[n] = int(residual / SC);
878
-
879
- residual = residual % SC;
880
- thread_s[n] = int(residual / problem_size.C);
881
- thread_c[n] = int(residual % problem_size.C);
882
- }
883
-
884
- // Clear accumulators
885
- CUTLASS_PRAGMA_UNROLL
886
- for (int m = 0; m < kThreadM; ++m) {
887
- CUTLASS_PRAGMA_UNROLL
888
- for (int n = 0; n < kThreadN; ++n) {
889
- accum[m][n] = ElementAccumulator();
890
- }
891
- }
892
-
893
- // Compute convolution
894
- for (int N = 0; N < problem_size.N; ++N) {
895
- for (int Z = 0; Z < problem_size.Z; ++Z) {
896
- for (int P = 0; P < problem_size.P; ++P) {
897
- for (int Q = 0; Q < problem_size.Q; ++Q) {
898
-
899
- CUTLASS_PRAGMA_UNROLL
900
- for (int m = 0; m < kThreadM; ++m) {
901
- int thread_k = k_start + m;
902
-
903
- element_A[m] = ElementAccumulator();
904
-
905
- if (thread_k < problem_size.K) {
906
- element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k}));
907
- }
908
- }
909
-
910
- // Load from filters tensor
911
- CUTLASS_PRAGMA_UNROLL
912
- for (int n = 0; n < kThreadN; ++n) {
913
-
914
- // Load from activations tensor
915
- int filter_t = thread_t[n];
916
- int filter_r = thread_r[n];
917
- int filter_s = thread_s[n];
918
-
919
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
920
- filter_t = problem_size.T - 1 - filter_t;
921
- filter_r = problem_size.R - 1 - filter_r;
922
- filter_s = problem_size.S - 1 - filter_s;
923
- }
924
-
925
- int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
926
- int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
927
- int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
928
-
929
- element_B[n] = ElementAccumulator();
930
-
931
- if (d >= 0 && d < problem_size.D &&
932
- h >= 0 && h < problem_size.H &&
933
- w >= 0 && w < problem_size.W &&
934
- thread_c[n] < problem_size.C) {
935
-
936
- element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]}));
937
- }
938
- }
939
-
940
- // Accumulate matrix product
941
- CUTLASS_PRAGMA_UNROLL
942
- for (int m = 0; m < kThreadM; ++m) {
943
- CUTLASS_PRAGMA_UNROLL
944
- for (int n = 0; n < kThreadN; ++n) {
945
- accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
946
- }
947
- }
948
-
949
- } // for (Q)
950
- } // for (P)
951
- } // for (Z)
952
- } // for (N)
953
-
954
- // Write out the results
955
- CUTLASS_PRAGMA_UNROLL
956
- for (int m = 0; m < kThreadM; ++m) {
957
- int thread_k = k_start + m;
958
-
959
- if (thread_k < problem_size.K) {
960
-
961
- CUTLASS_PRAGMA_UNROLL
962
- for (int n = 0; n < kThreadN; ++n) {
963
-
964
- if (thread_t[n] < problem_size.T &&
965
- thread_r[n] < problem_size.R &&
966
- thread_s[n] < problem_size.S &&
967
- thread_c[n] < problem_size.C) {
968
-
969
- ElementCompute c_ref = ElementCompute();
970
-
971
- if (beta != ElementCompute()) {
972
- c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}));
973
- }
974
-
975
- tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
976
- alpha * ElementCompute(accum[m][n]) + beta * c_ref);
977
- }
978
- }
979
- }
980
- }
981
- }
982
-
983
- /////////////////////////////////////////////////////////////////////////////////////////////////
984
-
985
- } // namespace kernel
986
-
987
- /////////////////////////////////////////////////////////////////////////////////////////////////
988
-
989
- /// Conv2d Fprop dispatcher - y = fprop(x, w)
990
- template <
991
- typename ElementA,
992
- typename LayoutA,
993
- typename ElementB,
994
- typename LayoutB,
995
- typename ElementC,
996
- typename LayoutC,
997
- typename ElementCompute,
998
- typename ElementAccumulator = ElementCompute,
999
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1000
- typename InnerProductOp = multiply_add<ElementAccumulator>
1001
- >
1002
- Status Conv2dFprop(
1003
- conv::Conv2dProblemSize problem_size,
1004
- TensorRef<ElementA, LayoutA> tensor_x,
1005
- TensorRef<ElementB, LayoutB> tensor_w,
1006
- TensorRef<ElementC, LayoutC> tensor_y_in,
1007
- TensorRef<ElementC, LayoutC> tensor_y_out,
1008
- ElementCompute alpha,
1009
- ElementCompute beta,
1010
- cudaStream_t stream = nullptr) {
1011
-
1012
- //
1013
- // Blocking factors improve performance of reference implementation
1014
- //
1015
-
1016
- int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
1017
- int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1018
- int const kCtaShapeM = 16; // shape of a threadblock in units of threads
1019
- int const kCtaShapeN = 8; // shape of a threadblock in units of threads
1020
-
1021
- int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q;
1022
- int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
1023
-
1024
- dim3 block(kCtaShapeM, kCtaShapeN);
1025
- dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
1026
-
1027
- kernel::Conv2dFprop<
1028
- ElementA,
1029
- LayoutA,
1030
- ElementB,
1031
- LayoutB,
1032
- ElementC,
1033
- LayoutC,
1034
- ElementCompute,
1035
- ElementAccumulator,
1036
- ConvertOp,
1037
- InnerProductOp,
1038
- kThreadM,
1039
- kThreadN,
1040
- kCtaShapeM,
1041
- kCtaShapeN
1042
- ><<< grid, block, 0, stream >>>(
1043
- problem_size,
1044
- tensor_x,
1045
- tensor_w,
1046
- tensor_y_in,
1047
- tensor_y_out,
1048
- alpha,
1049
- beta
1050
- );
1051
-
1052
- cudaError_t result = cudaPeekAtLastError();
1053
- if (result != cudaSuccess) {
1054
- return Status::kErrorInternal;
1055
- }
1056
-
1057
- return Status::kSuccess;
1058
- }
1059
-
1060
- /// Conv3d Fprop dispatcher - y = fprop(x, w)
1061
- template <
1062
- typename ElementA,
1063
- typename LayoutA,
1064
- typename ElementB,
1065
- typename LayoutB,
1066
- typename ElementC,
1067
- typename LayoutC,
1068
- typename ElementCompute,
1069
- typename ElementAccumulator = ElementCompute,
1070
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1071
- typename InnerProductOp = multiply_add<ElementAccumulator>
1072
- >
1073
- Status Conv3dFprop(
1074
- conv::Conv3dProblemSize problem_size,
1075
- TensorRef<ElementA, LayoutA> tensor_x,
1076
- TensorRef<ElementB, LayoutB> tensor_w,
1077
- TensorRef<ElementC, LayoutC> tensor_y_in,
1078
- TensorRef<ElementC, LayoutC> tensor_y_out,
1079
- ElementCompute alpha,
1080
- ElementCompute beta,
1081
- cudaStream_t stream = nullptr) {
1082
-
1083
- //
1084
- // Blocking factors improve performance of reference implementation
1085
- //
1086
-
1087
- int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
1088
- int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1089
- int const kCtaShapeM = 16; // shape of a threadblock in units of threads
1090
- int const kCtaShapeN = 8; // shape of a threadblock in units of threads
1091
-
1092
- int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q;
1093
- int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
1094
-
1095
- dim3 block(kCtaShapeM, kCtaShapeN);
1096
- dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
1097
-
1098
- kernel::Conv3dFprop<
1099
- ElementA,
1100
- LayoutA,
1101
- ElementB,
1102
- LayoutB,
1103
- ElementC,
1104
- LayoutC,
1105
- ElementCompute,
1106
- ElementAccumulator,
1107
- ConvertOp,
1108
- InnerProductOp,
1109
- kThreadM,
1110
- kThreadN,
1111
- kCtaShapeM,
1112
- kCtaShapeN
1113
- ><<< grid, block, 0, stream >>>(
1114
- problem_size,
1115
- tensor_x,
1116
- tensor_w,
1117
- tensor_y_in,
1118
- tensor_y_out,
1119
- alpha,
1120
- beta
1121
- );
1122
-
1123
- cudaError_t result = cudaPeekAtLastError();
1124
- if (result != cudaSuccess) {
1125
- return Status::kErrorInternal;
1126
- }
1127
-
1128
- return Status::kSuccess;
1129
- }
1130
-
1131
- /// Conv2d Dgrad dispatcher - dx = dgrad(dy, w)
1132
- template <
1133
- typename ElementA,
1134
- typename LayoutA,
1135
- typename ElementB,
1136
- typename LayoutB,
1137
- typename ElementC,
1138
- typename LayoutC,
1139
- typename ElementCompute,
1140
- typename ElementAccumulator = ElementCompute,
1141
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1142
- typename InnerProductOp = multiply_add<ElementAccumulator>
1143
- >
1144
- Status Conv2dDgrad(
1145
- conv::Conv2dProblemSize problem_size,
1146
- TensorRef<ElementA, LayoutA> tensor_dy,
1147
- TensorRef<ElementB, LayoutB> tensor_w,
1148
- TensorRef<ElementC, LayoutC> tensor_dx_in,
1149
- TensorRef<ElementC, LayoutC> tensor_dx_out,
1150
- ElementCompute alpha,
1151
- ElementCompute beta,
1152
- cudaStream_t stream = nullptr) {
1153
-
1154
- //
1155
- // Blocking factors improve performance of reference implementation
1156
- //
1157
-
1158
- int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
1159
- int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1160
- int const kCtaShapeM = 16; // shape of a threadblock in units of threads
1161
- int const kCtaShapeN = 8; // shape of a threadblock in units of threads
1162
-
1163
- int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W;
1164
- int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
1165
-
1166
- dim3 block(kCtaShapeM, kCtaShapeN);
1167
- dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
1168
-
1169
- kernel::Conv2dDgrad<
1170
- ElementA,
1171
- LayoutA,
1172
- ElementB,
1173
- LayoutB,
1174
- ElementC,
1175
- LayoutC,
1176
- ElementCompute,
1177
- ElementAccumulator,
1178
- ConvertOp,
1179
- InnerProductOp,
1180
- kThreadM,
1181
- kThreadN,
1182
- kCtaShapeM,
1183
- kCtaShapeN
1184
- ><<< grid, block, 0, stream >>>(
1185
- problem_size,
1186
- tensor_dy,
1187
- tensor_w,
1188
- tensor_dx_in,
1189
- tensor_dx_out,
1190
- alpha,
1191
- beta
1192
- );
1193
-
1194
- cudaError_t result = cudaPeekAtLastError();
1195
- if (result != cudaSuccess) {
1196
- return Status::kErrorInternal;
1197
- }
1198
-
1199
- return Status::kSuccess;
1200
- }
1201
-
1202
- /// Conv3d Dgrad dispatcher - dx = dgrad(dy, w)
1203
- template <
1204
- typename ElementA,
1205
- typename LayoutA,
1206
- typename ElementB,
1207
- typename LayoutB,
1208
- typename ElementC,
1209
- typename LayoutC,
1210
- typename ElementCompute,
1211
- typename ElementAccumulator = ElementCompute,
1212
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1213
- typename InnerProductOp = multiply_add<ElementAccumulator>
1214
- >
1215
- Status Conv3dDgrad(
1216
- conv::Conv3dProblemSize problem_size,
1217
- TensorRef<ElementA, LayoutA> tensor_dy,
1218
- TensorRef<ElementB, LayoutB> tensor_w,
1219
- TensorRef<ElementC, LayoutC> tensor_dx_in,
1220
- TensorRef<ElementC, LayoutC> tensor_dx_out,
1221
- ElementCompute alpha,
1222
- ElementCompute beta,
1223
- cudaStream_t stream = nullptr) {
1224
-
1225
- //
1226
- // Blocking factors improve performance of reference implementation
1227
- //
1228
-
1229
- int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
1230
- int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1231
- int const kCtaShapeM = 16; // shape of a threadblock in units of threads
1232
- int const kCtaShapeN = 8; // shape of a threadblock in units of threads
1233
-
1234
- int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W;
1235
- int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
1236
-
1237
- dim3 block(kCtaShapeM, kCtaShapeN);
1238
- dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
1239
-
1240
- kernel::Conv3dDgrad<
1241
- ElementA,
1242
- LayoutA,
1243
- ElementB,
1244
- LayoutB,
1245
- ElementC,
1246
- LayoutC,
1247
- ElementCompute,
1248
- ElementAccumulator,
1249
- ConvertOp,
1250
- InnerProductOp,
1251
- kThreadM,
1252
- kThreadN,
1253
- kCtaShapeM,
1254
- kCtaShapeN
1255
- ><<< grid, block, 0, stream >>>(
1256
- problem_size,
1257
- tensor_dy,
1258
- tensor_w,
1259
- tensor_dx_in,
1260
- tensor_dx_out,
1261
- alpha,
1262
- beta
1263
- );
1264
-
1265
- cudaError_t result = cudaPeekAtLastError();
1266
- if (result != cudaSuccess) {
1267
- return Status::kErrorInternal;
1268
- }
1269
-
1270
- return Status::kSuccess;
1271
- }
1272
-
1273
- /// Conv2d Wgrad dispatcher - dw = wgrad(dy, x)
1274
- template <
1275
- typename ElementA,
1276
- typename LayoutA,
1277
- typename ElementB,
1278
- typename LayoutB,
1279
- typename ElementC,
1280
- typename LayoutC,
1281
- typename ElementCompute,
1282
- typename ElementAccumulator = ElementCompute,
1283
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1284
- typename InnerProductOp = multiply_add<ElementAccumulator>
1285
- >
1286
- Status Conv2dWgrad(
1287
- conv::Conv2dProblemSize problem_size,
1288
- TensorRef<ElementA, LayoutA> tensor_dy,
1289
- TensorRef<ElementB, LayoutB> tensor_x,
1290
- TensorRef<ElementC, LayoutC> tensor_dw_in,
1291
- TensorRef<ElementC, LayoutC> tensor_dw_out,
1292
- ElementCompute alpha,
1293
- ElementCompute beta,
1294
- cudaStream_t stream = nullptr) {
1295
-
1296
- //
1297
- // Blocking factors improve performance of reference implementation
1298
- //
1299
-
1300
- int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
1301
- int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1302
- int const kCtaShapeM = 8; // shape of a threadblock in units of threads
1303
- int const kCtaShapeN = 16; // shape of a threadblock in units of threads
1304
-
1305
- int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C;
1306
- int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);
1307
-
1308
- dim3 block(kCtaShapeM, kCtaShapeN);
1309
- dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));
1310
-
1311
- kernel::Conv2dWgrad<
1312
- ElementA,
1313
- LayoutA,
1314
- ElementB,
1315
- LayoutB,
1316
- ElementC,
1317
- LayoutC,
1318
- ElementCompute,
1319
- ElementAccumulator,
1320
- ConvertOp,
1321
- InnerProductOp,
1322
- kThreadM,
1323
- kThreadN,
1324
- kCtaShapeM,
1325
- kCtaShapeN
1326
- ><<< grid, block, 0, stream >>>(
1327
- problem_size,
1328
- tensor_dy,
1329
- tensor_x,
1330
- tensor_dw_in,
1331
- tensor_dw_out,
1332
- alpha,
1333
- beta
1334
- );
1335
-
1336
- cudaError_t result = cudaPeekAtLastError();
1337
- if (result != cudaSuccess) {
1338
- return Status::kErrorInternal;
1339
- }
1340
-
1341
- return Status::kSuccess;
1342
- }
1343
-
1344
- /// Conv3d Wgrad dispatcher - dw = wgrad(dy, x)
1345
- template <
1346
- typename ElementA,
1347
- typename LayoutA,
1348
- typename ElementB,
1349
- typename LayoutB,
1350
- typename ElementC,
1351
- typename LayoutC,
1352
- typename ElementCompute,
1353
- typename ElementAccumulator = ElementCompute,
1354
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1355
- typename InnerProductOp = multiply_add<ElementAccumulator>
1356
- >
1357
- Status Conv3dWgrad(
1358
- conv::Conv3dProblemSize problem_size,
1359
- TensorRef<ElementA, LayoutA> tensor_dy,
1360
- TensorRef<ElementB, LayoutB> tensor_x,
1361
- TensorRef<ElementC, LayoutC> tensor_dw_in,
1362
- TensorRef<ElementC, LayoutC> tensor_dw_out,
1363
- ElementCompute alpha,
1364
- ElementCompute beta,
1365
- cudaStream_t stream = nullptr) {
1366
-
1367
- //
1368
- // Blocking factors improve performance of reference implementation
1369
- //
1370
-
1371
- int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
1372
- int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1373
- int const kCtaShapeM = 8; // shape of a threadblock in units of threads
1374
- int const kCtaShapeN = 16; // shape of a threadblock in units of threads
1375
-
1376
- int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C;
1377
- int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);
1378
-
1379
- dim3 block(kCtaShapeM, kCtaShapeN);
1380
- dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));
1381
-
1382
- kernel::Conv3dWgrad<
1383
- ElementA,
1384
- LayoutA,
1385
- ElementB,
1386
- LayoutB,
1387
- ElementC,
1388
- LayoutC,
1389
- ElementCompute,
1390
- ElementAccumulator,
1391
- ConvertOp,
1392
- InnerProductOp,
1393
- kThreadM,
1394
- kThreadN,
1395
- kCtaShapeM,
1396
- kCtaShapeN
1397
- ><<< grid, block, 0, stream >>>(
1398
- problem_size,
1399
- tensor_dy,
1400
- tensor_x,
1401
- tensor_dw_in,
1402
- tensor_dw_out,
1403
- alpha,
1404
- beta
1405
- );
1406
-
1407
- cudaError_t result = cudaPeekAtLastError();
1408
- if (result != cudaSuccess) {
1409
- return Status::kErrorInternal;
1410
- }
1411
-
1412
- return Status::kSuccess;
1413
- }
1414
-
1415
- /////////////////////////////////////////////////////////////////////////////////////////////////
1416
-
1417
- /// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
1418
- template <
1419
- typename ElementA,
1420
- typename LayoutA,
1421
- typename ElementB,
1422
- typename LayoutB,
1423
- typename ElementC,
1424
- typename LayoutC,
1425
- typename ElementCompute,
1426
- typename ElementAccumulator = ElementCompute,
1427
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1428
- typename InnerProductOp = multiply_add<ElementAccumulator>
1429
- >
1430
- Status Conv2d(
1431
- conv::Operator convolutional_operator,
1432
- conv::Conv2dProblemSize problem_size,
1433
- TensorRef<ElementA, LayoutA> tensor_A,
1434
- TensorRef<ElementB, LayoutB> tensor_B,
1435
- TensorRef<ElementC, LayoutC> tensor_C,
1436
- TensorRef<ElementC, LayoutC> tensor_D,
1437
- ElementCompute alpha,
1438
- ElementCompute beta,
1439
- cudaStream_t stream = nullptr) {
1440
-
1441
- switch (convolutional_operator) {
1442
- case conv::Operator::kFprop:
1443
- return Conv2dFprop<
1444
- ElementA, LayoutA,
1445
- ElementB, LayoutB,
1446
- ElementC, LayoutC,
1447
- ElementCompute,
1448
- ElementAccumulator,
1449
- ConvertOp, InnerProductOp
1450
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1451
- break;
1452
-
1453
- case conv::Operator::kDgrad:
1454
- return Conv2dDgrad<
1455
- ElementA, LayoutA,
1456
- ElementB, LayoutB,
1457
- ElementC, LayoutC,
1458
- ElementCompute,
1459
- ElementAccumulator,
1460
- ConvertOp, InnerProductOp
1461
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1462
- break;
1463
-
1464
- case conv::Operator::kWgrad:
1465
- return Conv2dWgrad<
1466
- ElementA, LayoutA,
1467
- ElementB, LayoutB,
1468
- ElementC, LayoutC,
1469
- ElementCompute,
1470
- ElementAccumulator,
1471
- ConvertOp, InnerProductOp
1472
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1473
- break;
1474
-
1475
- default: break;
1476
- }
1477
-
1478
- return Status::kErrorNotSupported;
1479
- }
1480
-
1481
- /// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad.
1482
- template <
1483
- typename ElementA,
1484
- typename LayoutA,
1485
- typename ElementB,
1486
- typename LayoutB,
1487
- typename ElementC,
1488
- typename LayoutC,
1489
- typename ElementCompute,
1490
- typename ElementAccumulator = ElementCompute,
1491
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1492
- typename InnerProductOp = multiply_add<ElementAccumulator>
1493
- >
1494
- Status Conv3d(
1495
- conv::Operator convolutional_operator,
1496
- conv::Conv3dProblemSize problem_size,
1497
- TensorRef<ElementA, LayoutA> tensor_A,
1498
- TensorRef<ElementB, LayoutB> tensor_B,
1499
- TensorRef<ElementC, LayoutC> tensor_C,
1500
- TensorRef<ElementC, LayoutC> tensor_D,
1501
- ElementCompute alpha,
1502
- ElementCompute beta,
1503
- cudaStream_t stream = nullptr) {
1504
-
1505
- switch (convolutional_operator) {
1506
- case conv::Operator::kFprop:
1507
- return Conv3dFprop<
1508
- ElementA, LayoutA,
1509
- ElementB, LayoutB,
1510
- ElementC, LayoutC,
1511
- ElementCompute,
1512
- ElementAccumulator,
1513
- ConvertOp, InnerProductOp
1514
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1515
-
1516
- case conv::Operator::kDgrad:
1517
- return Conv3dDgrad<
1518
- ElementA, LayoutA,
1519
- ElementB, LayoutB,
1520
- ElementC, LayoutC,
1521
- ElementCompute,
1522
- ElementAccumulator,
1523
- ConvertOp, InnerProductOp
1524
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1525
-
1526
- case conv::Operator::kWgrad:
1527
- return Conv3dWgrad<
1528
- ElementA, LayoutA,
1529
- ElementB, LayoutB,
1530
- ElementC, LayoutC,
1531
- ElementCompute,
1532
- ElementAccumulator,
1533
- ConvertOp, InnerProductOp
1534
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1535
-
1536
- default: break;
1537
- }
1538
-
1539
- return Status::kErrorNotSupported;
1540
- }
1541
-
1542
- ////////////////////////////////////////////////////////////////////////////////////////////////////
1543
-
1544
- } // namespace device
1545
- } // namespace reference
1546
- } // namespace cutlass
1547
-
1548
- ////////////////////////////////////////////////////////////////////////////////////////////////////
1549
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h DELETED
@@ -1,385 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for GEMM in device-side code.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/coord.h"
38
-
39
- #include "cutlass/numeric_types.h"
40
- #include "cutlass/functional.h"
41
- #include "cutlass/numeric_conversion.h"
42
-
43
- #include "cutlass/tensor_view.h"
44
- #include "cutlass/gemm/gemm.h"
45
-
46
- #include "cutlass/util/reference/device/kernel/gemm.h"
47
-
48
- namespace cutlass {
49
- namespace reference {
50
- namespace device {
51
-
52
- ////////////////////////////////////////////////////////////////////////////////////////////////////
53
-
54
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
55
- /// objects.
56
- ///
57
- /// Explicitly naming types needed by this template can be cumbersome, particularly for the
58
- /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
59
- /// AccumulatorType(0) as the last function argument can be easier than naming all template
60
- /// arguments explicitly.
61
- template <
62
- typename ElementA,
63
- typename LayoutA,
64
- typename ElementB,
65
- typename LayoutB,
66
- typename ElementC,
67
- typename LayoutC,
68
- typename ScalarType,
69
- typename AccumulatorType,
70
- typename InnerProductOp = multiply_add<AccumulatorType>,
71
- typename ConvertOp = NumericConverter<ElementC, ScalarType>
72
- >
73
- void compute_gemm(
74
- gemm::GemmCoord problem_size,
75
- ScalarType alpha,
76
- TensorRef<ElementA, LayoutA> tensor_a,
77
- TensorRef<ElementB, LayoutB> tensor_b,
78
- ScalarType beta,
79
- TensorRef<ElementC, LayoutC> tensor_c,
80
- TensorRef<ElementC, LayoutC> tensor_d,
81
- AccumulatorType initial_accum) {
82
-
83
- static_assert(
84
- LayoutA::kRank == 2 &&
85
- LayoutB::kRank == 2 &&
86
- LayoutC::kRank == 2, "Tensors must be of rank 2");
87
-
88
- // Blocking structure potentially improves performance of reference implementation
89
- // with a minor increase in complexity.
90
- //
91
- // Note, this reference implementation is NOT expected to approach peak performance.
92
- using OutputTile = MatrixShape<4, 4>;
93
-
94
- dim3 block(16, 8);
95
-
96
- dim3 grid(
97
- (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
98
- (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
99
- );
100
-
101
- // Launch a GEMM kernel
102
- kernel::Gemm<
103
- TensorRef<ElementA, LayoutA>,
104
- TensorRef<ElementB, LayoutB>,
105
- TensorRef<ElementC, LayoutC>,
106
- ScalarType,
107
- AccumulatorType,
108
- OutputTile,
109
- InnerProductOp,
110
- ConvertOp
111
- ><<< grid, block >>>(
112
- problem_size,
113
- alpha,
114
- tensor_a,
115
- tensor_b,
116
- beta,
117
- tensor_c,
118
- tensor_d,
119
- initial_accum
120
- );
121
- }
122
- ////////////////////////////////////////////////////////////////////////////////////////////////////
123
-
124
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
125
- /// objects.
126
- ///
127
- /// This assumes the accumulator type is the same type as the scalars.
128
- template <
129
- typename ElementA,
130
- typename LayoutA,
131
- typename ElementB,
132
- typename LayoutB,
133
- typename ElementC,
134
- typename LayoutC,
135
- typename ScalarType,
136
- typename AccumulatorType,
137
- typename InnerProductOp = multiply_add<AccumulatorType>,
138
- typename ConvertOp = NumericConverter<ElementC, ScalarType>
139
- >
140
- void compute_gemm(
141
- gemm::GemmCoord problem_size,
142
- ScalarType alpha,
143
- TensorRef<ElementA, LayoutA> tensor_a,
144
- TensorRef<ElementB, LayoutB> tensor_b,
145
- ScalarType beta,
146
- TensorRef<ElementC, LayoutC> tensor_c,
147
- AccumulatorType initial_accum) {
148
-
149
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
150
- ScalarType, AccumulatorType, InnerProductOp, ConvertOp>(
151
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
152
- initial_accum);
153
- }
154
-
155
- template <
156
- typename ElementA,
157
- typename LayoutA,
158
- typename ElementB,
159
- typename LayoutB,
160
- typename ElementC,
161
- typename LayoutC,
162
- typename ScalarType,
163
- typename AccumulatorType,
164
- typename InnerProductOp = cutlass::arch::OpMultiplyAdd
165
- >
166
- struct Gemm;
167
-
168
- ////////////////////////////////////////////////////////////////////////////////////////////////////
169
-
170
- /// Partial specialization for multiply-add
171
- template <typename ElementA, typename LayoutA, typename ElementB,
172
- typename LayoutB, typename ElementC, typename LayoutC,
173
- typename ScalarType, typename AccumulatorType>
174
- struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
175
- ScalarType, AccumulatorType, arch::OpMultiplyAdd> {
176
-
177
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
178
- TensorRef<ElementA, LayoutA> tensor_a,
179
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
180
- TensorRef<ElementC, LayoutC> tensor_c,
181
- AccumulatorType initial_accum = AccumulatorType(0)) {
182
-
183
- static_assert(
184
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
185
- "Tensors must be of rank 2");
186
-
187
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
188
- ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
189
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
190
- }
191
-
192
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
193
- TensorRef<ElementA, LayoutA> tensor_a,
194
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
195
- TensorRef<ElementC, LayoutC> tensor_c,
196
- TensorRef<ElementC, LayoutC> tensor_d,
197
- AccumulatorType initial_accum = AccumulatorType(0)) {
198
- static_assert(
199
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
200
- "Tensors must be of rank 2");
201
-
202
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
203
- ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
204
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
205
- }
206
- };
207
-
208
- ////////////////////////////////////////////////////////////////////////////////////////////////////
209
-
210
- /// Partial specialization for multiply-add-saturate
211
- template <typename ElementA, typename LayoutA, typename ElementB,
212
- typename LayoutB, typename ElementC, typename LayoutC,
213
- typename ScalarType, typename AccumulatorType>
214
- struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
215
- AccumulatorType, arch::OpMultiplyAddSaturate> {
216
-
217
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
218
- TensorRef<ElementA, LayoutA> tensor_a,
219
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
220
- TensorRef<ElementC, LayoutC> tensor_c,
221
- AccumulatorType initial_accum = AccumulatorType(0)) {
222
- static_assert(
223
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
224
- "Tensors must be of rank 2");
225
-
226
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
227
- ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
228
- NumericConverterClamp<ElementC, ScalarType>>(
229
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
230
- }
231
-
232
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
233
- TensorRef<ElementA, LayoutA> tensor_a,
234
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
235
- TensorRef<ElementC, LayoutC> tensor_c,
236
- TensorRef<ElementC, LayoutC> tensor_d,
237
- AccumulatorType initial_accum = AccumulatorType(0)) {
238
- static_assert(
239
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
240
- "Tensors must be of rank 2");
241
-
242
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
243
- ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
244
- NumericConverterClamp<ElementC, ScalarType>>(
245
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
246
- }
247
- };
248
-
249
- ////////////////////////////////////////////////////////////////////////////////////////////////////
250
-
251
- /// Partial specialization for XOR-popc
252
- template <typename ElementA, typename LayoutA, typename ElementB,
253
- typename LayoutB, typename ElementC, typename LayoutC,
254
- typename ScalarType, typename AccumulatorType>
255
- struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
256
- AccumulatorType, arch::OpXorPopc> {
257
-
258
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
259
- TensorRef<ElementA, LayoutA> tensor_a,
260
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
261
- TensorRef<ElementC, LayoutC> tensor_c,
262
- AccumulatorType initial_accum = AccumulatorType(0)) {
263
- static_assert(
264
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
265
- "Tensors must be of rank 2");
266
-
267
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
268
- ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
269
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
270
- }
271
-
272
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
273
- TensorRef<ElementA, LayoutA> tensor_a,
274
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
275
- TensorRef<ElementC, LayoutC> tensor_c,
276
- TensorRef<ElementC, LayoutC> tensor_d,
277
- AccumulatorType initial_accum = AccumulatorType(0)) {
278
- static_assert(
279
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
280
- "Tensors must be of rank 2");
281
-
282
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
283
- ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
284
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
285
- }
286
- };
287
-
288
-
289
- ////////////////////////////////////////////////////////////////////////////////////////////////////
290
- //
291
- // Batched GEMM
292
- //
293
- ////////////////////////////////////////////////////////////////////////////////////////////////////
294
-
295
- /// Computes a batch of GEMMs over a set of matrices of common dimension.
296
- //
297
- // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
298
- //
299
- template <
300
- typename TensorRefCollectionA,
301
- typename TensorRefCollectionB,
302
- typename TensorRefCollectionC,
303
- typename ScalarType,
304
- typename AccumulatorType,
305
- typename InnerProductOp,
306
- typename ConvertOp
307
- >
308
- void BatchedGemm(
309
- gemm::GemmCoord problem_size,
310
- int batch_count,
311
- ScalarType alpha,
312
- TensorRefCollectionA const& tensor_a,
313
- TensorRefCollectionB const& tensor_b,
314
- ScalarType beta,
315
- TensorRefCollectionC &tensor_c,
316
- AccumulatorType initial_accum) {
317
-
318
- static_assert(
319
- TensorRefCollectionA::kRank == 2 &&
320
- TensorRefCollectionB::kRank == 2 &&
321
- TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2");
322
-
323
- // Blocking structure potentially improves performance of reference implementation
324
- // with a minor increase in complexity.
325
- //
326
- // Note, this reference implementation is NOT expected to approach peak performance.
327
- using OutputTile = MatrixShape<4, 4>;
328
-
329
- dim3 block(16, 8);
330
- dim3 grid(
331
- (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
332
- (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn),
333
- batch_count
334
- );
335
-
336
- // Launch a GEMM kernel
337
- kernel::BatchedGemm<
338
- TensorRefCollectionA,
339
- TensorRefCollectionB,
340
- TensorRefCollectionC,
341
- ScalarType,
342
- AccumulatorType,
343
- OutputTile,
344
- InnerProductOp,
345
- ConvertOp
346
- ><<< grid, block >>>(
347
- problem_size,
348
- alpha,
349
- tensor_a,
350
- tensor_b,
351
- beta,
352
- tensor_c,
353
- initial_accum
354
- );
355
- }
356
-
357
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
358
- /// objects.
359
- //
360
- // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
361
- //
362
- template <
363
- typename TensorRefCollectionA,
364
- typename TensorRefCollectionB,
365
- typename TensorRefCollectionC,
366
- typename ScalarType,
367
- typename AccumulatorType
368
- >
369
- void BatchedGemm(
370
- gemm::GemmCoord problem_size,
371
- int batch_count,
372
- ScalarType alpha,
373
- TensorRefCollectionA const& tensor_a,
374
- TensorRefCollectionB const& tensor_b,
375
- ScalarType beta,
376
- TensorRefCollectionC &tensor_c) {
377
-
378
- BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
379
- }
380
-
381
- ////////////////////////////////////////////////////////////////////////////////////////////////////
382
-
383
- } // namespace device
384
- } // namespace reference
385
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h DELETED
@@ -1,350 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for complex-valued GEMM in device-side code.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/coord.h"
38
- #include "cutlass/complex.h"
39
- #include "cutlass/numeric_types.h"
40
- #include "cutlass/functional.h"
41
- #include "cutlass/numeric_conversion.h"
42
-
43
- #include "cutlass/tensor_view.h"
44
- #include "cutlass/gemm/gemm.h"
45
-
46
- namespace cutlass {
47
- namespace reference {
48
- namespace device {
49
-
50
- ////////////////////////////////////////////////////////////////////////////////////////////////////
51
-
52
- namespace kernel {
53
-
54
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
55
- /// objects.
56
- ///
57
- /// Explicitly naming types needed by this template can be cumbersome, particularly for the
58
- /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
59
- /// AccumulatorType(0) as the last function argument can be easier than naming all template
60
- /// arguments explicitly.
61
- template <
62
- typename ElementA,
63
- typename LayoutA,
64
- typename ElementB,
65
- typename LayoutB,
66
- typename ElementC,
67
- typename LayoutC,
68
- typename ScalarType,
69
- typename ComputeType,
70
- typename ElementD = ElementC,
71
- typename ConvertOp = NumericConverter<ElementD, ScalarType>,
72
- typename InnerProductOp = multiply_add<ComputeType>,
73
- int kMblock = 4,
74
- int kNblock = 4
75
- >
76
- __global__ void GemmComplex(
77
- gemm::GemmCoord problem_size,
78
- ScalarType alpha,
79
- TensorRef<ElementA, LayoutA> tensor_a,
80
- ComplexTransform transform_a,
81
- TensorRef<ElementB, LayoutB> tensor_b,
82
- ComplexTransform transform_b,
83
- ScalarType beta,
84
- TensorRef<ElementC, LayoutC> tensor_c,
85
- TensorRef<ElementD, LayoutC> tensor_d,
86
- ComputeType initial_accum,
87
- int batch_count = 1,
88
- int64_t batch_stride_A = 0,
89
- int64_t batch_stride_B = 0,
90
- int64_t batch_stride_C = 0,
91
- int64_t batch_stride_D = 0) {
92
-
93
- static_assert(
94
- LayoutA::kRank == 2 &&
95
- LayoutB::kRank == 2 &&
96
- LayoutC::kRank == 2, "Tensors must be of rank 2");
97
-
98
- int const M = problem_size.m();
99
- int const N = problem_size.n();
100
- int const K = problem_size.k();
101
-
102
- ConvertOp convert_op;
103
- InnerProductOp inner_product_op;
104
-
105
- int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
106
- int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
107
- int batch_idx = blockIdx.z;
108
-
109
- tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
110
- tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
111
- tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
112
- tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
113
-
114
- for (; batch_idx < batch_count; batch_idx += gridDim.z) {
115
-
116
- // Compute matrix product using blocks
117
- ComputeType accum[kMblock][kNblock];
118
-
119
- CUTLASS_PRAGMA_UNROLL
120
- for (int j = 0; j < kNblock; j++) {
121
- CUTLASS_PRAGMA_UNROLL
122
- for (int i = 0; i < kMblock; i++) {
123
- accum[i][j] = initial_accum;
124
- }
125
- }
126
-
127
- for (int k_block = 0; k_block < K; ++k_block) {
128
- CUTLASS_PRAGMA_UNROLL
129
- for (int j = 0; j < kNblock; j++) {
130
- CUTLASS_PRAGMA_UNROLL
131
- for (int i = 0; i < kMblock; i++) {
132
- int row = row_block + i;
133
- int col = col_block + j;
134
-
135
- if (row < M && col < N) {
136
- ElementA a = tensor_a.at(MatrixCoord(row, k_block));
137
- ElementB b = tensor_b.at(MatrixCoord(k_block, col));
138
-
139
- ComputeType a_ik = ComputeType(a);
140
- ComputeType b_kj = ComputeType(b);
141
-
142
- if (transform_a == ComplexTransform::kConjugate) {
143
- a_ik = conj(a_ik);
144
- }
145
-
146
- if (transform_b == ComplexTransform::kConjugate) {
147
- b_kj = conj(b_kj);
148
- }
149
-
150
- accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
151
- }
152
- }
153
- }
154
- }
155
-
156
- CUTLASS_PRAGMA_UNROLL
157
- for (int j = 0; j < kNblock; j++) {
158
- CUTLASS_PRAGMA_UNROLL
159
- for (int i = 0; i < kMblock; i++) {
160
- int row = row_block + i;
161
- int col = col_block + j;
162
-
163
- MatrixCoord coord = MatrixCoord(row, col);
164
-
165
- if (row < M && col < N) {
166
-
167
- tensor_d.at(coord) = convert_op(
168
- alpha * ScalarType(accum[i][j]) +
169
- beta * ScalarType(tensor_c.at(coord)));
170
- }
171
- }
172
- }
173
-
174
- tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
175
- tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
176
- tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
177
- tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
178
-
179
- } // for (batch_idx)
180
- }
181
-
182
- } // namespace kernel
183
-
184
- ////////////////////////////////////////////////////////////////////////////////////////////////////
185
-
186
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
187
- /// objects.
188
- ///
189
- /// Explicitly naming types needed by this template can be cumbersome, particularly for the
190
- /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
191
- /// AccumulatorType(0) as the last function argument can be easier than naming all template
192
- /// arguments explicitly.
193
- template <
194
- typename ElementA,
195
- typename LayoutA,
196
- typename ElementB,
197
- typename LayoutB,
198
- typename ElementC,
199
- typename LayoutC,
200
- typename ScalarType,
201
- typename ComputeType,
202
- typename ElementD = ElementC,
203
- typename ConvertOp = NumericConverter<ElementD, ScalarType>,
204
- typename InnerProductOp = multiply_add<ComputeType>
205
- >
206
- void GemmComplex(
207
- gemm::GemmCoord problem_size,
208
- ScalarType alpha,
209
- TensorRef<ElementA, LayoutA> tensor_a,
210
- ComplexTransform transform_a,
211
- TensorRef<ElementB, LayoutB> tensor_b,
212
- ComplexTransform transform_b,
213
- ScalarType beta,
214
- TensorRef<ElementC, LayoutC> tensor_c,
215
- TensorRef<ElementD, LayoutC> tensor_d,
216
- ComputeType initial_accum,
217
- int batch_count = 1,
218
- int64_t batch_stride_A = 0,
219
- int64_t batch_stride_B = 0,
220
- int64_t batch_stride_C = 0,
221
- int64_t batch_stride_D = 0) {
222
-
223
- static_assert(
224
- LayoutA::kRank == 2 &&
225
- LayoutB::kRank == 2 &&
226
- LayoutC::kRank == 2, "Tensors must be of rank 2");
227
-
228
- int const kMblock = 4;
229
- int const kNblock = 4;
230
-
231
- dim3 block(16, 8);
232
- dim3 grid(
233
- (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
234
- (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
235
- batch_count % std::numeric_limits<uint16_t>::max()
236
- );
237
-
238
- if (grid.y <= std::numeric_limits<uint16_t>::max()) {
239
- kernel::GemmComplex<
240
- ElementA,
241
- LayoutA,
242
- ElementB,
243
- LayoutB,
244
- ElementC,
245
- LayoutC,
246
- ScalarType,
247
- ComputeType,
248
- ElementD,
249
- ConvertOp,
250
- InnerProductOp,
251
- kMblock,
252
- kNblock
253
- ><<< grid, block >>>(
254
- problem_size,
255
- alpha,
256
- tensor_a,
257
- transform_a,
258
- tensor_b,
259
- transform_b,
260
- beta,
261
- tensor_c,
262
- tensor_d,
263
- initial_accum,
264
- batch_count,
265
- batch_stride_A,
266
- batch_stride_B,
267
- batch_stride_C,
268
- batch_stride_D
269
- );
270
- } else {
271
- // Using bigger thread tile size
272
- int const kBigMblock = 4;
273
- int const kBigNblock = 16;
274
-
275
- dim3 Bigblock(16, 8);
276
- dim3 Biggrid(
277
- (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock),
278
- (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock),
279
- batch_count % std::numeric_limits<uint16_t>::max()
280
- );
281
-
282
- kernel::GemmComplex<
283
- ElementA,
284
- LayoutA,
285
- ElementB,
286
- LayoutB,
287
- ElementC,
288
- LayoutC,
289
- ScalarType,
290
- ComputeType,
291
- ElementD,
292
- ConvertOp,
293
- InnerProductOp,
294
- kBigMblock,
295
- kBigNblock
296
- ><<< Biggrid, Bigblock >>>(
297
- problem_size,
298
- alpha,
299
- tensor_a,
300
- transform_a,
301
- tensor_b,
302
- transform_b,
303
- beta,
304
- tensor_c,
305
- tensor_d,
306
- initial_accum,
307
- batch_count,
308
- batch_stride_A,
309
- batch_stride_B,
310
- batch_stride_C,
311
- batch_stride_D
312
- );
313
- }
314
- }
315
-
316
- ////////////////////////////////////////////////////////////////////////////////////////////////////
317
-
318
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
319
- /// objects.
320
- ///
321
- /// This assumes the accumulator type is the same type as the scalars.
322
- template <
323
- typename ElementA,
324
- typename LayoutA,
325
- typename ElementB,
326
- typename LayoutB,
327
- typename ElementC,
328
- typename LayoutC,
329
- typename ScalarType,
330
- typename ElementD = ElementC
331
- >
332
- void GemmComplex(
333
- gemm::GemmCoord problem_size,
334
- ScalarType alpha,
335
- TensorRef<ElementA, LayoutA> tensor_a,
336
- ComplexTransform transform_a,
337
- TensorRef<ElementB, LayoutB> tensor_b,
338
- ComplexTransform transform_b,
339
- ScalarType beta,
340
- TensorRef<ElementC, LayoutC> tensor_c,
341
- TensorRef<ElementD, LayoutC> tensor_d) {
342
-
343
- GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
344
- }
345
-
346
- ////////////////////////////////////////////////////////////////////////////////////////////////////
347
-
348
- } // namespace device
349
- } // namespace reference
350
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h DELETED
@@ -1,311 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for complex-valued GEMM in device code.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/coord.h"
38
- #include "cutlass/complex.h"
39
- #include "cutlass/matrix_coord.h"
40
- #include "cutlass/numeric_types.h"
41
- #include "cutlass/functional.h"
42
- #include "cutlass/numeric_conversion.h"
43
- #include "cutlass/tensor_ref_planar_complex.h"
44
-
45
- #include "cutlass/tensor_view.h"
46
- #include "cutlass/gemm/gemm.h"
47
-
48
- namespace cutlass {
49
- namespace reference {
50
- namespace device {
51
-
52
- ////////////////////////////////////////////////////////////////////////////////////////////////////
53
-
54
- namespace kernel {
55
-
56
- ////////////////////////////////////////////////////////////////////////////////////////////////////
57
-
58
- static int const kGemmPlanarComplexBlockSize = 4;
59
-
60
- template <
61
- typename ElementA,
62
- typename LayoutA,
63
- typename ElementB,
64
- typename LayoutB,
65
- typename ElementC,
66
- typename LayoutC,
67
- typename ScalarType,
68
- typename ComputeType,
69
- typename ConvertOp = NumericConverter<ElementC, ScalarType>,
70
- typename InnerProductOp = multiply_add<complex<ComputeType>>
71
- >
72
- __global__ void GemmPlanarComplex(
73
- gemm::GemmCoord problem_size,
74
- complex<ScalarType> alpha,
75
- TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
76
- ComplexTransform transform_a,
77
- TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
78
- ComplexTransform transform_b,
79
- complex<ScalarType> beta,
80
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
81
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
82
- complex<ComputeType> initial_accum) {
83
-
84
- int const kMblock = kGemmPlanarComplexBlockSize;
85
- int const kNblock = kGemmPlanarComplexBlockSize;
86
-
87
- using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
88
- using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
89
- using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;
90
-
91
- // Note: batch is ignored.
92
- int const M = problem_size.m();
93
- int const N = problem_size.n();
94
- int const K = problem_size.k();
95
-
96
- ConvertOp convert_op;
97
- InnerProductOp inner_product_op;
98
-
99
- complex<ComputeType> accum[kMblock][kNblock];
100
-
101
- int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
102
- int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
103
-
104
- CUTLASS_PRAGMA_UNROLL
105
- for (int j = 0; j < kNblock; j++) {
106
- CUTLASS_PRAGMA_UNROLL
107
- for (int i = 0; i < kMblock; i++) {
108
- accum[i][j] = initial_accum;
109
- }
110
- }
111
-
112
- CUTLASS_PRAGMA_NO_UNROLL
113
- for (int k_block = 0; k_block < K; ++k_block) {
114
-
115
- CUTLASS_PRAGMA_UNROLL
116
- for (int j = 0; j < kNblock; j++) {
117
-
118
- CUTLASS_PRAGMA_UNROLL
119
- for (int i = 0; i < kMblock; i++) {
120
-
121
- int row = row_block + i;
122
- int col = col_block + j;
123
-
124
- if (row < M && col < N) {
125
-
126
- ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
127
- ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
128
-
129
- complex<ComputeType> a = complex<ComputeType>{
130
- ComputeType(a_ik.real()),
131
- ComputeType(a_ik.imag())
132
- };
133
-
134
- complex<ComputeType> b = complex<ComputeType>{
135
- ComputeType(b_kj.real()),
136
- ComputeType(b_kj.imag())
137
- };
138
-
139
- if (transform_a == ComplexTransform::kConjugate) {
140
- a = conj(a);
141
- }
142
-
143
- if (transform_b == ComplexTransform::kConjugate) {
144
- b = conj(b);
145
- }
146
-
147
- accum[i][j] = inner_product_op(a, b, accum[i][j]);
148
- }
149
- }
150
- }
151
- }
152
-
153
- CUTLASS_PRAGMA_UNROLL
154
- for (int j = 0; j < kNblock; j++) {
155
- CUTLASS_PRAGMA_UNROLL
156
- for (int i = 0; i < kMblock; i++) {
157
-
158
- int row = row_block + i;
159
- int col = col_block + j;
160
-
161
- MatrixCoord coord = MatrixCoord(row, col);
162
-
163
- if (row < M && col < N) {
164
-
165
- complex<ScalarType> acc{
166
- ScalarType(accum[i][j].real()),
167
- ScalarType(accum[i][j].imag())
168
- };
169
-
170
- ComplexC c_ij = ComplexC();
171
-
172
- if (beta.real() != ScalarType() || beta.imag() != ScalarType()) {
173
- c_ij = tensor_c.at(coord);
174
- }
175
-
176
- complex<ScalarType> src{
177
- ScalarType(c_ij.real()),
178
- ScalarType(c_ij.imag())
179
- };
180
-
181
- complex<ScalarType> result = alpha * acc + beta * src;
182
-
183
- ComplexC d_ij;
184
-
185
- d_ij.real() = convert_op(result.real());
186
- d_ij.imag() = convert_op(result.imag());
187
-
188
- tensor_d.at(coord) = d_ij;
189
- }
190
- }
191
- }
192
- }
193
-
194
- ////////////////////////////////////////////////////////////////////////////////////////////////////
195
-
196
- } // namespace kernel
197
-
198
- ////////////////////////////////////////////////////////////////////////////////////////////////////
199
-
200
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
201
- /// objects.
202
- ///
203
- /// Explicitly naming types needed by this template can be cumbersome, particularly for the
204
- /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
205
- /// AccumulatorType(0) as the last function argument can be easier than naming all template
206
- /// arguments explicitly.
207
- template <
208
- typename ElementA,
209
- typename LayoutA,
210
- typename ElementB,
211
- typename LayoutB,
212
- typename ElementC,
213
- typename LayoutC,
214
- typename ScalarType,
215
- typename ComputeType,
216
- typename ConvertOp = NumericConverter<ElementC, ScalarType>,
217
- typename InnerProductOp = multiply_add<complex<ComputeType>>
218
- >
219
- void GemmPlanarComplex(
220
- gemm::GemmCoord problem_size,
221
- complex<ScalarType> alpha,
222
- TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
223
- ComplexTransform transform_a,
224
- TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
225
- ComplexTransform transform_b,
226
- complex<ScalarType> beta,
227
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
228
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
229
- complex<ComputeType> initial_accum) {
230
-
231
- static_assert(
232
- LayoutA::kRank == 2 &&
233
- LayoutB::kRank == 2 &&
234
- LayoutC::kRank == 2, "Tensors must be of rank 2");
235
-
236
- int const kMblock = kernel::kGemmPlanarComplexBlockSize;
237
- int const kNblock = kernel::kGemmPlanarComplexBlockSize;
238
-
239
- dim3 block(16, 8);
240
-
241
- dim3 grid(
242
- (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
243
- (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
244
- 1);
245
-
246
- kernel::GemmPlanarComplex<
247
- ElementA, LayoutA,
248
- ElementB, LayoutB,
249
- ElementC, LayoutC,
250
- ScalarType,
251
- ComputeType,
252
- ConvertOp,
253
- InnerProductOp
254
- ><<< grid, block >>>(
255
- problem_size,
256
- alpha,
257
- tensor_a,
258
- transform_a,
259
- tensor_b,
260
- transform_b,
261
- beta,
262
- tensor_c,
263
- tensor_d,
264
- initial_accum
265
- );
266
- }
267
-
268
- ////////////////////////////////////////////////////////////////////////////////////////////////////
269
-
270
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
271
- /// objects.
272
- ///
273
- /// This assumes the accumulator type is the same type as the scalars.
274
- template <
275
- typename ElementA,
276
- typename LayoutA,
277
- typename ElementB,
278
- typename LayoutB,
279
- typename ElementC,
280
- typename LayoutC,
281
- typename ScalarType
282
- >
283
- void GemmPlanarComplex(
284
- gemm::GemmCoord problem_size,
285
- complex<ScalarType> alpha,
286
- TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
287
- ComplexTransform transform_a,
288
- TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
289
- ComplexTransform transform_b,
290
- complex<ScalarType> beta,
291
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
292
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
293
-
294
- GemmPlanarComplex(
295
- problem_size,
296
- alpha,
297
- tensor_a, transform_a,
298
- tensor_b, transform_b,
299
- beta,
300
- tensor_c,
301
- tensor_d,
302
- complex<ScalarType>());
303
- }
304
-
305
- ////////////////////////////////////////////////////////////////////////////////////////////////////
306
-
307
- } // namespace device
308
- } // namespace reference
309
- } // namespace cutlass
310
-
311
- ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp DELETED
@@ -1,146 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief GETT device reference code
33
- */
34
- #pragma once
35
-
36
- #include <cute/tensor.hpp>
37
-
38
- namespace cutlass::reference::device {
39
-
40
- template <
41
- class ATensor,
42
- class BTensor,
43
- class CTensor,
44
- class DTensor,
45
- class ElementAccumulator,
46
- class ElementEpilogue>
47
- __global__ static
48
- void
49
- gett_kernel(
50
- DTensor D,
51
- ATensor const A,
52
- BTensor const B,
53
- CTensor const C,
54
- ElementEpilogue alpha, ElementEpilogue beta,
55
- ElementAccumulator acc_init)
56
- {
57
- using namespace cute;
58
-
59
- static_assert(DTensor::rank == 3, "(M,N,L)");
60
- static_assert(ATensor::rank == 3, "(M,K,L)");
61
- static_assert(BTensor::rank == 3, "(N,K,L)");
62
- static_assert(CTensor::rank == 3, "(M,N,L)");
63
-
64
- assert(size<0>(A) == size<0>(D)); // M
65
- assert(size<0>(C) == size<0>(D)); // M
66
- assert(size<0>(B) == size<1>(D)); // N
67
- assert(size<1>(C) == size<1>(D)); // N
68
- assert(size<1>(A) == size<1>(B)); // K
69
- assert(size<2>(A) == size<2>(D)); // L
70
- assert(size<2>(B) == size<2>(D)); // L
71
- assert(size<2>(C) == size<2>(D)); // L
72
-
73
- NumericConverter<ElementAccumulator, typename ATensor::value_type> a_converter;
74
- NumericConverter<ElementAccumulator, typename BTensor::value_type> b_converter;
75
- NumericConverter<ElementEpilogue, ElementAccumulator> acc_converter;
76
- NumericConverter<ElementEpilogue, typename CTensor::value_type> source_converter;
77
- NumericConverter<typename DTensor::value_type, ElementEpilogue> output_converter;
78
-
79
- // Thread id to each element of D
80
- for (int tid = threadIdx.x + blockDim.x * blockIdx.x;
81
- tid < size(D);
82
- tid += blockDim.x * gridDim.x) {
83
- // (m,n,l) coordinate
84
- auto mnl_coord = idx2crd(tid, product_each(shape(D)));
85
- auto m = get<0>(mnl_coord);
86
- auto n = get<1>(mnl_coord);
87
- auto l = get<2>(mnl_coord);
88
-
89
- auto A_ml = A(m,_,l);
90
- auto B_nl = B(n,_,l);
91
-
92
- ElementAccumulator accum = ElementAccumulator(0);
93
- for (int k = 0; k < size<1>(A); ++k) {
94
- ElementAccumulator a = a_converter(A_ml(k));
95
- ElementAccumulator b = b_converter(B_nl(k));
96
- accum += a * b;
97
- }
98
-
99
- ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l)));
100
- D(m,n,l) = output_converter(scaled_output);
101
- }
102
- }
103
-
104
- // Most general version
105
- template <
106
- class ProblemShapeMNKL,
107
- class ElementA,
108
- class StrideA,
109
- class ElementB,
110
- class StrideB,
111
- class ElementAccumulator,
112
- class ElementC,
113
- class StrideC,
114
- class ElementD,
115
- class StrideD,
116
- class ElementEpilogue>
117
- void
118
- gett(
119
- ProblemShapeMNKL problem_shape_mnkl,
120
- ElementA const* ptr_A, StrideA stride_a_mkl,
121
- ElementB const* ptr_B, StrideB stride_b_nkl,
122
- ElementAccumulator _,
123
- ElementC const* ptr_C, StrideC stride_c_mnl,
124
- ElementD * ptr_D, StrideD stride_d_mnl,
125
- ElementEpilogue alpha, ElementEpilogue beta,
126
- cudaStream_t stream = 0) {
127
- using namespace cute;
128
-
129
- static_assert(cute::rank(ProblemShapeMNKL{}) == 4);
130
- auto M = get<0>(problem_shape_mnkl);
131
- auto N = get<1>(problem_shape_mnkl);
132
- auto K = get<2>(problem_shape_mnkl);
133
- auto L = get<3>(problem_shape_mnkl);
134
-
135
- // Represent the full tensors
136
- auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L)
137
- auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L)
138
- auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L)
139
- auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L)
140
-
141
- dim3 dimBlock(256);
142
- dim3 dimGrid(240);
143
- gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0));
144
- }
145
-
146
- } // namespace cutlass::reference::device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h DELETED
@@ -1,162 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for GEMM in host-side code.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/coord.h"
38
- #include "cutlass/tensor_view.h"
39
- #include "cutlass/gemm/gemm.h"
40
-
41
- #include "cutlass/util/reference/device/thread/gemm.h"
42
-
43
- namespace cutlass {
44
- namespace reference {
45
- namespace device {
46
- namespace kernel {
47
-
48
- ////////////////////////////////////////////////////////////////////////////////////////////////////
49
-
50
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
51
- /// objects.
52
- template <
53
- typename TensorRefA,
54
- typename TensorRefB,
55
- typename TensorRefC,
56
- typename ScalarType,
57
- typename AccumulatorType,
58
- typename OutputTile,
59
- typename InnerProductOp,
60
- typename ConvertOp
61
- >
62
- __global__ void Gemm(
63
- gemm::GemmCoord problem_size,
64
- ScalarType alpha,
65
- TensorRefA tensor_a,
66
- TensorRefB tensor_b,
67
- ScalarType beta,
68
- TensorRefC tensor_c,
69
- TensorRefC tensor_d,
70
- AccumulatorType initial_accum) {
71
-
72
- // Map each thread to a unique tile of the output matrix
73
- MatrixCoord output_coord(
74
- MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow),
75
- MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn)
76
- );
77
-
78
- // Compute the general matrix product
79
- thread::Gemm<
80
- TensorRefA,
81
- TensorRefB,
82
- TensorRefC,
83
- ScalarType,
84
- AccumulatorType,
85
- OutputTile,
86
- InnerProductOp,
87
- ConvertOp
88
- > gemm(initial_accum);
89
-
90
- gemm.multiply_add(
91
- problem_size,
92
- tensor_a,
93
- tensor_b,
94
- output_coord);
95
-
96
- gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord);
97
- }
98
-
99
- ////////////////////////////////////////////////////////////////////////////////////////////////////
100
-
101
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
102
- /// objects.
103
- template <
104
- typename TensorRefCollectionA,
105
- typename TensorRefCollectionB,
106
- typename TensorRefCollectionC,
107
- typename ScalarType,
108
- typename AccumulatorType,
109
- typename OutputTile,
110
- typename InnerProductOp,
111
- typename ConvertOp
112
- >
113
- __global__ void BatchedGemm(
114
- gemm::GemmCoord problem_size,
115
- ScalarType alpha,
116
- TensorRefCollectionA tensor_collection_a,
117
- TensorRefCollectionB tensor_collection_b,
118
- ScalarType beta,
119
- TensorRefCollectionC tensor_collection_c,
120
- AccumulatorType initial_accum) {
121
-
122
- // Obtain batch ID
123
- int batch_id = blockIdx.z;
124
-
125
- // Dereference based on batch_id
126
- typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id);
127
- typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id);
128
- typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id);
129
-
130
- // Map each thread to a unique tile of the output matrix
131
- MatrixCoord output_coord(
132
- (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn,
133
- (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow
134
- );
135
-
136
- // Compute the general matrix product
137
- thread::Gemm<
138
- typename TensorRefCollectionA::TensorRef,
139
- typename TensorRefCollectionB::TensorRef,
140
- typename TensorRefCollectionC::TensorRef,
141
- ScalarType,
142
- AccumulatorType,
143
- OutputTile,
144
- InnerProductOp,
145
- ConvertOp
146
- > gemm(initial_accum);
147
-
148
- gemm.multiply_add(
149
- problem_size,
150
- tensor_a,
151
- tensor_b,
152
- output_coord);
153
-
154
- gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
155
- }
156
-
157
- ////////////////////////////////////////////////////////////////////////////////////////////////////
158
-
159
- } // namespace kernel
160
- } // namespace device
161
- } // namespace reference
162
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h DELETED
@@ -1,168 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include <curand_kernel.h>
35
-
36
- #include "cutlass/cutlass.h"
37
-
38
- namespace cutlass {
39
- namespace reference {
40
- namespace device {
41
- namespace kernel {
42
-
43
- ////////////////////////////////////////////////////////////////////////////////////////////////////
44
-
45
- /// Kernel to initialize tensor to uniform random distribution
46
- template <typename T>
47
- __global__ void TensorInitializeUniform(
48
- Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
49
- __shared__ curandState_t rng_state[1024];
50
-
51
- uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
52
-
53
- curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
54
-
55
- int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
56
- int s_idx = blockIdx.y * blockDim.x;
57
-
58
- tensor += s_idx * ldm + c_idx;
59
-
60
- for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
61
- if (s_idx < dim_strided && c_idx < dim_contiguous) {
62
- double range = dist.uniform.max - dist.uniform.min;
63
-
64
- double rnd = curand_uniform(&rng_state[threadIdx.x]);
65
-
66
- rnd = dist.uniform.min + range * rnd;
67
-
68
- // Random values are cast to integer after scaling by a power of two to facilitate error
69
- // testing
70
- if (dist.int_scale >= 0) {
71
- rnd = double(int(rnd * double(1 << dist.int_scale)));
72
- *tensor = T(rnd / double(1 << dist.int_scale));
73
- } else {
74
- *tensor = T(rnd);
75
- }
76
-
77
- tensor += ldm;
78
- }
79
- }
80
- }
81
-
82
- ///////////////////////////////////////////////////////////////////////////////////////////////////
83
-
84
- /// Kernel to initialize tensor to uniform distribution
85
- template <typename T>
86
- __global__ void TensorInitializeGaussian(
87
- Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
88
- __shared__ curandState_t rng_state[1024];
89
-
90
- uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
91
-
92
- curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
93
-
94
- int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
95
- int s_idx = blockIdx.y * blockDim.x;
96
-
97
- tensor += s_idx * ldm + c_idx;
98
-
99
- for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
100
- if (s_idx < dim_strided && c_idx < dim_contiguous) {
101
- // Random values are cast to integer after scaling by a power of two to facilitate error
102
- // testing
103
-
104
- double rnd = curand_normal(&rng_state[threadIdx.x]);
105
-
106
- rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd;
107
-
108
- if (dist.int_scale >= 0) {
109
- rnd = double(int(rnd * double(1 << dist.int_scale)));
110
- *tensor = T(rnd / double(1 << dist.int_scale));
111
- } else {
112
- *tensor = T(rnd);
113
- }
114
- }
115
- }
116
- }
117
-
118
- /// Kernel to initialize tensor to an identity matrix
119
- template <typename T>
120
- __global__ void TensorInitializeLinear(
121
- Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
122
- __shared__ curandState_t rng_state[1024];
123
-
124
- uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
125
-
126
- curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
127
-
128
- int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
129
- int s_idx = blockIdx.y * blockDim.x;
130
-
131
- tensor += s_idx * ldm + c_idx;
132
-
133
- for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
134
- if (s_idx < dim_strided && c_idx < dim_contiguous) {
135
- *tensor =
136
- dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx;
137
- }
138
- }
139
- }
140
-
141
- /// Kernel to initialize tensor to an identity matrix
142
- template <typename T>
143
- __global__ void TensorInitializeIdentity(
144
- Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
145
- __shared__ curandState_t rng_state[1024];
146
-
147
- uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
148
-
149
- curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
150
-
151
- int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
152
- int s_idx = blockIdx.y * blockDim.x;
153
-
154
- tensor += s_idx * ldm + c_idx;
155
-
156
- for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
157
- if (s_idx < dim_strided && c_idx < dim_contiguous) {
158
- *tensor = (c_idx == s_idx ? T(1) : T(0));
159
- }
160
- }
161
- }
162
-
163
- ////////////////////////////////////////////////////////////////////////////////////////////////////
164
-
165
- } // namespace kernel
166
- } // namespace device
167
- } // namespace reference
168
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h DELETED
@@ -1,159 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include "cutlass/cutlass.h"
35
- #include "cutlass/coord.h"
36
- #include "cutlass/subbyte_reference.h"
37
- #include "cutlass/fast_math.h"
38
-
39
- namespace cutlass {
40
- namespace reference {
41
- namespace device {
42
- namespace kernel {
43
-
44
- ///////////////////////////////////////////////////////////////////////////////////////////////////
45
-
46
- /// Defines several helpers
47
- namespace detail {
48
-
49
- /// Helper to perform for-each operation
50
- template <typename Func, int Rank, int RankRemaining>
51
- struct TensorForEachHelper {
52
-
53
- /// Constructor for general rank
54
- __inline__ __device__
55
- TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {
56
-
57
- int64_t product = 1;
58
-
59
- CUTLASS_PRAGMA_UNROLL
60
- for (int i = Rank - RankRemaining; i < Rank; ++i) {
61
- product *= size[i];
62
- }
63
-
64
- coord[Rank - 1 - RankRemaining] = index / product;
65
- int64_t remaining = index % product;
66
-
67
- TensorForEachHelper<Func, Rank, RankRemaining-1>(func, size, coord, remaining);
68
- }
69
- };
70
-
71
- /// Helper to perform for-each operation
72
- template <typename Func, int Rank>
73
- struct TensorForEachHelper<Func, Rank, 0> {
74
-
75
- /// Constructor for fastest changing rank
76
- __inline__ __device__
77
- TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {
78
-
79
- coord[Rank - 1] = index;
80
-
81
- if (coord < size) {
82
- func(coord);
83
- }
84
- }
85
- };
86
-
87
- } // namespace detail
88
-
89
- ///////////////////////////////////////////////////////////////////////////////////////////////////
90
-
91
- /// Kernel calls a functor for each element in a tensor's index space
92
- template <typename Func, int Rank, typename Params>
93
- __global__ void TensorForEach(Coord<Rank> size, Params params = Params()) {
94
-
95
- Func func(params);
96
-
97
- int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
98
- int64_t max_index = 1;
99
-
100
- CUTLASS_PRAGMA_UNROLL
101
- for (int i = 0; i < Rank; ++i) {
102
- max_index *= size[i];
103
- }
104
-
105
- CUTLASS_PRAGMA_NO_UNROLL
106
- while (index < max_index) {
107
- Coord<Rank> coord;
108
-
109
- detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, size, coord, index);
110
- index += blockDim.x * gridDim.x;
111
- }
112
- }
113
-
114
- ///////////////////////////////////////////////////////////////////////////////////////////////////
115
-
116
- /// Kernel calls a functor for each element along a tensor's diagonal
117
- template <typename Func, int Rank, typename Params>
118
- __global__ void TensorDiagonalForEach(Coord<Rank> size, Params params, int start, int end) {
119
-
120
- Func func(params);
121
-
122
- int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start;
123
-
124
- if (index < end) {
125
- Coord<Rank> coord;
126
-
127
- CUTLASS_PRAGMA_UNROLL
128
- for (int i = 0; i < Rank; ++i) {
129
- coord[i] = index;
130
- }
131
-
132
- func(coord);
133
- }
134
- }
135
-
136
- ///////////////////////////////////////////////////////////////////////////////////////////////////
137
-
138
- template <typename Element, typename Func>
139
- __global__ void BlockForEach(
140
- Element *ptr,
141
- size_t capacity,
142
- typename Func::Params params) {
143
-
144
- Func func(params);
145
-
146
- size_t index = threadIdx.x + blockIdx.x * blockDim.x;
147
-
148
- for (; index < capacity; index += blockDim.x * gridDim.x) {
149
- ReferenceFactory<Element>::get(ptr, index) = func();
150
- }
151
- }
152
-
153
- ///////////////////////////////////////////////////////////////////////////////////////////////////
154
-
155
- } // namespace kernel
156
- } // namespace device
157
- } // namespace reference
158
- } // namespace cutlass
159
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h DELETED
@@ -1,355 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for complex-valued GEMM in device-side code.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/blas3.h"
38
- #include "cutlass/complex.h"
39
- #include "cutlass/numeric_conversion.h"
40
- #include "cutlass/tensor_view.h"
41
- #include "cutlass/gemm/gemm.h"
42
-
43
- namespace cutlass {
44
- namespace reference {
45
- namespace device {
46
-
47
- ////////////////////////////////////////////////////////////////////////////////////////////////////
48
-
49
- namespace kernel {
50
-
51
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
52
- /// objects.
53
- ///
54
- /// Explicitly naming types needed by this template can be cumbersome, particularly for the
55
- /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
56
- /// AccumulatorType(0) as the last function argument can be easier than naming all template
57
- /// arguments explicitly.
58
- template <
59
- typename ElementA,
60
- typename LayoutA,
61
- typename ElementB,
62
- typename LayoutB,
63
- typename ElementC,
64
- typename LayoutC,
65
- typename ScalarType,
66
- typename ComputeType,
67
- typename ConvertOp = NumericConverter<ElementC, ScalarType>,
68
- typename InnerProductOp = multiply_add<ComputeType>,
69
- int kMblock = 4,
70
- int kNblock = 4
71
- >
72
- __global__ void Rank2KComplex(
73
- gemm::GemmCoord problem_size,
74
- ScalarType alpha,
75
- TensorRef<ElementA, LayoutA> tensor_a,
76
- ComplexTransform transform_a,
77
- TensorRef<ElementB, LayoutB> tensor_b,
78
- ComplexTransform transform_b,
79
- ScalarType beta,
80
- TensorRef<ElementC, LayoutC> tensor_c,
81
- TensorRef<ElementC, LayoutC> tensor_d,
82
- ComputeType initial_accum,
83
- FillMode fill_mode_c,
84
- BlasMode blas_mode,
85
- int batch_count = 1,
86
- int64_t batch_stride_A = 0,
87
- int64_t batch_stride_B = 0,
88
- int64_t batch_stride_C = 0,
89
- int64_t batch_stride_D = 0) {
90
-
91
- static_assert(
92
- LayoutA::kRank == 2 &&
93
- LayoutB::kRank == 2 &&
94
- LayoutC::kRank == 2, "Tensors must be of rank 2");
95
-
96
- int const M = problem_size.m();
97
- int const N = problem_size.n();
98
- int const K = problem_size.k();
99
-
100
- assert(M=N);
101
-
102
- ConvertOp convert_op;
103
- InnerProductOp inner_product_op;
104
-
105
- int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
106
- int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
107
- int batch_idx = blockIdx.z;
108
-
109
- tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
110
- tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
111
- tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
112
- tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
113
-
114
- for (; batch_idx < batch_count; batch_idx += gridDim.z) {
115
-
116
- // Compute matrix product using blocks
117
- ComputeType accum[kMblock][kNblock];
118
-
119
- CUTLASS_PRAGMA_UNROLL
120
- for (int j = 0; j < kNblock; j++) {
121
- CUTLASS_PRAGMA_UNROLL
122
- for (int i = 0; i < kMblock; i++) {
123
- accum[i][j] = initial_accum;
124
- }
125
- }
126
-
127
- for (int k_block = 0; k_block < K; ++k_block) {
128
- CUTLASS_PRAGMA_UNROLL
129
- for (int j = 0; j < kNblock; j++) {
130
- CUTLASS_PRAGMA_UNROLL
131
- for (int i = 0; i < kMblock; i++) {
132
- int row = row_block + i;
133
- int col = col_block + j;
134
-
135
- if (row < M && col < N &&
136
- ( (fill_mode_c == FillMode::kLower && row >= col) ||
137
- (fill_mode_c == FillMode::kUpper && row <= col) )
138
- ) {
139
-
140
- // A x B^T (Symmetric) or A x B^H (Hermitian)
141
- // complex conjugation on operandB (b_t) is function of blas3 computation
142
- ElementA a = tensor_a.at(MatrixCoord(row, k_block));
143
- ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
144
- conj(tensor_b.at(MatrixCoord(col, k_block))) :
145
- tensor_b.at(MatrixCoord(col, k_block));
146
-
147
- ComputeType a_ik = ComputeType(a);
148
- ComputeType b_jk = ComputeType(b_t);
149
-
150
- // complex conjugation is a function of operand layouts
151
- if (transform_a == ComplexTransform::kConjugate) {
152
- a_ik = conj(a_ik);
153
- }
154
- // complex conjugation is a function of operand layouts
155
- if (transform_b == ComplexTransform::kConjugate) {
156
- b_jk = conj(b_jk);
157
- }
158
-
159
- accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
160
-
161
- // B x A^T (Symmetric) or B x A^H (Hermitian)
162
- // complex conjugation on operandB (a_t) is function of blas3 computation
163
- ElementB b = tensor_b.at(MatrixCoord(row, k_block));
164
- ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
165
- conj(tensor_a.at(MatrixCoord(col, k_block))):
166
- tensor_a.at(MatrixCoord(col, k_block));
167
-
168
- ComputeType b_ik = ComputeType(b);
169
- ComputeType a_jk = ComputeType(a_t);
170
-
171
- // complex conjugation here is a function of operand layouts
172
- if (transform_b == ComplexTransform::kConjugate) {
173
- b_ik = conj(b_ik);
174
- }
175
- // complex conjugation here is a function of operand layouts
176
- if (transform_a == ComplexTransform::kConjugate) {
177
- a_jk = conj(a_jk);
178
- }
179
-
180
- accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
181
- }
182
- }
183
- }
184
- }
185
-
186
- CUTLASS_PRAGMA_UNROLL
187
- for (int j = 0; j < kNblock; j++) {
188
- CUTLASS_PRAGMA_UNROLL
189
- for (int i = 0; i < kMblock; i++) {
190
- int row = row_block + i;
191
- int col = col_block + j;
192
-
193
- MatrixCoord coord = MatrixCoord(row, col);
194
-
195
- if (row < M && col < N &&
196
- ((fill_mode_c == FillMode::kLower && row >= col) ||
197
- (fill_mode_c == FillMode::kUpper && row <= col))
198
- ) {
199
-
200
- ScalarType c = tensor_c.at(coord);
201
- // The imaginary parts of the diagonal elements of
202
- // a complex data type are assumed and set to zero
203
- if (blas_mode == BlasMode::kHermitian) {
204
- c = (row == col) ? real(c) : c;
205
- }
206
-
207
- tensor_d.at(coord) = convert_op(
208
- alpha * ScalarType(accum[i][j]) +
209
- beta * c);
210
- }
211
- }
212
- }
213
-
214
- tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
215
- tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
216
- tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
217
- tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
218
-
219
- } // for (batch_idx)
220
- }
221
-
222
- } // namespace kernel
223
-
224
- ////////////////////////////////////////////////////////////////////////////////////////////////////
225
-
226
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
227
- /// objects.
228
- ///
229
- /// Explicitly naming types needed by this template can be cumbersome, particularly for the
230
- /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
231
- /// AccumulatorType(0) as the last function argument can be easier than naming all template
232
- /// arguments explicitly.
233
- template <
234
- typename ElementA,
235
- typename LayoutA,
236
- typename ElementB,
237
- typename LayoutB,
238
- typename ElementC,
239
- typename LayoutC,
240
- typename ScalarType,
241
- typename ComputeType,
242
- typename ConvertOp = NumericConverter<ElementC, ScalarType>,
243
- typename InnerProductOp = multiply_add<ComputeType>
244
- >
245
- void Rank2KComplex(
246
- gemm::GemmCoord problem_size,
247
- ScalarType alpha,
248
- TensorRef<ElementA, LayoutA> tensor_a,
249
- ComplexTransform transform_a,
250
- TensorRef<ElementB, LayoutB> tensor_b,
251
- ComplexTransform transform_b,
252
- ScalarType beta,
253
- TensorRef<ElementC, LayoutC> tensor_c,
254
- TensorRef<ElementC, LayoutC> tensor_d,
255
- ComputeType initial_accum,
256
- FillMode fill_mode_c,
257
- BlasMode blas_mode,
258
- int batch_count = 1,
259
- int64_t batch_stride_A = 0,
260
- int64_t batch_stride_B = 0,
261
- int64_t batch_stride_C = 0,
262
- int64_t batch_stride_D = 0) {
263
-
264
- static_assert(
265
- LayoutA::kRank == 2 &&
266
- LayoutB::kRank == 2 &&
267
- LayoutC::kRank == 2, "Tensors must be of rank 2");
268
-
269
- int const kMblock = 4;
270
- int const kNblock = 4;
271
-
272
- dim3 block(16, 8);
273
- dim3 grid(
274
- (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
275
- (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
276
- batch_count % std::numeric_limits<uint16_t>::max()
277
- );
278
-
279
- kernel::Rank2KComplex<
280
- ElementA,
281
- LayoutA,
282
- ElementB,
283
- LayoutB,
284
- ElementC,
285
- LayoutC,
286
- ScalarType,
287
- ComputeType,
288
- ConvertOp,
289
- InnerProductOp,
290
- kMblock,
291
- kNblock
292
- ><<< grid, block >>>(
293
- problem_size,
294
- alpha,
295
- tensor_a,
296
- transform_a,
297
- tensor_b,
298
- transform_b,
299
- beta,
300
- tensor_c,
301
- tensor_d,
302
- initial_accum,
303
- fill_mode_c,
304
- blas_mode,
305
- batch_count,
306
- batch_stride_A,
307
- batch_stride_B,
308
- batch_stride_C,
309
- batch_stride_D
310
- );
311
- }
312
-
313
- ////////////////////////////////////////////////////////////////////////////////////////////////////
314
-
315
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
316
- /// objects.
317
- ///
318
- /// This assumes the accumulator type is the same type as the scalars.
319
- template <
320
- typename ElementA,
321
- typename LayoutA,
322
- typename ElementB,
323
- typename LayoutB,
324
- typename ElementC,
325
- typename LayoutC,
326
- typename ScalarType
327
- >
328
- void Rank2KComplex(
329
- gemm::GemmCoord problem_size,
330
- ScalarType alpha,
331
- TensorRef<ElementA, LayoutA> tensor_a,
332
- ComplexTransform transform_a,
333
- TensorRef<ElementB, LayoutB> tensor_b,
334
- ComplexTransform transform_b,
335
- ScalarType beta,
336
- TensorRef<ElementC, LayoutC> tensor_c,
337
- TensorRef<ElementC, LayoutC> tensor_d,
338
- FillMode fill_mode_c,
339
- BlasMode blas_mode) {
340
-
341
- Rank2KComplex(
342
- problem_size, alpha,
343
- tensor_a, transform_a,
344
- tensor_b, transform_b,
345
- beta, tensor_c, tensor_d,
346
- ScalarType(0),
347
- fill_mode_c,
348
- blas_mode);
349
- }
350
-
351
- ////////////////////////////////////////////////////////////////////////////////////////////////////
352
-
353
- } // namespace device
354
- } // namespace reference
355
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h DELETED
@@ -1,250 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /* \file
32
- \brief Defines host-side elementwise operations on TensorView.
33
- */
34
-
35
- #pragma once
36
- // Standard Library includes
37
- #include <utility>
38
-
39
- // Cutlass includes
40
- #include "cutlass/cutlass.h"
41
- #include "cutlass/relatively_equal.h"
42
-
43
- #include "cutlass/util/distribution.h"
44
-
45
- #include "tensor_foreach.h"
46
-
47
- namespace cutlass {
48
- namespace reference {
49
- namespace device {
50
-
51
- ///////////////////////////////////////////////////////////////////////////////////////////////////
52
-
53
- namespace kernel {
54
-
55
- template <typename Element>
56
- __global__ void BlockCompareEqual(
57
- int *equal,
58
- Element const *ptr_A,
59
- Element const *ptr_B,
60
- size_t capacity) {
61
-
62
- size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
63
-
64
- for (; idx < capacity; idx += gridDim.x * blockDim.x) {
65
-
66
- Element a = cutlass::ReferenceFactory<Element>::get(ptr_A, idx);
67
- Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);
68
-
69
- if (a != b) {
70
- *equal = 0;
71
-
72
- return;
73
- }
74
- }
75
- }
76
-
77
- template <typename Element>
78
- __global__ void BlockCompareRelativelyEqual(
79
- int *equal,
80
- Element const *ptr_A,
81
- Element const *ptr_B,
82
- size_t capacity,
83
- Element epsilon,
84
- Element nonzero_floor) {
85
-
86
- size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
87
-
88
- for (; idx < capacity; idx += gridDim.x * blockDim.x) {
89
-
90
- Element a = cutlass::ReferenceFactory<Element>::get(ptr_A, idx);
91
- Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);
92
-
93
- if (!relatively_equal(a, b, epsilon, nonzero_floor)) {
94
- *equal = 0;
95
- return;
96
- }
97
- }
98
- }
99
-
100
- } // namespace kernel
101
-
102
-
103
- ///////////////////////////////////////////////////////////////////////////////////////////////////
104
-
105
- /// Performs a bit-level equality check between two blocks
106
- template <typename Element>
107
- bool BlockCompareEqual(
108
- Element const *ptr_A,
109
- Element const *ptr_B,
110
- size_t capacity,
111
- int grid_size = 0,
112
- int block_size = 0,
113
- cudaStream_t stream = nullptr) {
114
-
115
- int equal_flag = 1;
116
- int *device_equal_flag = nullptr;
117
-
118
- if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) {
119
- throw std::runtime_error("Failed to allocate device flag.");
120
- }
121
-
122
- if (cudaMemcpy(
123
- device_equal_flag,
124
- &equal_flag,
125
- sizeof(int),
126
- cudaMemcpyHostToDevice) != cudaSuccess) {
127
-
128
- throw std::runtime_error("Failed to copy equality flag to device.");
129
- }
130
-
131
- if (!grid_size || !block_size) {
132
-
133
- // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
134
- cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
135
- &grid_size,
136
- &block_size,
137
- reinterpret_cast<void const *>(kernel::BlockCompareEqual<Element>));
138
-
139
- if (result != cudaSuccess) {
140
- throw std::runtime_error("Failed to query occupancy.");
141
- }
142
- // Limit block size. This has the effect of increasing the number of items processed by a
143
- // single thread and reduces the impact of initialization overhead.
144
- block_size = (block_size < 128 ? block_size : 128);
145
- }
146
-
147
- dim3 grid(grid_size, 1, 1);
148
- dim3 block(block_size, 1, 1);
149
-
150
- kernel::BlockCompareEqual<Element><<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity);
151
-
152
- cudaStreamSynchronize(stream);
153
-
154
- if (cudaMemcpy(
155
- &equal_flag,
156
- device_equal_flag,
157
- sizeof(int),
158
- cudaMemcpyDeviceToHost) != cudaSuccess) {
159
-
160
- cudaFree(device_equal_flag);
161
-
162
- throw std::runtime_error("Failed to copy equality flag from device.");
163
- }
164
-
165
- cudaFree(device_equal_flag);
166
-
167
- return equal_flag;
168
- }
169
-
170
- ///////////////////////////////////////////////////////////////////////////////////////////////////
171
-
172
- /// Performs a bit-level equality check between two blocks
173
- template <typename Element>
174
- bool BlockCompareRelativelyEqual(
175
- Element const *ptr_A,
176
- Element const *ptr_B,
177
- size_t capacity,
178
- Element epsilon,
179
- Element nonzero_floor,
180
- int grid_size = 0,
181
- int block_size = 0,
182
- cudaStream_t stream = nullptr) {
183
-
184
- int equal_flag = 1;
185
- int *device_equal_flag = nullptr;
186
-
187
- if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) {
188
- throw std::runtime_error("Failed to allocate device flag.");
189
- }
190
-
191
- if (cudaMemcpy(
192
- device_equal_flag,
193
- &equal_flag,
194
- sizeof(int),
195
- cudaMemcpyHostToDevice) != cudaSuccess) {
196
-
197
- throw std::runtime_error("Failed to copy equality flag to device.");
198
- }
199
-
200
- if (!grid_size || !block_size) {
201
-
202
- // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
203
- cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
204
- &grid_size,
205
- &block_size,
206
- reinterpret_cast<void const *>(kernel::BlockCompareRelativelyEqual<Element>));
207
-
208
- if (result != cudaSuccess) {
209
- throw std::runtime_error("Failed to query occupancy.");
210
- }
211
- // Limit block size. This has the effect of increasing the number of items processed by a
212
- // single thread and reduces the impact of initialization overhead.
213
- block_size = (block_size < 128 ? block_size : 128);
214
- }
215
-
216
- dim3 grid(grid_size, 1, 1);
217
- dim3 block(block_size, 1, 1);
218
-
219
- kernel::BlockCompareRelativelyEqual<Element><<< grid, block, 0, stream >>>(
220
- device_equal_flag,
221
- ptr_A,
222
- ptr_B,
223
- capacity,
224
- epsilon,
225
- nonzero_floor
226
- );
227
-
228
- cudaStreamSynchronize(stream);
229
-
230
- if (cudaMemcpy(
231
- &equal_flag,
232
- device_equal_flag,
233
- sizeof(int),
234
- cudaMemcpyDeviceToHost) != cudaSuccess) {
235
-
236
- cudaFree(device_equal_flag);
237
-
238
- throw std::runtime_error("Failed to copy equality flag from device.");
239
- }
240
-
241
- cudaFree(device_equal_flag);
242
-
243
- return equal_flag;
244
- }
245
-
246
- ///////////////////////////////////////////////////////////////////////////////////////////////////
247
-
248
- } // device
249
- } // reference
250
- } // cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h DELETED
@@ -1,2075 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /* \file
32
- \brief Defines device-side elementwise operations on TensorView. Note, the operations defined
33
- in this header are not specialized for any particular data layout and are therefore not
34
- intended to offer the best possible performance. Rather, they are intended to be generic
35
- reference implementations to support the CUTLASS unit tests.
36
- */
37
-
38
- #pragma once
39
-
40
- #if !defined(__CUDACC_RTC__)
41
-
42
- // Standard Library includes
43
- #include <utility>
44
- #include <cstdlib>
45
- #include <cmath>
46
- #include <type_traits>
47
- #include <cstdint>
48
-
49
- #endif
50
-
51
- // CUDA includes
52
- #include <curand_kernel.h>
53
-
54
- // Cutlass includes
55
- #include "cutlass/cutlass.h"
56
- #include "cutlass/array.h"
57
- #include "cutlass/complex.h"
58
- #include "cutlass/tensor_view.h"
59
- #include "cutlass/blas3.h"
60
- #include "cutlass/numeric_types.h"
61
-
62
- #include "cutlass/layout/vector.h"
63
-
64
- #include "cutlass/util/reference/device/tensor_foreach.h"
65
- #include "cutlass/util/distribution.h"
66
-
67
- ///////////////////////////////////////////////////////////////////////////////////////////////////
68
-
69
- namespace cutlass {
70
- namespace reference {
71
- namespace device {
72
-
73
- ///////////////////////////////////////////////////////////////////////////////////////////////////
74
- ///////////////////////////////////////////////////////////////////////////////////////////////////
75
-
76
- namespace detail {
77
-
78
- template <typename FloatType>
79
- CUTLASS_DEVICE
80
- FloatType random_normal_float(curandState_t *state) {
81
- return curand_normal(state);
82
- }
83
-
84
- template <>
85
- CUTLASS_DEVICE
86
- double random_normal_float<double>(curandState_t *state) {
87
- return curand_normal_double(state);
88
- }
89
-
90
- template <typename FloatType>
91
- CUTLASS_DEVICE
92
- FloatType random_uniform_float(curandState_t *state) {
93
- return curand_uniform(state);
94
- }
95
-
96
- template <>
97
- CUTLASS_DEVICE
98
- double random_uniform_float<double>(curandState_t *state) {
99
- return curand_uniform_double(state);
100
- }
101
-
102
- template <typename Element>
103
- struct RandomGaussianFunc {
104
-
105
- using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type;
106
- using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type;
107
-
108
- /// Parameters structure
109
- struct Params {
110
-
111
- //
112
- // Data members
113
- //
114
-
115
- uint64_t seed;
116
- FloatType mean;
117
- FloatType stddev;
118
- int int_scale;
119
- FloatType float_scale_up;
120
- FloatType float_scale_down;
121
- int exclude_zero; ///< If non-negative, excludes zeros
122
-
123
- //
124
- // Methods
125
- //
126
-
127
- /// Construction of Gaussian RNG functor.
128
- Params(
129
- uint64_t seed_ = 0,
130
- Element mean_ = 0,
131
- Element stddev_ = 1,
132
- int int_scale_ = -1,
133
- int exclude_zero_ = -1
134
- ):
135
- seed(seed_),
136
- mean(static_cast<FloatType>(mean_)),
137
- stddev(static_cast<FloatType>(stddev_)),
138
- int_scale(int_scale_),
139
- exclude_zero(exclude_zero_) {
140
-
141
- float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits
142
- float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
143
- }
144
- };
145
-
146
- //
147
- // Data members
148
- //
149
-
150
- /// Parameters object
151
- Params params;
152
-
153
- /// RNG state object
154
- curandState_t rng_state;
155
-
156
- //
157
- // Methods
158
- //
159
-
160
- /// Device-side initialization of RNG
161
- CUTLASS_DEVICE
162
- RandomGaussianFunc(Params const &params): params(params) {
163
-
164
- uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
165
-
166
- curand_init(params.seed, gtid, 0, &rng_state);
167
- }
168
-
169
- /// Compute random value and update RNG state
170
- CUTLASS_DEVICE
171
- Element operator()() {
172
-
173
- FloatType rnd = random_normal_float<FloatType>(&rng_state);
174
- rnd = params.mean + params.stddev * rnd;
175
-
176
- Element result;
177
- if (params.int_scale >= 0) {
178
- rnd = FloatType(std::llround(rnd * params.float_scale_up));
179
- result = Element(rnd * params.float_scale_down);
180
- }
181
- else {
182
- result = Element(rnd);
183
- }
184
-
185
- if (params.exclude_zero >=0 && result == Element(0.0)) {
186
- if (rnd > FloatType(0)) {
187
- rnd += FloatType(1);
188
- } else {
189
- rnd -= FloatType(1);
190
- }
191
- result = Element(rnd);
192
- }
193
-
194
- return result;
195
- }
196
- };
197
-
198
-
199
- template <typename Real>
200
- struct RandomGaussianFunc<complex<Real>> {
201
-
202
- using Element = complex<Real>;
203
- using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type;
204
- using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type;
205
-
206
- /// Parameters structure
207
- struct Params {
208
-
209
- //
210
- // Data members
211
- //
212
-
213
- uint64_t seed;
214
- FloatType mean;
215
- FloatType stddev;
216
- int int_scale;
217
- FloatType float_scale_up;
218
- FloatType float_scale_down;
219
- int exclude_zero; ///< If non-negative, excludes zeros
220
-
221
- //
222
- // Methods
223
- //
224
-
225
- /// Construction of Gaussian RNG functor.
226
- Params(
227
- uint64_t seed_ = 0,
228
- Real mean_ = 0,
229
- Real stddev_ = 1,
230
- int int_scale_ = -1,
231
- int exclude_zero_ = -1
232
- ):
233
- seed(seed_),
234
- mean(static_cast<FloatType>(mean_)),
235
- stddev(static_cast<FloatType>(stddev_)),
236
- int_scale(int_scale_),
237
- exclude_zero(exclude_zero_) {
238
-
239
- float_scale_up = FloatType(IntType(1) << int_scale);
240
- float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
241
- }
242
- };
243
-
244
- //
245
- // Data members
246
- //
247
-
248
- /// Parameters object
249
- Params params;
250
-
251
- /// RNG state object
252
- curandState_t rng_state;
253
-
254
- //
255
- // Methods
256
- //
257
-
258
- /// Device-side initialization of RNG
259
- CUTLASS_DEVICE
260
- RandomGaussianFunc(Params const &params): params(params) {
261
-
262
- uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
263
-
264
- curand_init(params.seed, gtid, 0, &rng_state);
265
- }
266
-
267
- /// Compute random value and update RNG state
268
- CUTLASS_DEVICE
269
- Element operator()() {
270
-
271
- FloatType rnd_r = random_normal_float<FloatType>(&rng_state);
272
- FloatType rnd_i = random_normal_float<FloatType>(&rng_state);
273
- rnd_r = params.mean + params.stddev * rnd_r;
274
- rnd_i = params.mean + params.stddev * rnd_i;
275
-
276
- Element result;
277
- if (params.int_scale >= 0) {
278
- rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up));
279
- rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up));
280
-
281
- result = {
282
- Real(rnd_r * params.float_scale_down),
283
- Real(rnd_i * params.float_scale_down)
284
- };
285
- }
286
- else {
287
- result = Element(Real(rnd_r), Real(rnd_i));
288
- }
289
-
290
- if (params.exclude_zero >= 0 &&
291
- result.real() == Real(0.0) &&
292
- result.imag() == Real(0.0)) {
293
-
294
- if (rnd_r > FloatType(0)) {
295
- rnd_r += FloatType(1);
296
- } else {
297
- rnd_r -= FloatType(1);
298
- }
299
- result = Element(Real(rnd_r), Real(rnd_i));
300
- }
301
-
302
- return result;
303
- }
304
- };
305
-
306
- /// Computes a random Gaussian distribution
307
- template <
308
- typename Element, ///< Element type
309
- typename Layout> ///< Layout function
310
- struct TensorFillRandomGaussianFunc {
311
-
312
- /// View type
313
- using TensorView = TensorView<Element, Layout>;
314
-
315
- /// Scalar type
316
- typedef typename TensorView::Element T;
317
-
318
- /// Coordinate in tensor's index space
319
- typedef typename TensorView::TensorCoord TensorCoord;
320
-
321
- using RandomFunc = RandomGaussianFunc<Element>;
322
-
323
- /// Parameters structure
324
- struct Params {
325
-
326
- //
327
- // Data members
328
- //
329
-
330
- TensorView view;
331
- typename RandomFunc::Params random;
332
-
333
- //
334
- // Methods
335
- //
336
-
337
- /// Construction of Gaussian RNG functor.
338
- Params(
339
- TensorView view_ = TensorView(),
340
- typename RandomFunc::Params random_ = typename RandomFunc::Params()
341
- ):
342
- view(view_), random(random_) {
343
-
344
- }
345
- };
346
-
347
- //
348
- // Data members
349
- //
350
-
351
- Params params;
352
- RandomFunc random;
353
-
354
- //
355
- // Methods
356
- //
357
-
358
- /// Device-side initialization of RNG
359
- CUTLASS_DEVICE
360
- TensorFillRandomGaussianFunc(Params const &params): params(params), random(params.random) {
361
-
362
- }
363
-
364
- /// Compute random value and update RNG state
365
- CUTLASS_DEVICE
366
- void operator()(TensorCoord const &coord) {
367
-
368
- params.view.at(coord) = random();
369
- }
370
- };
371
-
372
- } // namespace detail
373
-
374
- ///////////////////////////////////////////////////////////////////////////////////////////////////
375
-
376
- /// Fills a tensor with random values with a Gaussian distribution.
377
- template <
378
- typename Element, ///< Element type
379
- typename Layout> ///< Layout function
380
- void TensorFillRandomGaussian(
381
- TensorView<Element, Layout> view, ///< destination tensor
382
- uint64_t seed, ///< seed for RNG
383
- typename RealType<Element>::Type mean = Element(0), ///< Gaussian distribution's mean
384
- typename RealType<Element>::Type stddev = Element(1), ///< Gaussian distribution's standard deviation
385
- int bits = -1, ///< If non-negative, specifies number of fractional bits that
386
- /// are not truncated to zero. Permits reducing precision of
387
- /// data.
388
- int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init
389
- cudaStream_t stream = nullptr) {
390
-
391
- using RandomFunc = detail::RandomGaussianFunc<Element>;
392
- using Func = detail::TensorFillRandomGaussianFunc<Element, Layout>;
393
- using Params = typename Func::Params;
394
-
395
- TensorForEach<Func, Layout::kRank, Params>(
396
- view.extent(),
397
- Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)),
398
- /*grid_size*/0, /*block_size*/0,
399
- stream
400
- );
401
- }
402
-
403
- ///////////////////////////////////////////////////////////////////////////////////////////////////
404
-
405
- /// Fills a tensor with random values with a Gaussian distribution.
406
- template <typename Element> ///< Element type
407
- void BlockFillRandomGaussian(
408
- Element *ptr,
409
- size_t capacity,
410
- uint64_t seed, ///< seed for RNG
411
- typename RealType<Element>::Type mean, ///< Gaussian distribution's mean
412
- typename RealType<Element>::Type stddev, ///< Gaussian distribution's standard deviation
413
- int bits = -1, ///< If non-negative, specifies number of fractional bits that
414
- /// are not truncated to zero. Permits reducing precision of
415
- /// data.
416
- cudaStream_t stream = nullptr) {
417
-
418
- using RandomFunc = detail::RandomGaussianFunc<Element>;
419
-
420
- typename RandomFunc::Params params(seed, mean, stddev, bits);
421
-
422
- BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
423
- }
424
-
425
- ///////////////////////////////////////////////////////////////////////////////////////////////////
426
- ///////////////////////////////////////////////////////////////////////////////////////////////////
427
-
428
- namespace detail {
429
-
430
- /// Computes a random uniform distribution
431
- template <typename Element> ///< Element type
432
- struct RandomUniformFunc {
433
-
434
- using FloatType = typename std::conditional<
435
- (sizeof(Element) > 4),
436
- double,
437
- float>::type;
438
-
439
- using IntType = typename std::conditional<
440
- (sizeof(Element) > 4),
441
- int64_t,
442
- int>::type;
443
-
444
- /// Parameters structure
445
- struct Params {
446
-
447
- //
448
- // Data members
449
- //
450
-
451
- uint64_t seed;
452
- FloatType range;
453
- FloatType max;
454
- int int_scale;
455
- double pnan;
456
- FloatType float_scale_up;
457
- FloatType float_scale_down;
458
- int exclude_zero; ///< If non-negative, excludes zeros
459
-
460
- /// Default ctor
461
- CUTLASS_HOST_DEVICE
462
- Params() { }
463
-
464
- //
465
- // Methods
466
- //
467
-
468
- /// Construction of Gaussian RNG functor.
469
- Params(
470
- uint64_t seed_ = 0,
471
- Element max_ = 1,
472
- Element min = 0,
473
- int int_scale_ = -1,
474
- double pnan_ = 0,
475
- int exclude_zero_ = -1
476
- ):
477
- seed(seed_),
478
- range(static_cast<FloatType>(max_) - static_cast<FloatType>(min)),
479
- max(static_cast<FloatType>(max_)),
480
- int_scale(int_scale_),
481
- pnan(pnan_),
482
- exclude_zero(exclude_zero_) {
483
-
484
- float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits
485
- float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
486
-
487
- // Handle cases where min = 0 or max = 0 for excluding zeros
488
- if (exclude_zero >= 0) {
489
- range = (min == Element(0)) ? range - FloatType(1): range;
490
- max = (max_ == Element(0)) ? max - FloatType(1): max;
491
- }
492
- }
493
- };
494
-
495
- //
496
- // Data members
497
- //
498
-
499
- /// Parameters object
500
- Params params;
501
-
502
- /// RNG state object
503
- curandState_t rng_state;
504
-
505
- //
506
- // Methods
507
- //
508
-
509
- /// Device-side initialization of RNG
510
- CUTLASS_DEVICE
511
- RandomUniformFunc(Params const &params): params(params) {
512
-
513
- uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
514
-
515
- curand_init(params.seed, gtid, 0, &rng_state);
516
- }
517
-
518
- /// Compute random value and update RNG state
519
- CUTLASS_DEVICE
520
- Element operator()() {
521
-
522
- // Draw random float in [0.0, 1.0] to determine if element should be NaN.
523
- if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
524
- if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
525
- return Element(NAN);
526
- }
527
- }
528
-
529
- FloatType rnd = random_uniform_float<FloatType>(&rng_state);
530
- rnd = params.max - params.range * rnd;
531
-
532
- // Random values are cast to integer after scaling by a power of two to facilitate error
533
- // testing
534
- Element result;
535
-
536
- if (params.int_scale >= 0) {
537
- rnd = FloatType(std::llround(rnd * params.float_scale_up));
538
- result = Element(rnd * params.float_scale_down);
539
- }
540
- else {
541
- result = Element(rnd);
542
- }
543
-
544
- if (params.exclude_zero >=0 && result == Element(0.0)) {
545
- if (rnd > FloatType(0)) {
546
- rnd = std::min(params.max, rnd + FloatType(1));
547
- } else {
548
- rnd = std::max((params.max - params.range), rnd - FloatType(1));
549
- }
550
- result = Element(rnd);
551
- }
552
-
553
- return result;
554
- }
555
- };
556
-
557
- /// Computes a random Gaussian distribution
558
- template <typename Real>
559
- struct RandomUniformFunc<complex<Real>> {
560
-
561
- using Element = complex<Real>;
562
-
563
- using FloatType = typename std::conditional<
564
- (sizeof(Real) > 4),
565
- double,
566
- float>::type;
567
-
568
- using IntType = typename std::conditional<
569
- (sizeof(Real) > 4),
570
- int64_t,
571
- int>::type;
572
-
573
- /// Parameters structure
574
- struct Params {
575
-
576
- //
577
- // Data members
578
- //
579
-
580
- uint64_t seed;
581
- FloatType range;
582
- FloatType min;
583
- int int_scale;
584
- double pnan;
585
- FloatType float_scale_up;
586
- FloatType float_scale_down;
587
- int exclude_zero; ///< If non-negative, excludes zeros
588
-
589
- /// Default ctor
590
- CUTLASS_HOST_DEVICE
591
- Params() { }
592
-
593
- //
594
- // Methods
595
- //
596
-
597
- /// Construction of Gaussian RNG functor.
598
- Params(
599
- uint64_t seed_ = 0,
600
- FloatType max = 1,
601
- FloatType min_ = 0,
602
- int int_scale_ = -1,
603
- double pnan_ = 0,
604
- int exclude_zero_ = -1
605
- ):
606
- seed(seed_),
607
- range(static_cast<FloatType>(max - min_)),
608
- min(static_cast<FloatType>(min_)),
609
- int_scale(int_scale_),
610
- pnan(pnan_),
611
- exclude_zero(exclude_zero_) {
612
-
613
- float_scale_up = FloatType(IntType(1) << int_scale);
614
- float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
615
-
616
- // Handle cases where min = 0 or max = 0 for excluding zeros
617
- if (exclude_zero >= 0) {
618
- min = (min == FloatType(0)) ? min + FloatType(1): min;
619
- range = (max == FloatType(0)) ? range - FloatType(1): range;
620
- }
621
- }
622
- };
623
-
624
- //
625
- // Data members
626
- //
627
-
628
- /// Parameters object
629
- Params params;
630
-
631
- /// RNG state object
632
- curandState_t rng_state;
633
-
634
- //
635
- // Methods
636
- //
637
-
638
- /// Device-side initialization of RNG
639
- CUTLASS_DEVICE
640
- RandomUniformFunc(Params const &params): params(params) {
641
-
642
- uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
643
-
644
- curand_init(params.seed, gtid, 0, &rng_state);
645
- }
646
-
647
- /// Compute random value and update RNG state
648
- CUTLASS_DEVICE
649
- Element operator()() {
650
-
651
- // Draw random float in [0.0, 1.0] to determine if element should be NaN.
652
- if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
653
- if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
654
- return Element(Real(NAN), Real(NAN));
655
- }
656
- }
657
-
658
- FloatType rnd_r = random_uniform_float<FloatType>(&rng_state);
659
- FloatType rnd_i = random_uniform_float<FloatType>(&rng_state);
660
-
661
- rnd_r = params.min + params.range * rnd_r;
662
- rnd_i = params.min + params.range * rnd_i;
663
-
664
- // Random values are cast to integer after scaling by a power of two to facilitate error
665
- // testing
666
- Element result;
667
-
668
- if (params.int_scale >= 0) {
669
- rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up));
670
- rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up));
671
-
672
- result = {
673
- Real(rnd_r * params.float_scale_down),
674
- Real(rnd_i * params.float_scale_down)
675
- };
676
- }
677
- else {
678
- result = Element(Real(rnd_r), Real(rnd_i));
679
- }
680
-
681
- if (params.exclude_zero >= 0 &&
682
- result.real() == Real(0.0) &&
683
- result.imag() == Real(0.0)) {
684
-
685
- if (rnd_r > FloatType(0)) {
686
- rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1));
687
- } else {
688
- rnd_r = std::max((params.min), rnd_r - FloatType(1));
689
- }
690
- result = Element(Real(rnd_r), Real(rnd_i));
691
- }
692
-
693
- return result;
694
- }
695
- };
696
-
697
- /// Computes a random uniform distribution
698
- template <
699
- typename Element, ///< Element type
700
- typename Layout> ///< Layout function
701
- struct TensorFillRandomUniformFunc {
702
-
703
- /// View type
704
- using TensorView = TensorView<Element, Layout>;
705
-
706
- /// Scalar type
707
- typedef typename TensorView::Element T;
708
-
709
- /// Coordinate in tensor's index space
710
- typedef typename TensorView::TensorCoord TensorCoord;
711
-
712
- using RandomFunc = RandomUniformFunc<Element>;
713
-
714
- /// Parameters structure
715
- struct Params {
716
-
717
- //
718
- // Data members
719
- //
720
-
721
- TensorView view;
722
- typename RandomFunc::Params random;
723
-
724
- /// Default ctor
725
- CUTLASS_HOST_DEVICE
726
- Params() { }
727
-
728
- //
729
- // Methods
730
- //
731
-
732
- /// Construction of Gaussian RNG functor.
733
- Params(
734
- TensorView view_ = TensorView(),
735
- typename RandomFunc::Params random_ = RandomFunc::Params()
736
- ):
737
- view(view_), random(random_) {
738
-
739
- }
740
- };
741
-
742
- //
743
- // Data members
744
- //
745
-
746
- Params params;
747
- RandomFunc random;
748
-
749
- //
750
- // Methods
751
- //
752
-
753
- /// Device-side initialization of RNG
754
- CUTLASS_DEVICE
755
- TensorFillRandomUniformFunc(Params const &params): params(params), random(params.random) {
756
- }
757
-
758
- /// Compute random value and update RNG state
759
- CUTLASS_DEVICE
760
- void operator()(TensorCoord const &coord) {
761
-
762
- params.view.at(coord) = random();
763
- }
764
- };
765
-
766
- } // namespace detail
767
-
768
- ///////////////////////////////////////////////////////////////////////////////////////////////////
769
-
770
- /// Fills a tensor with random values with a uniform random distribution.
771
- template <
772
- typename Element, ///< Element type
773
- typename Layout> ///< Layout function
774
- void TensorFillRandomUniform(
775
- TensorView<Element, Layout> view, ///< destination tensor
776
- uint64_t seed, ///< seed for RNG
777
- typename RealType<Element>::Type max = Element(1), ///< upper bound of distribution
778
- typename RealType<Element>::Type min = Element(0), ///< lower bound for distribution
779
- int bits = -1, ///< If non-negative, specifies number of fractional bits that
780
- /// are not truncated to zero. Permits reducing precision of
781
- /// data.
782
- double pnan = 0, ///< Percentage of NaN elements.
783
- int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init
784
- cudaStream_t stream = nullptr) {
785
-
786
- using RandomFunc = detail::RandomUniformFunc<Element>;
787
- using Func = detail::TensorFillRandomUniformFunc<Element, Layout>;
788
- using Params = typename Func::Params;
789
-
790
- typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero);
791
-
792
- TensorForEach<Func, Layout::kRank, Params>(
793
- view.extent(),
794
- Params(view, random),
795
- /*grid_size*/0, /*block_size*/0,
796
- stream
797
- );
798
- }
799
-
800
- ///////////////////////////////////////////////////////////////////////////////////////////////////
801
-
802
- /// Fills a tensor with random values with a uniform random distribution.
803
- template <typename Element>
804
- void BlockFillRandomUniform(
805
- Element *ptr,
806
- size_t capacity,
807
- uint64_t seed, ///< seed for RNG
808
- typename RealType<Element>::Type max, ///< upper bound of distribution
809
- typename RealType<Element>::Type min, ///< lower bound for distribution
810
- int bits = -1, ///< If non-negative, specifies number of fractional bits that
811
- /// are not truncated to zero. Permits reducing precision of
812
- /// data.
813
- double pnan = 0, ///< Percentage of NaN elements.
814
- cudaStream_t stream = nullptr) {
815
-
816
- using RandomFunc = detail::RandomUniformFunc<Element>;
817
-
818
- typename RandomFunc::Params params(seed, max, min, bits, pnan);
819
-
820
- BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
821
- }
822
-
823
- ///////////////////////////////////////////////////////////////////////////////////////////////////
824
- ///////////////////////////////////////////////////////////////////////////////////////////////////
825
-
826
- namespace detail {
827
-
828
- /// Computes a random sparse meta
829
- template <typename Element> ///< Element type
830
- struct RandomSparseMetaFunc {
831
-
832
- using FloatType = float;
833
-
834
- using IntType = int32_t;
835
-
836
- /// Parameters structure
837
- struct Params {
838
-
839
- //
840
- // Data members
841
- //
842
-
843
- uint64_t seed;
844
- FloatType range;
845
- int MetaSizeInBits;
846
-
847
- /// Default ctor
848
- CUTLASS_HOST_DEVICE
849
- Params() { }
850
-
851
- //
852
- // Methods
853
- //
854
-
855
- /// Construction of Gaussian RNG functor.
856
- Params(
857
- uint64_t seed_ = 0,
858
- int MetaSizeInBits_ = 2
859
- ):
860
- seed(seed_),
861
- MetaSizeInBits(MetaSizeInBits_) {
862
- if (MetaSizeInBits_ == 2) {
863
- range = 6;
864
- }
865
- else if (MetaSizeInBits_ == 4) {
866
- range = 2;
867
- }
868
- else {
869
- throw std::invalid_argument("Invalid MetaSizeInBits");
870
- }
871
- }
872
- };
873
-
874
- //
875
- // Data members
876
- //
877
-
878
- /// Parameters object
879
- Params params;
880
-
881
- /// RNG state object
882
- curandState_t rng_state;
883
-
884
- //
885
- // Methods
886
- //
887
-
888
- /// Device-side initialization of RNG
889
- CUTLASS_DEVICE
890
- RandomSparseMetaFunc(Params const &params): params(params) {
891
-
892
- uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
893
-
894
- curand_init(params.seed, gtid, 0, &rng_state);
895
- }
896
-
897
- /// Compute random value and update RNG state
898
- CUTLASS_DEVICE
899
- Element operator()() {
900
- Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe};
901
- Element TwoToOneMeta[2] = {0x4, 0xe};
902
-
903
- Element *MetaArray =
904
- (params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta;
905
-
906
- Element result = 0x0;
907
-
908
- CUTLASS_PRAGMA_UNROLL
909
- for (int i = 0; i < cutlass::sizeof_bits<Element>::value / 4; ++i) {
910
- FloatType rnd = random_uniform_float<FloatType>(&rng_state);
911
- rnd = params.range * rnd;
912
- Element meta = MetaArray[(int)rnd];
913
-
914
- result = (Element)(result | ((Element)(meta << (i * 4))));
915
- }
916
-
917
- return result;
918
- }
919
- };
920
-
921
- /// Computes a random Gaussian distribution
922
- template <
923
- typename Element, ///< Element type
924
- typename Layout> ///< Layout function
925
- struct TensorFillRandomSparseMetaFunc {
926
-
927
- /// View type
928
- using TensorView = TensorView<Element, Layout>;
929
-
930
- /// Scalar type
931
- typedef typename TensorView::Element T;
932
-
933
- /// Coordinate in tensor's index space
934
- typedef typename TensorView::TensorCoord TensorCoord;
935
-
936
- using RandomFunc = RandomSparseMetaFunc<Element>;
937
-
938
- /// Parameters structure
939
- struct Params {
940
-
941
- //
942
- // Data members
943
- //
944
-
945
- TensorView view;
946
- typename RandomFunc::Params random;
947
-
948
- /// Default ctor
949
- CUTLASS_HOST_DEVICE
950
- Params() { }
951
-
952
- //
953
- // Methods
954
- //
955
-
956
- /// Construction of Gaussian RNG functor.
957
- Params(
958
- TensorView view_ = TensorView(),
959
- typename RandomFunc::Params random_ = RandomFunc::Params()
960
- ):
961
- view(view_), random(random_) {
962
-
963
- }
964
- };
965
-
966
- //
967
- // Data members
968
- //
969
-
970
- Params params;
971
- RandomFunc random;
972
-
973
- //
974
- // Methods
975
- //
976
-
977
- /// Device-side initialization of RNG
978
- CUTLASS_DEVICE
979
- TensorFillRandomSparseMetaFunc(Params const &params): params(params), random(params.random) {
980
- }
981
-
982
- /// Compute random value and update RNG state
983
- CUTLASS_DEVICE
984
- void operator()(TensorCoord const &coord) {
985
-
986
- params.view.at(coord) = random();
987
- }
988
- };
989
-
990
- } // namespace detail
991
-
992
- ///////////////////////////////////////////////////////////////////////////////////////////////////
993
-
994
- /// Fills a tensor with random values with a uniform random distribution.
995
- template <
996
- typename Element, ///< Element type
997
- typename Layout> ///< Layout function
998
- void TensorFillRandomSparseMeta(
999
- TensorView<Element, Layout> view, ///< destination tensor
1000
- uint64_t seed, ///< seed for RNG
1001
- int MetaSizeInBits = 2, ///< meta data size
1002
- cudaStream_t stream = nullptr) {
1003
-
1004
- using RandomFunc = detail::RandomSparseMetaFunc<Element>;
1005
- using Func = detail::TensorFillRandomUniformFunc<Element, Layout>;
1006
- using Params = typename Func::Params;
1007
-
1008
- typename RandomFunc::Params random(seed, MetaSizeInBits);
1009
-
1010
- TensorForEach<Func, Layout::kRank, Params>(
1011
- view.extent(),
1012
- Params(view, random),
1013
- /*grid_size*/0, /*block_size*/0,
1014
- stream
1015
- );
1016
- }
1017
-
1018
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1019
-
1020
- /// Fills a tensor with random values with a uniform random distribution.
1021
- template <typename Element>
1022
- void BlockFillRandomSparseMeta(
1023
- Element *ptr,
1024
- size_t capacity,
1025
- uint64_t seed, ///< seed for RNG
1026
- int MetaSizeInBits = 2, ///< meta data size
1027
- cudaStream_t stream = nullptr) {
1028
-
1029
- using RandomFunc = detail::RandomSparseMetaFunc<Element>;
1030
-
1031
- typename RandomFunc::Params params(seed, MetaSizeInBits);
1032
-
1033
- BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
1034
- }
1035
-
1036
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1037
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1038
-
1039
- namespace detail {
1040
-
1041
- /// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal.
1042
- template <
1043
- typename Element, ///< Element type
1044
- typename Layout> ///< Layout function
1045
- struct TensorFillDiagonalFunc {
1046
-
1047
- /// View type
1048
- using TensorView = TensorView<Element, Layout>;
1049
-
1050
- /// Scalar type
1051
- typedef typename TensorView::Element T;
1052
-
1053
- /// Coordinate in tensor's index space
1054
- typedef typename TensorView::TensorCoord TensorCoord;
1055
-
1056
- /// Parameters structure
1057
- struct Params {
1058
-
1059
- //
1060
- // Data members
1061
- //
1062
-
1063
- TensorView view;
1064
- Element diag;
1065
- Element other;
1066
-
1067
- /// Default ctor
1068
- CUTLASS_HOST_DEVICE
1069
- Params() { }
1070
-
1071
- //
1072
- // Methods
1073
- //
1074
-
1075
- Params(
1076
- TensorView view_ = TensorView(),
1077
- Element diag_ = Element(1),
1078
- Element other_ = Element(0)
1079
- ):
1080
- view(view_), diag(diag_), other(other_) {
1081
-
1082
- }
1083
- };
1084
-
1085
- //
1086
- // Data members
1087
- //
1088
-
1089
- /// Parameters object
1090
- Params params;
1091
-
1092
- //
1093
- // Methods
1094
- //
1095
-
1096
- /// Device-side initialization of RNG
1097
- CUTLASS_DEVICE
1098
- TensorFillDiagonalFunc(Params const &params): params(params) {
1099
-
1100
- }
1101
-
1102
- /// Updates the tensor
1103
- CUTLASS_DEVICE
1104
- void operator()(TensorCoord const &coord) {
1105
-
1106
- bool is_diag = true;
1107
-
1108
- CUTLASS_PRAGMA_UNROLL
1109
- for (int i = 1; i < Layout::kRank; ++i) {
1110
- if (coord[i] != coord[i - 1]) {
1111
- is_diag = false;
1112
- break;
1113
- }
1114
- }
1115
-
1116
- params.view.at(coord) = (is_diag ? params.diag : params.other);
1117
- }
1118
- };
1119
-
1120
- // Overwrites the elements of a tensor with a uniform value depending on fill mode
1121
- template <
1122
- typename Element, ///< Element type
1123
- typename Layout> ///< Layout function
1124
- struct TensorFillPartialFunc {
1125
-
1126
- /// View type
1127
- using TensorView = TensorView<Element, Layout>;
1128
-
1129
- /// Scalar type
1130
- typedef typename TensorView::Element T;
1131
-
1132
- /// Coordinate in tensor's index space
1133
- typedef typename TensorView::TensorCoord TensorCoord;
1134
-
1135
- /// Parameters structure
1136
- struct Params {
1137
-
1138
- //
1139
- // Data members
1140
- //
1141
-
1142
- TensorView view;
1143
- Element element;
1144
- FillMode fill_mode;
1145
-
1146
- /// Default ctor
1147
- CUTLASS_HOST_DEVICE
1148
- Params(): fill_mode(FillMode::kNone) { }
1149
-
1150
- //
1151
- // Methods
1152
- //
1153
-
1154
- /// Construction of Gaussian RNG functor.
1155
- Params(
1156
- TensorView view_,
1157
- Element element_,
1158
- FillMode fill_mode_
1159
- ):
1160
- view(view_), element(element_), fill_mode(fill_mode_) {
1161
-
1162
- }
1163
- };
1164
-
1165
- //
1166
- // Data members
1167
- //
1168
-
1169
- /// Parameters object
1170
- Params params;
1171
-
1172
- //
1173
- // Methods
1174
- //
1175
-
1176
- CUTLASS_DEVICE
1177
- TensorFillPartialFunc(Params const &params): params(params) {
1178
-
1179
- }
1180
-
1181
- /// Overwrites the element if it is within the covered region.
1182
- CUTLASS_DEVICE
1183
- void operator()(TensorCoord const &coord) {
1184
-
1185
- bool predicate = true;
1186
-
1187
- switch (params.fill_mode) {
1188
- case FillMode::kFull:
1189
- predicate = true;
1190
- break;
1191
-
1192
- case FillMode::kLower:
1193
- CUTLASS_PRAGMA_UNROLL
1194
- for (int i = 1; i < Layout::kRank; ++i) {
1195
- if (coord[i - 1] < coord[i]) {
1196
- predicate = false;
1197
- break;
1198
- }
1199
- }
1200
- break;
1201
-
1202
- case FillMode::kUpper:
1203
- CUTLASS_PRAGMA_UNROLL
1204
- for (int i = 1; i < Layout::kRank; ++i) {
1205
- if (coord[i - 1] > coord[i]) {
1206
- predicate = false;
1207
- break;
1208
- }
1209
- }
1210
- break;
1211
-
1212
- case FillMode::kDiagonal:
1213
- CUTLASS_PRAGMA_UNROLL
1214
- for (int i = 1; i < Layout::kRank; ++i) {
1215
- if (coord[i - 1] != coord[i]) {
1216
- predicate = false;
1217
- break;
1218
- }
1219
- }
1220
- break;
1221
-
1222
- case FillMode::kNone: // fall-through
1223
-
1224
- default:
1225
- predicate = false;
1226
- break;
1227
- }
1228
-
1229
- if (predicate) {
1230
- params.view.at(coord) = params.element;
1231
- }
1232
- }
1233
- };
1234
-
1235
-
1236
- template <
1237
- typename Element, ///< Element type
1238
- typename Layout> ///< Layout function
1239
- struct TensorClearPartialFunc {
1240
-
1241
- /// View type
1242
- using TensorView = TensorView<Element, Layout>;
1243
-
1244
- /// Scalar type
1245
- typedef typename TensorView::Element T;
1246
-
1247
- /// Coordinate in tensor's index space
1248
- typedef typename TensorView::TensorCoord TensorCoord;
1249
-
1250
- ///
1251
- static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices");
1252
-
1253
- /// Parameters structure
1254
- struct Params {
1255
- TensorView view{};
1256
- Element element{};
1257
- FillMode fill_mode{FillMode::kNone};
1258
- int alignment{0};
1259
- };
1260
-
1261
- //
1262
- // Data members
1263
- //
1264
-
1265
- /// Parameters object
1266
- Params params;
1267
-
1268
- //
1269
- // Methods
1270
- //
1271
-
1272
- CUTLASS_DEVICE
1273
- TensorClearPartialFunc(Params const &params): params(params) {
1274
-
1275
- }
1276
-
1277
- /// Overwrites the element if it is within the covered region.
1278
- CUTLASS_DEVICE
1279
- void operator()(TensorCoord const &coord) {
1280
-
1281
- bool predicate = true;
1282
-
1283
- switch (params.fill_mode) {
1284
-
1285
- case FillMode::kLower:
1286
- if ((coord[0] >= coord[1]) ||
1287
- ((coord[1] - coord[0]) >= params.alignment)) {
1288
- predicate = false;
1289
- break;
1290
- }
1291
- break;
1292
-
1293
- case FillMode::kUpper:
1294
- if ((coord[0] <= coord[1]) ||
1295
- ((coord[0] - coord[1]) >= params.alignment)) {
1296
- predicate = false;
1297
- break;
1298
- }
1299
- break;
1300
-
1301
- case FillMode::kNone: // fall-through
1302
-
1303
- default:
1304
- predicate = false;
1305
- break;
1306
- }
1307
-
1308
- if (predicate) {
1309
- params.view.at(coord) = params.element;
1310
- }
1311
- }
1312
- };
1313
-
1314
- } // namespace detail
1315
-
1316
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1317
-
1318
- /// Fills a tensor everywhere with a unique value for its diagonal.
1319
- template <
1320
- typename Element, ///< Element type
1321
- typename Layout> ///< Layout function
1322
- void TensorFillDiagonal(
1323
- TensorView<Element, Layout> view, ///< destination tensor
1324
- Element diag = Element(1), ///< value to write in the diagonal
1325
- Element other = Element(0), ///< value to write off the diagonal
1326
- cudaStream_t stream = nullptr) {
1327
-
1328
- typedef detail::TensorFillDiagonalFunc<Element, Layout> Func;
1329
- typedef typename Func::Params Params;
1330
-
1331
- TensorForEach<Func, Layout::kRank, Params>(
1332
- view.extent(),
1333
- Params(view, diag, other),
1334
- /*grid_size*/0, /*block_size*/0,
1335
- stream
1336
- );
1337
- }
1338
-
1339
- /// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are
1340
- /// not written.
1341
- template <
1342
- typename Element, ///< Element type
1343
- typename Layout> ///< Layout function
1344
- void TensorFillPartial(
1345
- TensorView<Element, Layout> view, ///< destination tensor
1346
- Element element,
1347
- FillMode fill_mode,
1348
- cudaStream_t stream = nullptr) {
1349
-
1350
- typedef detail::TensorFillPartialFunc<Element, Layout> Func;
1351
- typedef typename Func::Params Params;
1352
-
1353
- TensorForEach<Func, Layout::kRank, Params>(
1354
- view.extent(),
1355
- Params(view, element, fill_mode),
1356
- stream
1357
- );
1358
- }
1359
-
1360
- /// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side
1361
- /// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros)
1362
- template <
1363
- typename Element, ///< Element type
1364
- typename Layout> ///< Layout function
1365
- void TensorClearPartial(
1366
- TensorView<Element, Layout> view, ///< destination tensor
1367
- Element element,
1368
- FillMode fill_mode,
1369
- int alignment,
1370
- cudaStream_t stream = nullptr) {
1371
-
1372
- typedef detail::TensorClearPartialFunc<Element, Layout> Func;
1373
- typedef typename Func::Params Params;
1374
-
1375
- TensorForEach<Func, Layout::kRank, Params>(
1376
- view.extent(),
1377
- Params{view, element, fill_mode, alignment},
1378
- /*grid_size*/0, /*block_size*/0,
1379
- stream
1380
- );
1381
- }
1382
-
1383
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1384
-
1385
- /// Fills a tensor with a uniform value
1386
- template <
1387
- typename Element, ///< Element type
1388
- typename Layout> ///< Layout function
1389
- void TensorFill(
1390
- TensorView<Element, Layout> view, ///< destination tensor
1391
- Element val = Element(0), ///< value to uniformly fill it with
1392
- cudaStream_t stream = nullptr) {
1393
-
1394
- TensorFillDiagonal(view, val, val, stream);
1395
- }
1396
-
1397
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1398
-
1399
- /// Fills a tensor's diagonal with 1 and 0 everywhere else.
1400
- template <
1401
- typename Element, ///< Element type
1402
- typename Layout> ///< Layout function
1403
- void TensorFillIdentity(
1404
- TensorView<Element, Layout> view, ///< destination tensor
1405
- cudaStream_t stream = nullptr) {
1406
-
1407
- TensorFillDiagonal(view, Element(1), Element(0), stream);
1408
- }
1409
-
1410
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1411
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1412
-
1413
- namespace detail {
1414
-
1415
- /// Computes a random Gaussian distribution
1416
- template <
1417
- typename Element, ///< Element type
1418
- typename Layout> ///< Layout function
1419
- struct TensorUpdateDiagonalFunc {
1420
-
1421
- /// View type
1422
- using TensorView = TensorView<Element, Layout>;
1423
-
1424
- /// Scalar type
1425
- typedef typename TensorView::Element T;
1426
-
1427
- /// Coordinate in tensor's index space
1428
- typedef typename TensorView::TensorCoord TensorCoord;
1429
-
1430
- /// Parameters structure
1431
- struct Params {
1432
-
1433
- //
1434
- // Data members
1435
- //
1436
-
1437
- TensorView view;
1438
- Element diag;
1439
-
1440
- /// Default ctor
1441
- CUTLASS_HOST_DEVICE
1442
- Params() { }
1443
-
1444
- //
1445
- // Methods
1446
- //
1447
-
1448
- /// Construction of Gaussian RNG functor.
1449
- Params(
1450
- TensorView view_ = TensorView(),
1451
- Element diag_ = Element(1)
1452
- ):
1453
- view(view_), diag(diag_) {
1454
-
1455
- }
1456
- };
1457
-
1458
- //
1459
- // Data members
1460
- //
1461
-
1462
- /// Parameters object
1463
- Params params;
1464
-
1465
- //
1466
- // Methods
1467
- //
1468
-
1469
- /// Device-side initialization of RNG
1470
- CUTLASS_DEVICE
1471
- TensorUpdateDiagonalFunc(Params const &params): params(params) {
1472
-
1473
- }
1474
-
1475
- /// Compute random value and update RNG state
1476
- CUTLASS_DEVICE
1477
- void operator()(TensorCoord const &coord) {
1478
-
1479
- bool is_diag = true;
1480
-
1481
- CUTLASS_PRAGMA_UNROLL
1482
- for (int i = 1; i < Layout::kRank; ++i) {
1483
- if (coord[i] != coord[i - 1]) {
1484
- is_diag = false;
1485
- break;
1486
- }
1487
- }
1488
-
1489
- if (is_diag) {
1490
- params.view.at(coord) = params.diag;
1491
- }
1492
- }
1493
- };
1494
-
1495
- } // namespace detail
1496
-
1497
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1498
-
1499
- /// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements.
1500
- template <
1501
- typename Element, ///< Element type
1502
- typename Layout> ///< Layout function
1503
- void TensorUpdateDiagonal(
1504
- TensorView<Element, Layout> view, ///< destination tensor
1505
- Element diag = Element(1),
1506
- cudaStream_t stream = nullptr) {
1507
-
1508
- typedef detail::TensorUpdateDiagonalFunc<Element, Layout> Func;
1509
- typedef typename Func::Params Params;
1510
-
1511
- TensorForEach<Func, Layout::kRank, Params>(
1512
- view.extent(),
1513
- Params(view, diag),
1514
- /*grid_size*/0, /*block_size*/0,
1515
- stream
1516
- );
1517
- }
1518
-
1519
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1520
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1521
-
1522
- namespace detail {
1523
-
1524
- /// Computes a random Gaussian distribution
1525
- template <
1526
- typename Element, ///< Element type
1527
- typename Layout> ///< Layout function
1528
- struct TensorUpdateOffDiagonalFunc {
1529
-
1530
- /// View type
1531
- using TensorView = TensorView<Element, Layout>;
1532
-
1533
- /// Scalar type
1534
- typedef typename TensorView::Element T;
1535
-
1536
- /// Coordinate in tensor's index space
1537
- typedef typename TensorView::TensorCoord TensorCoord;
1538
-
1539
- /// Parameters structure
1540
- struct Params {
1541
-
1542
- //
1543
- // Data members
1544
- //
1545
-
1546
- TensorView view;
1547
- Element other;
1548
-
1549
- /// Default ctor
1550
- CUTLASS_HOST_DEVICE
1551
- Params() { }
1552
-
1553
- //
1554
- // Methods
1555
- //
1556
-
1557
- /// Construction of Gaussian RNG functor.
1558
- Params(
1559
- TensorView view_ = TensorView(),
1560
- Element other_ = Element(0)
1561
- ):
1562
- view(view_), other(other_) {
1563
-
1564
- }
1565
- };
1566
-
1567
- //
1568
- // Data members
1569
- //
1570
-
1571
- /// Parameters object
1572
- Params params;
1573
-
1574
- //
1575
- // Methods
1576
- //
1577
-
1578
- /// Device-side initialization of RNG
1579
- CUTLASS_DEVICE
1580
- TensorUpdateOffDiagonalFunc(Params const &params): params(params) {
1581
-
1582
- }
1583
-
1584
- /// Compute random value and update RNG state
1585
- CUTLASS_DEVICE
1586
- void operator()(TensorCoord const &coord) {
1587
-
1588
- bool is_diag = true;
1589
-
1590
- CUTLASS_PRAGMA_UNROLL
1591
- for (int i = 1; i < Layout::kRank; ++i) {
1592
- if (coord[i] != coord[i - 1]) {
1593
- is_diag = false;
1594
- break;
1595
- }
1596
- }
1597
-
1598
- if (!is_diag) {
1599
- params.view.at(coord) = params.other;
1600
- }
1601
- }
1602
- };
1603
-
1604
- } // namespace detail
1605
-
1606
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1607
-
1608
- /// Writes a uniform value to all elements in the tensor without modifying diagonal elements.
1609
- template <
1610
- typename Element, ///< Element type
1611
- typename Layout> ///< Layout function
1612
- void TensorUpdateOffDiagonal(
1613
- TensorView<Element, Layout> view, ///< destination tensor
1614
- Element other = Element(1),
1615
- cudaStream_t stream = nullptr) {
1616
-
1617
- typedef detail::TensorUpdateOffDiagonalFunc<Element, Layout> Func;
1618
- typedef typename Func::Params Params;
1619
-
1620
- TensorForEach<Func, Layout::kRank, Params>(
1621
- view.extent(),
1622
- Params(view, other),
1623
- /*grid_size*/0, /*block_size*/0,
1624
- stream
1625
- );
1626
- }
1627
-
1628
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1629
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1630
-
1631
- namespace detail {
1632
-
1633
- /// Computes a random Gaussian distribution
1634
- template <
1635
- typename Element, ///< Element type
1636
- typename Layout> ///< Layout function
1637
- struct TensorFillLinearFunc {
1638
-
1639
- /// View type
1640
- using TensorView = TensorView<Element, Layout>;
1641
-
1642
- /// Scalar type
1643
- typedef typename TensorView::Element T;
1644
-
1645
- /// Coordinate in tensor's index space
1646
- typedef typename TensorView::TensorCoord TensorCoord;
1647
-
1648
- /// Parameters structure
1649
- struct Params {
1650
-
1651
- //
1652
- // Data members
1653
- //
1654
-
1655
- TensorView view;
1656
- Array<Element, Layout::kRank> v;
1657
- Element s;
1658
-
1659
- /// Default ctor
1660
- CUTLASS_HOST_DEVICE
1661
- Params() { }
1662
-
1663
- //
1664
- // Methods
1665
- //
1666
-
1667
- /// Construction of Gaussian RNG functor.
1668
- Params(
1669
- TensorView view_, ///< destination tensor
1670
- Array<Element, Layout::kRank> const & v_,
1671
- Element s_ = Element(0)
1672
- ):
1673
- view(view_), v(v_), s(s_) {
1674
-
1675
- }
1676
- };
1677
-
1678
- //
1679
- // Data members
1680
- //
1681
-
1682
- /// Parameters object
1683
- Params params;
1684
-
1685
- //
1686
- // Methods
1687
- //
1688
-
1689
- /// Device-side initialization of RNG
1690
- CUTLASS_DEVICE
1691
- TensorFillLinearFunc(Params const &params): params(params) {
1692
-
1693
- }
1694
-
1695
- /// Compute random value and update RNG state
1696
- CUTLASS_DEVICE
1697
- void operator()(TensorCoord const &coord) {
1698
-
1699
- Element sum = params.s;
1700
-
1701
- CUTLASS_PRAGMA_UNROLL
1702
- for (int i = 0; i < Layout::kRank; ++i) {
1703
- if constexpr (is_complex<Element>::value) {
1704
- if constexpr (sizeof_bits<Element>::value <= 32) {
1705
- sum = Element(static_cast<complex<float>>(sum) +
1706
- static_cast<complex<float>>(params.v[i]) * static_cast<complex<float>>(coord[i]));
1707
- }
1708
- }
1709
- else if constexpr (sizeof_bits<Element>::value <= 32) {
1710
- if constexpr (std::numeric_limits<Element>::is_integer) {
1711
- sum = Element(static_cast<int32_t>(sum) +
1712
- static_cast<int32_t>(params.v[i]) * static_cast<int32_t>(coord[i]));
1713
- }
1714
- else {
1715
- sum = Element(static_cast<float>(sum) +
1716
- static_cast<float>(params.v[i]) * static_cast<float>(coord[i]));
1717
- }
1718
- }
1719
- else {
1720
- sum += params.v[i] * coord[i];
1721
- }
1722
- }
1723
-
1724
- params.view.at(coord) = sum;
1725
- }
1726
- };
1727
-
1728
- } // namespace detail
1729
-
1730
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1731
-
1732
- /// Fills tensor with a linear combination of its coordinate and another vector
1733
- template <
1734
- typename Element, ///< Element type
1735
- typename Layout> ///< Layout function
1736
- void TensorFillLinear(
1737
- TensorView<Element, Layout> view, ///< destination tensor
1738
- Array<Element, Layout::kRank> const & v,
1739
- Element s = Element(0),
1740
- cudaStream_t stream = nullptr) {
1741
-
1742
- using Func = detail::TensorFillLinearFunc<Element, Layout>;
1743
- using Params = typename Func::Params;
1744
-
1745
- TensorForEach<Func, Layout::kRank, Params>(
1746
- view.extent(),
1747
- Params(view, v, s),
1748
- /*grid_size*/0, /*block_size*/0,
1749
- stream
1750
- );
1751
- }
1752
-
1753
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1754
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1755
-
1756
- /// Fills a tensor with random values from a distribution.
1757
- template <
1758
- typename Element, ///< Element type
1759
- typename Layout> ///< Layout function
1760
- void TensorFillRandom(
1761
- TensorView<Element, Layout> view, ///< destination tensor
1762
- uint64_t seed,
1763
- Distribution dist,
1764
- cudaStream_t stream = nullptr,
1765
- int exclude_zero = -1 ///< If non-negative, excludes 0.
1766
- /// Note that setting this flag will result in more 1's,
1767
- /// as we use a simple mechanism to replace 0's by adding/subtracting 1's.
1768
- ) {
1769
-
1770
- using Real = typename RealType<Element>::Type;
1771
-
1772
- if (dist.kind == Distribution::Gaussian) {
1773
- TensorFillRandomGaussian<Element, Layout>(
1774
- view,
1775
- seed,
1776
- static_cast<Real>(dist.gaussian.mean),
1777
- static_cast<Real>(dist.gaussian.stddev),
1778
- dist.int_scale,
1779
- exclude_zero,
1780
- stream);
1781
- } else if (dist.kind == Distribution::Uniform) {
1782
- TensorFillRandomUniform<Element, Layout>(
1783
- view,
1784
- seed,
1785
- static_cast<Real>(dist.uniform.max),
1786
- static_cast<Real>(dist.uniform.min),
1787
- dist.int_scale,
1788
- dist.uniform.pnan,
1789
- exclude_zero,
1790
- stream);
1791
- }
1792
- }
1793
-
1794
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1795
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1796
-
1797
- /// Fills a block of data with sequential elements
1798
- template <
1799
- typename Element
1800
- >
1801
- void BlockFillSequential(
1802
- Element *ptr,
1803
- int64_t capacity,
1804
- Element v = Element(1),
1805
- Element s = Element(0)) {
1806
-
1807
- using Layout = layout::PackedVectorLayout;
1808
- Layout::TensorCoord size(static_cast<Layout::Index>(capacity)); // -Wconversion
1809
- Layout layout = Layout::packed(size);
1810
- TensorView<Element, Layout> view(ptr, layout, size);
1811
-
1812
- Array<Element, Layout::kRank> c{};
1813
- c[0] = v;
1814
-
1815
- TensorFillLinear(view, c, s);
1816
- }
1817
-
1818
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1819
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1820
-
1821
- /// Fills a block of data with sequential elements
1822
- template <
1823
- typename Element
1824
- >
1825
- void BlockFillRandom(
1826
- Element *ptr,
1827
- size_t capacity,
1828
- uint64_t seed,
1829
- Distribution dist,
1830
- cudaStream_t stream = nullptr) {
1831
-
1832
- using Real = typename RealType<Element>::Type;
1833
-
1834
- if (dist.kind == Distribution::Gaussian) {
1835
- BlockFillRandomGaussian<Element>(
1836
- ptr,
1837
- capacity,
1838
- seed,
1839
- static_cast<Real>(dist.gaussian.mean),
1840
- static_cast<Real>(dist.gaussian.stddev),
1841
- dist.int_scale,
1842
- stream);
1843
- }
1844
- else if (dist.kind == Distribution::Uniform) {
1845
- BlockFillRandomUniform<Element>(
1846
- ptr,
1847
- capacity,
1848
- seed,
1849
- static_cast<Real>(dist.uniform.max),
1850
- static_cast<Real>(dist.uniform.min),
1851
- dist.int_scale,
1852
- dist.uniform.pnan,
1853
- stream);
1854
- }
1855
- }
1856
-
1857
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1858
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1859
-
1860
- namespace detail {
1861
-
1862
- /// Computes a random Gaussian distribution
1863
- template <
1864
- typename Element, ///< Element type
1865
- typename Layout> ///< Layout function
1866
- struct TensorCopyDiagonalInFunc {
1867
-
1868
- /// View type
1869
- using TensorView = TensorView<Element, Layout>;
1870
-
1871
- /// Scalar type
1872
- typedef typename TensorView::Element T;
1873
-
1874
- /// Coordinate in tensor's index space
1875
- typedef typename TensorView::TensorCoord TensorCoord;
1876
-
1877
- /// Parameters structure
1878
- struct Params {
1879
-
1880
- //
1881
- // Data members
1882
- //
1883
-
1884
- TensorView view;
1885
- Element const *ptr;
1886
-
1887
- /// Default ctor
1888
- CUTLASS_HOST_DEVICE
1889
- Params() { }
1890
-
1891
- //
1892
- // Methods
1893
- //
1894
-
1895
- /// Construction of Gaussian RNG functor.
1896
- Params(
1897
- TensorView view_, ///< destination tensor
1898
- Element const *ptr_
1899
- ):
1900
- view(view_), ptr(ptr_) {
1901
-
1902
- }
1903
- };
1904
-
1905
- //
1906
- // Data members
1907
- //
1908
-
1909
- /// Parameters object
1910
- Params params;
1911
-
1912
- //
1913
- // Methods
1914
- //
1915
-
1916
- /// Device-side initialization of RNG
1917
- CUTLASS_DEVICE
1918
- TensorCopyDiagonalInFunc(Params const &params): params(params) {
1919
-
1920
- }
1921
-
1922
- /// Only update the diagonal element
1923
- CUTLASS_DEVICE
1924
- void operator()(TensorCoord const &coord) {
1925
- bool is_diagonal = true;
1926
-
1927
- CUTLASS_PRAGMA_UNROLL
1928
- for (int i = 1; i < Layout::kRank; ++i) {
1929
- if (coord[i] != coord[0]) {
1930
- is_diagonal = false;
1931
- }
1932
- }
1933
- if (is_diagonal) {
1934
- params.view.at(coord) = params.ptr[coord[0]];
1935
- }
1936
- }
1937
- };
1938
-
1939
- } // namespace detail
1940
-
1941
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1942
-
1943
- /// Copies a diagonal in from host memory without modifying off-diagonal elements.
1944
- template <
1945
- typename Element, ///< Element type
1946
- typename Layout> ///< Layout function
1947
- void TensorCopyDiagonalIn(
1948
- TensorView<Element, Layout> view, ///< destination tensor
1949
- Element const *ptr, ///< dense buffer of elements
1950
- cudaStream_t stream = nullptr) {
1951
-
1952
- using Func = detail::TensorCopyDiagonalInFunc<Element, Layout>;
1953
- using Params = typename Func::Params;
1954
-
1955
- TensorForEach<Func, Layout::kRank, Params>(
1956
- view.extent(),
1957
- Params(view, ptr),
1958
- /*grid_size*/0, /*block_size*/0,
1959
- stream
1960
- );
1961
- }
1962
-
1963
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1964
- ///////////////////////////////////////////////////////////////////////////////////////////////////
1965
-
1966
-
1967
- namespace detail {
1968
-
1969
- /// Computes a random Gaussian distribution
1970
- template <
1971
- typename Element, ///< Element type
1972
- typename Layout> ///< Layout function
1973
- struct TensorCopyDiagonalOutFunc {
1974
-
1975
- /// View type
1976
- using TensorView = TensorView<Element, Layout>;
1977
-
1978
- /// Scalar type
1979
- typedef typename TensorView::Element T;
1980
-
1981
- /// Coordinate in tensor's index space
1982
- typedef typename TensorView::TensorCoord TensorCoord;
1983
-
1984
- /// Parameters structure
1985
- struct Params {
1986
-
1987
- //
1988
- // Data members
1989
- //
1990
-
1991
- TensorView view;
1992
- Element *ptr;
1993
-
1994
- /// Default ctor
1995
- CUTLASS_HOST_DEVICE
1996
- Params() { }
1997
-
1998
- //
1999
- // Methods
2000
- //
2001
-
2002
- /// Construction of Gaussian RNG functor.
2003
- Params(
2004
- TensorView view_, ///< destination tensor
2005
- Element *ptr_
2006
- ):
2007
- view(view_), ptr(ptr_) {
2008
-
2009
- }
2010
- };
2011
-
2012
- //
2013
- // Data members
2014
- //
2015
-
2016
- /// Parameters object
2017
- Params params;
2018
-
2019
- //
2020
- // Methods
2021
- //
2022
-
2023
- /// Device-side initialization of RNG
2024
- CUTLASS_DEVICE
2025
- TensorCopyDiagonalOutFunc(Params const &params): params(params) {
2026
-
2027
- }
2028
-
2029
- /// Compute random value and update RNG state
2030
- CUTLASS_DEVICE
2031
- void operator()(TensorCoord const &coord) {
2032
- bool is_diagonal = true;
2033
-
2034
- CUTLASS_PRAGMA_UNROLL
2035
- for (int i = 1; i < Layout::kRank; ++i) {
2036
- if (coord[i] != coord[0]) {
2037
- is_diagonal = false;
2038
- }
2039
- }
2040
- if (is_diagonal) {
2041
- params.ptr[coord[0]] = params.view.at(coord);
2042
- }
2043
- }
2044
- };
2045
-
2046
- } // namespace detail
2047
-
2048
- ///////////////////////////////////////////////////////////////////////////////////////////////////
2049
-
2050
- /// Copies the diagonal of a tensor into a dense buffer in host memory.
2051
- template <
2052
- typename Element, ///< Element type
2053
- typename Layout> ///< Layout function
2054
- void TensorCopyDiagonalOut(
2055
- Element *ptr, ///< dense buffer of elements
2056
- TensorView<Element, Layout> view, ///< source tensor
2057
- cudaStream_t stream = nullptr) {
2058
-
2059
- using Func = detail::TensorCopyDiagonalOutFunc<Element, Layout>;
2060
- using Params = typename Func::Params;
2061
-
2062
- TensorForEach<Func, Layout::kRank, Params>(
2063
- view.extent(),
2064
- Params(view, ptr),
2065
- /*grid_size*/0, /*block_size*/0,
2066
- stream
2067
- );
2068
- }
2069
-
2070
- ///////////////////////////////////////////////////////////////////////////////////////////////////
2071
- ///////////////////////////////////////////////////////////////////////////////////////////////////
2072
-
2073
- } // namespace device
2074
- } // namespace reference
2075
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h DELETED
@@ -1,142 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include <stdexcept>
34
- #include "cutlass/cutlass.h"
35
- #include "cutlass/util/reference/device/kernel/tensor_foreach.h"
36
-
37
- namespace cutlass {
38
- namespace reference {
39
- namespace device {
40
-
41
- ///////////////////////////////////////////////////////////////////////////////////////////////////
42
-
43
- /// Launches a kernel calling a functor for each element in a tensor's index space.
44
- template <typename Func, int Rank, typename Params>
45
- struct TensorForEach {
46
-
47
- /// Constructor performs the operation.
48
- TensorForEach(
49
- Coord<Rank> size, Params params = Params(),
50
- int grid_size = 0, int block_size = 0,
51
- cudaStream_t stream = nullptr) {
52
-
53
- if (!grid_size || !block_size) {
54
-
55
- // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
56
- cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
57
- &grid_size,
58
- &block_size,
59
- reinterpret_cast<void const *>(kernel::TensorForEach<Func, Rank, Params>));
60
-
61
- if (result != cudaSuccess) {
62
- throw std::runtime_error("Failed to query occupancy.");
63
- }
64
- // Limit block size. This has the effect of increasing the number of items processed by a
65
- // single thread and reduces the impact of initialization overhead.
66
- block_size = (block_size < 128 ? block_size : 128);
67
- }
68
-
69
- dim3 grid(grid_size, 1, 1);
70
- dim3 block(block_size, 1, 1);
71
-
72
- kernel::TensorForEach<Func, Rank, Params><<< grid, block, 0, stream >>>(size, params);
73
- }
74
- };
75
-
76
- ///////////////////////////////////////////////////////////////////////////////////////////////////
77
-
78
- /// Launches a kernel calling a functor for each element along a tensor's diagonal
79
- template <typename Func, int Rank, typename Params>
80
- struct TensorDiagonalForEach {
81
-
82
- /// Constructor performs the operation
83
- TensorDiagonalForEach(
84
- Coord<Rank> size, Params params = Params(),
85
- int start = 0, int end = -1,
86
- int block_size = 128, cudaStream_t stream = nullptr) {
87
-
88
- if (end < 0) {
89
- end = size.min();
90
- }
91
-
92
- dim3 block(block_size, 1, 1);
93
- dim3 grid((end - start + block_size - 1) / block_size, 1, 1);
94
-
95
- kernel::TensorDiagonalForEach<Func, Rank, Params><<< grid, block, 0, stream >>>(
96
- size, params, start, end);
97
- }
98
- };
99
-
100
-
101
- ///////////////////////////////////////////////////////////////////////////////////////////////////
102
-
103
- template <typename Element, typename Func>
104
- struct BlockForEach {
105
-
106
- /// Constructor performs the operation.
107
- BlockForEach(
108
- Element *ptr,
109
- size_t capacity,
110
- typename Func::Params params = typename Func::Params(),
111
- int grid_size = 0,
112
- int block_size = 0,
113
- cudaStream_t stream = nullptr) {
114
-
115
- if (!grid_size || !block_size) {
116
-
117
- // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
118
- cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
119
- &grid_size,
120
- &block_size,
121
- reinterpret_cast<void const *>(kernel::BlockForEach<Element, Func>));
122
-
123
- if (result != cudaSuccess) {
124
- throw std::runtime_error("Failed to query occupancy.");
125
- }
126
- // Limit block size. This has the effect of increasing the number of items processed by a
127
- // single thread and reduces the impact of initialization overhead.
128
- block_size = (block_size < 128 ? block_size : 128);
129
- }
130
-
131
- dim3 grid(grid_size, 1, 1);
132
- dim3 block(block_size, 1, 1);
133
-
134
- kernel::BlockForEach<Element, Func><<< grid, block, 0, stream >>>(ptr, capacity, params);
135
- }
136
- };
137
-
138
- ///////////////////////////////////////////////////////////////////////////////////////////////////
139
-
140
- } // namespace device
141
- } // namespace reference
142
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h DELETED
@@ -1,514 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include <cmath>
34
-
35
- #include "cutlass/cutlass.h"
36
- #include "cutlass/complex.h"
37
- #include "cutlass/functional.h"
38
- #include "cutlass/numeric_conversion.h"
39
- #include "cutlass/tensor_view.h"
40
- #include "cutlass/util/device_memory.h"
41
- #include "cutlass/util/reference/detail/linear_to_coordinate.h"
42
-
43
- /////////////////////////////////////////////////////////////////////////////////////////////////
44
-
45
- namespace cutlass {
46
- namespace reference {
47
- namespace device {
48
-
49
- /////////////////////////////////////////////////////////////////////////////////////////////////
50
-
51
- namespace kernel {
52
-
53
- template <
54
- typename Element,
55
- typename Layout,
56
- typename ComputeType,
57
- typename ReduceOp,
58
- typename TransformOp,
59
- int kBlockSize = 128
60
- >
61
- __global__ void TensorTransformReducePartial(
62
- TensorView<Element, Layout> view, /// View of the tensor to reduce over
63
- ComputeType identity, /// Identity element of the reduction operation
64
- ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
65
- TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
66
- ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
67
-
68
- int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
69
- int64_t size = view.size();
70
-
71
- __shared__ ComputeType scratchpad[kBlockSize];
72
-
73
- for (; idx < size; idx += blockDim.x * gridDim.x) {
74
-
75
- // Map linear thread ID onto tensor coordinate
76
- typename Layout::TensorCoord coord;
77
-
78
- cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view.extent());
79
-
80
- if (view.contains(coord)) {
81
-
82
- // Fetch element
83
- Element x = view.at(coord);
84
-
85
- // Transform
86
- identity = reduce(identity, transform(x));
87
- }
88
- }
89
-
90
- scratchpad[threadIdx.x] = identity;
91
-
92
- __syncthreads();
93
-
94
- // One thread performs the final reduction and stores out. This could be enhanced via
95
- // a tree reduction and pipelining.
96
- if (threadIdx.x == 0) {
97
-
98
- for (int i = 1; i < kBlockSize; ++i) {
99
- identity = reduce(identity, scratchpad[i]);
100
- }
101
-
102
- workspace[blockIdx.x] = identity;
103
- }
104
- }
105
-
106
- template <
107
- typename Element,
108
- typename Layout,
109
- typename ComputeType,
110
- typename ReduceOp,
111
- typename TransformOp,
112
- int kBlockSize = 128
113
- >
114
- __global__ void TensorTransformReducePartial(
115
- TensorView<Element, Layout> view_A, /// View of the tensor to reduce over
116
- TensorView<Element, Layout> view_B, /// View of the tensor to reduce over
117
- ComputeType identity, /// Identity element of the reduction operation
118
- ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
119
- TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
120
- ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
121
-
122
- int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
123
- auto size = static_cast<int64_t>(view_A.size());
124
-
125
- __shared__ ComputeType scratchpad[kBlockSize];
126
-
127
- for (; idx < size; idx += blockDim.x * gridDim.x) {
128
-
129
- // Map linear thread ID onto tensor coordinate
130
- typename Layout::TensorCoord coord;
131
-
132
- cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view_A.extent());
133
-
134
- if (view_A.contains(coord)) {
135
-
136
- // Fetch element
137
- Element a = view_A.at(coord);
138
- Element b = view_B.at(coord);
139
-
140
- // Transform
141
- identity = reduce(identity, transform(a, b));
142
- }
143
- }
144
-
145
- scratchpad[threadIdx.x] = identity;
146
-
147
- __syncthreads();
148
-
149
- // One thread performs the final reduction and stores out. This could be enhanced via
150
- // a tree reduction and pipelining.
151
- if (threadIdx.x == 0) {
152
-
153
- for (int i = 1; i < kBlockSize; ++i) {
154
- identity = reduce(identity, scratchpad[i]);
155
- }
156
-
157
- workspace[blockIdx.x] = identity;
158
- }
159
- }
160
-
161
-
162
- template <
163
- typename ComputeType,
164
- typename ReduceOp,
165
- int kBlockSize = 32
166
- >
167
- __global__ void TensorTransformReduceFinalize(
168
- ComputeType *workspace,
169
- ComputeType identity,
170
- int workspace_size,
171
- ReduceOp reduce) {
172
-
173
- __shared__ ComputeType scratchpad[kBlockSize];
174
-
175
- for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) {
176
- identity = reduce(identity, workspace[idx]);
177
- }
178
-
179
- scratchpad[threadIdx.x] = identity;
180
-
181
- __syncthreads();
182
-
183
- if (threadIdx.x == 0) {
184
-
185
- for (int i = 1; i < kBlockSize; ++i) {
186
- identity = reduce(identity, scratchpad[i]);
187
- }
188
-
189
- workspace[0] = identity;
190
- }
191
- }
192
-
193
- } // namespace kernel
194
-
195
- /////////////////////////////////////////////////////////////////////////////////////////////////
196
-
197
- /// Transform-reduce operation over the elements of a tensor
198
- template <
199
- typename Element,
200
- typename Layout,
201
- typename ComputeType,
202
- typename ReduceOp,
203
- typename TransformOp
204
- >
205
- ComputeType TensorTransformReduce(
206
- TensorView<Element, Layout> view, /// View of the tensor to reduce over
207
- ComputeType identity, /// Identity element of the reduction operation
208
- ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
209
- TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
210
- ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
211
- int workspace_size, /// Number of elements in workspace
212
- cudaStream_t stream = nullptr, /// CUDA stream to launch into
213
- bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned.
214
- ) {
215
-
216
- int const kBlockSize = 128;
217
-
218
- dim3 block(kBlockSize, 1);
219
- dim3 grid(workspace_size, 1);
220
-
221
- kernel::TensorTransformReducePartial<
222
- Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize
223
- ><<< grid, block, 0, stream >>>(
224
- view, identity, reduce, transform, workspace
225
- );
226
-
227
- int const kFinalizeBlockSize = 32;
228
-
229
- kernel::TensorTransformReduceFinalize<
230
- ComputeType, ReduceOp, kFinalizeBlockSize
231
- ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>(
232
- workspace, identity, workspace_size, reduce
233
- );
234
-
235
- cudaStreamSynchronize(stream);
236
-
237
- if (copy_out) {
238
- cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
239
- if (result != cudaSuccess) {
240
- throw std::runtime_error("cudaMemcpy() failed");
241
- }
242
- }
243
-
244
- return identity;
245
- }
246
-
247
- /// Transform-reduce operation over the elements of two tensors, zipped together
248
- template <
249
- typename Element,
250
- typename Layout,
251
- typename ComputeType,
252
- typename ReduceOp,
253
- typename TransformOp
254
- >
255
- ComputeType TensorTransformReduce(
256
- TensorView<Element, Layout> view_A, /// View of the tensor to reduce over
257
- TensorView<Element, Layout> view_B, /// View of the tensor to reduce over
258
- ComputeType identity, /// Identity element of the reduction operation
259
- ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
260
- TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
261
- ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
262
- int workspace_size, /// Number of elements in workspace
263
- cudaStream_t stream = nullptr, /// CUDA stream to launch into
264
- bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned.
265
- ) {
266
-
267
- if (view_A.extent() != view_B.extent()) {
268
- throw std::runtime_error("Extents must be equal.");
269
- }
270
-
271
- int const kBlockSize = 128;
272
-
273
- dim3 block(kBlockSize, 1);
274
- dim3 grid(workspace_size, 1);
275
-
276
- kernel::TensorTransformReducePartial<
277
- Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize
278
- ><<< grid, block, 0, stream >>>(
279
- view_A, view_B, identity, reduce, transform, workspace
280
- );
281
-
282
- int const kFinalizeBlockSize = 32;
283
-
284
- kernel::TensorTransformReduceFinalize<
285
- ComputeType, ReduceOp, kFinalizeBlockSize
286
- ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>(
287
- workspace, identity, workspace_size, reduce
288
- );
289
-
290
- cudaStreamSynchronize(stream);
291
-
292
- if (copy_out) {
293
- cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
294
- if (result != cudaSuccess) {
295
- throw std::runtime_error("cudaMemcpy() failed");
296
- }
297
- }
298
-
299
- return identity;
300
- }
301
-
302
- /// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
303
- /// workspace
304
- template <
305
- typename Element,
306
- typename Layout,
307
- typename ComputeType,
308
- typename ReduceOp,
309
- typename TransformOp
310
- >
311
- ComputeType TensorTransformReduce(
312
- TensorView<Element, Layout> view,
313
- ComputeType identity,
314
- ReduceOp reduce,
315
- TransformOp transform,
316
- cudaStream_t stream = nullptr,
317
- int workspace_size = 0
318
- ) {
319
-
320
- // Optionally query for the SM count to size the workspace.
321
- if (!workspace_size) {
322
-
323
- int device_idx = 0;
324
- cudaDeviceProp prop;
325
-
326
- cudaError_t result = cudaGetDevice(&device_idx);
327
- if (result != cudaSuccess) {
328
- throw std::runtime_error("cudaGetDevice() failed");
329
- }
330
-
331
- result = cudaGetDeviceProperties(&prop, device_idx);
332
- if (result != cudaSuccess) {
333
- throw std::runtime_error("cudaGetDeviceProp() failed");
334
- }
335
-
336
- workspace_size = int(prop.multiProcessorCount);
337
- }
338
-
339
- DeviceAllocation<ComputeType> workspace(workspace_size);
340
-
341
- ComputeType output = TensorTransformReduce(
342
- view,
343
- identity,
344
- reduce,
345
- transform,
346
- workspace.get(),
347
- workspace_size,
348
- stream,
349
- true);
350
-
351
- return output;
352
- }
353
-
354
-
355
- /// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
356
- /// workspace
357
- template <
358
- typename Element,
359
- typename Layout,
360
- typename ComputeType,
361
- typename ReduceOp,
362
- typename TransformOp
363
- >
364
- ComputeType TensorTransformReduce(
365
- TensorView<Element, Layout> view_A,
366
- TensorView<Element, Layout> view_B,
367
- ComputeType identity,
368
- ReduceOp reduce,
369
- TransformOp transform,
370
- cudaStream_t stream = nullptr,
371
- int workspace_size = 0
372
- ) {
373
-
374
- // Optionally query for the SM count to size the workspace.
375
- if (!workspace_size) {
376
-
377
- int device_idx = 0;
378
- cudaDeviceProp prop;
379
-
380
- cudaError_t result = cudaGetDevice(&device_idx);
381
- if (result != cudaSuccess) {
382
- throw std::runtime_error("cudaGetDevice() failed");
383
- }
384
-
385
- result = cudaGetDeviceProperties(&prop, device_idx);
386
- if (result != cudaSuccess) {
387
- throw std::runtime_error("cudaGetDeviceProp() failed");
388
- }
389
-
390
- workspace_size = int(prop.multiProcessorCount);
391
- }
392
-
393
- DeviceAllocation<ComputeType> workspace(workspace_size);
394
-
395
- ComputeType output = TensorTransformReduce(
396
- view_A,
397
- view_B,
398
- identity,
399
- reduce,
400
- transform,
401
- workspace.get(),
402
- workspace_size,
403
- stream,
404
- true);
405
-
406
- return output;
407
- }
408
-
409
- /////////////////////////////////////////////////////////////////////////////////////////////////
410
-
411
- /// Helper to compute the sum of the elements of a tensor
412
- template <
413
- typename Element,
414
- typename Layout,
415
- typename ComputeType = Element
416
- >
417
- ComputeType TensorSum(
418
- TensorView<Element, Layout> view,
419
- ComputeType identity = ComputeType(),
420
- cudaStream_t stream = nullptr,
421
- int workspace_size = 0
422
- ) {
423
-
424
- plus<ComputeType> reduce;
425
- NumericConverter<ComputeType, Element> transform;
426
-
427
- return TensorTransformReduce(
428
- view, identity, reduce, transform, stream, workspace_size);
429
- }
430
-
431
- /// Helper to compute the sum of the squares of the elements of a tensor
432
- template <
433
- typename Element,
434
- typename Layout,
435
- typename ComputeType = Element
436
- >
437
- ComputeType TensorSumSq(
438
- TensorView<Element, Layout> view,
439
- ComputeType identity = ComputeType(),
440
- cudaStream_t stream = nullptr,
441
- int workspace_size = 0
442
- ) {
443
-
444
- plus<ComputeType> reduce;
445
- magnitude_squared<Element, ComputeType> transform;
446
-
447
- return TensorTransformReduce(
448
- view, identity, reduce, transform, stream, workspace_size);
449
- }
450
-
451
- /// Helper to compute the norm of the elements of a tensor.
452
- template <
453
- typename Element,
454
- typename Layout,
455
- typename ComputeType = double
456
- >
457
- ComputeType TensorNorm(
458
- TensorView<Element, Layout> view,
459
- ComputeType identity = ComputeType(),
460
- cudaStream_t stream = nullptr,
461
- int workspace_size = 0
462
- ) {
463
-
464
- return std::sqrt(TensorSumSq(view, identity, stream, workspace_size));
465
- }
466
-
467
- /////////////////////////////////////////////////////////////////////////////////////////////////
468
-
469
- /// Helper to compute the sum of the squares of the differences of two tensors
470
- template <
471
- typename Element,
472
- typename Layout,
473
- typename ComputeType = double
474
- >
475
- ComputeType TensorSumSqDiff(
476
- TensorView<Element, Layout> view_A,
477
- TensorView<Element, Layout> view_B,
478
- ComputeType identity = ComputeType(),
479
- cudaStream_t stream = nullptr,
480
- int workspace_size = 0
481
- ) {
482
-
483
- plus<ComputeType> reduce;
484
- magnitude_squared_difference<Element, ComputeType> transform;
485
-
486
- return TensorTransformReduce(
487
- view_A, view_B, identity, reduce, transform, stream, workspace_size);
488
- }
489
-
490
-
491
- /// Helper to compute the norm of the tensor computed as the difference of two tensors in memory
492
- template <
493
- typename Element,
494
- typename Layout,
495
- typename ComputeType = double
496
- >
497
- ComputeType TensorNormDiff(
498
- TensorView<Element, Layout> view_A,
499
- TensorView<Element, Layout> view_B,
500
- ComputeType identity = ComputeType(),
501
- cudaStream_t stream = nullptr,
502
- int workspace_size = 0
503
- ) {
504
-
505
- return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size));
506
- }
507
-
508
- /////////////////////////////////////////////////////////////////////////////////////////////////
509
-
510
- } // namespace device
511
- } // namespace reference
512
- } // namespace cutlass
513
-
514
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h DELETED
@@ -1,141 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /* \file
32
- \brief Defines device-side elementwise operations on TensorView. Note, the operations defined
33
- in this header are not specialized for any particular data layout and are therefore not
34
- intended to offer the best possible performance. Rather, they are intended to be generic
35
- reference implementations to support the CUTLASS unit tests.
36
- */
37
-
38
- #pragma once
39
-
40
- // Cutlass includes
41
- #include "cutlass/cutlass.h"
42
- #include "cutlass/tensor_view.h"
43
-
44
- #include "cutlass/util/reference/device/tensor_foreach.h"
45
-
46
- ///////////////////////////////////////////////////////////////////////////////////////////////////
47
-
48
- namespace cutlass {
49
- namespace reference {
50
- namespace device {
51
-
52
- ///////////////////////////////////////////////////////////////////////////////////////////////////
53
- ///////////////////////////////////////////////////////////////////////////////////////////////////
54
-
55
- namespace detail {
56
-
57
- template <
58
- typename Element, ///< Element type
59
- typename Layout> ///< Layout function
60
- struct TensorReLuFunc {
61
-
62
- /// View type
63
- using TensorView = TensorView<Element, Layout>;
64
-
65
- /// Coordinate in tensor's index space
66
- using TensorCoord = typename TensorView::TensorCoord;
67
-
68
- /// Parameters structure
69
- struct Params {
70
-
71
- //
72
- // Data members
73
- //
74
-
75
- TensorView view;
76
- Element threshold;
77
-
78
-
79
- //
80
- // Methods
81
- //
82
-
83
- Params(
84
- TensorView view_ = TensorView(),
85
- Element threshold_ = Element(0)
86
- ):
87
- view(view_), threshold(threshold_) {
88
-
89
- }
90
- };
91
-
92
- //
93
- // Data members
94
- //
95
-
96
- Params params;
97
-
98
- //
99
- // Methods
100
- //
101
-
102
- CUTLASS_DEVICE
103
- TensorReLuFunc(Params const &params): params(params) {
104
-
105
- }
106
-
107
- CUTLASS_DEVICE
108
- void operator()(TensorCoord const &coord) {
109
-
110
- Element const & value = params.view.at(coord);
111
- params.view.at(coord) = (value < params.threshold) ? params.threshold : value;
112
- }
113
- };
114
-
115
- } // namespace detail
116
-
117
- ///////////////////////////////////////////////////////////////////////////////////////////////////
118
-
119
- /// Apply ReLu on a tensor
120
- template <
121
- typename Element, ///< Element type
122
- typename Layout> ///< Layout function
123
- void TensorReLu(
124
- TensorView<Element, Layout> view, ///< destination tensor
125
- Element threshold = Element(0)) { ///< ReLu threshold
126
-
127
- using Func = detail::TensorReLuFunc<Element, Layout>;
128
- using Params = typename Func::Params;
129
-
130
- TensorForEach<Func, Layout::kRank, Params>(
131
- view.extent(),
132
- Params(view, threshold)
133
- );
134
- }
135
-
136
- ///////////////////////////////////////////////////////////////////////////////////////////////////
137
- ///////////////////////////////////////////////////////////////////////////////////////////////////
138
-
139
- } // namespace device
140
- } // namespace reference
141
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h DELETED
@@ -1,186 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for GEMM in host-side code.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/coord.h"
38
- #include "cutlass/tensor_view.h"
39
- #include "cutlass/gemm/gemm.h"
40
-
41
- namespace cutlass {
42
- namespace reference {
43
- namespace device {
44
- namespace thread {
45
-
46
- ////////////////////////////////////////////////////////////////////////////////////////////////////
47
-
48
- /// Thread-level blocked general matrix product.
49
- //
50
- // Note, this is a reference implementation. Performance is not expected to approach peak.
51
- //
52
- template <
53
- typename TensorRefA,
54
- typename TensorRefB,
55
- typename TensorRefC,
56
- typename ScalarType,
57
- typename AccumulatorType,
58
- typename OutputTile,
59
- typename InnerProductOp = multiply_add<AccumulatorType>,
60
- typename ConvertOp = NumericConverter<typename TensorRefC::Element, ScalarType>
61
- >
62
- struct Gemm {
63
-
64
- using ElementA = typename TensorRefA::Element;
65
- using ElementB = typename TensorRefB::Element;
66
- using ElementC = typename TensorRefC::Element;
67
-
68
- //
69
- // Data members
70
- //
71
-
72
- /// Tile for A operand
73
- ElementA A_tile[OutputTile::kColumn];
74
-
75
- /// Tile for B operand
76
- ElementB B_tile[OutputTile::kRow];
77
-
78
- /// Tile for Accumulator
79
- AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow];
80
-
81
- //
82
- // Methods
83
- //
84
-
85
- /// Constructor
86
- CUTLASS_HOST_DEVICE
87
- Gemm(AccumulatorType initial_accum = AccumulatorType(0)) {
88
-
89
- // Clear fetch registers
90
- for (int i = 0; i < OutputTile::kColumn; ++i) {
91
- A_tile[i] = ElementA(0);
92
- }
93
-
94
- for (int j = 0; j < OutputTile::kRow; ++j) {
95
- B_tile[j] = ElementB(0);
96
- }
97
-
98
- // Clear accumulators
99
- CUTLASS_PRAGMA_UNROLL
100
- for (int j = 0; j < OutputTile::kColumn; ++j) {
101
- CUTLASS_PRAGMA_UNROLL
102
- for (int i = 0; i < OutputTile::kRow; ++i) {
103
- accum[j][i] = initial_accum;
104
- }
105
- }
106
- }
107
-
108
- /// Computes a matrix product
109
- CUTLASS_HOST_DEVICE
110
- Gemm & multiply_add(
111
- gemm::GemmCoord problem_size,
112
- TensorRefA tensor_a,
113
- TensorRefB tensor_b,
114
- MatrixCoord output_coord = MatrixCoord()) {
115
-
116
- InnerProductOp inner_product_op;
117
-
118
- // Loop over the GEMM K dimension
119
- CUTLASS_PRAGMA_NO_UNROLL
120
- for (int k = 0; k < problem_size.k(); ++k) {
121
-
122
- // Fetch a slice of the A matrix
123
- CUTLASS_PRAGMA_UNROLL
124
- for (int i = 0; i < OutputTile::kColumn; ++i) {
125
- if (output_coord.row() + i < problem_size.m()) {
126
- A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k));
127
- }
128
- }
129
-
130
- // Fetch a slice of the B matrix
131
- CUTLASS_PRAGMA_UNROLL
132
- for (int j = 0; j < OutputTile::kRow; ++j) {
133
- if (output_coord.column() + j < problem_size.n()) {
134
- B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j));
135
- }
136
- }
137
-
138
- // Compute an accumulated matrix product
139
- CUTLASS_PRAGMA_UNROLL
140
- for (int j = 0; j < OutputTile::kRow; ++j) {
141
- CUTLASS_PRAGMA_UNROLL
142
- for (int i = 0; i < OutputTile::kColumn; ++i) {
143
- accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]);
144
- }
145
- }
146
- }
147
-
148
- return *this;
149
- }
150
-
151
- /// Performs linear scaling of matrix product and updates output tensor
152
- CUTLASS_HOST_DEVICE
153
- Gemm & epilogue(
154
- gemm::GemmCoord problem_size,
155
- ScalarType alpha,
156
- ScalarType beta,
157
- TensorRefC tensor_c,
158
- TensorRefC tensor_d,
159
- MatrixCoord output_coord = MatrixCoord()) {
160
-
161
- ConvertOp convert_op;
162
-
163
- // Update the output tensor
164
- for (int j = 0; j < OutputTile::kRow; ++j) {
165
- for (int i = 0; i < OutputTile::kColumn; ++i) {
166
- MatrixCoord coord = output_coord + MatrixCoord(i, j);
167
- if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
168
-
169
- tensor_d.at(coord) = convert_op(
170
- alpha * ScalarType(accum[j][i]) +
171
- beta * ScalarType(tensor_c.at(coord))
172
- );
173
- }
174
- }
175
- }
176
-
177
- return *this;
178
- }
179
- };
180
-
181
- ////////////////////////////////////////////////////////////////////////////////////////////////////
182
-
183
- } // namespace thread
184
- } // namespace device
185
- } // namespace reference
186
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp DELETED
@@ -1,782 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for CONV in host-side code.
33
- */
34
- #pragma once
35
-
36
- /////////////////////////////////////////////////////////////////////////////////////////////////
37
-
38
- #include "cutlass/complex.h"
39
- #include "cutlass/numeric_conversion.h"
40
- #include "cutlass/epilogue/thread/activation.h"
41
-
42
- #include "cute/tensor.hpp"
43
-
44
- #include <cuda_runtime.h>
45
-
46
- /////////////////////////////////////////////////////////////////////////////////////////////////
47
-
48
- namespace cutlass::reference::host {
49
-
50
- /////////////////////////////////////////////////////////////////////////////////////////////////
51
-
52
- namespace detail {
53
-
54
- template<class EngineAct, class LayoutAct>
55
- bool
56
- is_activation_in_bounds(
57
- cute::Tensor<EngineAct, LayoutAct> const& activation,
58
- int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) {
59
- return ((g_ >= 0 && g_ < size<5>(activation)) &&
60
- (n_ >= 0 && n_ < size<4>(activation)) &&
61
- (d_ >= 0 && d_ < size<3>(activation)) &&
62
- (h_ >= 0 && h_ < size<2>(activation)) &&
63
- (w_ >= 0 && w_ < size<1>(activation)) &&
64
- (c_ >= 0 && c_ < size<0>(activation)));
65
- }
66
-
67
- template<class EngineAct, class LayoutAct>
68
- bool
69
- is_activation_in_bounds(
70
- cute::Tensor<EngineAct, LayoutAct> const& activation,
71
- int32_t n_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) {
72
- return ((g_ >= 0 && g_ < size<4>(activation)) &&
73
- (n_ >= 0 && n_ < size<3>(activation)) &&
74
- (h_ >= 0 && h_ < size<2>(activation)) &&
75
- (w_ >= 0 && w_ < size<1>(activation)) &&
76
- (c_ >= 0 && c_ < size<0>(activation)));
77
- }
78
-
79
- template<class EngineAct, class LayoutAct>
80
- bool
81
- is_activation_in_bounds(
82
- cute::Tensor<EngineAct, LayoutAct> const& activation,
83
- int32_t n_, int32_t w_, int32_t c_, int32_t g_) {
84
- return ((g_ >= 0 && g_ < size<3>(activation)) &&
85
- (n_ >= 0 && n_ < size<2>(activation)) &&
86
- (w_ >= 0 && w_ < size<1>(activation)) &&
87
- (c_ >= 0 && c_ < size<0>(activation)));
88
- }
89
-
90
- } // namespace detail
91
-
92
- template<
93
- class ElementAcc_,
94
- class ElementScalar_,
95
- class ElementCompute_,
96
- class ElementC_,
97
- class ElementOut_,
98
- bool ResidualAdd_,
99
- class TensorAlpha_,
100
- class TensorBeta_,
101
- class TensorBias_,
102
- class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>
103
- >
104
- struct ConvEpilogueFusionParams {
105
- using ElementAcc = ElementAcc_;
106
- using ElementScalar = ElementScalar_;
107
- using ElementCompute = ElementCompute_;
108
- using ElementC = ElementC_;
109
- using ElementOut = ElementOut_;
110
- using TensorAlpha = TensorAlpha_;
111
- using TensorBeta = TensorBeta_;
112
- using TensorBias = TensorBias_;
113
- using ActivationFunctor = ActivationFunctor_;
114
- static constexpr bool ResidualAdd = ResidualAdd_; // Source added after activation
115
-
116
- ElementScalar alpha = ElementScalar(1);
117
- ElementScalar beta = ElementScalar(0);
118
-
119
- TensorAlpha tensor_alpha{};
120
- TensorBeta tensor_beta{};
121
- TensorBias tensor_bias{};
122
- };
123
-
124
- template<
125
- cutlass::conv::Operator ConvOp,
126
- int NumSpatialDims,
127
- class TensorA,
128
- class TensorB,
129
- class TensorC,
130
- class TensorD,
131
- class ShapePadding,
132
- class StrideTraversal,
133
- class ShapeDilation,
134
- class EpilogueFusionParams
135
- >
136
- struct ConvReferenceImpl {
137
- // Hard code accumlulator type to float to avoid data lost in accumulating add.
138
- using ElementAcc = cutlass::platform::conditional_t<cutlass::platform::is_same_v<typename EpilogueFusionParams::ElementAcc, double>, double, float>;
139
- using ElementC = typename EpilogueFusionParams::ElementC;
140
- using ElementOut = typename EpilogueFusionParams::ElementOut;
141
- using ElementScalar = typename EpilogueFusionParams::ElementScalar;
142
- using ElementCompute = typename EpilogueFusionParams::ElementCompute;
143
- using ElementBias = typename EpilogueFusionParams::TensorBias::value_type;
144
- using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor;
145
-
146
- // Input related converter
147
- NumericConverter<ElementCompute, ElementAcc> acc_converter;
148
- NumericConverter<ElementCompute, ElementC> residual_converter;
149
- NumericConverter<ElementCompute, ElementBias> bias_converter;
150
- // Scale related converter
151
- NumericConverter<ElementCompute, ElementScalar> scale_converter;
152
- // Output related converter
153
- NumericConverter<ElementOut, ElementCompute> output_converter;
154
-
155
- EpilogueFusionParams& epi_fusion_params_;
156
- TensorA const& tensor_a_;
157
- TensorB const& tensor_b_;
158
- TensorC const& tensor_c_;
159
- TensorD& tensor_d_;
160
-
161
- ShapePadding const& padding_;
162
- StrideTraversal const& tstride_;
163
- ShapeDilation const& dilation_;
164
-
165
- // Epilogue activation operation
166
- ActivationFunctor epi_activation;
167
-
168
- ConvReferenceImpl(
169
- TensorA const& tensor_a,
170
- TensorB const& tensor_b,
171
- TensorC const& tensor_c,
172
- TensorD& tensor_d,
173
- ShapePadding const& padding,
174
- StrideTraversal const& tstride,
175
- ShapeDilation const& dilation,
176
- EpilogueFusionParams& epi_fusion_params)
177
- : tensor_a_(tensor_a),
178
- tensor_b_(tensor_b),
179
- tensor_c_(tensor_c),
180
- tensor_d_(tensor_d),
181
- padding_(padding),
182
- tstride_(tstride),
183
- dilation_(dilation),
184
- epi_fusion_params_(epi_fusion_params)
185
- {
186
- static_assert(rank(ShapePadding{}) == rank(ShapeDilation{}));
187
- static_assert(rank(ShapePadding{}) == rank(StrideTraversal{}));
188
- }
189
-
190
- void compute_reference() {
191
- if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
192
- fprop_reference(cute::Int<NumSpatialDims>{});
193
- }
194
- else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
195
- dgrad_reference(cute::Int<NumSpatialDims>{});
196
- }
197
- else {
198
- wgrad_reference(cute::Int<NumSpatialDims>{});
199
- }
200
- }
201
-
202
- private:
203
- // Specialization for 1D fprop kernel
204
- void fprop_reference(cute::Int<1> spatial_dims) {
205
- int32_t G = size<3>(tensor_d_);
206
- int32_t N = size<2>(tensor_d_);
207
- int32_t Q = size<1>(tensor_d_);
208
- int32_t K = size<0>(tensor_d_);
209
- int32_t S = size<1>(tensor_b_);
210
- int32_t C = size<0>(tensor_b_);
211
-
212
- #if defined(_OPENMP)
213
- #pragma omp parallel for collapse(2)
214
- #endif
215
- for (int32_t g = 0; g < G; ++g) {
216
- for (int32_t n = 0; n < N; ++n) {
217
- for (int32_t q = 0; q < Q; ++q) {
218
- for (int32_t k = 0; k < K; ++k) {
219
- auto accumulator = ElementAcc(0);
220
- for (int32_t s = 0; s < S; ++s) {
221
- for (int32_t c = 0; c < C; ++c) {
222
- int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
223
- if (detail::is_activation_in_bounds(tensor_a_, n, w, c, g)) {
224
- auto a = tensor_a_(c, w, n, g);
225
- auto b = tensor_b_(c, s, k, g);
226
- accumulator += ElementAcc(a * b);
227
- }
228
- }
229
- }
230
- ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
231
- epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
232
- ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
233
- epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
234
- ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
235
- if (not EpilogueFusionParams::ResidualAdd) {
236
- output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g));
237
- }
238
- if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
239
- output += bias_converter(epi_fusion_params_.tensor_bias[k]);
240
- }
241
- output = epi_activation(output);
242
- if (EpilogueFusionParams::ResidualAdd) {
243
- output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g));
244
- }
245
- tensor_d_(k, q, n, g) = output_converter(output);
246
- }
247
- }
248
- }
249
- }
250
-
251
- }
252
-
253
- // Specialization for 2D fprop kernel
254
- void fprop_reference(cute::Int<2> spatial_dims) {
255
- int32_t G = size<4>(tensor_d_);
256
- int32_t N = size<3>(tensor_d_);
257
- int32_t P = size<2>(tensor_d_);
258
- int32_t Q = size<1>(tensor_d_);
259
- int32_t K = size<0>(tensor_d_);
260
- int32_t R = size<2>(tensor_b_);
261
- int32_t S = size<1>(tensor_b_);
262
- int32_t C = size<0>(tensor_b_);
263
-
264
- #if defined(_OPENMP)
265
- #pragma omp parallel for collapse(3)
266
- #endif
267
- for (int32_t g = 0; g < G; ++g) {
268
- for (int32_t n = 0; n < N; ++n) {
269
- for (int32_t p = 0; p < P; ++p) {
270
- for (int32_t q = 0; q < Q; ++q) {
271
- for (int32_t k = 0; k < K; ++k) {
272
- auto accumulator = ElementAcc(0);
273
- for (int32_t r = 0; r < R; ++r) {
274
- for (int32_t s = 0; s < S; ++s) {
275
- for (int32_t c = 0; c < C; ++c) {
276
- int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
277
- int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
278
- if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c, g)) {
279
- auto a = tensor_a_(c, w, h, n, g);
280
- auto b = tensor_b_(c, s, r, k, g);
281
- accumulator += ElementAcc(a * b);
282
- }
283
- }
284
- }
285
- }
286
- ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
287
- epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
288
- ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
289
- epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
290
- ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
291
- if (not EpilogueFusionParams::ResidualAdd) {
292
- output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g));
293
- }
294
- if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
295
- output += bias_converter(epi_fusion_params_.tensor_bias[k]);
296
- }
297
- output = epi_activation(output);
298
- if (EpilogueFusionParams::ResidualAdd) {
299
- output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g));
300
- }
301
- tensor_d_(k, q, p, n, g) = output_converter(output);
302
- }
303
- }
304
- }
305
- }
306
- }
307
-
308
- }
309
-
310
- // Specialization for 3D fprop kernel
311
- void fprop_reference(cute::Int<3> spatial_dims) {
312
- int32_t G = size<5>(tensor_d_);
313
- int32_t N = size<4>(tensor_d_);
314
- int32_t Z = size<3>(tensor_d_);
315
- int32_t P = size<2>(tensor_d_);
316
- int32_t Q = size<1>(tensor_d_);
317
- int32_t K = size<0>(tensor_d_);
318
- int32_t T = size<3>(tensor_b_);
319
- int32_t R = size<2>(tensor_b_);
320
- int32_t S = size<1>(tensor_b_);
321
- int32_t C = size<0>(tensor_b_);
322
-
323
- #if defined(_OPENMP)
324
- #pragma omp parallel for collapse(3)
325
- #endif
326
- for (int32_t g = 0; g < G; ++g) {
327
- for (int32_t n = 0; n < N; ++n) {
328
- for (int32_t z = 0; z < Z; ++z) {
329
- for (int32_t p = 0; p < P; ++p) {
330
- for (int32_t q = 0; q < Q; ++q) {
331
- for (int32_t k = 0; k < K; ++k) {
332
- auto accumulator = ElementAcc(0);
333
- for (int32_t t = 0; t < T; ++t) {
334
- for (int32_t r = 0; r < R; ++r) {
335
- for (int32_t s = 0; s < S; ++s) {
336
- for (int32_t c = 0; c < C; ++c) {
337
- int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
338
- int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
339
- int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
340
- if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c, g)) {
341
- auto a = tensor_a_(c, w, h, d, n, g);
342
- auto b = tensor_b_(c, s, r, t, k, g);
343
- accumulator += ElementAcc(a * b);
344
- }
345
- }
346
- }
347
- }
348
- }
349
- ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
350
- epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
351
- ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
352
- epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
353
- ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
354
- if (not EpilogueFusionParams::ResidualAdd) {
355
- output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g));
356
- }
357
- if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
358
- output += bias_converter(epi_fusion_params_.tensor_bias[k]);
359
- }
360
- output = epi_activation(output);
361
- if (EpilogueFusionParams::ResidualAdd) {
362
- output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g));
363
- }
364
- tensor_d_(k, q, p, z, n, g) = output_converter(output);
365
- }
366
- }
367
- }
368
- }
369
- }
370
- }
371
-
372
- }
373
-
374
- // Specialization for 1D dgrad kernel
375
- void dgrad_reference(cute::Int<1> spatial_dims) {
376
- int32_t G = size<3>(tensor_d_);
377
- int32_t N = size<2>(tensor_d_);
378
- int32_t W = size<1>(tensor_d_);
379
- int32_t C = size<0>(tensor_d_);
380
- int32_t K = size<2>(tensor_b_);
381
- int32_t S = size<1>(tensor_b_);
382
-
383
- #if defined(_OPENMP)
384
- #pragma omp parallel for collapse(2)
385
- #endif
386
- for (int32_t g = 0; g < G; ++g) {
387
- for (int32_t n = 0; n < N; ++n) {
388
- for (int32_t w = 0; w < W; ++w) {
389
- for (int32_t c = 0; c < C; ++c) {
390
- auto accumulator = ElementAcc(0);
391
- for (int32_t k = 0; k < K; ++k) {
392
- for (int32_t s = 0; s < S; ++s) {
393
- int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
394
-
395
- if (q % cute::get<0>(tstride_) == 0) {
396
- q /= cute::get<0>(tstride_);
397
- } else {
398
- continue;
399
- }
400
-
401
- if (detail::is_activation_in_bounds(tensor_a_, n, q, k, g)) {
402
- accumulator += ElementAcc(tensor_a_(k, q, n, g) * tensor_b_(c, s, k, g));
403
- }
404
- }
405
- }
406
- ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
407
- ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
408
- ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
409
- ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
410
- ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
411
- if (not EpilogueFusionParams::ResidualAdd) {
412
- output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g));
413
- }
414
- if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
415
- output += bias_converter(epi_fusion_params_.tensor_bias[c]);
416
- }
417
- output = epi_activation(output);
418
- if (EpilogueFusionParams::ResidualAdd) {
419
- output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g));
420
- }
421
- tensor_d_(c, w, n, g) = output_converter(output);
422
- }
423
- }
424
- }
425
- }
426
-
427
- }
428
-
429
- // Specialization for 2D dgrad kernel
430
- void dgrad_reference(cute::Int<2> spatial_dims) {
431
- int32_t G = size<4>(tensor_d_);
432
- int32_t N = size<3>(tensor_d_);
433
- int32_t H = size<2>(tensor_d_);
434
- int32_t W = size<1>(tensor_d_);
435
- int32_t C = size<0>(tensor_d_);
436
- int32_t K = size<3>(tensor_b_);
437
- int32_t R = size<2>(tensor_b_);
438
- int32_t S = size<1>(tensor_b_);
439
-
440
- #if defined(_OPENMP)
441
- #pragma omp parallel for collapse(3)
442
- #endif
443
- for (int32_t g = 0; g < G; ++g) {
444
- for (int32_t n = 0; n < N; ++n) {
445
- for (int32_t h = 0; h < H; ++h) {
446
- for (int32_t w = 0; w < W; ++w) {
447
- for (int32_t c = 0; c < C; ++c) {
448
- auto accumulator = ElementAcc(0);
449
- for (int32_t k = 0; k < K; ++k) {
450
- for (int32_t r = 0; r < R; ++r) {
451
- for (int32_t s = 0; s < S; ++s) {
452
- int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
453
- int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);
454
-
455
- if (q % cute::get<0>(tstride_) == 0) {
456
- q /= cute::get<0>(tstride_);
457
- } else {
458
- continue;
459
- }
460
-
461
- if (p % cute::get<1>(tstride_) == 0) {
462
- p /= cute::get<1>(tstride_);
463
- } else {
464
- continue;
465
- }
466
-
467
- if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k, g)) {
468
- accumulator += ElementAcc(tensor_a_(k, q, p, n, g) * tensor_b_(c, s, r, k, g));
469
- }
470
- }
471
- }
472
- }
473
- ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
474
- ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
475
- ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
476
- ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
477
- ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
478
- if (not EpilogueFusionParams::ResidualAdd) {
479
- output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g));
480
- }
481
- if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
482
- output += bias_converter(epi_fusion_params_.tensor_bias[c]);
483
- }
484
- output = epi_activation(output);
485
- if (EpilogueFusionParams::ResidualAdd) {
486
- output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g));
487
- }
488
-
489
- tensor_d_(c, w, h, n, g) = output_converter(output);
490
- }
491
- }
492
- }
493
- }
494
- }
495
-
496
- }
497
-
498
- // Specialization for 3D dgrad kernel
499
- void dgrad_reference(cute::Int<3> spatial_dims) {
500
- int32_t G = size<5>(tensor_d_);
501
- int32_t N = size<4>(tensor_d_);
502
- int32_t D = size<3>(tensor_d_);
503
- int32_t H = size<2>(tensor_d_);
504
- int32_t W = size<1>(tensor_d_);
505
- int32_t C = size<0>(tensor_d_);
506
- int32_t K = size<4>(tensor_b_);
507
- int32_t T = size<3>(tensor_b_);
508
- int32_t R = size<2>(tensor_b_);
509
- int32_t S = size<1>(tensor_b_);
510
-
511
- #if defined(_OPENMP)
512
- #pragma omp parallel for collapse(3)
513
- #endif
514
- for (int32_t g = 0; g < G; ++g) {
515
- for (int32_t n = 0; n < N; ++n) {
516
- for (int32_t d = 0; d < D; ++d) {
517
- for (int32_t h = 0; h < H; ++h) {
518
- for (int32_t w = 0; w < W; ++w) {
519
- for (int32_t c = 0; c < C; ++c) {
520
- auto accumulator = ElementAcc(0);
521
- for (int32_t k = 0; k < K; ++k) {
522
- for (int32_t t = 0; t < T; ++t) {
523
- for (int32_t r = 0; r < R; ++r) {
524
- for (int32_t s = 0; s < S; ++s) {
525
- int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
526
- int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);
527
- int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_);
528
-
529
- if (q % cute::get<0>(tstride_) == 0) {
530
- q /= cute::get<0>(tstride_);
531
- } else {
532
- continue;
533
- }
534
-
535
- if (p % cute::get<1>(tstride_) == 0) {
536
- p /= cute::get<1>(tstride_);
537
- } else {
538
- continue;
539
- }
540
-
541
- if (z % cute::get<2>(tstride_) == 0) {
542
- z /= cute::get<2>(tstride_);
543
- } else {
544
- continue;
545
- }
546
-
547
- if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k, g)) {
548
- accumulator += ElementAcc(tensor_a_(k, q, p, z, n, g) * tensor_b_(c, s, r, t, k, g));
549
- }
550
- }
551
- }
552
- }
553
- }
554
- ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
555
- ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
556
- ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
557
- ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
558
- ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
559
- if (not EpilogueFusionParams::ResidualAdd) {
560
- output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g));
561
- }
562
- if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
563
- output += bias_converter(epi_fusion_params_.tensor_bias[c]);
564
- }
565
- output = epi_activation(output);
566
- if (EpilogueFusionParams::ResidualAdd) {
567
- output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g));
568
- }
569
- tensor_d_(c, w, h, d, n, g) = output_converter(output);
570
- }
571
- }
572
- }
573
- }
574
- }
575
- }
576
-
577
- }
578
-
579
- // Specialization for 1D wgrad kernel
580
- void wgrad_reference(cute::Int<1> spatial_dims) {
581
- int32_t G = size<3>(tensor_d_);
582
- int32_t N =
583
- size<2>(tensor_a_);
584
- int32_t Q =
585
- size<1>(tensor_a_);
586
- int32_t K =
587
- size<0>(tensor_a_);
588
- int32_t S = size<1>(tensor_d_);
589
- int32_t C = size<0>(tensor_d_);
590
-
591
- #if defined(_OPENMP)
592
- #pragma omp parallel for collapse(2)
593
- #endif
594
- for (int32_t g = 0; g < G; ++g) {
595
- for (int32_t k = 0; k < K; ++k) {
596
- for (int32_t s = 0; s < S; ++s) {
597
- for (int32_t c = 0; c < C; ++c) {
598
- auto accumulator = ElementAcc(0);
599
- for (int32_t n = 0; n < N; ++n) {
600
- for (int32_t q = 0; q < Q; ++q) {
601
- int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
602
- bool is_in_bounds =
603
- detail::is_activation_in_bounds(tensor_b_, n, w, c, g);
604
- if (is_in_bounds) {
605
- auto act =
606
- tensor_b_(c, w, n, g);
607
- auto xformed_act =
608
- tensor_a_(k, q, n, g);
609
- accumulator += ElementAcc(act * xformed_act);
610
- }
611
- }
612
- }
613
-
614
- ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
615
- epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
616
- ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
617
- epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
618
-
619
- ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
620
- if (not EpilogueFusionParams::ResidualAdd) {
621
- output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g));
622
- }
623
- if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
624
- output += bias_converter(epi_fusion_params_.tensor_bias[c]);
625
- }
626
- output = epi_activation(output);
627
- if (EpilogueFusionParams::ResidualAdd) {
628
- output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g));
629
- }
630
- tensor_d_(c, s, k, g) = output_converter(output);
631
- }
632
- }
633
- }
634
- }
635
- }
636
-
637
- // Specialization for 2D wgrad kernel
638
- void wgrad_reference(cute::Int<2> spatial_dims) {
639
- int32_t G = size<4>(tensor_d_);
640
- int32_t N =
641
- size<3>(tensor_a_);
642
- int32_t P =
643
- size<2>(tensor_a_);
644
- int32_t Q =
645
- size<1>(tensor_a_);
646
- int32_t K =
647
- size<0>(tensor_a_);
648
- int32_t R = size<2>(tensor_d_);
649
- int32_t S = size<1>(tensor_d_);
650
- int32_t C = size<0>(tensor_d_);
651
-
652
- #if defined(_OPENMP)
653
- #pragma omp parallel for collapse(3)
654
- #endif
655
- for (int32_t g = 0; g < G; ++g) {
656
- for (int32_t k = 0; k < K; ++k) {
657
- for (int32_t r = 0; r < R; ++r) {
658
- for (int32_t s = 0; s < S; ++s) {
659
- for (int32_t c = 0; c < C; ++c) {
660
- auto accumulator = ElementAcc(0);
661
- for (int32_t n = 0; n < N; ++n) {
662
- for (int32_t p = 0; p < P; ++p) {
663
- for (int32_t q = 0; q < Q; ++q) {
664
- int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
665
- int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
666
- bool is_in_bounds =
667
- detail::is_activation_in_bounds(tensor_b_, n, h, w, c, g);
668
- if (is_in_bounds) {
669
- auto act =
670
- tensor_b_(c, w, h, n, g);
671
- auto xformed_act =
672
- tensor_a_(k, q, p, n, g);
673
- accumulator += ElementAcc(act * xformed_act);
674
- }
675
- }
676
- }
677
- }
678
-
679
- ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
680
- epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
681
- ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
682
- epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
683
-
684
- ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
685
- if (not EpilogueFusionParams::ResidualAdd) {
686
- output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g));
687
- }
688
- if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
689
- output += bias_converter(epi_fusion_params_.tensor_bias[c]);
690
- }
691
- output = epi_activation(output);
692
- if (EpilogueFusionParams::ResidualAdd) {
693
- output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g));
694
- }
695
- tensor_d_(c, s, r, k, g) = output_converter(output);
696
- }
697
- }
698
- }
699
- }
700
- }
701
- }
702
-
703
- // Specialization for 3D wgrad kernel
704
- void wgrad_reference(cute::Int<3> spatial_dims) {
705
- int32_t G = size<5>(tensor_d_);
706
- int32_t N =
707
- size<4>(tensor_a_);
708
- int32_t Z =
709
- size<3>(tensor_a_);
710
- int32_t P =
711
- size<2>(tensor_a_);
712
- int32_t Q =
713
- size<1>(tensor_a_);
714
- int32_t K =
715
- size<0>(tensor_a_);
716
- int32_t T = size<3>(tensor_d_);
717
- int32_t R = size<2>(tensor_d_);
718
- int32_t S = size<1>(tensor_d_);
719
- int32_t C = size<0>(tensor_d_);
720
-
721
- #if defined(_OPENMP)
722
- #pragma omp parallel for collapse(3)
723
- #endif
724
- for (int32_t g = 0 ; g < G; ++g) {
725
- for (int32_t k = 0; k < K; ++k) {
726
- for (int32_t t = 0; t < T; ++t) {
727
- for (int32_t r = 0; r < R; ++r) {
728
- for (int32_t s = 0; s < S; ++s) {
729
- for (int32_t c = 0; c < C; ++c) {
730
- auto accumulator = ElementAcc(0);
731
- for (int32_t n = 0; n < N; ++n) {
732
- for (int32_t z = 0; z < Z; ++z) {
733
- for (int32_t p = 0; p < P; ++p) {
734
- for (int32_t q = 0; q < Q; ++q) {
735
- int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
736
- int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
737
- int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
738
- bool is_in_bounds =
739
- detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c, g);
740
- if (is_in_bounds) {
741
- auto act =
742
- tensor_b_(c, w, h, d, n, g);
743
- auto xformed_act =
744
- tensor_a_(k, q, p, z, n, g);
745
- accumulator += ElementAcc(act * xformed_act);
746
- }
747
- }
748
- }
749
- }
750
- }
751
-
752
- ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
753
- epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
754
- ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
755
- epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
756
-
757
- ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
758
- if (not EpilogueFusionParams::ResidualAdd) {
759
- output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g));
760
- }
761
- if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
762
- output += bias_converter(epi_fusion_params_.tensor_bias[c]);
763
- }
764
- output = epi_activation(output);
765
- if (EpilogueFusionParams::ResidualAdd) {
766
- output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g));
767
- }
768
- tensor_d_(c, s, r, t, k, g) = output_converter(output);
769
- }
770
- }
771
- }
772
- }
773
- }
774
- }
775
- }
776
- };
777
-
778
- /////////////////////////////////////////////////////////////////////////////////////////////////
779
-
780
- } // cutlass::reference::host
781
-
782
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h DELETED
@@ -1,802 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief Reference implementation for convolution in host-side code.
34
- */
35
-
36
- #pragma once
37
-
38
- #include "cutlass/coord.h"
39
- #include "cutlass/functional.h"
40
- #include "cutlass/layout/tensor.h"
41
- #include "cutlass/numeric_conversion.h"
42
- #include "cutlass/numeric_types.h"
43
- #include "cutlass/tensor_ref.h"
44
- #include "cutlass/tensor_view.h"
45
- #include "cutlass/conv/convolution.h"
46
- #include "cutlass/conv/conv2d_problem_size.h"
47
- #include "cutlass/conv/conv3d_problem_size.h"
48
- #include <iostream>
49
-
50
- namespace cutlass {
51
- namespace reference {
52
- namespace host {
53
-
54
- ////////////////////////////////////////////////////////////////////////////////////////////////////
55
- /// Forward propagation
56
- ////////////////////////////////////////////////////////////////////////////////////////////////////
57
-
58
- /// y = conv2d(x, w)
59
- template <
60
- typename ElementA,
61
- typename LayoutA,
62
- typename ElementB,
63
- typename LayoutB,
64
- typename ElementC,
65
- typename LayoutC,
66
- typename ElementCompute,
67
- typename ElementAccumulator = ElementCompute,
68
- typename ElementD = ElementC,
69
- typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
70
- typename InnerProductOp = multiply_add<ElementAccumulator>
71
- >
72
- void Conv2dFprop(
73
- conv::Conv2dProblemSize problem_size,
74
- TensorRef<ElementA, LayoutA> tensor_x,
75
- TensorRef<ElementB, LayoutB> tensor_w,
76
- TensorRef<ElementC, LayoutC> tensor_y_in,
77
- TensorRef<ElementD, LayoutC> tensor_y_out,
78
- ElementCompute alpha,
79
- ElementCompute beta) {
80
-
81
- ConvertOp convert_op;
82
- InnerProductOp inner_product_op;
83
-
84
- // Apply MMA and accumulate ElementAccumulator
85
- for (int n = 0; n < problem_size.N; ++n) {
86
- for (int p = 0; p < problem_size.P; ++p) {
87
- for (int q = 0; q < problem_size.Q; ++q) {
88
- for (int k = 0; k < problem_size.K; ++k) {
89
-
90
- int group_idx = k / (problem_size.K / problem_size.groups);
91
- int channels_per_group = problem_size.C / problem_size.groups;
92
-
93
- ElementAccumulator acc = ElementAccumulator();
94
-
95
- for (int r = 0; r < problem_size.R; ++r) {
96
- for (int s = 0; s < problem_size.S; ++s) {
97
- for (int c = 0; c < channels_per_group; ++c) {
98
-
99
- int filter_r = r;
100
- int filter_s = s;
101
-
102
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
103
- filter_r = problem_size.R - 1 - r;
104
- filter_s = problem_size.S - 1 - s;
105
- }
106
-
107
- int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
108
- int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
109
-
110
- if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) {
111
-
112
- ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group});
113
- ElementB b = tensor_w.at({k, r, s, c});
114
-
115
- acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
116
-
117
- }
118
- }
119
- }
120
- }
121
-
122
- // Apply Epilogue, compute ElementCompute, convert and store ElementC
123
- ElementC c_ref = ElementC();
124
-
125
- if (beta != ElementCompute()) {
126
- c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k));
127
- }
128
-
129
- tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) =
130
- convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
131
- }
132
- }
133
- }
134
- }
135
- }
136
-
137
- /// Depthwise-separable convolution
138
- template <typename ElementA,
139
- typename LayoutA,
140
- typename ElementB,
141
- typename LayoutB,
142
- typename ElementC,
143
- typename LayoutC,
144
- typename ElementCompute,
145
- typename ElementAccumulator = ElementCompute,
146
- typename ElementD = ElementC,
147
- typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
148
- typename InnerProductOp = multiply_add<ElementAccumulator>>
149
- void Depsep_Fprop(cutlass::TensorView<ElementA, LayoutA> tensor_A,
150
- cutlass::TensorView<ElementB, LayoutB> tensor_B,
151
- cutlass::TensorView<ElementC, LayoutC> tensor_C,
152
- cutlass::TensorView<ElementD, LayoutC> tensor_D,
153
- ElementCompute alpha,
154
- ElementCompute beta,
155
- cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(),
156
- cutlass::Coord<2> conv_stride = cutlass::Coord<2>(),
157
- cutlass::Coord<2> dilation = cutlass::Coord<2>(),
158
- cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) {
159
-
160
- ConvertOp convert_op;
161
- InnerProductOp inner_product_op;
162
-
163
- // Apply MMA and accumulate ElementAccumulator
164
- for (int n = 0; n < tensor_C.extent().n(); ++n) {
165
- for (int p = 0; p < tensor_C.extent().h(); ++p) {
166
- for (int q = 0; q < tensor_C.extent().w(); ++q) {
167
- for (int g = 0; g < tensor_C.extent().c(); ++g) {
168
- ElementAccumulator acc = ElementAccumulator();
169
- for (int r = 0; r < tensor_B.extent().h(); ++r) {
170
- for (int s = 0; s < tensor_B.extent().w(); ++s) {
171
-
172
- // input activation H and W
173
- int h = p * conv_stride[0] - padding[0] + r * dilation[0];
174
- int w = q * conv_stride[1] - padding[2] + s * dilation[1];
175
-
176
- if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) {
177
- ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g));
178
-
179
- ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation)
180
- ? tensor_B.at(cutlass::make_Coord(g, r, s, 0))
181
- : tensor_B.at(cutlass::make_Coord(
182
- g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0));
183
-
184
- acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
185
- }
186
- }
187
- }
188
-
189
- // Apply Epilogue, compute ElementCompute, convert and store ElementC
190
- ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g));
191
- tensor_D.at(cutlass::make_Coord(n, p, q, g)) =
192
- convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
193
- }
194
- }
195
- }
196
- }
197
- }
198
-
199
- ////////////////////////////////////////////////////////////////////////////////////////////////////
200
- /// Dgrad / Deconv
201
- ////////////////////////////////////////////////////////////////////////////////////////////////////
202
-
203
- /// dx = dgrad(dy, w)
204
- template <
205
- typename ElementA,
206
- typename LayoutA,
207
- typename ElementB,
208
- typename LayoutB,
209
- typename ElementC,
210
- typename LayoutC,
211
- typename ElementCompute,
212
- typename ElementAccumulator = ElementCompute,
213
- typename ElementD = ElementC,
214
- typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
215
- typename InnerProductOp = multiply_add<ElementAccumulator>
216
- >
217
- void Conv2dDgrad(
218
- cutlass::conv::Conv2dProblemSize problem_size,
219
- TensorRef<ElementA, LayoutA> tensor_dy,
220
- TensorRef<ElementB, LayoutB> tensor_w,
221
- TensorRef<ElementC, LayoutC> tensor_dx_in,
222
- TensorRef<ElementD, LayoutC> tensor_dx_out,
223
- ElementCompute alpha,
224
- ElementCompute beta,
225
- bool is_deconv = false) {
226
-
227
- ConvertOp convert_op;
228
- InnerProductOp inner_product_op;
229
-
230
- // Apply MMA and accumulate ElementAccumulator
231
- for (int n = 0; n < problem_size.N; ++n) {
232
- for (int h = 0; h < problem_size.H; ++h) {
233
- for (int w = 0; w < problem_size.W; ++w) {
234
- for (int c = 0; c < problem_size.C; ++c) {
235
-
236
- ElementAccumulator acc = ElementAccumulator();
237
-
238
- for (int r = 0; r < problem_size.R; ++r) {
239
- for (int s = 0; s < problem_size.S; ++s) {
240
- for (int k = 0; k < problem_size.K; ++k) {
241
-
242
- int filter_r = r;
243
- int filter_s = s;
244
-
245
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
246
- filter_r = problem_size.R - 1 - r;
247
- filter_s = problem_size.S - 1 - s;
248
- }
249
-
250
- int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h;
251
- int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w;
252
-
253
- if (p >= 0 && (p % problem_size.stride_h) == 0 &&
254
- q >= 0 && (q % problem_size.stride_w) == 0) {
255
-
256
- p = p / problem_size.stride_h;
257
- q = q / problem_size.stride_w;
258
- #if 0
259
- std::cout << "row:"
260
- << n * problem_size.H * problem_size.W +
261
- h * problem_size.W +
262
- w << " "
263
- << "n, p, q: ("
264
- << n << ", "
265
- << p << ", "
266
- << q << ") * "
267
- << "r, s: ("
268
- << r << ", "
269
- << s << ") ["
270
- << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]"
271
- << std::endl;
272
- #endif
273
- if (p < problem_size.P && q < problem_size.Q) {
274
-
275
- ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k));
276
- ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k))
277
- : tensor_w.at(cutlass::make_Coord(k, r, s, c));
278
-
279
- acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
280
- }
281
- }
282
-
283
- } // for (K)
284
- } // for (S)
285
- } // for (R)
286
-
287
- // Apply Epilogue, compute ElementCompute, convert and store ElementC
288
- ElementC c_ref = ElementC();
289
-
290
- if (beta != ElementCompute()) {
291
- c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c));
292
- }
293
-
294
- tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) =
295
- convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
296
-
297
- } // for (C)
298
- } // for (W)
299
- } // for (H)
300
- } // for (N)
301
- }
302
-
303
- ////////////////////////////////////////////////////////////////////////////////////////////////////
304
- /// Wgrad
305
- ////////////////////////////////////////////////////////////////////////////////////////////////////
306
-
307
- /// dw = wgrad(dy, x)
308
- template <
309
- typename ElementA,
310
- typename LayoutA,
311
- typename ElementB,
312
- typename LayoutB,
313
- typename ElementC,
314
- typename LayoutC,
315
- typename ElementCompute,
316
- typename ElementAccumulator = ElementCompute,
317
- typename ElementD = ElementC,
318
- typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
319
- typename InnerProductOp = multiply_add<ElementAccumulator>
320
- >
321
- void Conv2dWgrad(
322
- cutlass::conv::Conv2dProblemSize problem_size,
323
- TensorRef<ElementA, LayoutA> tensor_dy,
324
- TensorRef<ElementB, LayoutB> tensor_x,
325
- TensorRef<ElementC, LayoutC> tensor_dw_in,
326
- TensorRef<ElementD, LayoutC> tensor_dw_out,
327
- ElementCompute alpha,
328
- ElementCompute beta) {
329
-
330
- InnerProductOp inner_product_op;
331
- ConvertOp convert_op;
332
-
333
- // Apply MMA and accumulate ElementAccumulator
334
- for (int k = 0; k < problem_size.K; ++k) {
335
- for (int r = 0; r < problem_size.R; ++r) {
336
- for (int s = 0; s < problem_size.S; ++s) {
337
- for (int c = 0; c < problem_size.C; ++c) {
338
-
339
- ElementAccumulator acc = ElementAccumulator();
340
-
341
- for (int n = 0; n < problem_size.N; ++n) {
342
- for (int p = 0; p < problem_size.P; ++p) {
343
- for (int q = 0; q < problem_size.Q; ++q) {
344
-
345
- cutlass::Tensor4DCoord b_coord;
346
-
347
- int filter_r = r;
348
- int filter_s = s;
349
-
350
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
351
- filter_r = problem_size.R - 1 - r;
352
- filter_s = problem_size.S - 1 - s;
353
- }
354
-
355
- b_coord = make_Coord(
356
- n,
357
- p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h,
358
- q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w,
359
- c);
360
-
361
- if (b_coord.h() < problem_size.H && b_coord.h() >= 0 &&
362
- b_coord.w() < problem_size.W && b_coord.w() >= 0) {
363
-
364
- ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k)));
365
- ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord));
366
- acc = inner_product_op(a, b, acc);
367
- }
368
- }
369
- }
370
- }
371
-
372
- // Apply Epilogue, compute ElementCompute, convert and store ElementC
373
- ElementC c_ref = ElementC();
374
-
375
- if (beta != ElementCompute()) {
376
- c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c));
377
- }
378
-
379
- tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) =
380
- convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
381
-
382
- } // for (C)
383
- } // for (S)
384
- } // for (R)
385
- } // for (K)
386
- }
387
-
388
- /// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
389
- template <
390
- typename ElementA,
391
- typename LayoutA,
392
- typename ElementB,
393
- typename LayoutB,
394
- typename ElementC,
395
- typename LayoutC,
396
- typename ElementCompute,
397
- typename ElementAccumulator = ElementCompute,
398
- typename ElementD = ElementC,
399
- typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
400
- typename InnerProductOp = multiply_add<ElementAccumulator>
401
- >
402
- void Conv2d(
403
- conv::Operator convolutional_operator,
404
- conv::Conv2dProblemSize problem_size,
405
- TensorRef<ElementA, LayoutA> tensor_A,
406
- TensorRef<ElementB, LayoutB> tensor_B,
407
- TensorRef<ElementC, LayoutC> tensor_C,
408
- TensorRef<ElementD, LayoutC> tensor_D,
409
- ElementCompute alpha,
410
- ElementCompute beta) {
411
-
412
- switch (convolutional_operator) {
413
- case conv::Operator::kFprop:
414
- Conv2dFprop<
415
- ElementA, LayoutA,
416
- ElementB, LayoutB,
417
- ElementC, LayoutC,
418
- ElementCompute,
419
- ElementAccumulator,
420
- ElementD,
421
- ConvertOp, InnerProductOp
422
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
423
- break;
424
-
425
- case conv::Operator::kDeconv:
426
- case conv::Operator::kDgrad:
427
- Conv2dDgrad<
428
- ElementA, LayoutA,
429
- ElementB, LayoutB,
430
- ElementC, LayoutC,
431
- ElementCompute,
432
- ElementAccumulator,
433
- ElementD,
434
- ConvertOp, InnerProductOp
435
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
436
- break;
437
-
438
- case conv::Operator::kWgrad:
439
- Conv2dWgrad<
440
- ElementA, LayoutA,
441
- ElementB, LayoutB,
442
- ElementC, LayoutC,
443
- ElementCompute,
444
- ElementAccumulator,
445
- ElementD,
446
- ConvertOp, InnerProductOp
447
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
448
- break;
449
-
450
- default:
451
- break;
452
- }
453
- }
454
-
455
- ////////////////////////////////////////////////////////////////////////////////////////////////////
456
- /// 3D convolution
457
- ////////////////////////////////////////////////////////////////////////////////////////////////////
458
-
459
- /// y = conv3d(x, w)
460
- template <
461
- typename ElementA,
462
- typename LayoutA,
463
- typename ElementB,
464
- typename LayoutB,
465
- typename ElementC,
466
- typename LayoutC,
467
- typename ElementCompute,
468
- typename ElementAccumulator = ElementCompute,
469
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
470
- typename InnerProductOp = multiply_add<ElementAccumulator>
471
- >
472
- void Conv3dFprop(
473
- conv::Conv3dProblemSize problem_size,
474
- TensorRef<ElementA, LayoutA> tensor_x,
475
- TensorRef<ElementB, LayoutB> tensor_w,
476
- TensorRef<ElementC, LayoutC> tensor_y_in,
477
- TensorRef<ElementC, LayoutC> tensor_y_out,
478
- ElementCompute alpha,
479
- ElementCompute beta) {
480
-
481
- ConvertOp convert_op;
482
- InnerProductOp inner_product_op;
483
-
484
- // Apply MMA and accumulate ElementAccumulator
485
- for (int n = 0; n < problem_size.N; ++n) {
486
- for (int z = 0; z < problem_size.Z; ++z) {
487
- for (int p = 0; p < problem_size.P; ++p) {
488
- for (int q = 0; q < problem_size.Q; ++q) {
489
- for (int k = 0; k < problem_size.K; ++k) {
490
-
491
- ElementAccumulator acc = ElementAccumulator();
492
-
493
- for (int t = 0; t < problem_size.T; ++t) {
494
- for (int r = 0; r < problem_size.R; ++r) {
495
- for (int s = 0; s < problem_size.S; ++s) {
496
- for (int c = 0; c < problem_size.C; ++c) {
497
-
498
- int filter_t = t;
499
- int filter_r = r;
500
- int filter_s = s;
501
-
502
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
503
- filter_t = problem_size.T - 1 - t;
504
- filter_r = problem_size.R - 1 - r;
505
- filter_s = problem_size.S - 1 - s;
506
- }
507
-
508
- int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
509
- int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
510
- int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
511
-
512
- if (d >= 0 && d < problem_size.D &&
513
- h >=0 && h < problem_size.H &&
514
- w >= 0 && w < problem_size.W) {
515
-
516
- ElementA a = tensor_x.at({n, d, h, w, c});
517
- ElementB b = tensor_w.at({k, t, r, s, c});
518
-
519
- acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
520
- }
521
- }
522
- }
523
- }
524
- }
525
-
526
- // Apply Epilogue, compute ElementCompute, convert and store ElementC
527
- ElementC c_ref = ElementC();
528
-
529
- if (beta != ElementCompute()) {
530
- c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k));
531
- }
532
-
533
- tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) =
534
- convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
535
- }
536
- }
537
- }
538
- }
539
- }
540
- }
541
-
542
- ////////////////////////////////////////////////////////////////////////////////////////////////////
543
- /// Dgrad / Deconv
544
- ////////////////////////////////////////////////////////////////////////////////////////////////////
545
-
546
- /// dx = dgrad(dy, w)
547
- template <
548
- typename ElementA,
549
- typename LayoutA,
550
- typename ElementB,
551
- typename LayoutB,
552
- typename ElementC,
553
- typename LayoutC,
554
- typename ElementCompute,
555
- typename ElementAccumulator = ElementCompute,
556
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
557
- typename InnerProductOp = multiply_add<ElementAccumulator>
558
- >
559
- void Conv3dDgrad(
560
- cutlass::conv::Conv3dProblemSize problem_size,
561
- TensorRef<ElementA, LayoutA> tensor_dy,
562
- TensorRef<ElementB, LayoutB> tensor_w,
563
- TensorRef<ElementC, LayoutC> tensor_dx_in,
564
- TensorRef<ElementC, LayoutC> tensor_dx_out,
565
- ElementCompute alpha,
566
- ElementCompute beta,
567
- bool is_deconv = false) {
568
-
569
- ConvertOp convert_op;
570
- InnerProductOp inner_product_op;
571
-
572
- // Apply MMA and accumulate ElementAccumulator
573
- for (int n = 0; n < problem_size.N; ++n) {
574
- for (int d = 0; d < problem_size.D; ++d) {
575
- for (int h = 0; h < problem_size.H; ++h) {
576
- for (int w = 0; w < problem_size.W; ++w) {
577
- for (int c = 0; c < problem_size.C; ++c) {
578
-
579
- ElementAccumulator acc = ElementAccumulator();
580
-
581
- for (int t = 0; t < problem_size.T; ++t) {
582
- for (int r = 0; r < problem_size.R; ++r) {
583
- for (int s = 0; s < problem_size.S; ++s) {
584
- for (int k = 0; k < problem_size.K; ++k) {
585
-
586
- int filter_t = t;
587
- int filter_r = r;
588
- int filter_s = s;
589
-
590
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
591
- filter_t = problem_size.T - 1 - t;
592
- filter_r = problem_size.R - 1 - r;
593
- filter_s = problem_size.S - 1 - s;
594
- }
595
-
596
- int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d;
597
- int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h;
598
- int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w;
599
-
600
- if (z >= 0 && (z % problem_size.stride_d) == 0 &&
601
- p >= 0 && (p % problem_size.stride_h) == 0 &&
602
- q >= 0 && (q % problem_size.stride_w) == 0) {
603
-
604
- z = z / problem_size.stride_d;
605
- p = p / problem_size.stride_h;
606
- q = q / problem_size.stride_w;
607
-
608
- if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {
609
-
610
- ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k));
611
- ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k))
612
- : tensor_w.at(cutlass::make_Coord(k, t, r, s, c));
613
- acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
614
- }
615
- }
616
-
617
- } // for (K)
618
- } // for (S)
619
- } // for (R)
620
- } // for (T)
621
-
622
- // Apply Epilogue, compute ElementCompute, convert and store ElementC
623
- ElementC c_ref = ElementC();
624
-
625
- if (beta != ElementCompute()) {
626
- c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c));
627
- }
628
-
629
- tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) =
630
- convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
631
-
632
- } // for (C)
633
- } // for (W)
634
- } // for (H)
635
- } // for (D)
636
- } // for (N)
637
- }
638
-
639
- ////////////////////////////////////////////////////////////////////////////////////////////////////
640
- /// Wgrad
641
- ////////////////////////////////////////////////////////////////////////////////////////////////////
642
-
643
- /// dw = wgrad(dy, x)
644
- template <
645
- typename ElementA,
646
- typename LayoutA,
647
- typename ElementB,
648
- typename LayoutB,
649
- typename ElementC,
650
- typename LayoutC,
651
- typename ElementCompute,
652
- typename ElementAccumulator = ElementCompute,
653
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
654
- typename InnerProductOp = multiply_add<ElementAccumulator>
655
- >
656
- void Conv3dWgrad(
657
- cutlass::conv::Conv3dProblemSize problem_size,
658
- TensorRef<ElementA, LayoutA> tensor_dy,
659
- TensorRef<ElementB, LayoutB> tensor_x,
660
- TensorRef<ElementC, LayoutC> tensor_dw_in,
661
- TensorRef<ElementC, LayoutC> tensor_dw_out,
662
- ElementCompute alpha,
663
- ElementCompute beta) {
664
-
665
- InnerProductOp inner_product_op;
666
- ConvertOp convert_op;
667
-
668
- // Apply MMA and accumulate ElementAccumulator
669
- for (int k = 0; k < problem_size.K; ++k) {
670
- for (int t = 0; t < problem_size.T; ++t) {
671
- for (int r = 0; r < problem_size.R; ++r) {
672
- for (int s = 0; s < problem_size.S; ++s) {
673
- for (int c = 0; c < problem_size.C; ++c) {
674
-
675
- ElementAccumulator acc = ElementAccumulator();
676
-
677
- for (int n = 0; n < problem_size.N; ++n) {
678
- for (int z = 0; z < problem_size.Z; ++z) {
679
- for (int p = 0; p < problem_size.P; ++p) {
680
- for (int q = 0; q < problem_size.Q; ++q) {
681
-
682
- int filter_t = t;
683
- int filter_r = r;
684
- int filter_s = s;
685
-
686
- if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
687
- filter_t = problem_size.T - 1 - t;
688
- filter_r = problem_size.R - 1 - r;
689
- filter_s = problem_size.S - 1 - s;
690
- }
691
-
692
- Tensor5DCoord b_coord = make_Coord(
693
- n,
694
- z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d,
695
- p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h,
696
- q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w,
697
- c);
698
-
699
- if (b_coord.d() < problem_size.D && b_coord.d() >= 0 &&
700
- b_coord.h() < problem_size.H && b_coord.h() >= 0 &&
701
- b_coord.w() < problem_size.W && b_coord.w() >= 0) {
702
-
703
- ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)));
704
- ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord));
705
-
706
- acc = inner_product_op(a, b, acc);
707
- }
708
- }
709
- }
710
- }
711
- }
712
-
713
- // Apply Epilogue, compute ElementCompute, convert and store ElementC
714
- ElementC c_ref = ElementC();
715
-
716
- if (beta != ElementCompute()) {
717
- c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c));
718
- }
719
-
720
- tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) =
721
- convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
722
-
723
- } // for (C)
724
- } // for (S)
725
- } // for (R)
726
- } // for (T)
727
- } // for (K)
728
- }
729
-
730
- ///////////////////////////////////////////////////////////////////////////////////////////////////
731
-
732
- /// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
733
- template <
734
- typename ElementA,
735
- typename LayoutA,
736
- typename ElementB,
737
- typename LayoutB,
738
- typename ElementC,
739
- typename LayoutC,
740
- typename ElementCompute,
741
- typename ElementAccumulator = ElementCompute,
742
- typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
743
- typename InnerProductOp = multiply_add<ElementAccumulator>
744
- >
745
- void Conv3d(
746
- conv::Operator convolutional_operator,
747
- conv::Conv3dProblemSize problem_size,
748
- TensorRef<ElementA, LayoutA> tensor_A,
749
- TensorRef<ElementB, LayoutB> tensor_B,
750
- TensorRef<ElementC, LayoutC> tensor_C,
751
- TensorRef<ElementC, LayoutC> tensor_D,
752
- ElementCompute alpha,
753
- ElementCompute beta) {
754
-
755
- switch (convolutional_operator) {
756
- case conv::Operator::kFprop:
757
- Conv3dFprop<
758
- ElementA, LayoutA,
759
- ElementB, LayoutB,
760
- ElementC, LayoutC,
761
- ElementCompute,
762
- ElementAccumulator,
763
- ConvertOp, InnerProductOp
764
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
765
- break;
766
-
767
- case conv::Operator::kDeconv:
768
- case conv::Operator::kDgrad:
769
- Conv3dDgrad<
770
- ElementA, LayoutA,
771
- ElementB, LayoutB,
772
- ElementC, LayoutC,
773
- ElementCompute,
774
- ElementAccumulator,
775
- ConvertOp, InnerProductOp
776
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
777
- break;
778
-
779
- case conv::Operator::kWgrad:
780
- Conv3dWgrad<
781
- ElementA, LayoutA,
782
- ElementB, LayoutB,
783
- ElementC, LayoutC,
784
- ElementCompute,
785
- ElementAccumulator,
786
- ConvertOp, InnerProductOp
787
- >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
788
- break;
789
-
790
- default:
791
- break;
792
- }
793
- }
794
-
795
- /////////////////////////////////////////////////////////////////////////////////////////////////
796
-
797
- } // namespace host
798
- } // namespace reference
799
- } // namespace cutlass
800
-
801
- /////////////////////////////////////////////////////////////////////////////////////////////////
802
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h DELETED
@@ -1,66 +0,0 @@
1
-
2
- /***************************************************************************************************
3
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- * SPDX-License-Identifier: BSD-3-Clause
5
- *
6
- * Redistribution and use in source and binary forms, with or without
7
- * modification, are permitted provided that the following conditions are met:
8
- *
9
- * 1. Redistributions of source code must retain the above copyright notice, this
10
- * list of conditions and the following disclaimer.
11
- *
12
- * 2. Redistributions in binary form must reproduce the above copyright notice,
13
- * this list of conditions and the following disclaimer in the documentation
14
- * and/or other materials provided with the distribution.
15
- *
16
- * 3. Neither the name of the copyright holder nor the names of its
17
- * contributors may be used to endorse or promote products derived from
18
- * this software without specific prior written permission.
19
- *
20
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- *
31
- **************************************************************************************************/
32
- #pragma once
33
-
34
- #include <cmath>
35
-
36
- #include "cutlass/cutlass.h"
37
- #include "cutlass/complex.h"
38
- #include "cutlass/util/reference/host/tensor_reduce.h"
39
- #include "cutlass/core_io.h"
40
-
41
- namespace cutlass {
42
- namespace reference {
43
- namespace host {
44
-
45
- /// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference
46
- template <
47
- typename Element,
48
- typename Layout,
49
- typename ComputeType = double
50
- >
51
- ComputeType TensorRelativeErrorMetric(
52
- TensorView<Element, Layout> view_A_computed,
53
- TensorView<Element, Layout> view_B_reference,
54
- ComputeType identity = ComputeType()
55
- ) {
56
-
57
- return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) /
58
- cutlass::reference::host::TensorNorm(view_B_reference, identity);
59
- }
60
-
61
-
62
- ///////////////////////////////////////////////////////////////////////////////////////////////////
63
-
64
- } // namespace host
65
- } // namespace reference
66
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h DELETED
@@ -1,531 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for GEMM in host-side code.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/coord.h"
38
- #include "cutlass/numeric_types.h"
39
- #include "cutlass/functional.h"
40
- #include "cutlass/numeric_conversion.h"
41
-
42
- #include "cutlass/tensor_view.h"
43
- #include "cutlass/gemm/gemm.h"
44
- #include "cutlass/arch/mma.h"
45
- #include "cutlass/util/host_tensor.h"
46
-
47
- namespace cutlass {
48
- namespace reference {
49
- namespace host {
50
-
51
- template<typename Out, typename In>
52
- struct CastIfScalar {
53
- static Out cast(In in) {
54
- return Out(in);
55
- }
56
- };
57
-
58
- template<typename OutScalar, typename In>
59
- struct CastIfScalar<cutlass::complex<OutScalar>, In> {
60
- typedef cutlass::complex<OutScalar> Out;
61
- static Out cast(In in) {
62
- return Out(static_cast<OutScalar>(in));
63
- }
64
- };
65
-
66
- template<typename OutScalar, typename InScalar>
67
- struct CastIfScalar<cutlass::complex<OutScalar>, cutlass::complex<InScalar>> {
68
- typedef cutlass::complex<OutScalar> Out;
69
- typedef cutlass::complex<InScalar> In;
70
- static Out cast(In in) {
71
- return Out(in);
72
- }
73
- };
74
-
75
- template<typename Out, typename In>
76
- Out cast_if_scalar(In in) {
77
- return CastIfScalar<Out, In>::cast(in);
78
- }
79
-
80
- ////////////////////////////////////////////////////////////////////////////////////////////////////
81
-
82
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
83
- /// objects.
84
- template <
85
- typename ElementA,
86
- typename LayoutA,
87
- typename ElementB,
88
- typename LayoutB,
89
- typename ElementC,
90
- typename LayoutC,
91
- typename ScalarType,
92
- typename ComputeType,
93
- typename InnerProductOp = multiply_add<ComputeType>,
94
- typename ConvertOp = NumericConverter<ElementC, ScalarType>
95
- >
96
- void compute_gemm(
97
- gemm::GemmCoord problem_size,
98
- ScalarType alpha,
99
- TensorRef<ElementA, LayoutA> tensor_a,
100
- TensorRef<ElementB, LayoutB> tensor_b,
101
- ScalarType beta,
102
- TensorRef<ElementC, LayoutC> tensor_c,
103
- TensorRef<ElementC, LayoutC> tensor_d,
104
- ComputeType initial_accum) {
105
-
106
- static_assert(
107
- LayoutA::kRank == 2 &&
108
- LayoutB::kRank == 2 &&
109
- LayoutC::kRank == 2, "Tensors must be of rank 2");
110
-
111
-
112
- // Note: batch is ignored.
113
- int const M = problem_size.m();
114
- int const N = problem_size.n();
115
- int const K = problem_size.k();
116
-
117
- // Blocking necessary to speedup reference implementation
118
- int const Mblock = 16;
119
- int const Nblock = 16;
120
-
121
- ConvertOp convert_op;
122
- InnerProductOp inner_product_op;
123
-
124
- for (int row_block = 0; row_block < M; row_block += Mblock) {
125
- for (int col_block = 0; col_block < N; col_block += Nblock) {
126
-
127
- ComputeType accum[Mblock][Nblock];
128
-
129
- for (int j = 0; j < Nblock; j++) {
130
- for (int i = 0; i < Mblock; i++) {
131
- accum[i][j] = initial_accum;
132
- }
133
- }
134
-
135
- for (int k_block = 0; k_block < K; ++k_block) {
136
- for (int j = 0; j < Nblock; j++) {
137
- for (int i = 0; i < Mblock; i++) {
138
- int row = row_block + i;
139
- int col = col_block + j;
140
-
141
- if (row < M && col < N) {
142
- ElementA a = tensor_a.at(MatrixCoord(row, k_block));
143
- ElementB b = tensor_b.at(MatrixCoord(k_block, col));
144
-
145
- ComputeType compute_a(cast_if_scalar<ComputeType>(a));
146
- ComputeType compute_b(cast_if_scalar<ComputeType>(b));
147
-
148
- accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]);
149
- }
150
- }
151
- }
152
- }
153
-
154
- for (int j = 0; j < Nblock; j++) {
155
- for (int i = 0; i < Mblock; i++) {
156
- int row = row_block + i;
157
- int col = col_block + j;
158
-
159
- MatrixCoord coord = MatrixCoord(row, col);
160
-
161
- if (row < M && col < N) {
162
- tensor_d.at(coord) = convert_op(
163
- alpha * ScalarType(accum[i][j]) +
164
- beta * ScalarType(tensor_c.at(coord)));
165
- }
166
- }
167
- }
168
- }
169
- }
170
- }
171
-
172
- ////////////////////////////////////////////////////////////////////////////////////////////////////
173
-
174
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
175
- /// objects.
176
- template <
177
- typename ElementA,
178
- typename LayoutA,
179
- typename ElementB,
180
- typename LayoutB,
181
- typename ElementC,
182
- typename LayoutC,
183
- typename ScalarType,
184
- typename ComputeType,
185
- typename InnerProductOp = multiply_add<ComputeType>,
186
- typename ConvertOp = NumericConverter<ElementC, ScalarType>
187
- >
188
- void compute_gemm(
189
- gemm::GemmCoord problem_size,
190
- ScalarType alpha,
191
- TensorRef<ElementA, LayoutA> tensor_a,
192
- TensorRef<ElementB, LayoutB> tensor_b,
193
- ScalarType beta,
194
- TensorRef<ElementC, LayoutC> tensor_c,
195
- ComputeType initial_accum) {
196
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
197
- ScalarType, ComputeType, InnerProductOp, ConvertOp>(
198
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
199
- initial_accum);
200
- }
201
-
202
- ////////////////////////////////////////////////////////////////////////////////////////////////////
203
-
204
- template <
205
- typename ElementA,
206
- typename LayoutA,
207
- typename ElementB,
208
- typename LayoutB,
209
- typename ElementC,
210
- typename LayoutC,
211
- typename ScalarType,
212
- typename ComputeType,
213
- typename InnerProductOp = cutlass::arch::OpMultiplyAdd
214
- >
215
- struct Gemm;
216
-
217
- ////////////////////////////////////////////////////////////////////////////////////////////////////
218
-
219
- /// Partial specialization for multiply-add
220
- template <typename ElementA, typename LayoutA, typename ElementB,
221
- typename LayoutB, typename ElementC, typename LayoutC,
222
- typename ScalarType, typename ComputeType>
223
- struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
224
- ComputeType, arch::OpMultiplyAdd> {
225
-
226
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
227
- TensorRef<ElementA, LayoutA> tensor_a,
228
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
229
- TensorRef<ElementC, LayoutC> tensor_c,
230
- ComputeType initial_accum = ComputeType(0)) {
231
- static_assert(
232
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
233
- "Tensors must be of rank 2");
234
-
235
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
236
- ScalarType, ComputeType, multiply_add<ComputeType>>(
237
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
238
- }
239
-
240
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
241
- TensorRef<ElementA, LayoutA> tensor_a,
242
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
243
- TensorRef<ElementC, LayoutC> tensor_c,
244
- TensorRef<ElementC, LayoutC> tensor_d,
245
- ComputeType initial_accum = ComputeType(0)) {
246
- static_assert(
247
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
248
- "Tensors must be of rank 2");
249
-
250
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
251
- ScalarType, ComputeType, multiply_add<ComputeType>>(
252
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
253
- }
254
- };
255
-
256
- ////////////////////////////////////////////////////////////////////////////////////////////////////
257
-
258
- /// Partial specialization for multiply-add
259
- template <typename ElementA, typename LayoutA, typename ElementB,
260
- typename LayoutB, typename ElementC, typename LayoutC,
261
- typename ScalarType, typename ComputeType>
262
- struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
263
- ComputeType, arch::OpMultiplyAddFastBF16> {
264
-
265
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
266
- TensorRef<ElementA, LayoutA> tensor_a,
267
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
268
- TensorRef<ElementC, LayoutC> tensor_c,
269
- ComputeType initial_accum = ComputeType(0)) {
270
- static_assert(
271
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
272
- "Tensors must be of rank 2");
273
-
274
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
275
- ScalarType, ComputeType, multiply_add<ComputeType>>(
276
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
277
- }
278
-
279
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
280
- TensorRef<ElementA, LayoutA> tensor_a,
281
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
282
- TensorRef<ElementC, LayoutC> tensor_c,
283
- TensorRef<ElementC, LayoutC> tensor_d,
284
- ComputeType initial_accum = ComputeType(0)) {
285
- static_assert(
286
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
287
- "Tensors must be of rank 2");
288
-
289
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
290
- ScalarType, ComputeType, multiply_add<ComputeType>>(
291
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
292
- }
293
- };
294
-
295
- ////////////////////////////////////////////////////////////////////////////////////////////////////
296
-
297
- /// Partial specialization for multiply-add-saturate
298
- template <typename ElementA, typename LayoutA, typename ElementB,
299
- typename LayoutB, typename ElementC, typename LayoutC,
300
- typename ScalarType, typename ComputeType>
301
- struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
302
- ComputeType, arch::OpMultiplyAddSaturate> {
303
-
304
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
305
- TensorRef<ElementA, LayoutA> tensor_a,
306
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
307
- TensorRef<ElementC, LayoutC> tensor_c,
308
- ComputeType initial_accum = ComputeType(0)) {
309
- static_assert(
310
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
311
- "Tensors must be of rank 2");
312
-
313
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
314
- ScalarType, ComputeType, multiply_add<ComputeType>,
315
- NumericConverterClamp<ElementC, ScalarType>>(
316
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
317
- }
318
-
319
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
320
- TensorRef<ElementA, LayoutA> tensor_a,
321
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
322
- TensorRef<ElementC, LayoutC> tensor_c,
323
- TensorRef<ElementC, LayoutC> tensor_d,
324
- ComputeType initial_accum = ComputeType(0)) {
325
- static_assert(
326
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
327
- "Tensors must be of rank 2");
328
-
329
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
330
- ScalarType, ComputeType, multiply_add<ComputeType>,
331
- NumericConverterClamp<ElementC, ScalarType>>(
332
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
333
- }
334
- };
335
-
336
- ////////////////////////////////////////////////////////////////////////////////////////////////////
337
-
338
- /// Partial specialization for XOR-popc
339
- template <typename ElementA, typename LayoutA, typename ElementB,
340
- typename LayoutB, typename ElementC, typename LayoutC,
341
- typename ScalarType, typename ComputeType>
342
- struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
343
- ComputeType, arch::OpXorPopc> {
344
-
345
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
346
- TensorRef<ElementA, LayoutA> tensor_a,
347
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
348
- TensorRef<ElementC, LayoutC> tensor_c,
349
- ComputeType initial_accum = ComputeType(0)) {
350
- static_assert(
351
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
352
- "Tensors must be of rank 2");
353
-
354
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
355
- ScalarType, ComputeType, xor_popc_add<ComputeType>>(
356
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
357
- }
358
-
359
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
360
- TensorRef<ElementA, LayoutA> tensor_a,
361
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
362
- TensorRef<ElementC, LayoutC> tensor_c,
363
- TensorRef<ElementC, LayoutC> tensor_d,
364
- ComputeType initial_accum = ComputeType(0)) {
365
- static_assert(
366
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
367
- "Tensors must be of rank 2");
368
-
369
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
370
- ScalarType, ComputeType, xor_popc_add<ComputeType>>(
371
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
372
- }
373
- };
374
-
375
- /// Partial specialization for AND-popc
376
- template <typename ElementA, typename LayoutA, typename ElementB,
377
- typename LayoutB, typename ElementC, typename LayoutC,
378
- typename ScalarType, typename ComputeType>
379
- struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
380
- ComputeType, arch::OpAndPopc> {
381
-
382
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
383
- TensorRef<ElementA, LayoutA> tensor_a,
384
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
385
- TensorRef<ElementC, LayoutC> tensor_c,
386
- ComputeType initial_accum = ComputeType(0)) {
387
- static_assert(
388
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
389
- "Tensors must be of rank 2");
390
-
391
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
392
- ScalarType, ComputeType, and_popc_add<ComputeType>>(
393
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
394
- }
395
-
396
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
397
- TensorRef<ElementA, LayoutA> tensor_a,
398
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
399
- TensorRef<ElementC, LayoutC> tensor_c,
400
- TensorRef<ElementC, LayoutC> tensor_d,
401
- ComputeType initial_accum = ComputeType(0)) {
402
- static_assert(
403
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
404
- "Tensors must be of rank 2");
405
-
406
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
407
- ScalarType, ComputeType, and_popc_add<ComputeType>>(
408
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
409
- }
410
- };
411
-
412
- ////////////////////////////////////////////////////////////////////////////////////////////////////
413
-
414
- /// Partial specialization for multiply-add
415
- template <typename ElementA, typename LayoutA, typename ElementB,
416
- typename LayoutB, typename ElementC, typename LayoutC,
417
- typename ScalarType, typename ComputeType>
418
- struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
419
- ComputeType, arch::OpMultiplyAddFastF32> {
420
-
421
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
422
- TensorRef<ElementA, LayoutA> tensor_a,
423
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
424
- TensorRef<ElementC, LayoutC> tensor_c,
425
- ComputeType initial_accum = ComputeType(0)) {
426
- static_assert(
427
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
428
- "Tensors must be of rank 2");
429
-
430
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
431
- ScalarType, ComputeType, multiply_add<ComputeType>>(
432
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
433
- }
434
-
435
- void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
436
- TensorRef<ElementA, LayoutA> tensor_a,
437
- TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
438
- TensorRef<ElementC, LayoutC> tensor_c,
439
- TensorRef<ElementC, LayoutC> tensor_d,
440
- ComputeType initial_accum = ComputeType(0)) {
441
- static_assert(
442
- LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
443
- "Tensors must be of rank 2");
444
-
445
- compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
446
- ScalarType, ComputeType, multiply_add<ComputeType>>(
447
- problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
448
- }
449
- };
450
-
451
- ////////////////////////////////////////////////////////////////////////////////////////////////////
452
-
453
- ////////////////////////////////////////////////////////////////////////////////////////////////////
454
- //
455
- // Batched GEMM
456
- //
457
- ////////////////////////////////////////////////////////////////////////////////////////////////////
458
-
459
- /// Computes a batch of GEMMs over a set of matrices of common dimension.
460
- //
461
- // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
462
- //
463
- template <
464
- typename TensorRefCollectionA,
465
- typename TensorRefCollectionB,
466
- typename TensorRefCollectionC,
467
- typename ScalarType,
468
- typename AccumulatorType
469
- >
470
- void BatchedGemm(
471
- gemm::GemmCoord problem_size,
472
- int batch_count,
473
- ScalarType alpha,
474
- TensorRefCollectionA const& tensor_a,
475
- TensorRefCollectionB const& tensor_b,
476
- ScalarType beta,
477
- TensorRefCollectionC &tensor_c,
478
- AccumulatorType initial_accum) {
479
-
480
- typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin();
481
- typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin();
482
- typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin();
483
-
484
- for (int batch = 0;
485
- batch < batch_count;
486
- ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) {
487
-
488
- Gemm<typename TensorRefCollectionA::Element,
489
- typename TensorRefCollectionA::Layout,
490
- typename TensorRefCollectionB::Element,
491
- typename TensorRefCollectionB::Layout,
492
- typename TensorRefCollectionC::Element,
493
- typename TensorRefCollectionC::Layout,
494
- typename TensorRefCollectionC::Element,
495
- typename TensorRefCollectionC::Element>
496
- gemm;
497
-
498
- gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it,
499
- initial_accum);
500
- }
501
- }
502
-
503
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
504
- /// objects.
505
- //
506
- // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
507
- //
508
- template <
509
- typename TensorRefCollectionA,
510
- typename TensorRefCollectionB,
511
- typename TensorRefCollectionC,
512
- typename ScalarType,
513
- typename AccumulatorType
514
- >
515
- void BatchedGemm(
516
- gemm::GemmCoord problem_size,
517
- int batch_count,
518
- ScalarType alpha,
519
- TensorRefCollectionA const& tensor_a,
520
- TensorRefCollectionB const& tensor_b,
521
- ScalarType beta,
522
- TensorRefCollectionC &tensor_c) {
523
-
524
- BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
525
- }
526
-
527
- ////////////////////////////////////////////////////////////////////////////////////////////////////
528
-
529
- } // namespace host
530
- } // namespace reference
531
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h DELETED
@@ -1,210 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for complex-valued GEMM in host-side code.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/coord.h"
38
- #include "cutlass/complex.h"
39
- #include "cutlass/numeric_types.h"
40
- #include "cutlass/functional.h"
41
- #include "cutlass/numeric_conversion.h"
42
- #include "cutlass/matrix_coord.h"
43
-
44
- #include "cutlass/tensor_view.h"
45
-
46
- #include "cutlass/gemm/gemm.h"
47
-
48
- namespace cutlass {
49
- namespace reference {
50
- namespace host {
51
-
52
- ////////////////////////////////////////////////////////////////////////////////////////////////////
53
-
54
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
55
- /// objects.
56
- ///
57
- /// Explicitly naming types needed by this template can be cumbersome, particularly for the
58
- /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
59
- /// AccumulatorType(0) as the last function argument can be easier than naming all template
60
- /// arguments explicitly.
61
- template <
62
- typename ElementA,
63
- typename LayoutA,
64
- typename ElementB,
65
- typename LayoutB,
66
- typename ElementC,
67
- typename LayoutC,
68
- typename ScalarType,
69
- typename ComputeType,
70
- typename ElementD = ElementC,
71
- typename ConvertOp = NumericConverter<ElementD, ScalarType>,
72
- typename InnerProductOp = multiply_add<ComputeType>
73
- >
74
- void GemmComplex(
75
- gemm::GemmCoord problem_size,
76
- ScalarType alpha,
77
- TensorRef<ElementA, LayoutA> tensor_a,
78
- ComplexTransform transform_a,
79
- TensorRef<ElementB, LayoutB> tensor_b,
80
- ComplexTransform transform_b,
81
- ScalarType beta,
82
- TensorRef<ElementC, LayoutC> tensor_c,
83
- TensorRef<ElementD, LayoutC> tensor_d,
84
- ComputeType initial_accum,
85
- int batch_count = 1,
86
- int64_t batch_stride_A = 0,
87
- int64_t batch_stride_B = 0,
88
- int64_t batch_stride_C = 0,
89
- int64_t batch_stride_D = 0) {
90
-
91
- static_assert(
92
- LayoutA::kRank == 2 &&
93
- LayoutB::kRank == 2 &&
94
- LayoutC::kRank == 2, "Tensors must be of rank 2");
95
-
96
- // Note: batch is ignored.
97
- int const M = problem_size.m();
98
- int const N = problem_size.n();
99
- int const K = problem_size.k();
100
-
101
- // Blocking necessary to speedup reference implementation
102
- int const Mblock = 16;
103
- int const Nblock = 16;
104
-
105
- ConvertOp convert_op;
106
- InnerProductOp inner_product_op;
107
-
108
- for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
109
-
110
- // Compute matrix product using blocks
111
- for (int row_block = 0; row_block < M; row_block += Mblock) {
112
- for (int col_block = 0; col_block < N; col_block += Nblock) {
113
-
114
- ComputeType accum[Mblock][Nblock];
115
-
116
- for (int j = 0; j < Nblock; j++) {
117
- for (int i = 0; i < Mblock; i++) {
118
- accum[i][j] = initial_accum;
119
- }
120
- }
121
-
122
- for (int k_block = 0; k_block < K; ++k_block) {
123
- for (int j = 0; j < Nblock; j++) {
124
- for (int i = 0; i < Mblock; i++) {
125
- int row = row_block + i;
126
- int col = col_block + j;
127
-
128
- if (row < M && col < N) {
129
- ElementA a = tensor_a.at(MatrixCoord(row, k_block));
130
- ElementB b = tensor_b.at(MatrixCoord(k_block, col));
131
-
132
- ComputeType a_ik = ComputeType(a);
133
- ComputeType b_kj = ComputeType(b);
134
-
135
- if (transform_a == ComplexTransform::kConjugate) {
136
- a_ik = conj(a_ik);
137
- }
138
-
139
- if (transform_b == ComplexTransform::kConjugate) {
140
- b_kj = conj(b_kj);
141
- }
142
-
143
- accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
144
- }
145
- }
146
- }
147
- }
148
-
149
- for (int j = 0; j < Nblock; j++) {
150
- for (int i = 0; i < Mblock; i++) {
151
- int row = row_block + i;
152
- int col = col_block + j;
153
-
154
- MatrixCoord coord = MatrixCoord(row, col);
155
-
156
- if (row < M && col < N) {
157
-
158
- tensor_d.at(coord) = convert_op(
159
- alpha * ScalarType(accum[i][j]) +
160
- beta * ScalarType(tensor_c.at(coord)));
161
- }
162
- }
163
- }
164
-
165
- } // for (col_block)
166
- } // for (row_block)
167
-
168
- tensor_a.add_pointer_offset(batch_stride_A);
169
- tensor_b.add_pointer_offset(batch_stride_B);
170
- tensor_c.add_pointer_offset(batch_stride_C);
171
- tensor_d.add_pointer_offset(batch_stride_D);
172
-
173
- } // for (batch_idx)
174
- }
175
-
176
- ////////////////////////////////////////////////////////////////////////////////////////////////////
177
-
178
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
179
- /// objects.
180
- ///
181
- /// This assumes the accumulator type is the same type as the scalars.
182
- template <
183
- typename ElementA,
184
- typename LayoutA,
185
- typename ElementB,
186
- typename LayoutB,
187
- typename ElementC,
188
- typename LayoutC,
189
- typename ScalarType,
190
- typename ElementD = ElementC
191
- >
192
- void GemmComplex(
193
- gemm::GemmCoord problem_size,
194
- ScalarType alpha,
195
- TensorRef<ElementA, LayoutA> tensor_a,
196
- ComplexTransform transform_a,
197
- TensorRef<ElementB, LayoutB> tensor_b,
198
- ComplexTransform transform_b,
199
- ScalarType beta,
200
- TensorRef<ElementC, LayoutC> tensor_c,
201
- TensorRef<ElementD, LayoutC> tensor_d) {
202
-
203
- GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
204
- }
205
-
206
- ////////////////////////////////////////////////////////////////////////////////////////////////////
207
-
208
- } // namespace host
209
- } // namespace reference
210
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h DELETED
@@ -1,228 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Reference implementation for complex-valued GEMM in host-side code.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/coord.h"
38
- #include "cutlass/complex.h"
39
- #include "cutlass/numeric_types.h"
40
- #include "cutlass/functional.h"
41
- #include "cutlass/numeric_conversion.h"
42
- #include "cutlass/tensor_ref_planar_complex.h"
43
-
44
- #include "cutlass/tensor_view.h"
45
- #include "cutlass/gemm/gemm.h"
46
-
47
- namespace cutlass {
48
- namespace reference {
49
- namespace host {
50
-
51
- ////////////////////////////////////////////////////////////////////////////////////////////////////
52
-
53
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
54
- /// objects.
55
- ///
56
- /// Explicitly naming types needed by this template can be cumbersome, particularly for the
57
- /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
58
- /// AccumulatorType(0) as the last function argument can be easier than naming all template
59
- /// arguments explicitly.
60
- template <
61
- typename ElementA,
62
- typename LayoutA,
63
- typename ElementB,
64
- typename LayoutB,
65
- typename ElementC,
66
- typename LayoutC,
67
- typename ScalarType,
68
- typename ComputeType,
69
- typename ConvertOp = NumericConverter<ElementC, ScalarType>,
70
- typename InnerProductOp = multiply_add<complex<ComputeType>>
71
- >
72
- void GemmPlanarComplex(
73
- gemm::GemmCoord problem_size,
74
- complex<ScalarType> alpha,
75
- TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
76
- ComplexTransform transform_a,
77
- TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
78
- ComplexTransform transform_b,
79
- complex<ScalarType> beta,
80
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
81
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
82
- complex<ComputeType> initial_accum) {
83
-
84
- static_assert(
85
- LayoutA::kRank == 2 &&
86
- LayoutB::kRank == 2 &&
87
- LayoutC::kRank == 2, "Tensors must be of rank 2");
88
-
89
- using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
90
- using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
91
- using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;
92
-
93
- // Note: batch is ignored.
94
- int const M = problem_size.m();
95
- int const N = problem_size.n();
96
- int const K = problem_size.k();
97
-
98
- // Blocking necessary to speedup reference implementation
99
- int const Mblock = 16;
100
- int const Nblock = 16;
101
-
102
- ConvertOp convert_op;
103
- InnerProductOp inner_product_op;
104
-
105
- for (int row_block = 0; row_block < M; row_block += Mblock) {
106
- for (int col_block = 0; col_block < N; col_block += Nblock) {
107
-
108
- complex<ComputeType> accum[Mblock][Nblock];
109
-
110
- for (int j = 0; j < Nblock; j++) {
111
- for (int i = 0; i < Mblock; i++) {
112
- accum[i][j] = initial_accum;
113
- }
114
- }
115
-
116
- for (int k_block = 0; k_block < K; ++k_block) {
117
- for (int j = 0; j < Nblock; j++) {
118
- for (int i = 0; i < Mblock; i++) {
119
- int row = row_block + i;
120
- int col = col_block + j;
121
-
122
- if (row < M && col < N) {
123
-
124
- ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
125
- ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
126
-
127
- complex<ComputeType> a = complex<ComputeType>{
128
- ComputeType(a_ik.real()),
129
- ComputeType(a_ik.imag())
130
- };
131
-
132
- complex<ComputeType> b = complex<ComputeType>{
133
- ComputeType(b_kj.real()),
134
- ComputeType(b_kj.imag())
135
- };
136
-
137
- if (transform_a == ComplexTransform::kConjugate) {
138
- a = conj(a);
139
- }
140
-
141
- if (transform_b == ComplexTransform::kConjugate) {
142
- b = conj(b);
143
- }
144
-
145
- accum[i][j] = inner_product_op(a, b, accum[i][j]);
146
- }
147
- }
148
- }
149
- }
150
-
151
- for (int j = 0; j < Nblock; j++) {
152
- for (int i = 0; i < Mblock; i++) {
153
- int row = row_block + i;
154
- int col = col_block + j;
155
-
156
- MatrixCoord coord = MatrixCoord(row, col);
157
-
158
- if (row < M && col < N) {
159
-
160
- complex<ScalarType> acc{
161
- ScalarType(accum[i][j].real()),
162
- ScalarType(accum[i][j].imag())
163
- };
164
-
165
- ComplexC d_ij = tensor_c.at(coord);
166
-
167
- complex<ScalarType> src{
168
- ScalarType(d_ij.real()),
169
- ScalarType(d_ij.imag())
170
- };
171
-
172
- complex<ScalarType> result = alpha * acc + beta * src;
173
-
174
- d_ij.real() = convert_op(result.real());
175
- d_ij.imag() = convert_op(result.imag());
176
-
177
- tensor_d.at(coord) = d_ij;
178
- }
179
- }
180
- }
181
- }
182
- }
183
- }
184
-
185
- ////////////////////////////////////////////////////////////////////////////////////////////////////
186
-
187
- /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
188
- /// objects.
189
- ///
190
- /// This assumes the accumulator type is the same type as the scalars.
191
- template <
192
- typename ElementA,
193
- typename LayoutA,
194
- typename ElementB,
195
- typename LayoutB,
196
- typename ElementC,
197
- typename LayoutC,
198
- typename ScalarType
199
- >
200
- void GemmPlanarComplex(
201
- gemm::GemmCoord problem_size,
202
- complex<ScalarType> alpha,
203
- TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
204
- ComplexTransform transform_a,
205
- TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
206
- ComplexTransform transform_b,
207
- complex<ScalarType> beta,
208
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
209
- TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
210
-
211
- GemmPlanarComplex(
212
- problem_size,
213
- alpha,
214
- tensor_a, transform_a,
215
- tensor_b, transform_b,
216
- beta,
217
- tensor_c,
218
- tensor_d,
219
- complex<ScalarType>());
220
- }
221
-
222
- ////////////////////////////////////////////////////////////////////////////////////////////////////
223
-
224
- } // namespace host
225
- } // namespace reference
226
- } // namespace cutlass
227
-
228
- ////////////////////////////////////////////////////////////////////////////////////////////////////