Kernels
danieldk HF Staff committed on
Commit
38c7386
·
verified ·
1 Parent(s): e7410d9

Build uploaded using `kernels` (batch 7/10).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h +432 -0
  2. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h +252 -0
  3. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h +142 -0
  4. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h +661 -0
  5. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h +789 -0
  6. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h +1500 -0
  7. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h +641 -0
  8. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h +241 -0
  9. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h +1234 -0
  10. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h +406 -0
  11. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h +89 -0
  12. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h +125 -0
  13. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h +104 -0
  14. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h +147 -0
  15. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp +1271 -0
  16. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h +218 -0
  17. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h +132 -0
  18. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h +206 -0
  19. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h +203 -0
  20. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h +2860 -0
  21. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h +89 -0
  22. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h +561 -0
  23. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h +377 -0
  24. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h +679 -0
  25. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h +143 -0
  26. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h +78 -0
  27. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h +267 -0
  28. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp +394 -0
  29. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h +821 -0
  30. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h +1239 -0
  31. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp +94 -0
  32. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp +63 -0
  33. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp +271 -0
  34. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp +917 -0
  35. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +785 -0
  36. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h +658 -0
  37. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h +519 -0
  38. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp +601 -0
  39. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h +194 -0
  40. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp +137 -0
  41. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp +448 -0
  42. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h +270 -0
  43. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h +388 -0
  44. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h +269 -0
  45. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp +136 -0
  46. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp +65 -0
  47. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h +322 -0
  48. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h +1927 -0
  49. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h +2007 -0
  50. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h +357 -0
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/arch/mma.h"
38
+ #include "cutlass/complex.h"
39
+ #include "cutlass/quaternion.h"
40
+ #include "cutlass/functional.h"
41
+
42
+ #include "cutlass/layout/matrix.h"
43
+ #include "cutlass/gemm/gemm.h"
44
+
45
+ /////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ namespace cutlass {
48
+ namespace arch {
49
+
50
+ /////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ /// Matrix multiply-add operation
53
+ template <
54
+ /// Layout of A matrix
55
+ typename LayoutA,
56
+ /// Layout of B matrix
57
+ typename LayoutB,
58
+ /// Layout of C matrix
59
+ typename LayoutC
60
+ >
61
+ struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> {
62
+
63
+ using Shape = gemm::GemmShape<1, 1, 1>;
64
+ using Operator = OpMultiplyAdd;
65
+ using ElementC = float;
66
+
67
+ CUTLASS_HOST_DEVICE
68
+ void operator()(
69
+ Array<float, 1> &d,
70
+ Array<float, 1> const &a,
71
+ Array<float, 1> const &b,
72
+ Array<float, 1> const &c
73
+ ) {
74
+ d[0] = a[0] * b[0] + c[0];
75
+ }
76
+ };
77
+
78
+ /////////////////////////////////////////////////////////////////////////////////////////////////
79
+
80
+ /// Matrix multiply-add operation
81
+ template <
82
+ /// Layout of A matrix
83
+ typename LayoutA,
84
+ /// Layout of B matrix
85
+ typename LayoutB,
86
+ /// Layout of C matrix
87
+ typename LayoutC
88
+ >
89
+ struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> {
90
+
91
+ using Shape = gemm::GemmShape<1, 1, 1>;
92
+ using Operator = OpMultiplyAdd;
93
+ using ElementC = double;
94
+
95
+ CUTLASS_HOST_DEVICE
96
+ void operator()(
97
+ Array<double, 1> &d,
98
+ Array<double, 1> const &a,
99
+ Array<double, 1> const &b,
100
+ Array<double, 1> const &c
101
+ ) {
102
+
103
+ d[0] = a[0] * b[0] + c[0];
104
+ }
105
+ };
106
+
107
+ /////////////////////////////////////////////////////////////////////////////////////////////////
108
+
109
+ /// Matrix multiply-add operation
110
+ template <
111
+ /// Layout of A matrix
112
+ typename LayoutA,
113
+ /// Layout of B matrix
114
+ typename LayoutB,
115
+ /// Layout of C matrix
116
+ typename LayoutC
117
+ >
118
+ struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> {
119
+
120
+ using Shape = gemm::GemmShape<1, 1, 1>;
121
+ using Operator = OpMultiplyAdd;
122
+ using ElementC = int;
123
+
124
+ CUTLASS_HOST_DEVICE
125
+ void operator()(
126
+ Array<int, 1> &d,
127
+ Array<int, 1> const &a,
128
+ Array<int, 1> const &b,
129
+ Array<int, 1> const &c
130
+ ) {
131
+
132
+ d[0] = a[0] * b[0] + c[0];
133
+ }
134
+ };
135
+
136
+ /////////////////////////////////////////////////////////////////////////////////////////////////
137
+
138
+ /// Matrix multiply-add operation
139
+ template <
140
+ /// Layout of A matrix
141
+ typename LayoutA,
142
+ /// Layout of B matrix
143
+ typename LayoutB,
144
+ /// Layout of C matrix
145
+ typename LayoutC
146
+ >
147
+ struct Mma<
148
+ gemm::GemmShape<1, 1, 1>,
149
+ 1,
150
+ complex<float>,
151
+ LayoutA,
152
+ complex<float>,
153
+ LayoutB,
154
+ complex<float>,
155
+ LayoutC,
156
+ OpMultiplyAdd> {
157
+
158
+ using Shape = gemm::GemmShape<1, 1, 1>;
159
+ using Operator = OpMultiplyAddComplex;
160
+ using ElementC = complex<float>;
161
+
162
+ CUTLASS_HOST_DEVICE
163
+ void operator()(
164
+ Array<complex<float>, 1> &d,
165
+ Array<complex<float>, 1> const &a,
166
+ Array<complex<float>, 1> const &b,
167
+ Array<complex<float>, 1> const &c
168
+ ) {
169
+
170
+ d[0].real() = a[0].real() * b[0].real() + c[0].real();
171
+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
172
+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
173
+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
174
+ }
175
+ };
176
+
177
+ /////////////////////////////////////////////////////////////////////////////////////////////////
178
+
179
+ /// Matrix multiply-add operation
180
+ template <
181
+ /// Layout of A matrix
182
+ typename LayoutA,
183
+ /// Layout of B matrix
184
+ typename LayoutB,
185
+ /// Layout of C matrix
186
+ typename LayoutC
187
+ >
188
+ struct Mma<
189
+ gemm::GemmShape<1, 1, 1>,
190
+ 1,
191
+ complex<float>,
192
+ LayoutA,
193
+ float,
194
+ LayoutB,
195
+ complex<float>,
196
+ LayoutC,
197
+ OpMultiplyAdd> {
198
+
199
+ using Shape = gemm::GemmShape<1, 1, 1>;
200
+ using Operator = OpMultiplyAddComplex;
201
+ using ElementC = complex<float>;
202
+
203
+ CUTLASS_HOST_DEVICE
204
+ void operator()(
205
+ Array<complex<float>, 1> &d,
206
+ Array<complex<float>, 1> const &a,
207
+ Array<float, 1> const &b,
208
+ Array<complex<float>, 1> const &c
209
+ ) {
210
+
211
+ d[0].real() = a[0].real() * b[0] + c[0].real();
212
+ d[0].imag() = a[0].imag() * b[0] + c[0].imag();
213
+ }
214
+ };
215
+
216
+ /////////////////////////////////////////////////////////////////////////////////////////////////
217
+
218
+ /// Matrix multiply-add operation
219
+ template <
220
+ /// Layout of A matrix
221
+ typename LayoutA,
222
+ /// Layout of B matrix
223
+ typename LayoutB,
224
+ /// Layout of C matrix
225
+ typename LayoutC
226
+ >
227
+ struct Mma<
228
+ gemm::GemmShape<1, 1, 1>,
229
+ 1,
230
+ float,
231
+ LayoutA,
232
+ complex<float>,
233
+ LayoutB,
234
+ complex<float>,
235
+ LayoutC,
236
+ OpMultiplyAdd> {
237
+
238
+ using Shape = gemm::GemmShape<1, 1, 1>;
239
+ using Operator = OpMultiplyAddComplex;
240
+ using ElementC = complex<float>;
241
+
242
+ CUTLASS_HOST_DEVICE
243
+ void operator()(
244
+ Array<complex<float>, 1> &d,
245
+ Array<float, 1> const &a,
246
+ Array<complex<float>, 1> const &b,
247
+ Array<complex<float>, 1> const &c
248
+ ) {
249
+
250
+ d[0].real() = a[0] * b[0].real() + c[0].real();
251
+ d[0].imag() = a[0] * b[0].imag() + d[0].imag();
252
+ }
253
+ };
254
+
255
+ /////////////////////////////////////////////////////////////////////////////////////////////////
256
+
257
+ /// Matrix multiply-add operation
258
+ template <
259
+ /// Layout of A matrix
260
+ typename LayoutA,
261
+ /// Layout of B matrix
262
+ typename LayoutB,
263
+ /// Layout of C matrix
264
+ typename LayoutC
265
+ >
266
+ struct Mma<
267
+ gemm::GemmShape<1, 1, 1>,
268
+ 1,
269
+ complex<double>,
270
+ LayoutA,
271
+ complex<double>,
272
+ LayoutB,
273
+ complex<double>,
274
+ LayoutC,
275
+ OpMultiplyAdd> {
276
+
277
+ using Shape = gemm::GemmShape<1, 1, 1>;
278
+ using Operator = OpMultiplyAddComplex;
279
+ using ElementC = complex<double>;
280
+
281
+ CUTLASS_HOST_DEVICE
282
+ void operator()(
283
+ Array<complex<double>, 1> &d,
284
+ Array<complex<double>, 1> const &a,
285
+ Array<complex<double>, 1> const &b,
286
+ Array<complex<double>, 1> const &c
287
+ ) {
288
+
289
+ d[0].real() = a[0].real() * b[0].real() + c[0].real();
290
+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
291
+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
292
+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
293
+ }
294
+ };
295
+
296
+ /// Matrix multiply-add operation
297
+ template <
298
+ /// Layout of A matrix
299
+ typename LayoutA,
300
+ /// Layout of B matrix
301
+ typename LayoutB,
302
+ /// Layout of C matrix
303
+ typename LayoutC
304
+ >
305
+ struct Mma<
306
+ gemm::GemmShape<1, 1, 1>,
307
+ 1,
308
+ complex<double>,
309
+ LayoutA,
310
+ double,
311
+ LayoutB,
312
+ complex<double>,
313
+ LayoutC,
314
+ OpMultiplyAdd> {
315
+
316
+ using Shape = gemm::GemmShape<1, 1, 1>;
317
+ using Operator = OpMultiplyAddComplex;
318
+ using ElementC = complex<double>;
319
+
320
+ CUTLASS_HOST_DEVICE
321
+ void operator()(
322
+ Array<complex<double>, 1> &d,
323
+ Array<complex<double>, 1> const &a,
324
+ Array<double, 1> const &b,
325
+ Array<complex<double>, 1> const &c
326
+ ) {
327
+
328
+ d[0].real() = a[0].real() * b[0] + c[0].real();
329
+ d[0].imag() = a[0].imag() * b[0] + c[0].imag();
330
+ }
331
+ };
332
+
333
+ /// Matrix multiply-add operation
334
+ template <
335
+ /// Layout of A matrix
336
+ typename LayoutA,
337
+ /// Layout of B matrix
338
+ typename LayoutB,
339
+ /// Layout of C matrix
340
+ typename LayoutC
341
+ >
342
+ struct Mma<
343
+ gemm::GemmShape<1, 1, 1>,
344
+ 1,
345
+ double,
346
+ LayoutA,
347
+ complex<double>,
348
+ LayoutB,
349
+ complex<double>,
350
+ LayoutC,
351
+ OpMultiplyAdd> {
352
+
353
+ using Shape = gemm::GemmShape<1, 1, 1>;
354
+ using Operator = OpMultiplyAddComplex;
355
+ using ElementC = complex<double>;
356
+
357
+ CUTLASS_HOST_DEVICE
358
+ void operator()(
359
+ Array<complex<double>, 1> &d,
360
+ Array<double, 1> const &a,
361
+ Array<complex<double>, 1> const &b,
362
+ Array<complex<double>, 1> const &c
363
+ ) {
364
+
365
+ d[0].real() = a[0] * b[0].real() + c[0].real();
366
+ d[0].imag() = a[0] * b[0].imag() + d[0].imag();
367
+ }
368
+ };
369
+
370
+ /////////////////////////////////////////////////////////////////////////////////////////////////
371
+
372
+ /// Matrix multiply-add operation
373
+ template <
374
+ /// Layout of A matrix
375
+ typename LayoutA,
376
+ /// Layout of B matrix
377
+ typename LayoutB,
378
+ /// Layout of C matrix
379
+ typename LayoutC
380
+ >
381
+ struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> {
382
+
383
+ using Shape = gemm::GemmShape<1, 1, 1>;
384
+ using Operator = OpMultiplyAdd;
385
+ using ElementC = float;
386
+
387
+ CUTLASS_HOST_DEVICE
388
+ void operator()(
389
+ Array<float, 1> &d,
390
+ Array<half_t, 1> const &a,
391
+ Array<half_t, 1> const &b,
392
+ Array<float, 1> const &c
393
+ ) {
394
+ d[0] = float(a[0]) * float(b[0]) + c[0];
395
+ }
396
+ };
397
+
398
+ /////////////////////////////////////////////////////////////////////////////////////////////////
399
+
400
+ /// Matrix multiply-add operation for Quaternions
401
+ template <
402
+ /// Layout of A matrix
403
+ typename LayoutA,
404
+ /// Layout of B matrix
405
+ typename LayoutB,
406
+ /// Layout of C matrix
407
+ typename LayoutC
408
+ >
409
+ struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<float>, LayoutB, Quaternion<float>, LayoutC, OpMultiplyAdd> {
410
+
411
+ using Shape = gemm::GemmShape<1, 1, 1>;
412
+ using Operator = OpMultiplyAdd;
413
+ using Element = Quaternion<float>;
414
+ using ElementC = Element;
415
+
416
+ CUTLASS_HOST_DEVICE
417
+ void operator()(
418
+ Array<Element, 1> &d,
419
+ Array<Element, 1> const &a,
420
+ Array<Element, 1> const &b,
421
+ Array<Element, 1> const &c
422
+ ) {
423
+ multiply_add<Element, Element, Element> op;
424
+ d[0] = op(a[0], b[0], c[0]);
425
+ }
426
+
427
+ };
428
+
429
+ }
430
+ }
431
+
432
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include <cuda_fp16.h>
38
+
39
+ #include "cutlass/arch/mma.h"
40
+
41
+ #include "cutlass/layout/matrix.h"
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ namespace cutlass {
46
+ namespace arch {
47
+
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ /// Matrix multiply-add operation
51
+ template <typename LayoutA, typename LayoutB, typename LayoutC>
52
+ struct Mma<
53
+ gemm::GemmShape<2,1,1>,
54
+ 1,
55
+ half_t,
56
+ LayoutA,
57
+ half_t,
58
+ LayoutB,
59
+ half_t,
60
+ LayoutC,
61
+ OpMultiplyAdd> {
62
+
63
+ using Shape = gemm::GemmShape<2, 1, 1>;
64
+ using Operator = OpMultiplyAdd;
65
+ using ElementC = half_t;
66
+
67
+ CUTLASS_HOST_DEVICE
68
+ void operator()(
69
+ Array<half_t, 2> &d,
70
+ Array<half_t, 2> const &a,
71
+ Array<half_t, 1> const &b,
72
+ Array<half_t, 2> const &c
73
+ ) {
74
+
75
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
76
+
77
+ __half2 const & A = reinterpret_cast<__half2 const &>(a);
78
+ __half2 B = __half2half2(reinterpret_cast<__half const &>(b));
79
+ __half2 const & C = reinterpret_cast<__half2 const &>(c);
80
+
81
+ __half2 D = __hfma2(A, B, C);
82
+
83
+ d = reinterpret_cast<Array<half_t, 2> &>(D);
84
+
85
+ #else
86
+ CUTLASS_PRAGMA_UNROLL
87
+ for (int i = 0; i < 2; ++i) {
88
+ d[i] = a[i] * b[0] + c[i];
89
+ }
90
+ #endif
91
+ }
92
+ };
93
+
94
+ /////////////////////////////////////////////////////////////////////////////////////////////////
95
+
96
+ /// Matrix multiply-add operation
97
+ template <typename LayoutA, typename LayoutB>
98
+ struct Mma<
99
+ gemm::GemmShape<1,2,1>,
100
+ 1,
101
+ half_t,
102
+ LayoutA,
103
+ half_t,
104
+ LayoutB,
105
+ half_t,
106
+ layout::RowMajor,
107
+ OpMultiplyAdd> {
108
+
109
+ using Shape = gemm::GemmShape<1, 2, 1>;
110
+ using Operator = OpMultiplyAdd;
111
+ using ElementC = half_t;
112
+
113
+ CUTLASS_HOST_DEVICE
114
+ void operator()(
115
+ Array<half_t, 2> &d,
116
+ Array<half_t, 1> const &a,
117
+ Array<half_t, 2> const &b,
118
+ Array<half_t, 2> const &c
119
+ ) {
120
+
121
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
122
+
123
+ __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a));
124
+ __half2 B = reinterpret_cast<__half2 const &>(b);
125
+ __half2 const & C = reinterpret_cast<__half2 const &>(c);
126
+
127
+ __half2 D = __hfma2(A, B, C);
128
+
129
+ d = reinterpret_cast<Array<half_t, 2> &>(D);
130
+
131
+ #else
132
+ CUTLASS_PRAGMA_UNROLL
133
+ for (int i = 0; i < 2; ++i) {
134
+ d[i] = a[0] * b[i] + c[i];
135
+ }
136
+ #endif
137
+ }
138
+ };
139
+
140
+ /////////////////////////////////////////////////////////////////////////////////////////////////
141
+
142
+ /// Matrix multiply-add operation
143
+ template <>
144
+ struct Mma <
145
+ gemm::GemmShape<2, 2, 1>,
146
+ 1,
147
+ half_t,
148
+ layout::ColumnMajor,
149
+ half_t,
150
+ layout::RowMajor,
151
+ half_t,
152
+ layout::ColumnMajor,
153
+ OpMultiplyAdd> {
154
+
155
+ using Shape = gemm::GemmShape<2, 2, 1>;
156
+ using Operator = OpMultiplyAdd;
157
+ using ElementC = half_t;
158
+
159
+ CUTLASS_HOST_DEVICE
160
+ void operator()(
161
+ Array<half_t, 4> &d,
162
+ Array<half_t, 2> const &a,
163
+ Array<half_t, 2> const &b,
164
+ Array<half_t, 4> const &c
165
+ ) {
166
+
167
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
168
+
169
+ __half2 const & A = reinterpret_cast<__half2 const &>(a);
170
+ __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b));
171
+ __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b));
172
+
173
+ __half2 const *C = reinterpret_cast<__half2 const *>(&c);
174
+
175
+ __half2 Dlo = __hfma2(A, Blo, C[0]);
176
+ __half2 Dhi = __hfma2(A, Bhi, C[1]);
177
+
178
+ Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);
179
+
180
+ D[0] = reinterpret_cast<Array<half_t, 2> const &>(Dlo);
181
+ D[1] = reinterpret_cast<Array<half_t, 2> const &>(Dhi);
182
+
183
+ #else
184
+ CUTLASS_PRAGMA_UNROLL
185
+ for (int j = 0; j < 2; ++j) {
186
+ CUTLASS_PRAGMA_UNROLL
187
+ for (int i = 0; i < 2; ++i) {
188
+ d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j];
189
+ }
190
+ }
191
+ #endif
192
+ }
193
+ };
194
+
195
+ /////////////////////////////////////////////////////////////////////////////////////////////////
196
+
197
+ /// Matrix multiply-add operation
198
+ template <>
199
+ struct Mma<
200
+ gemm::GemmShape<2, 2, 1>,
201
+ 1,
202
+ half_t,
203
+ layout::ColumnMajor,
204
+ half_t,
205
+ layout::RowMajor,
206
+ half_t,
207
+ layout::RowMajor,
208
+ OpMultiplyAdd> {
209
+
210
+ using Shape = gemm::GemmShape<2, 2, 1>;
211
+ using Operator = OpMultiplyAdd;
212
+ using ElementC = half_t;
213
+
214
+ CUTLASS_HOST_DEVICE
215
+ void operator()(
216
+ Array<half_t, 4> &d,
217
+ Array<half_t, 2> const &a,
218
+ Array<half_t, 2> const &b,
219
+ Array<half_t, 4> const &c
220
+ ) {
221
+
222
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
223
+
224
+ __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a));
225
+ __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a));
226
+ __half2 const & B = reinterpret_cast<__half2 const &>(b);
227
+
228
+ __half2 const *C = reinterpret_cast<__half2 const *>(&c);
229
+
230
+ __half2 Dlo = __hfma2(Alo, B, C[0]);
231
+ __half2 Dhi = __hfma2(Ahi, B, C[1]);
232
+
233
+ Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);
234
+
235
+ D[0] = reinterpret_cast<Array<half_t, 2> &>(Dlo);
236
+ D[1] = reinterpret_cast<Array<half_t, 2> &>(Dhi);
237
+ #else
238
+ CUTLASS_PRAGMA_UNROLL
239
+ for (int i = 0; i < 2; ++i) {
240
+ CUTLASS_PRAGMA_UNROLL
241
+ for (int j = 0; j < 2; ++j) {
242
+ d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j];
243
+ }
244
+ }
245
+ #endif
246
+ }
247
+ };
248
+
249
+ /////////////////////////////////////////////////////////////////////////////////////////////////
250
+
251
+ }
252
+ }
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/layout/matrix.h"
38
+
39
+ /////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ namespace cutlass {
42
+ namespace arch {
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ /// Matrix multiply-add operation
47
+ template <typename LayoutA, typename LayoutB, typename LayoutC>
48
+ struct Mma<
49
+ gemm::GemmShape<1,1,4>,
50
+ 1,
51
+ int8_t,
52
+ LayoutA,
53
+ int8_t,
54
+ LayoutB,
55
+ int,
56
+ LayoutC,
57
+ OpMultiplyAdd> {
58
+
59
+ using Shape = gemm::GemmShape<1, 1, 4>;
60
+ using Operator = OpMultiplyAdd;
61
+ using ElementC = int;
62
+
63
+ CUTLASS_HOST_DEVICE
64
+ void operator()(
65
+ Array<int, 1> &d,
66
+ Array<int8_t, 4> const &a,
67
+ Array<int8_t, 4> const &b,
68
+ Array<int, 1> const &c
69
+ ) {
70
+
71
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))
72
+
73
+ unsigned const &A = reinterpret_cast<unsigned const &>(a);
74
+ unsigned const &B = reinterpret_cast<unsigned const &>(b);
75
+
76
+ asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
77
+ : "=r"(d[0])
78
+ : "r"(A), "r"(B), "r"(c[0]));
79
+
80
+ #else
81
+
82
+ d[0] = c[0];
83
+
84
+ CUTLASS_PRAGMA_UNROLL
85
+ for (int k = 0; k < 4; ++k) {
86
+ d[0] += a[k] * b[k];
87
+ }
88
+
89
+ #endif
90
+ }
91
+ };
92
+
93
+ /////////////////////////////////////////////////////////////////////////////////////////////////
94
+
95
+ /// Matrix multiply-add operation
96
+ template <typename LayoutC>
97
+ struct Mma<
98
+ gemm::GemmShape<1, 1, 2>,
99
+ 1,
100
+ int16_t,
101
+ layout::RowMajor,
102
+ int16_t,
103
+ layout::ColumnMajor,
104
+ int,
105
+ LayoutC,
106
+ OpMultiplyAdd> {
107
+
108
+ using Shape = gemm::GemmShape<1, 1, 2>;
109
+ using Operator = OpMultiplyAdd;
110
+ using ElementC = int;
111
+
112
+ CUTLASS_HOST_DEVICE
113
+ void operator()(
114
+ Array<int, 1> &d,
115
+ Array<int16_t, 2> const &a,
116
+ Array<int16_t, 2> const &b,
117
+ Array<int, 1> const &c
118
+ ) {
119
+
120
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))
121
+
122
+ unsigned const &A = reinterpret_cast<unsigned const &>(a);
123
+ unsigned const &B = reinterpret_cast<unsigned const &>(b);
124
+
125
+ asm volatile("dp2a.s32.s32 %0, %1, %2, %3;"
126
+ : "=r"(d[0])
127
+ : "r"(A), "r"(B), "r"(c[0]));
128
+ #else
129
+ d[0] = c[0];
130
+
131
+ CUTLASS_PRAGMA_UNROLL
132
+ for (int k = 0; k < 2; ++k) {
133
+ d[0] += a[k] * b[k];
134
+ }
135
+ #endif
136
+ }
137
+ };
138
+
139
+ /////////////////////////////////////////////////////////////////////////////////////////////////
140
+
141
+ }
142
+ }
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply
33
+ */
34
+ #pragma once
35
+ #include "cutlass/cutlass.h"
36
+ #include CUDA_STD_HEADER(cassert)
37
+
38
+ #include "mma.h"
39
+ #include "cutlass/layout/matrix.h"
40
+ #include "cutlass/numeric_types.h"
41
+
42
+ #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))
43
+ #define CUTLASS_ARCH_MMA_SM70_SUPPORTED
44
+ #endif
45
+
46
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
47
+
48
+ #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1))
49
+ #define CUTLASS_ARCH_MMA_SM70_ENABLED
50
+ #endif
51
+
52
+ #endif
53
+
54
+ /////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ namespace cutlass {
57
+ namespace arch {
58
+
59
+ /////////////////////////////////////////////////////////////////////////////////////////////////
60
+ //
61
+ // Matrix multiply accumulate 884 - FP16 accumulation
62
+ //
63
+ /////////////////////////////////////////////////////////////////////////////////////////////////
64
+
65
+ /// Matrix multiply-add operation: F16 = F16 * F16 + F16
66
+ template <>
67
+ struct Mma<
68
+ gemm::GemmShape<8,8,4>,
69
+ 8,
70
+ half_t,
71
+ layout::ColumnMajor,
72
+ half_t,
73
+ layout::ColumnMajor,
74
+ half_t,
75
+ layout::RowMajor,
76
+ OpMultiplyAdd> {
77
+
78
+ using Shape = gemm::GemmShape<8, 8, 4>;
79
+
80
+ using ElementA = half_t;
81
+ using LayoutA = layout::ColumnMajor;
82
+ using FragmentA = Array<half_t, 4>;
83
+
84
+ using ElementB = half_t;
85
+ using LayoutB = layout::ColumnMajor;
86
+ using FragmentB = Array<half_t, 4>;
87
+
88
+ using ElementC = half_t;
89
+ using LayoutC = layout::RowMajor;
90
+ using FragmentC = Array<half_t, 8>;
91
+
92
+ using Operator = OpMultiplyAdd;
93
+ using ArchTag = arch::Sm70;
94
+
95
+ CUTLASS_HOST_DEVICE
96
+ void operator()(
97
+ FragmentC &d,
98
+ FragmentA const &a,
99
+ FragmentB const &b,
100
+ FragmentC const &c
101
+ ) {
102
+
103
+ #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
104
+
105
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
106
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
107
+ unsigned const *C = reinterpret_cast<unsigned const *>(&c);
108
+ unsigned *D = reinterpret_cast<unsigned *>(&d);
109
+
110
+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
111
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
112
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
113
+ );
114
+
115
+ #else
116
+ assert(0);
117
+ #if defined(__CUDA_ARCH__)
118
+ asm volatile ("brkpt;\n" ::);
119
+ #endif
120
+ #endif
121
+ }
122
+ };
123
+
124
+ /// Matrix multiply-add operation: F16 = F16 * F16 + F16
125
+ template <>
126
+ struct Mma<
127
+ gemm::GemmShape<8, 8, 4>,
128
+ 8,
129
+ half_t,
130
+ layout::ColumnMajor,
131
+ half_t,
132
+ layout::RowMajor,
133
+ half_t,
134
+ layout::RowMajor,
135
+ OpMultiplyAdd> {
136
+
137
+ using Shape = gemm::GemmShape<8, 8, 4>;
138
+
139
+ using ElementA = half_t;
140
+ using LayoutA = layout::ColumnMajor;
141
+ using FragmentA = Array<half_t, 4>;
142
+
143
+ using ElementB = half_t;
144
+ using LayoutB = layout::RowMajor;
145
+ using FragmentB = Array<half_t, 4>;
146
+
147
+ using ElementC = half_t;
148
+ using LayoutC = layout::RowMajor;
149
+ using FragmentC = Array<half_t, 8>;
150
+
151
+ using Operator = OpMultiplyAdd;
152
+ using ArchTag = arch::Sm70;
153
+
154
+ CUTLASS_HOST_DEVICE
155
+ void operator()(
156
+ FragmentC &d,
157
+ FragmentA const &a,
158
+ FragmentB const &b,
159
+ FragmentC const &c
160
+ ) {
161
+
162
+ #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
163
+
164
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
165
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
166
+ unsigned const *C = reinterpret_cast<unsigned const *>(&c);
167
+ unsigned *D = reinterpret_cast<unsigned *>(&d);
168
+
169
+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
170
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
171
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
172
+ );
173
+
174
+ #else
175
+ assert(0);
176
+ #if defined(__CUDA_ARCH__)
177
+ asm volatile ("brkpt;\n" ::);
178
+ #endif
179
+ #endif
180
+ }
181
+ };
182
+
183
+ /// Matrix multiply-add operation: F16 = F16 * F16 + F16
184
+ template <>
185
+ struct Mma<
186
+ gemm::GemmShape<8, 8, 4>,
187
+ 8,
188
+ half_t,
189
+ layout::RowMajor,
190
+ half_t,
191
+ layout::ColumnMajor,
192
+ half_t,
193
+ layout::RowMajor,
194
+ OpMultiplyAdd> {
195
+
196
+ using Shape = gemm::GemmShape<8, 8, 4>;
197
+
198
+ using ElementA = half_t;
199
+ using LayoutA = layout::RowMajor;
200
+ using FragmentA = Array<half_t, 4>;
201
+
202
+ using ElementB = half_t;
203
+ using LayoutB = layout::ColumnMajor;
204
+ using FragmentB = Array<half_t, 4>;
205
+
206
+ using ElementC = half_t;
207
+ using LayoutC = layout::RowMajor;
208
+ using FragmentC = Array<half_t, 8>;
209
+
210
+ using Operator = OpMultiplyAdd;
211
+ using ArchTag = arch::Sm70;
212
+
213
+ CUTLASS_HOST_DEVICE
214
+ void operator()(
215
+ FragmentC &d,
216
+ FragmentA const &a,
217
+ FragmentB const &b,
218
+ FragmentC const &c
219
+ ) {
220
+
221
+ #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
222
+
223
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
224
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
225
+ unsigned const *C = reinterpret_cast<unsigned const *>(&c);
226
+ unsigned *D = reinterpret_cast<unsigned *>(&d);
227
+
228
+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
229
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
230
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
231
+ );
232
+
233
+ #else
234
+ assert(0);
235
+ #if defined(__CUDA_ARCH__)
236
+ asm volatile ("brkpt;\n" ::);
237
+ #endif
238
+ #endif
239
+ }
240
+ };
241
+
242
+ /// Matrix multiply-add operation: F16 = F16 * F16 + F16
243
+ template <>
244
+ struct Mma<
245
+ gemm::GemmShape<8, 8, 4>,
246
+ 8,
247
+ half_t,
248
+ layout::RowMajor,
249
+ half_t,
250
+ layout::RowMajor,
251
+ half_t,
252
+ layout::RowMajor,
253
+ OpMultiplyAdd> {
254
+
255
+ using Shape = gemm::GemmShape<8, 8, 4>;
256
+
257
+ using ElementA = half_t;
258
+ using LayoutA = layout::RowMajor;
259
+ using FragmentA = Array<half_t, 4>;
260
+
261
+ using ElementB = half_t;
262
+ using LayoutB = layout::RowMajor;
263
+ using FragmentB = Array<half_t, 4>;
264
+
265
+ using ElementC = half_t;
266
+ using LayoutC = layout::RowMajor;
267
+ using FragmentC = Array<half_t, 8>;
268
+
269
+ using Operator = OpMultiplyAdd;
270
+ using ArchTag = arch::Sm70;
271
+
272
+ CUTLASS_HOST_DEVICE
273
+ void operator()(
274
+ FragmentC &d,
275
+ FragmentA const &a,
276
+ FragmentB const &b,
277
+ FragmentC const &c
278
+ ) {
279
+
280
+ #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
281
+
282
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
283
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
284
+ unsigned const *C = reinterpret_cast<unsigned const *>(&c);
285
+ unsigned *D = reinterpret_cast<unsigned *>(&d);
286
+
287
+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
288
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
289
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
290
+ );
291
+
292
+ #else
293
+ assert(0);
294
+ #if defined(__CUDA_ARCH__)
295
+ asm volatile ("brkpt;\n" ::);
296
+ #endif
297
+ #endif
298
+ }
299
+ };
300
+
301
+ /////////////////////////////////////////////////////////////////////////////////////////////////
302
+ //
303
+ // Matrix multiply accumulate 884 - FP32 accumulation
304
+ //
305
+ /////////////////////////////////////////////////////////////////////////////////////////////////
306
+
307
+ /// Matrix multiply-add operation: F32 = F16 * F16 + F32
308
+ template <>
309
+ struct Mma<
310
+ gemm::GemmShape<8, 8, 4>,
311
+ 8,
312
+ half_t,
313
+ layout::ColumnMajor,
314
+ half_t,
315
+ layout::ColumnMajor,
316
+ float,
317
+ layout::RowMajor,
318
+ OpMultiplyAdd> {
319
+
320
+ using Shape = gemm::GemmShape<8, 8, 4>;
321
+
322
+ using ElementA = half_t;
323
+ using LayoutA = layout::ColumnMajor;
324
+ using FragmentA = Array<half_t, 4>;
325
+
326
+ using ElementB = half_t;
327
+ using LayoutB = layout::ColumnMajor;
328
+ using FragmentB = Array<half_t, 4>;
329
+
330
+ using ElementC = float;
331
+ using LayoutC = layout::RowMajor;
332
+ using FragmentC = Array<float, 8>;
333
+
334
+ using Operator = OpMultiplyAdd;
335
+ using ArchTag = arch::Sm70;
336
+
337
+ /// Multiply-add
338
+ CUTLASS_HOST_DEVICE
339
+ void operator()(
340
+ FragmentC &d,
341
+ FragmentA const &a,
342
+ FragmentB const &b,
343
+ FragmentC const &c
344
+ ) {
345
+
346
+ #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
347
+
348
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
349
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
350
+ float const *C = reinterpret_cast<float const *>(&c);
351
+ float *D = reinterpret_cast<float *>(&d);
352
+
353
+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
354
+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
355
+ : "=f"(D[0]),
356
+ "=f"(D[1]),
357
+ "=f"(D[2]),
358
+ "=f"(D[3]),
359
+ "=f"(D[4]),
360
+ "=f"(D[5]),
361
+ "=f"(D[6]),
362
+ "=f"(D[7])
363
+ : "r"(A[0]),
364
+ "r"(A[1]),
365
+ "r"(B[0]),
366
+ "r"(B[1]),
367
+ "f"(C[0]),
368
+ "f"(C[1]),
369
+ "f"(C[2]),
370
+ "f"(C[3]),
371
+ "f"(C[4]),
372
+ "f"(C[5]),
373
+ "f"(C[6]),
374
+ "f"(C[7])
375
+ );
376
+
377
+ #else
378
+ assert(0);
379
+ #if defined(__CUDA_ARCH__)
380
+ asm volatile ("brkpt;\n" ::);
381
+ #endif
382
+ #endif
383
+ }
384
+ };
385
+
386
+ /// Matrix multiply-add operation: F32 = F16 * F16 + F32
387
+ template <>
388
+ struct Mma<
389
+ gemm::GemmShape<8, 8, 4>,
390
+ 8,
391
+ half_t,
392
+ layout::ColumnMajor,
393
+ half_t,
394
+ layout::RowMajor,
395
+ float,
396
+ layout::RowMajor,
397
+ OpMultiplyAdd> {
398
+
399
+ using Shape = gemm::GemmShape<8, 8, 4>;
400
+
401
+ using ElementA = half_t;
402
+ using LayoutA = layout::ColumnMajor;
403
+ using FragmentA = Array<half_t, 4>;
404
+
405
+ using ElementB = half_t;
406
+ using LayoutB = layout::RowMajor;
407
+ using FragmentB = Array<half_t, 4>;
408
+
409
+ using ElementC = float;
410
+ using LayoutC = layout::RowMajor;
411
+ using FragmentC = Array<float, 8>;
412
+
413
+ using Operator = OpMultiplyAdd;
414
+ using ArchTag = arch::Sm70;
415
+
416
+ /// Multiply-add
417
+ CUTLASS_HOST_DEVICE
418
+ void operator()(
419
+ FragmentC &d,
420
+ FragmentA const &a,
421
+ FragmentB const &b,
422
+ FragmentC const &c
423
+ ) {
424
+
425
+ #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
426
+
427
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
428
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
429
+ float const *C = reinterpret_cast<float const *>(&c);
430
+ float *D = reinterpret_cast<float *>(&d);
431
+
432
+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
433
+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
434
+ : "=f"(D[0]),
435
+ "=f"(D[1]),
436
+ "=f"(D[2]),
437
+ "=f"(D[3]),
438
+ "=f"(D[4]),
439
+ "=f"(D[5]),
440
+ "=f"(D[6]),
441
+ "=f"(D[7])
442
+ : "r"(A[0]),
443
+ "r"(A[1]),
444
+ "r"(B[0]),
445
+ "r"(B[1]),
446
+ "f"(C[0]),
447
+ "f"(C[1]),
448
+ "f"(C[2]),
449
+ "f"(C[3]),
450
+ "f"(C[4]),
451
+ "f"(C[5]),
452
+ "f"(C[6]),
453
+ "f"(C[7])
454
+ );
455
+
456
+ #else
457
+ assert(0);
458
+ #if defined(__CUDA_ARCH__)
459
+ asm volatile ("brkpt;\n" ::);
460
+ #endif
461
+ #endif
462
+ }
463
+ };
464
+
465
+ /// Matrix multiply-add operation: F32 = F16 * F16 + F32
466
+ template <>
467
+ struct Mma<
468
+ gemm::GemmShape<8, 8, 4>,
469
+ 8,
470
+ half_t,
471
+ layout::RowMajor,
472
+ half_t,
473
+ layout::ColumnMajor,
474
+ float,
475
+ layout::RowMajor,
476
+ OpMultiplyAdd> {
477
+
478
+ using Shape = gemm::GemmShape<8, 8, 4>;
479
+
480
+ using ElementA = half_t;
481
+ using LayoutA = layout::RowMajor;
482
+ using FragmentA = Array<half_t, 4>;
483
+
484
+ using ElementB = half_t;
485
+ using LayoutB = layout::ColumnMajor;
486
+ using FragmentB = Array<half_t, 4>;
487
+
488
+ using ElementC = float;
489
+ using LayoutC = layout::RowMajor;
490
+ using FragmentC = Array<float, 8>;
491
+
492
+ using Operator = OpMultiplyAdd;
493
+ using ArchTag = arch::Sm70;
494
+
495
+ /// Multiply-add
496
+ CUTLASS_HOST_DEVICE
497
+ void operator()(
498
+ FragmentC &d,
499
+ FragmentA const &a,
500
+ FragmentB const &b,
501
+ FragmentC const &c
502
+ ) {
503
+
504
+ #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
505
+
506
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
507
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
508
+ float const *C = reinterpret_cast<float const *>(&c);
509
+ float *D = reinterpret_cast<float *>(&d);
510
+
511
+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
512
+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
513
+ : "=f"(D[0]),
514
+ "=f"(D[1]),
515
+ "=f"(D[2]),
516
+ "=f"(D[3]),
517
+ "=f"(D[4]),
518
+ "=f"(D[5]),
519
+ "=f"(D[6]),
520
+ "=f"(D[7])
521
+ : "r"(A[0]),
522
+ "r"(A[1]),
523
+ "r"(B[0]),
524
+ "r"(B[1]),
525
+ "f"(C[0]),
526
+ "f"(C[1]),
527
+ "f"(C[2]),
528
+ "f"(C[3]),
529
+ "f"(C[4]),
530
+ "f"(C[5]),
531
+ "f"(C[6]),
532
+ "f"(C[7])
533
+ );
534
+
535
+ #else
536
+ assert(0);
537
+ #if defined(__CUDA_ARCH__)
538
+ asm volatile ("brkpt;\n" ::);
539
+ #endif
540
+ #endif
541
+ }
542
+ };
543
+
544
+ /// Matrix multiply-add operation: F32 = F16 * F16 + F32
545
+ template <>
546
+ struct Mma<
547
+ gemm::GemmShape<8, 8, 4>,
548
+ 8,
549
+ half_t,
550
+ layout::RowMajor,
551
+ half_t,
552
+ layout::RowMajor,
553
+ float,
554
+ layout::RowMajor,
555
+ OpMultiplyAdd> {
556
+
557
+ using Shape = gemm::GemmShape<8, 8, 4>;
558
+
559
+ using ElementA = half_t;
560
+ using LayoutA = layout::RowMajor;
561
+ using FragmentA = Array<half_t, 4>;
562
+
563
+ using ElementB = half_t;
564
+ using LayoutB = layout::RowMajor;
565
+ using FragmentB = Array<half_t, 4>;
566
+
567
+ using ElementC = float;
568
+ using LayoutC = layout::RowMajor;
569
+ using FragmentC = Array<float, 8>;
570
+
571
+ using Operator = OpMultiplyAdd;
572
+ using ArchTag = arch::Sm70;
573
+
574
+ /// Multiply-add
575
+ CUTLASS_HOST_DEVICE
576
+ void operator()(
577
+ FragmentC &d,
578
+ FragmentA const &a,
579
+ FragmentB const &b,
580
+ FragmentC const &c
581
+ ) {
582
+
583
+ #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
584
+
585
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
586
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
587
+ float const *C = reinterpret_cast<float const *>(&c);
588
+ float *D = reinterpret_cast<float *>(&d);
589
+
590
+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
591
+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
592
+ : "=f"(D[0]),
593
+ "=f"(D[1]),
594
+ "=f"(D[2]),
595
+ "=f"(D[3]),
596
+ "=f"(D[4]),
597
+ "=f"(D[5]),
598
+ "=f"(D[6]),
599
+ "=f"(D[7])
600
+ : "r"(A[0]),
601
+ "r"(A[1]),
602
+ "r"(B[0]),
603
+ "r"(B[1]),
604
+ "f"(C[0]),
605
+ "f"(C[1]),
606
+ "f"(C[2]),
607
+ "f"(C[3]),
608
+ "f"(C[4]),
609
+ "f"(C[5]),
610
+ "f"(C[6]),
611
+ "f"(C[7])
612
+ );
613
+
614
+ #else
615
+ assert(0);
616
+ #if defined(__CUDA_ARCH__)
617
+ asm volatile ("brkpt;\n" ::);
618
+ #endif
619
+ #endif
620
+ }
621
+ };
622
+
623
+ /////////////////////////////////////////////////////////////////////////////////////////////////
624
+
625
+ /// Matrix multiply-add operation specialized for the entire warp
626
+ template <
627
+ typename LayoutA,
628
+ typename LayoutB,
629
+ typename ElementC,
630
+ typename LayoutC,
631
+ typename Operator
632
+ >
633
+ struct Mma<
634
+ gemm::GemmShape<16, 16, 4>,
635
+ 32,
636
+ half_t,
637
+ LayoutA,
638
+ half_t,
639
+ LayoutB,
640
+ ElementC,
641
+ LayoutC,
642
+ Operator
643
+ > :
644
+ public Mma<
645
+ gemm::GemmShape<8, 8, 4>,
646
+ 8,
647
+ half_t,
648
+ LayoutA,
649
+ half_t,
650
+ LayoutB,
651
+ ElementC,
652
+ LayoutC,
653
+ Operator> {
654
+
655
+ using Shape = gemm::GemmShape<16, 16, 4>;
656
+ };
657
+
658
+ /////////////////////////////////////////////////////////////////////////////////////////////////
659
+
660
+ } // namespace arch
661
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h ADDED
@@ -0,0 +1,789 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply for SM75
33
+ */
34
+
35
+ #pragma once
36
+ #include "cutlass/cutlass.h"
37
+ #include CUDA_STD_HEADER(cassert)
38
+
39
+ #include "cutlass/arch/wmma.h"
40
+
41
+ #if defined(CUTLASS_ARCH_WMMA_ENABLED)
42
+ // CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply.
43
+ #include <mma.h>
44
+ #include "cutlass/wmma_array.h"
45
+ #endif
46
+
47
+ // CUTLASS includes
48
+ #include "cutlass/arch/mma.h"
49
+ #include "cutlass/layout/matrix.h"
50
+ #include "cutlass/numeric_types.h"
51
+
52
+ ////////////////////////////////////////////////////////////////////////////////
53
+
54
+ #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))
55
+
56
+ #define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1
57
+
58
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
59
+ #define CUTLASS_ARCH_MMA_SM75_ENABLED
60
+ #endif
61
+ #endif
62
+
63
+ ////////////////////////////////////////////////////////////////////////////////
64
+
65
+ namespace cutlass {
66
+ namespace arch {
67
+
68
+ ////////////////////////////////////////////////////////////////////////////////
69
+ //
70
+ // Matrix Multiply 1688 - FP16 accumulation
71
+ //
72
+ ////////////////////////////////////////////////////////////////////////////////
73
+
74
+ /// Matrix multiply-add operation - F16 = F16 * F16 + F16
75
+ template <>
76
+ struct Mma<
77
+ gemm::GemmShape<16, 8, 8>,
78
+ 32,
79
+ half_t,
80
+ layout::RowMajor,
81
+ half_t,
82
+ layout::ColumnMajor,
83
+ half_t,
84
+ layout::RowMajor,
85
+ OpMultiplyAdd> {
86
+
87
+ using Shape = gemm::GemmShape<16, 8, 8>;
88
+
89
+ using ElementA = half_t;
90
+ using LayoutA = layout::RowMajor;
91
+ using FragmentA = Array<half_t, 4>;
92
+
93
+ using ElementB = half_t;
94
+ using LayoutB = layout::ColumnMajor;
95
+ using FragmentB = Array<half_t, 2>;
96
+
97
+ using ElementC = half_t;
98
+ using LayoutC = layout::RowMajor;
99
+ using FragmentC = Array<half_t, 4>;
100
+
101
+ using Operator = OpMultiplyAdd;
102
+ using ArchTag = arch::Sm75;
103
+
104
+ CUTLASS_HOST_DEVICE
105
+ void operator()(
106
+ FragmentC &d,
107
+ FragmentA const &a,
108
+ FragmentB const &b,
109
+ FragmentC const &c
110
+ ) const {
111
+
112
+ #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
113
+
114
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
115
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
116
+ unsigned const *C = reinterpret_cast<unsigned const *>(&c);
117
+ unsigned *D = reinterpret_cast<unsigned *>(&d);
118
+
119
+ asm volatile(
120
+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
121
+ : "=r"(D[0]), "=r"(D[1])
122
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));
123
+
124
+ #else
125
+ CUTLASS_UNUSED(a);
126
+ CUTLASS_UNUSED(b);
127
+ CUTLASS_UNUSED(c);
128
+ CUTLASS_UNUSED(d);
129
+ CUTLASS_NOT_IMPLEMENTED();
130
+ #endif
131
+ }
132
+ };
133
+
134
+ ////////////////////////////////////////////////////////////////////////////////
135
+ //
136
+ // Matrix Multiply 1688 - FP32 accumulation
137
+ //
138
+ ////////////////////////////////////////////////////////////////////////////////
139
+
140
+ /// Matrix multiply-add operation: F32 = F16 * F16 + F32
141
+ template <>
142
+ struct Mma<
143
+ gemm::GemmShape<16, 8, 8>,
144
+ 32,
145
+ half_t,
146
+ layout::RowMajor,
147
+ half_t,
148
+ layout::ColumnMajor,
149
+ float,
150
+ layout::RowMajor,
151
+ OpMultiplyAdd> {
152
+
153
+ using Shape = gemm::GemmShape<16, 8, 8>;
154
+
155
+ using ElementA = half_t;
156
+ using LayoutA = layout::RowMajor;
157
+ using FragmentA = Array<half_t, 4>;
158
+
159
+ using ElementB = half_t;
160
+ using LayoutB = layout::ColumnMajor;
161
+ using FragmentB = Array<half_t, 2>;
162
+
163
+ using ElementC = float;
164
+ using LayoutC = layout::RowMajor;
165
+ using FragmentC = Array<float, 4>;
166
+
167
+ using Operator = OpMultiplyAdd;
168
+ using ArchTag = arch::Sm75;
169
+
170
+ /// Computes multiply-add
171
+ CUTLASS_HOST_DEVICE
172
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
173
+ FragmentC const &c) const {
174
+
175
+ #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
176
+
177
+ unsigned const *A = reinterpret_cast<unsigned const *>(&a);
178
+ unsigned const *B = reinterpret_cast<unsigned const *>(&b);
179
+ float const *C = reinterpret_cast<float const *>(&c);
180
+ float *D = reinterpret_cast<float *>(&d);
181
+
182
+ asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
183
+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
184
+ :
185
+ "r"(A[0]), "r"(A[1]),
186
+ "r"(B[0]),
187
+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
188
+ );
189
+
190
+ #else
191
+ CUTLASS_UNUSED(a);
192
+ CUTLASS_UNUSED(b);
193
+ CUTLASS_UNUSED(c);
194
+ CUTLASS_UNUSED(d);
195
+ CUTLASS_NOT_IMPLEMENTED();
196
+ #endif
197
+ }
198
+ };
199
+
200
+ ////////////////////////////////////////////////////////////////////////////////
201
+ //
202
+ // Integer matrix multiply (8b) with SATURATE
203
+ //
204
+ ////////////////////////////////////////////////////////////////////////////////
205
+
206
+ /// Matrix multiply-add operation: S32 = S8 * S8 + S32
207
+ template <>
208
+ struct Mma<
209
+ gemm::GemmShape<8, 8, 16>,
210
+ 32,
211
+ int8_t,
212
+ layout::RowMajor,
213
+ int8_t,
214
+ layout::ColumnMajor,
215
+ int,
216
+ layout::RowMajor,
217
+ OpMultiplyAddSaturate> {
218
+
219
+ using Shape = gemm::GemmShape<8, 8, 16>;
220
+
221
+ using ElementA = int8_t;
222
+ using LayoutA = layout::RowMajor;
223
+ using FragmentA = Array<int8_t, 4>;
224
+
225
+ using ElementB = int8_t;
226
+ using LayoutB = layout::ColumnMajor;
227
+ using FragmentB = Array<int8_t, 4>;
228
+
229
+ using ElementC = int;
230
+ using LayoutC = layout::RowMajor;
231
+ using FragmentC = Array<int, 2>;
232
+
233
+ using Operator = OpMultiplyAddSaturate;
234
+ using ArchTag = arch::Sm75;
235
+
236
+ /// Computes multiply-add
237
+ CUTLASS_HOST_DEVICE
238
+ void operator()(
239
+ FragmentC &d,
240
+ FragmentA const &a,
241
+ FragmentB const &b,
242
+ FragmentC const &c
243
+ ) const {
244
+
245
+ #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
246
+
247
+ unsigned const & A = reinterpret_cast<unsigned const &>(a);
248
+ unsigned const & B = reinterpret_cast<unsigned const &>(b);
249
+
250
+ int const *C = reinterpret_cast<int const *>(&c);
251
+ int *D = reinterpret_cast<int *>(&d);
252
+
253
+ asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
254
+ : "=r"(D[0]), "=r"(D[1])
255
+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
256
+ #else
257
+ CUTLASS_UNUSED(a);
258
+ CUTLASS_UNUSED(b);
259
+ CUTLASS_UNUSED(c);
260
+ CUTLASS_UNUSED(d);
261
+ CUTLASS_NOT_IMPLEMENTED();
262
+ #endif
263
+ }
264
+ };
265
+
266
/// Matrix multiply-add operation: S32 = U8 * S8 + S32
///
/// Warp-synchronous 8-by-8-by-16 tensor core instruction (Turing / SM75) with
/// saturating ("satfinite") signed 32-bit accumulation; unsigned 8-bit A,
/// signed 8-bit B.
template <>
struct Mma<
  gemm::GemmShape<8, 8, 16>,
  32,
  uint8_t,
  layout::RowMajor,
  int8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<8, 8, 16>;

  // A operand: four unsigned 8-bit values per thread, row-major
  using ElementA = uint8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint8_t, 4>;

  // B operand: four signed 8-bit values per thread, column-major
  using ElementB = int8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int8_t, 4>;

  // C/D accumulators: two signed 32-bit values per thread
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add: d = saturate(a * b + c)
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

  // Each operand fragment packs its four 8-bit values into one 32-bit register.
  unsigned const & A = reinterpret_cast<unsigned const &>(a);
  unsigned const & B = reinterpret_cast<unsigned const &>(b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_UNUSED(d);
  CUTLASS_NOT_IMPLEMENTED();
#endif
  }
};
325
+
326
/// Matrix multiply-add operation: S32 = S8 * U8 + S32
///
/// Warp-synchronous 8-by-8-by-16 tensor core instruction (Turing / SM75) with
/// saturating ("satfinite") signed 32-bit accumulation; signed 8-bit A,
/// unsigned 8-bit B.
template <>
struct Mma<
  gemm::GemmShape<8, 8, 16>,
  32,
  int8_t,
  layout::RowMajor,
  uint8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<8, 8, 16>;

  // A operand: four signed 8-bit values per thread, row-major
  using ElementA = int8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int8_t, 4>;

  // B operand: four unsigned 8-bit values per thread, column-major
  using ElementB = uint8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint8_t, 4>;

  // C/D accumulators: two signed 32-bit values per thread
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add: d = saturate(a * b + c)
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

  // Each operand fragment packs its four 8-bit values into one 32-bit register.
  unsigned const & A = reinterpret_cast<unsigned const &>(a);
  unsigned const & B = reinterpret_cast<unsigned const &>(b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_UNUSED(d);
  CUTLASS_NOT_IMPLEMENTED();
#endif
  }
};
385
+
386
/// Matrix multiply-add operation: S32 = U8 * U8 + S32
///
/// Warp-synchronous 8-by-8-by-16 tensor core instruction (Turing / SM75) with
/// saturating ("satfinite") signed 32-bit accumulation; both operands are
/// unsigned 8-bit.
template <>
struct Mma<
  gemm::GemmShape<8, 8, 16>,
  32,
  uint8_t,
  layout::RowMajor,
  uint8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<8, 8, 16>;

  // A operand: four unsigned 8-bit values per thread, row-major
  using ElementA = uint8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint8_t, 4>;

  // B operand: four unsigned 8-bit values per thread, column-major
  using ElementB = uint8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint8_t, 4>;

  // C/D accumulators: two signed 32-bit values per thread
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add: d = saturate(a * b + c)
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

  // Each operand fragment packs its four 8-bit values into one 32-bit register.
  unsigned const & A = reinterpret_cast<unsigned const &>(a);
  unsigned const & B = reinterpret_cast<unsigned const &>(b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_UNUSED(d);
  CUTLASS_NOT_IMPLEMENTED();
#endif
  }
};
445
+
446
+ ////////////////////////////////////////////////////////////////////////////////
447
+ //
448
+ // Integer matrix multiply (4b) - SATURATE
449
+ //
450
+ ////////////////////////////////////////////////////////////////////////////////
451
+
452
/// Matrix multiply-add operation: S32 = S4 * S4 + S32
///
/// Warp-synchronous 8-by-8-by-32 tensor core instruction (Turing / SM75) with
/// saturating ("satfinite") signed 32-bit accumulation; both operands are
/// signed 4-bit.
template <>
struct Mma<
  gemm::GemmShape<8, 8, 32>,
  32,
  int4b_t,
  layout::RowMajor,
  int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<8, 8, 32>;

  // A operand: eight signed 4-bit values per thread, row-major
  using ElementA = int4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int4b_t, 8>;

  // B operand: eight signed 4-bit values per thread, column-major
  using ElementB = int4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int4b_t, 8>;

  // C/D accumulators: two signed 32-bit values per thread
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add: d = saturate(a * b + c)
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

  // Each operand fragment packs its eight 4-bit values into one 32-bit register.
  unsigned const & A = reinterpret_cast<unsigned const &>(a);
  unsigned const & B = reinterpret_cast<unsigned const &>(b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_UNUSED(d);
  CUTLASS_NOT_IMPLEMENTED();
#endif
  }
};
511
+
512
/// Matrix multiply-add operation: S32 = U4 * S4 + S32
///
/// Warp-synchronous 8-by-8-by-32 tensor core instruction (Turing / SM75) with
/// saturating ("satfinite") signed 32-bit accumulation; unsigned 4-bit A,
/// signed 4-bit B.
template <>
struct Mma<
  gemm::GemmShape<8, 8, 32>,
  32,
  uint4b_t,
  layout::RowMajor,
  int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<8, 8, 32>;

  // A operand: eight unsigned 4-bit values per thread, row-major
  using ElementA = uint4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint4b_t, 8>;

  // B operand: eight signed 4-bit values per thread, column-major
  using ElementB = int4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int4b_t, 8>;

  // C/D accumulators: two signed 32-bit values per thread
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add: d = saturate(a * b + c)
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

  // Each operand fragment packs its eight 4-bit values into one 32-bit register.
  unsigned const & A = reinterpret_cast<unsigned const &>(a);
  unsigned const & B = reinterpret_cast<unsigned const &>(b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_UNUSED(d);
  CUTLASS_NOT_IMPLEMENTED();
#endif
  }
};
571
+
572
/// Matrix multiply-add operation: S32 = S4 * U4 + S32
///
/// Warp-synchronous 8-by-8-by-32 tensor core instruction (Turing / SM75) with
/// saturating ("satfinite") signed 32-bit accumulation; signed 4-bit A,
/// unsigned 4-bit B.
template <>
struct Mma<
  gemm::GemmShape<8, 8, 32>,
  32,
  int4b_t,
  layout::RowMajor,
  uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<8, 8, 32>;

  // A operand: eight signed 4-bit values per thread, row-major
  using ElementA = int4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int4b_t, 8>;

  // B operand: eight unsigned 4-bit values per thread, column-major
  using ElementB = uint4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint4b_t, 8>;

  // C/D accumulators: two signed 32-bit values per thread
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add: d = saturate(a * b + c)
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

  // Each operand fragment packs its eight 4-bit values into one 32-bit register.
  unsigned const & A = reinterpret_cast<unsigned const &>(a);
  unsigned const & B = reinterpret_cast<unsigned const &>(b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_UNUSED(d);
  CUTLASS_NOT_IMPLEMENTED();
#endif
  }
};
631
+
632
/// Matrix multiply-add operation: S32 = U4 * U4 + S32
///
/// Warp-synchronous 8-by-8-by-32 tensor core instruction (Turing / SM75) with
/// saturating ("satfinite") signed 32-bit accumulation; both operands are
/// unsigned 4-bit.
template <>
struct Mma<
  gemm::GemmShape<8, 8, 32>,
  32,
  uint4b_t,
  layout::RowMajor,
  uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<8, 8, 32>;

  // A operand: eight unsigned 4-bit values per thread, row-major
  using ElementA = uint4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint4b_t, 8>;

  // B operand: eight unsigned 4-bit values per thread, column-major
  using ElementB = uint4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint4b_t, 8>;

  // C/D accumulators: two signed 32-bit values per thread
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add: d = saturate(a * b + c)
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

  // Each operand fragment packs its eight 4-bit values into one 32-bit register.
  unsigned const & A = reinterpret_cast<unsigned const &>(a);
  unsigned const & B = reinterpret_cast<unsigned const &>(b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_UNUSED(d);
  CUTLASS_NOT_IMPLEMENTED();
#endif
  }
};
691
+
692
+ ////////////////////////////////////////////////////////////////////////////////
693
+ //
694
+ // b1 ^ b1 + s32 => s32
695
+ //
696
+ ////////////////////////////////////////////////////////////////////////////////
697
+
698
+ /// Matrix multiply-add operation
699
+ template <>
700
+ struct Mma<
701
+ gemm::GemmShape<8,8,128>,
702
+ 32,
703
+ uint1b_t,
704
+ layout::RowMajor,
705
+ uint1b_t,
706
+ layout::ColumnMajor,
707
+ int,
708
+ layout::RowMajor,
709
+ OpXorPopc> {
710
+
711
+ using Shape = gemm::GemmShape<8,8,128>;
712
+
713
+ using ElementA = uint1b_t;
714
+ using LayoutA = layout::RowMajor;
715
+ using FragmentA = Array<uint1b_t, 32>;
716
+
717
+ using ElementB = uint1b_t;
718
+ using LayoutB = layout::ColumnMajor;
719
+ using FragmentB = Array<uint1b_t, 32>;
720
+
721
+ using ElementC = int;
722
+ using LayoutC = layout::RowMajor;
723
+ using FragmentC = Array<int, 2>;
724
+
725
+ using Operator = OpXorPopc;
726
+ using ArchTag = arch::Sm75;
727
+
728
+ /// Computes multiply-add
729
+ CUTLASS_HOST_DEVICE
730
+ void operator()(
731
+ FragmentC &d,
732
+ FragmentA const &a,
733
+ FragmentB const &b,
734
+ FragmentC const &c
735
+ ) const {
736
+
737
+ #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
738
+
739
+ #if defined(CUTLASS_ARCH_WMMA_ENABLED)
740
+ using WmmaFragmentA = nvcuda::wmma::fragment<
741
+ nvcuda::wmma::matrix_a,
742
+ Shape::kM,
743
+ Shape::kN,
744
+ Shape::kK,
745
+ nvcuda::wmma::experimental::precision::b1,
746
+ nvcuda::wmma::row_major>;
747
+
748
+ using WmmaFragmentB = nvcuda::wmma::fragment<
749
+ nvcuda::wmma::matrix_b,
750
+ Shape::kM,
751
+ Shape::kN,
752
+ Shape::kK,
753
+ nvcuda::wmma::experimental::precision::b1,
754
+ nvcuda::wmma::col_major>;
755
+
756
+ using WmmaFragmentC = nvcuda::wmma::fragment<
757
+ nvcuda::wmma::accumulator,
758
+ Shape::kM,
759
+ Shape::kN,
760
+ Shape::kK,
761
+ int>;
762
+
763
+ WmmaFragmentA const & A = reinterpret_cast<WmmaFragmentA const &>(a);
764
+ WmmaFragmentB const & B = reinterpret_cast<WmmaFragmentB const &>(b);
765
+
766
+ WmmaFragmentC const & C = reinterpret_cast<WmmaFragmentC const &>(c);
767
+ WmmaFragmentC & D = reinterpret_cast<WmmaFragmentC &>(d);
768
+
769
+ nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
770
+ nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
771
+
772
+ #else
773
+
774
+ CUTLASS_UNUSED(a);
775
+ CUTLASS_UNUSED(b);
776
+ CUTLASS_UNUSED(c);
777
+ CUTLASS_UNUSED(d);
778
+ CUTLASS_NOT_IMPLEMENTED(); // WMMA must be supported to issue binary matrix multiply-accumulate instructions.
779
+
780
+ #endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
781
+
782
+ #endif
783
+ }
784
+ };
785
+
786
+ ////////////////////////////////////////////////////////////////////////////////
787
+
788
+ } // namespace arch
789
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h ADDED
@@ -0,0 +1,1500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply
33
+ */
34
+
35
+ #pragma once
36
+ #include "cutlass/cutlass.h"
37
+ #include CUDA_STD_HEADER(cassert)
38
+
39
+ #include "mma.h"
40
+ #include "cutlass/layout/matrix.h"
41
+ #include "cutlass/numeric_types.h"
42
+
43
+ ////////////////////////////////////////////////////////////////////////////////
44
+
45
+ #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
46
+
47
+ #define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1
48
+
49
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
50
+ #define CUTLASS_ARCH_MMA_SM80_ENABLED
51
+
52
+ #if (__CUDA_ARCH__ <= 900)
53
+ #define CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED
54
+ #endif
55
+ #if (__CUDA_ARCH__ <= 890)
56
+ #define CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED
57
+ #endif
58
+
59
+ #endif
60
+
61
+ #endif
62
+
63
+ ////////////////////////////////////////////////////////////////////////////////
64
+
65
+ namespace cutlass {
66
+ namespace arch {
67
+
68
+ ////////////////////////////////////////////////////////////////////////////////
69
+ //
70
+ // Matrix Multiply 1688 - Float BF16, FP32 accumulation
71
+ //
72
+ ////////////////////////////////////////////////////////////////////////////////
73
+
74
+ /// Matrix multiply-add operation - F32 = bf16 * bf16 + F32
75
+ template <>
76
+ struct Mma<
77
+ gemm::GemmShape<16, 8, 8>,
78
+ 32,
79
+ bfloat16_t,
80
+ layout::RowMajor,
81
+ bfloat16_t,
82
+ layout::ColumnMajor,
83
+ float,
84
+ layout::RowMajor,
85
+ OpMultiplyAdd> {
86
+
87
+ using Shape = gemm::GemmShape<16, 8, 8>;
88
+
89
+ using ElementA = bfloat16_t;
90
+ using LayoutA = layout::RowMajor;
91
+ using FragmentA = Array<bfloat16_t, 4>;
92
+
93
+ using ElementB = bfloat16_t;
94
+ using LayoutB = layout::ColumnMajor;
95
+ using FragmentB = Array<bfloat16_t, 2>;
96
+
97
+ using ElementC = float;
98
+ using LayoutC = layout::RowMajor;
99
+ using FragmentC = Array<float, 4>;
100
+
101
+ using Operator = OpMultiplyAdd;
102
+ using ArchTag = arch::Sm80;
103
+
104
+ CUTLASS_HOST_DEVICE
105
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
106
+ FragmentC const &c) const {
107
+
108
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
109
+
110
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
111
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
112
+ float const *C = reinterpret_cast<float const *>(&c);
113
+ float *D = reinterpret_cast<float *>(&d);
114
+
115
+ asm(
116
+ "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
117
+ "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
118
+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
119
+ :
120
+ "r"(A[0]), "r"(A[1]),
121
+ "r"(B[0]),
122
+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
123
+ );
124
+
125
+ #else
126
+
127
+ CUTLASS_UNUSED(d);
128
+ CUTLASS_UNUSED(a);
129
+ CUTLASS_UNUSED(b);
130
+ CUTLASS_UNUSED(c);
131
+ CUTLASS_NOT_IMPLEMENTED();
132
+
133
+ #endif
134
+ }
135
+ };
136
+
137
+ ////////////////////////////////////////////////////////////////////////////////
138
+ //
139
+ // Matrix Multiply 1684 - Float TF32
140
+ //
141
+ ////////////////////////////////////////////////////////////////////////////////
142
+
143
/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32
///
/// Warp-synchronous 16-by-8-by-4 tensor core instruction (Ampere / SM80) with
/// TensorFloat-32 operands and single-precision accumulation.
template <>
struct Mma<
  gemm::GemmShape<16, 8, 4>,
  32,
  tfloat32_t,
  layout::RowMajor,
  tfloat32_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16, 8, 4>;

  // A operand: two tf32 values per thread, row-major
  using ElementA = tfloat32_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<tfloat32_t, 2>;

  // B operand: one tf32 value per thread, column-major
  using ElementB = tfloat32_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<tfloat32_t, 1>;

  // C/D accumulators: four fp32 values per thread
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 4>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add: d = a * b + c
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Each tf32 value occupies one 32-bit register.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
  float const *C = reinterpret_cast<float const *>(&c);
  float *D = reinterpret_cast<float *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      :
        "r"(A[0]), "r"(A[1]),
        "r"(B[0]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
  );

#else

  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(d);
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};
208
+
209
+ ////////////////////////////////////////////////////////////////////////////////
210
+ //
211
+ // Matrix Multiply 1688 - Float TF32
212
+ //
213
+ ////////////////////////////////////////////////////////////////////////////////
214
+
215
/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32
///
/// Warp-synchronous 16-by-8-by-8 tensor core instruction (Ampere / SM80) with
/// TensorFloat-32 operands and single-precision accumulation.
template <>
struct Mma<gemm::GemmShape<16, 8, 8>, 32, tfloat32_t, layout::RowMajor,
           tfloat32_t, layout::ColumnMajor, float, layout::RowMajor,
           OpMultiplyAdd> {
  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16, 8, 8>;

  // A operand: four tf32 values per thread, row-major
  using ElementA = tfloat32_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<tfloat32_t, 4>;

  // B operand: two tf32 values per thread, column-major
  using ElementB = tfloat32_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<tfloat32_t, 2>;

  // C/D accumulators: four fp32 values per thread
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 4>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add: d = a * b + c
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Each tf32 value occupies one 32-bit register.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
  float const *C = reinterpret_cast<float const *>(&c);
  float *D = reinterpret_cast<float *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));

#else

  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(d);
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};
266
+
267
+ ////////////////////////////////////////////////////////////////////////////////
268
+ //
269
+ // Matrix Multiply 16816
270
+ //
271
+ ////////////////////////////////////////////////////////////////////////////////
272
+
273
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
///
/// Warp-synchronous 16-by-8-by-16 tensor core instruction (Ampere / SM80) with
/// half-precision operands and half-precision accumulation.
template <>
struct Mma<
  gemm::GemmShape<16, 8, 16>,
  32,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16, 8, 16>;

  // A operand: eight fp16 values per thread, row-major
  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 8>;

  // B operand: four fp16 values per thread, column-major
  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 4>;

  // C/D accumulators: four fp16 values per thread
  using ElementC = half_t;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<half_t, 4>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add: d = a * b + c
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Pairs of fp16 values are packed into 32-bit registers; accumulators too.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
  uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
  uint32_t *D = reinterpret_cast<uint32_t *>(&d);

  asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
        "r"(B[0]), "r"(B[1]),
        "r"(C[0]), "r"(C[1])
  );

#else

  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(d);
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};
333
+
334
+ ////////////////////////////////////////////////////////////////////////////////
335
+
336
/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32
///
/// Warp-synchronous 16-by-8-by-16 tensor core instruction (Ampere / SM80) with
/// bfloat16 operands and single-precision accumulation.
template <>
struct Mma<
  gemm::GemmShape<16, 8, 16>,
  32,
  bfloat16_t,
  layout::RowMajor,
  bfloat16_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16, 8, 16>;

  // A operand: eight bf16 values per thread, row-major
  using ElementA = bfloat16_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<bfloat16_t, 8>;

  // B operand: four bf16 values per thread, column-major
  using ElementB = bfloat16_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<bfloat16_t, 4>;

  // C/D accumulators: four fp32 values per thread
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 4>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add: d = a * b + c
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Pairs of bf16 values are packed into 32-bit registers.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
  float const *C = reinterpret_cast<float const *>(&c);
  float *D = reinterpret_cast<float *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));

#else

  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(d);
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};
400
+
401
+ ////////////////////////////////////////////////////////////////////////////////
402
+
403
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
///
/// Warp-synchronous 16-by-8-by-16 tensor core instruction (Ampere / SM80) with
/// half-precision operands and single-precision accumulation.
template <>
struct Mma<
  gemm::GemmShape<16, 8, 16>,
  32,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16, 8, 16>;

  // A operand: eight fp16 values per thread, row-major
  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 8>;

  // B operand: four fp16 values per thread, column-major
  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 4>;

  // C/D accumulators: four fp32 values per thread
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 4>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add: d = a * b + c
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,        ///< [out] destination accumulator fragment
    FragmentA const &a,  ///< [in]  A operand fragment
    FragmentB const &b,  ///< [in]  B operand fragment
    FragmentC const &c   ///< [in]  source accumulator fragment
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Pairs of fp16 values are packed into 32-bit registers.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
  float const *C = reinterpret_cast<float const *>(&c);
  float *D = reinterpret_cast<float *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
      "{%10,%11,%12,%13};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));

#else

  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(d);
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};
467
+
468
+ ////////////////////////////////////////////////////////////////////////////////
469
+ //
470
+ // Matrix Multiply 884 - F64
471
+ //
472
+ ////////////////////////////////////////////////////////////////////////////////
473
+
474
/// Matrix multiply-add operation: F64 = F64 * F64 + F64
///
/// Warp-synchronous 8-by-8-by-4 double-precision tensor core instruction
/// (Ampere / SM80).
template <>
struct Mma<
  gemm::GemmShape<8,8,4>,
  32,
  double,
  layout::RowMajor,
  double,
  layout::ColumnMajor,
  double,
  layout::RowMajor,
  OpMultiplyAdd> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<8,8,4>;

  // A operand: one double per thread, row-major
  using ElementA = double;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<double, 1>;

  // B operand: one double per thread, column-major
  using ElementB = double;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<double, 1>;

  // C/D accumulators: two doubles per thread
  using ElementC = double;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<double, 2>;

  using Operator = OpMultiplyAdd;

  using ArchTag = arch::Sm80;

  /// Computes multiply-add: d = a * b + c
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Single-element fragments: bind each operand directly to a 64-bit register.
  double const & A = reinterpret_cast<double const &>(a);
  double const & B = reinterpret_cast<double const &>(b);

  double const *C = reinterpret_cast<double const *>(&c);
  double *D = reinterpret_cast<double *>(&d);

  asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
      : "=d"(D[0]), "=d"(D[1])
      : "d"(A), "d"(B), "d"(C[0]), "d"(C[1]));

#else

  // Target architecture lacks this instruction: silence unused-parameter
  // warnings and abort if ever invoked.
  CUTLASS_UNUSED(d);
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};
532
+
533
+ ////////////////////////////////////////////////////////////////////////////////
534
+ //
535
+ // Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE
536
+ //
537
+ ////////////////////////////////////////////////////////////////////////////////
538
+
539
+ /// Matrix multiply-add operation: S32 = S8 * S8 + S32
540
+ template <>
541
+ struct Mma<
542
+ gemm::GemmShape<16,8,16>,
543
+ 32,
544
+ int8_t,
545
+ layout::RowMajor,
546
+ int8_t,
547
+ layout::ColumnMajor,
548
+ int,
549
+ layout::RowMajor,
550
+ OpMultiplyAddSaturate> {
551
+
552
+ using Shape = gemm::GemmShape<16,8,16>;
553
+
554
+ using ElementA = int8_t;
555
+ using LayoutA = layout::RowMajor;
556
+ using FragmentA = Array<int8_t, 8>;
557
+
558
+ using ElementB = int8_t;
559
+ using LayoutB = layout::ColumnMajor;
560
+ using FragmentB = Array<int8_t, 4>;
561
+
562
+ using ElementC = int;
563
+ using LayoutC = layout::RowMajor;
564
+ using FragmentC = Array<int, 4>;
565
+
566
+ using Operator = OpMultiplyAddSaturate;
567
+ using ArchTag = arch::Sm80;
568
+
569
+ /// Computes multiply-add
570
+ CUTLASS_HOST_DEVICE
571
+ void operator()(
572
+ FragmentC &d,
573
+ FragmentA const &a,
574
+ FragmentB const &b,
575
+ FragmentC const &c
576
+ ) const {
577
+
578
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
579
+
580
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
581
+ uint32_t const &B = reinterpret_cast<uint32_t const &>(b);
582
+
583
+ int const *C = reinterpret_cast<int const *>(&c);
584
+ int *D = reinterpret_cast<int *>(&d);
585
+
586
+ asm volatile(
587
+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, "
588
+ "{%6}, {%7,%8,%9,%10};\n"
589
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
590
+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]),
591
+ "r"(C[3]));
592
+
593
+ #else
594
+ assert(0);
595
+ #endif
596
+ }
597
+ };
598
+
599
+ /// Matrix multiply-add operation: S32 = U8 * S8 + S32
600
+ template <>
601
+ struct Mma<
602
+ gemm::GemmShape<16,8,16>,
603
+ 32,
604
+ uint8_t,
605
+ layout::RowMajor,
606
+ int8_t,
607
+ layout::ColumnMajor,
608
+ int,
609
+ layout::RowMajor,
610
+ OpMultiplyAddSaturate> {
611
+
612
+ using Shape = gemm::GemmShape<16,8,16>;
613
+
614
+ using ElementA = uint8_t;
615
+ using LayoutA = layout::RowMajor;
616
+ using FragmentA = Array<uint8_t, 8>;
617
+
618
+ using ElementB = int8_t;
619
+ using LayoutB = layout::ColumnMajor;
620
+ using FragmentB = Array<int8_t, 4>;
621
+
622
+ using ElementC = int;
623
+ using LayoutC = layout::RowMajor;
624
+ using FragmentC = Array<int, 4>;
625
+
626
+ using Operator = OpMultiplyAddSaturate;
627
+ using ArchTag = arch::Sm80;
628
+
629
+ /// Computes multiply-add
630
+ CUTLASS_HOST_DEVICE
631
+ void operator()(
632
+ FragmentC &d,
633
+ FragmentA const &a,
634
+ FragmentB const &b,
635
+ FragmentC const &c
636
+ ) const {
637
+
638
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
639
+
640
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
641
+ uint32_t const &B = reinterpret_cast<uint32_t const &>(b);
642
+
643
+ int const *C = reinterpret_cast<int const *>(&c);
644
+ int *D = reinterpret_cast<int *>(&d);
645
+
646
+ asm volatile(
647
+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, "
648
+ "{%6}, {%7,%8,%9,%10};\n"
649
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
650
+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]),
651
+ "r"(C[3]));
652
+
653
+ #else
654
+ assert(0);
655
+ #endif
656
+ }
657
+ };
658
+
659
+ /// Matrix multiply-add operation: S32 = S8 * U8 + S32
660
+ template <>
661
+ struct Mma<
662
+ gemm::GemmShape<16,8,16>,
663
+ 32,
664
+ int8_t,
665
+ layout::RowMajor,
666
+ uint8_t,
667
+ layout::ColumnMajor,
668
+ int,
669
+ layout::RowMajor,
670
+ OpMultiplyAddSaturate> {
671
+
672
+ using Shape = gemm::GemmShape<16,8,16>;
673
+
674
+ using ElementA = int8_t;
675
+ using LayoutA = layout::RowMajor;
676
+ using FragmentA = Array<int8_t, 8>;
677
+
678
+ using ElementB = uint8_t;
679
+ using LayoutB = layout::ColumnMajor;
680
+ using FragmentB = Array<uint8_t, 4>;
681
+
682
+ using ElementC = int;
683
+ using LayoutC = layout::RowMajor;
684
+ using FragmentC = Array<int, 4>;
685
+
686
+ using Operator = OpMultiplyAddSaturate;
687
+ using ArchTag = arch::Sm80;
688
+
689
+ /// Computes multiply-add
690
+ CUTLASS_HOST_DEVICE
691
+ void operator()(
692
+ FragmentC &d,
693
+ FragmentA const &a,
694
+ FragmentB const &b,
695
+ FragmentC const &c
696
+ ) const {
697
+
698
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
699
+
700
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
701
+ uint32_t const &B = reinterpret_cast<uint32_t const &>(b);
702
+
703
+ int const *C = reinterpret_cast<int const *>(&c);
704
+ int *D = reinterpret_cast<int *>(&d);
705
+
706
+ asm volatile(
707
+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, "
708
+ "{%6}, {%7,%8,%9,%10};\n"
709
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
710
+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]),
711
+ "r"(C[3]));
712
+
713
+ #else
714
+ assert(0);
715
+ #endif
716
+ }
717
+ };
718
+
719
+ /// Matrix multiply-add operation: S32 = U8 * U8 + S32
720
+ template <>
721
+ struct Mma<
722
+ gemm::GemmShape<16,8,16>,
723
+ 32,
724
+ uint8_t,
725
+ layout::RowMajor,
726
+ uint8_t,
727
+ layout::ColumnMajor,
728
+ int,
729
+ layout::RowMajor,
730
+ OpMultiplyAddSaturate> {
731
+
732
+ using Shape = gemm::GemmShape<16,8,16>;
733
+
734
+ using ElementA = uint8_t;
735
+ using LayoutA = layout::RowMajor;
736
+ using FragmentA = Array<uint8_t, 8>;
737
+
738
+ using ElementB = uint8_t;
739
+ using LayoutB = layout::ColumnMajor;
740
+ using FragmentB = Array<uint8_t, 4>;
741
+
742
+ using ElementC = int;
743
+ using LayoutC = layout::RowMajor;
744
+ using FragmentC = Array<int, 4>;
745
+
746
+ using Operator = OpMultiplyAddSaturate;
747
+ using ArchTag = arch::Sm80;
748
+
749
+ /// Computes multiply-add
750
+ CUTLASS_HOST_DEVICE
751
+ void operator()(
752
+ FragmentC &d,
753
+ FragmentA const &a,
754
+ FragmentB const &b,
755
+ FragmentC const &c
756
+ ) const {
757
+
758
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
759
+
760
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
761
+ uint32_t const &B = reinterpret_cast<uint32_t const &>(b);
762
+
763
+ int const *C = reinterpret_cast<int const *>(&c);
764
+ int *D = reinterpret_cast<int *>(&d);
765
+
766
+ asm volatile(
767
+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, "
768
+ "{%6}, {%7,%8,%9,%10};\n"
769
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
770
+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]),
771
+ "r"(C[3]));
772
+
773
+ #else
774
+ assert(0);
775
+ #endif
776
+ }
777
+ };
778
+
779
+ ////////////////////////////////////////////////////////////////////////////////
780
+ //
781
+ // Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE
782
+ //
783
+ ////////////////////////////////////////////////////////////////////////////////
784
+
785
+ /// Matrix multiply-add operation: S32 = S8 * S8 + S32
786
+ template <>
787
+ struct Mma<
788
+ gemm::GemmShape<16,8,32>,
789
+ 32,
790
+ int8_t,
791
+ layout::RowMajor,
792
+ int8_t,
793
+ layout::ColumnMajor,
794
+ int,
795
+ layout::RowMajor,
796
+ OpMultiplyAddSaturate> {
797
+
798
+ using Shape = gemm::GemmShape<16,8,32>;
799
+
800
+ using ElementA = int8_t;
801
+ using LayoutA = layout::RowMajor;
802
+ using FragmentA = Array<int8_t, 16>;
803
+
804
+ using ElementB = int8_t;
805
+ using LayoutB = layout::ColumnMajor;
806
+ using FragmentB = Array<int8_t, 8>;
807
+
808
+ using ElementC = int;
809
+ using LayoutC = layout::RowMajor;
810
+ using FragmentC = Array<int, 4>;
811
+
812
+ using Operator = OpMultiplyAddSaturate;
813
+ using ArchTag = arch::Sm80;
814
+
815
+ /// Computes multiply-add
816
+ CUTLASS_HOST_DEVICE
817
+ void operator()(
818
+ FragmentC &d,
819
+ FragmentA const &a,
820
+ FragmentB const &b,
821
+ FragmentC const &c
822
+ ) const {
823
+
824
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
825
+
826
+ uint32_t const * A = reinterpret_cast<uint32_t const *>(&a);
827
+ uint32_t const * B = reinterpret_cast<uint32_t const *>(&b);
828
+
829
+ int const *C = reinterpret_cast<int const *>(&c);
830
+ int *D = reinterpret_cast<int *>(&d);
831
+
832
+ asm volatile(
833
+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, "
834
+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
835
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
836
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
837
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
838
+
839
+ #else
840
+ assert(0);
841
+ #endif
842
+ }
843
+ };
844
+
845
+ /// Matrix multiply-add operation: S32 = U8 * S8 + S32
846
+ template <>
847
+ struct Mma<
848
+ gemm::GemmShape<16,8,32>,
849
+ 32,
850
+ uint8_t,
851
+ layout::RowMajor,
852
+ int8_t,
853
+ layout::ColumnMajor,
854
+ int,
855
+ layout::RowMajor,
856
+ OpMultiplyAddSaturate> {
857
+
858
+ using Shape = gemm::GemmShape<16,8,32>;
859
+
860
+ using ElementA = uint8_t;
861
+ using LayoutA = layout::RowMajor;
862
+ using FragmentA = Array<uint8_t, 16>;
863
+
864
+ using ElementB = int8_t;
865
+ using LayoutB = layout::ColumnMajor;
866
+ using FragmentB = Array<int8_t, 8>;
867
+
868
+ using ElementC = int;
869
+ using LayoutC = layout::RowMajor;
870
+ using FragmentC = Array<int, 4>;
871
+
872
+ using Operator = OpMultiplyAddSaturate;
873
+ using ArchTag = arch::Sm80;
874
+
875
+ /// Computes multiply-add
876
+ CUTLASS_HOST_DEVICE
877
+ void operator()(
878
+ FragmentC &d,
879
+ FragmentA const &a,
880
+ FragmentB const &b,
881
+ FragmentC const &c
882
+ ) const {
883
+
884
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
885
+
886
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
887
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
888
+
889
+ int const *C = reinterpret_cast<int const *>(&c);
890
+ int *D = reinterpret_cast<int *>(&d);
891
+
892
+ asm volatile(
893
+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, "
894
+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
895
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
896
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
897
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
898
+
899
+ #else
900
+ assert(0);
901
+ #endif
902
+ }
903
+ };
904
+
905
+ /// Matrix multiply-add operation: S32 = S8 * U8 + S32
906
+ template <>
907
+ struct Mma<
908
+ gemm::GemmShape<16,8,32>,
909
+ 32,
910
+ int8_t,
911
+ layout::RowMajor,
912
+ uint8_t,
913
+ layout::ColumnMajor,
914
+ int,
915
+ layout::RowMajor,
916
+ OpMultiplyAddSaturate> {
917
+
918
+ using Shape = gemm::GemmShape<16,8,32>;
919
+
920
+ using ElementA = int8_t;
921
+ using LayoutA = layout::RowMajor;
922
+ using FragmentA = Array<int8_t, 16>;
923
+
924
+ using ElementB = uint8_t;
925
+ using LayoutB = layout::ColumnMajor;
926
+ using FragmentB = Array<uint8_t, 8>;
927
+
928
+ using ElementC = int;
929
+ using LayoutC = layout::RowMajor;
930
+ using FragmentC = Array<int, 4>;
931
+
932
+ using Operator = OpMultiplyAddSaturate;
933
+ using ArchTag = arch::Sm80;
934
+
935
+ /// Computes multiply-add
936
+ CUTLASS_HOST_DEVICE
937
+ void operator()(
938
+ FragmentC &d,
939
+ FragmentA const &a,
940
+ FragmentB const &b,
941
+ FragmentC const &c
942
+ ) const {
943
+
944
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
945
+
946
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
947
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
948
+
949
+ int const *C = reinterpret_cast<int const *>(&c);
950
+ int *D = reinterpret_cast<int *>(&d);
951
+
952
+ asm volatile(
953
+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, "
954
+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
955
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
956
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
957
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
958
+
959
+ #else
960
+ assert(0);
961
+ #endif
962
+ }
963
+ };
964
+
965
+ /// Matrix multiply-add operation: S32 = U8 * U8 + S32
966
+ template <>
967
+ struct Mma<
968
+ gemm::GemmShape<16,8,32>,
969
+ 32,
970
+ uint8_t,
971
+ layout::RowMajor,
972
+ uint8_t,
973
+ layout::ColumnMajor,
974
+ int,
975
+ layout::RowMajor,
976
+ OpMultiplyAddSaturate> {
977
+
978
+ using Shape = gemm::GemmShape<16,8,32>;
979
+
980
+ using ElementA = uint8_t;
981
+ using LayoutA = layout::RowMajor;
982
+ using FragmentA = Array<uint8_t, 16>;
983
+
984
+ using ElementB = uint8_t;
985
+ using LayoutB = layout::ColumnMajor;
986
+ using FragmentB = Array<uint8_t, 8>;
987
+
988
+ using ElementC = int;
989
+ using LayoutC = layout::RowMajor;
990
+ using FragmentC = Array<int, 4>;
991
+
992
+ using Operator = OpMultiplyAddSaturate;
993
+ using ArchTag = arch::Sm80;
994
+
995
+ /// Computes multiply-add
996
+ CUTLASS_HOST_DEVICE
997
+ void operator()(
998
+ FragmentC &d,
999
+ FragmentA const &a,
1000
+ FragmentB const &b,
1001
+ FragmentC const &c
1002
+ ) const {
1003
+
1004
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
1005
+
1006
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
1007
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
1008
+
1009
+ int const *C = reinterpret_cast<int const *>(&c);
1010
+ int *D = reinterpret_cast<int *>(&d);
1011
+
1012
+ asm volatile(
1013
+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, "
1014
+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
1015
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
1016
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
1017
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
1018
+
1019
+ #else
1020
+ assert(0);
1021
+ #endif
1022
+ }
1023
+ };
1024
+
1025
+ ////////////////////////////////////////////////////////////////////////////////
1026
+ //
1027
+ // Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE
1028
+ //
1029
+ ////////////////////////////////////////////////////////////////////////////////
1030
+
1031
+ /// Matrix multiply-add operation: S32 = S4 * S4 + S32
1032
+ template <>
1033
+ struct Mma<
1034
+ gemm::GemmShape<16, 8, 64>,
1035
+ 32,
1036
+ cutlass::int4b_t,
1037
+ layout::RowMajor,
1038
+ cutlass::int4b_t,
1039
+ layout::ColumnMajor,
1040
+ int,
1041
+ layout::RowMajor,
1042
+ OpMultiplyAddSaturate> {
1043
+
1044
+ using Shape = gemm::GemmShape<16, 8, 64>;
1045
+
1046
+ using ElementA = cutlass::int4b_t;
1047
+ using LayoutA = layout::RowMajor;
1048
+ using FragmentA = Array<cutlass::int4b_t, 32>;
1049
+
1050
+ using ElementB = cutlass::int4b_t;
1051
+ using LayoutB = layout::ColumnMajor;
1052
+ using FragmentB = Array<cutlass::int4b_t, 16>;
1053
+
1054
+ using ElementC = int;
1055
+ using LayoutC = layout::RowMajor;
1056
+ using FragmentC = Array<int, 4>;
1057
+
1058
+ using Operator = OpMultiplyAddSaturate;
1059
+ using ArchTag = arch::Sm80;
1060
+
1061
+ /// Computes multiply-add
1062
+ CUTLASS_HOST_DEVICE
1063
+ void operator()(
1064
+ FragmentC &d,
1065
+ FragmentA const &a,
1066
+ FragmentB const &b,
1067
+ FragmentC const &c
1068
+ ) const {
1069
+
1070
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
1071
+
1072
+ uint32_t const * A = reinterpret_cast<uint32_t const *>(&a);
1073
+ uint32_t const * B = reinterpret_cast<uint32_t const *>(&b);
1074
+
1075
+ int const *C = reinterpret_cast<int const *>(&c);
1076
+ int *D = reinterpret_cast<int *>(&d);
1077
+
1078
+ asm volatile(
1079
+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, "
1080
+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
1081
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
1082
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
1083
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
1084
+
1085
+ #else
1086
+ CUTLASS_UNUSED(a);
1087
+ CUTLASS_UNUSED(b);
1088
+ CUTLASS_UNUSED(c);
1089
+ CUTLASS_UNUSED(d);
1090
+ assert(0);
1091
+ #endif
1092
+ }
1093
+ };
1094
+
1095
+ /// Matrix multiply-add operation: S32 = U4 * S4 + S32
1096
+ template <>
1097
+ struct Mma<
1098
+ gemm::GemmShape<16, 8, 64>,
1099
+ 32,
1100
+ cutlass::uint4b_t,
1101
+ layout::RowMajor,
1102
+ cutlass::int4b_t,
1103
+ layout::ColumnMajor,
1104
+ int,
1105
+ layout::RowMajor,
1106
+ OpMultiplyAddSaturate> {
1107
+
1108
+ using Shape = gemm::GemmShape<16, 8, 64>;
1109
+
1110
+ using ElementA = cutlass::uint4b_t;
1111
+ using LayoutA = layout::RowMajor;
1112
+ using FragmentA = Array<cutlass::uint4b_t, 32>;
1113
+
1114
+ using ElementB = cutlass::int4b_t;
1115
+ using LayoutB = layout::ColumnMajor;
1116
+ using FragmentB = Array<cutlass::int4b_t, 16>;
1117
+
1118
+ using ElementC = int;
1119
+ using LayoutC = layout::RowMajor;
1120
+ using FragmentC = Array<int, 4>;
1121
+
1122
+ using Operator = OpMultiplyAddSaturate;
1123
+ using ArchTag = arch::Sm80;
1124
+
1125
+ /// Computes multiply-add
1126
+ CUTLASS_HOST_DEVICE
1127
+ void operator()(
1128
+ FragmentC &d,
1129
+ FragmentA const &a,
1130
+ FragmentB const &b,
1131
+ FragmentC const &c
1132
+ ) const {
1133
+
1134
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
1135
+
1136
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
1137
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
1138
+
1139
+ int const *C = reinterpret_cast<int const *>(&c);
1140
+ int *D = reinterpret_cast<int *>(&d);
1141
+
1142
+ asm volatile(
1143
+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, "
1144
+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
1145
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
1146
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
1147
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
1148
+
1149
+ #else
1150
+ CUTLASS_UNUSED(a);
1151
+ CUTLASS_UNUSED(b);
1152
+ CUTLASS_UNUSED(c);
1153
+ CUTLASS_UNUSED(d);
1154
+ assert(0);
1155
+ #endif
1156
+ }
1157
+ };
1158
+
1159
+ /// Matrix multiply-add operation: S32 = S4 * U4 + S32
1160
+ template <>
1161
+ struct Mma<
1162
+ gemm::GemmShape<16, 8, 64>,
1163
+ 32,
1164
+ cutlass::int4b_t,
1165
+ layout::RowMajor,
1166
+ cutlass::uint4b_t,
1167
+ layout::ColumnMajor,
1168
+ int,
1169
+ layout::RowMajor,
1170
+ OpMultiplyAddSaturate> {
1171
+
1172
+ using Shape = gemm::GemmShape<16, 8, 64>;
1173
+
1174
+ using ElementA = cutlass::int4b_t;
1175
+ using LayoutA = layout::RowMajor;
1176
+ using FragmentA = Array<cutlass::int4b_t, 32>;
1177
+
1178
+ using ElementB = cutlass::uint4b_t;
1179
+ using LayoutB = layout::ColumnMajor;
1180
+ using FragmentB = Array<cutlass::uint4b_t, 16>;
1181
+
1182
+ using ElementC = int;
1183
+ using LayoutC = layout::RowMajor;
1184
+ using FragmentC = Array<int, 4>;
1185
+
1186
+ using Operator = OpMultiplyAddSaturate;
1187
+ using ArchTag = arch::Sm80;
1188
+
1189
+ /// Computes multiply-add
1190
+ CUTLASS_HOST_DEVICE
1191
+ void operator()(
1192
+ FragmentC &d,
1193
+ FragmentA const &a,
1194
+ FragmentB const &b,
1195
+ FragmentC const &c
1196
+ ) const {
1197
+
1198
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
1199
+
1200
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
1201
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
1202
+
1203
+ int const *C = reinterpret_cast<int const *>(&c);
1204
+ int *D = reinterpret_cast<int *>(&d);
1205
+
1206
+ asm volatile(
1207
+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, "
1208
+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
1209
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
1210
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
1211
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
1212
+
1213
+ #else
1214
+ CUTLASS_UNUSED(a);
1215
+ CUTLASS_UNUSED(b);
1216
+ CUTLASS_UNUSED(c);
1217
+ CUTLASS_UNUSED(d);
1218
+ assert(0);
1219
+ #endif
1220
+ }
1221
+ };
1222
+
1223
+ /// Matrix multiply-add operation: S32 = U4 * U4 + S32
1224
+ template <>
1225
+ struct Mma<
1226
+ gemm::GemmShape<16, 8, 64>,
1227
+ 32,
1228
+ cutlass::uint4b_t,
1229
+ layout::RowMajor,
1230
+ cutlass::uint4b_t,
1231
+ layout::ColumnMajor,
1232
+ int,
1233
+ layout::RowMajor,
1234
+ OpMultiplyAddSaturate> {
1235
+
1236
+ using Shape = gemm::GemmShape<16, 8, 64>;
1237
+
1238
+ using ElementA = cutlass::uint4b_t;
1239
+ using LayoutA = layout::RowMajor;
1240
+ using FragmentA = Array<cutlass::uint4b_t, 32>;
1241
+
1242
+ using ElementB = cutlass::uint4b_t;
1243
+ using LayoutB = layout::ColumnMajor;
1244
+ using FragmentB = Array<cutlass::uint4b_t, 16>;
1245
+
1246
+ using ElementC = int;
1247
+ using LayoutC = layout::RowMajor;
1248
+ using FragmentC = Array<int, 4>;
1249
+
1250
+ using Operator = OpMultiplyAddSaturate;
1251
+ using ArchTag = arch::Sm80;
1252
+
1253
+ /// Computes multiply-add
1254
+ CUTLASS_HOST_DEVICE
1255
+ void operator()(
1256
+ FragmentC &d,
1257
+ FragmentA const &a,
1258
+ FragmentB const &b,
1259
+ FragmentC const &c
1260
+ ) const {
1261
+
1262
+ #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
1263
+
1264
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
1265
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
1266
+
1267
+ int const *C = reinterpret_cast<int const *>(&c);
1268
+ int *D = reinterpret_cast<int *>(&d);
1269
+
1270
+ asm volatile(
1271
+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, "
1272
+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
1273
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
1274
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
1275
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
1276
+
1277
+ #else
1278
+ CUTLASS_UNUSED(a);
1279
+ CUTLASS_UNUSED(b);
1280
+ CUTLASS_UNUSED(c);
1281
+ CUTLASS_UNUSED(d);
1282
+ assert(0);
1283
+ #endif
1284
+ }
1285
+ };
1286
+
1287
+ ////////////////////////////////////////////////////////////////////////////////
1288
+ //
1289
+ // Matrix Multiply 168256 - B1 input, S32 accumulation - AND,POPC
1290
+ //
1291
+ ////////////////////////////////////////////////////////////////////////////////
1292
+
1293
+ /// Matrix multiply-add operation: S32 = B1 & B1 + S32
1294
+ template <>
1295
+ struct Mma<
1296
+ gemm::GemmShape<16,8,256>,
1297
+ 32,
1298
+ cutlass::uint1b_t,
1299
+ layout::RowMajor,
1300
+ cutlass::uint1b_t,
1301
+ layout::ColumnMajor,
1302
+ int32_t,
1303
+ layout::RowMajor,
1304
+ OpAndPopc> {
1305
+
1306
+ using Shape = gemm::GemmShape<16,8,256>;
1307
+
1308
+ using ElementA = cutlass::uint1b_t;
1309
+ using LayoutA = layout::RowMajor;
1310
+ using FragmentA = Array<cutlass::uint1b_t, 128>;
1311
+
1312
+ using ElementB = cutlass::uint1b_t;
1313
+ using LayoutB = layout::ColumnMajor;
1314
+ using FragmentB = Array<cutlass::uint1b_t, 64>;
1315
+
1316
+ using ElementC = int32_t;
1317
+ using LayoutC = layout::RowMajor;
1318
+ using FragmentC = Array<int32_t, 4>;
1319
+
1320
+ using Operator = OpAndPopc;
1321
+ using ArchTag = arch::Sm80;
1322
+
1323
+ /// Computes multiply-add
1324
+ CUTLASS_HOST_DEVICE
1325
+ void operator()(
1326
+ FragmentC &d,
1327
+ FragmentA const &a,
1328
+ FragmentB const &b,
1329
+ FragmentC const &c
1330
+ ) const {
1331
+
1332
+ #if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED)
1333
+
1334
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
1335
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
1336
+
1337
+ int const *C = reinterpret_cast<int const *>(&c);
1338
+ int *D = reinterpret_cast<int *>(&d);
1339
+
1340
+ asm volatile(
1341
+ "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, "
1342
+ "{%4,%5,%6,%7}, "
1343
+ "{%8,%9}, {%10,%11,%12,%13};\n"
1344
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
1345
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
1346
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
1347
+
1348
+ #else
1349
+ CUTLASS_UNUSED(a);
1350
+ CUTLASS_UNUSED(b);
1351
+ CUTLASS_UNUSED(c);
1352
+ CUTLASS_UNUSED(d);
1353
+ assert(0);
1354
+ #endif
1355
+ }
1356
+ };
1357
+
1358
+ /// Matrix multiply-add operation: S32 = B1 & B1 + S32
1359
+ template <>
1360
+ struct Mma<
1361
+ gemm::GemmShape<16,8,256>,
1362
+ 32,
1363
+ cutlass::uint1b_t,
1364
+ layout::RowMajor,
1365
+ cutlass::uint1b_t,
1366
+ layout::ColumnMajor,
1367
+ int,
1368
+ layout::RowMajor,
1369
+ OpMultiplyAdd> {
1370
+
1371
+ using Shape = gemm::GemmShape<16,8,256>;
1372
+
1373
+ using ElementA = cutlass::uint1b_t;
1374
+ using LayoutA = layout::RowMajor;
1375
+ using FragmentA = Array<cutlass::uint1b_t, 128>;
1376
+
1377
+ using ElementB = cutlass::uint1b_t;
1378
+ using LayoutB = layout::ColumnMajor;
1379
+ using FragmentB = Array<cutlass::uint1b_t, 64>;
1380
+
1381
+ using ElementC = int32_t;
1382
+ using LayoutC = layout::RowMajor;
1383
+ using FragmentC = Array<int32_t, 4>;
1384
+
1385
+ using Operator = OpMultiplyAdd;
1386
+ using ArchTag = arch::Sm80;
1387
+
1388
+ /// Computes multiply-add
1389
+ CUTLASS_HOST_DEVICE
1390
+ void operator()(
1391
+ FragmentC &d,
1392
+ FragmentA const &a,
1393
+ FragmentB const &b,
1394
+ FragmentC const &c
1395
+ ) const {
1396
+
1397
+ #if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED)
1398
+
1399
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
1400
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
1401
+
1402
+ int const *C = reinterpret_cast<int const *>(&c);
1403
+ int *D = reinterpret_cast<int *>(&d);
1404
+
1405
+ asm volatile(
1406
+ "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, "
1407
+ "{%4,%5,%6,%7}, "
1408
+ "{%8,%9}, {%10,%11,%12,%13};\n"
1409
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
1410
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
1411
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
1412
+
1413
+ #else
1414
+ CUTLASS_UNUSED(a);
1415
+ CUTLASS_UNUSED(b);
1416
+ CUTLASS_UNUSED(c);
1417
+ CUTLASS_UNUSED(d);
1418
+ assert(0);
1419
+ #endif
1420
+ }
1421
+ };
1422
+
1423
+ ////////////////////////////////////////////////////////////////////////////////
1424
+ //
1425
+ // Matrix Multiply 168256 - B1 input, S32 accumulation - XOR,POPC
1426
+ //
1427
+ ////////////////////////////////////////////////////////////////////////////////
1428
+
1429
+ /// Matrix multiply-add operation: S32 = B1 & B1 + S32
1430
+ template <>
1431
+ struct Mma<
1432
+ gemm::GemmShape<16,8,256>,
1433
+ 32,
1434
+ cutlass::uint1b_t,
1435
+ layout::RowMajor,
1436
+ cutlass::uint1b_t,
1437
+ layout::ColumnMajor,
1438
+ int,
1439
+ layout::RowMajor,
1440
+ OpXorPopc> {
1441
+
1442
+ using Shape = gemm::GemmShape<16,8,256>;
1443
+
1444
+ using ElementA = cutlass::uint1b_t;
1445
+ using LayoutA = layout::RowMajor;
1446
+ using FragmentA = Array<cutlass::uint1b_t, 128>;
1447
+
1448
+ using ElementB = cutlass::uint1b_t;
1449
+ using LayoutB = layout::ColumnMajor;
1450
+ using FragmentB = Array<cutlass::uint1b_t, 64>;
1451
+
1452
+ using ElementC = int;
1453
+ using LayoutC = layout::RowMajor;
1454
+ using FragmentC = Array<int, 4>;
1455
+
1456
+ using Operator = OpXorPopc;
1457
+ using ArchTag = arch::Sm80;
1458
+
1459
+ /// Computes multiply-add
1460
+ CUTLASS_HOST_DEVICE
1461
+ void operator()(
1462
+ FragmentC &d,
1463
+ FragmentA const &a,
1464
+ FragmentB const &b,
1465
+ FragmentC const &c
1466
+ ) const {
1467
+
1468
+ #if defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED)
1469
+
1470
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
1471
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
1472
+
1473
+ int const *C = reinterpret_cast<int const *>(&c);
1474
+ int *D = reinterpret_cast<int *>(&d);
1475
+
1476
+ asm volatile(
1477
+ "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, "
1478
+ "{%4,%5,%6,%7}, "
1479
+ "{%8,%9}, {%10,%11,%12,%13};\n"
1480
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
1481
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
1482
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
1483
+
1484
+ #else
1485
+
1486
+ CUTLASS_UNUSED(a);
1487
+ CUTLASS_UNUSED(b);
1488
+ CUTLASS_UNUSED(c);
1489
+ CUTLASS_UNUSED(d);
1490
+ assert(0);
1491
+
1492
+ #endif // defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED)
1493
+ }
1494
+ };
1495
+
1496
+ ////////////////////////////////////////////////////////////////////////////////
1497
+
1498
+ } // namespace arch
1499
+ } // namespace cutlass
1500
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Matrix multiply-accumulate specialzied for SM89
34
+ */
35
+
36
+ #pragma once
37
+ #include "cutlass/cutlass.h"
38
+ #include CUDA_STD_HEADER(cassert)
39
+
40
+ #include "mma.h"
41
+ #include "cutlass/layout/matrix.h"
42
+ #include "cutlass/numeric_types.h"
43
+
44
+ ////////////////////////////////////////////////////////////////////////////////
45
+
46
+ #if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)
47
+ # define CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED
48
+ #endif
49
+
50
+ #if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)
51
+ # define CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED
52
+ #endif
53
+
54
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
55
+ # if defined(CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED)
56
+ # define CUTLASS_ARCH_MMA_F32_SM89_ENABLED
57
+ # endif
58
+
59
+ # if defined(CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED)
60
+ # define CUTLASS_ARCH_MMA_F16_SM89_ENABLED
61
+ # endif
62
+ #endif
63
+
64
+ ////////////////////////////////////////////////////////////////////////////////
65
+
66
+ namespace cutlass {
67
+ namespace arch {
68
+
69
+ ////////////////////////////////////////////////////////////////////////////////
70
+
71
+ namespace detail {
72
+
73
+ // Whether the Mma uses as SM89 staged accumulation policy
74
+ template <class Operator>
75
+ static constexpr bool is_sm89_staged_policy_v =
76
+ (
77
+ // ElementA must be FP8
78
+ platform::is_same<typename Operator::ElementA, cutlass::float_e4m3_t>::value ||
79
+ platform::is_same<typename Operator::ElementA, cutlass::float_e5m2_t>::value
80
+ ) &&
81
+ (
82
+ // ElementB must be FP8
83
+ platform::is_same<typename Operator::ElementB, cutlass::float_e4m3_t>::value ||
84
+ platform::is_same<typename Operator::ElementB, cutlass::float_e5m2_t>::value
85
+ ) &&
86
+ (
87
+ // The instruction shape must be 16x8x32
88
+ Operator::ArchMmaOperator::Shape::kM == 16 &&
89
+ Operator::ArchMmaOperator::Shape::kN == 8 &&
90
+ Operator::ArchMmaOperator::Shape::kK == 32
91
+ ) &&
92
+ (
93
+ // The operator must be OpMultiplyAdd (default)
94
+ platform::is_same<typename Operator::MathOperator, OpMultiplyAdd>::value
95
+ );
96
+ } // namespace detail
97
+
98
+ ////////////////////////////////////////////////////////////////////////////////
99
+
100
+ ////////////////////////////////////////////////////////////////////////////////
101
+ //
102
+ // Matrix Multiply 16832 - Float {E4M3, E5M2}, FP32 accumulation
103
+ //
104
+ ////////////////////////////////////////////////////////////////////////////////
105
+
106
+ /// Matrix multiply-add operation - F32 = fe4m3 * fe4m3 + F32
107
+ template <typename Operator_>
108
+ struct Mma<
109
+ gemm::GemmShape<16, 8, 32>,
110
+ 32,
111
+ cutlass::float_e4m3_t,
112
+ layout::RowMajor,
113
+ cutlass::float_e4m3_t,
114
+ layout::ColumnMajor,
115
+ float,
116
+ layout::RowMajor,
117
+ Operator_> {
118
+ static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
119
+ platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
120
+ "Invalid operator for SM89 FP8 instruction");
121
+
122
+ using Shape = gemm::GemmShape<16, 8, 32>;
123
+
124
+ using ElementA = cutlass::float_e4m3_t;
125
+ using LayoutA = layout::RowMajor;
126
+ using FragmentA = Array<ElementA, 16>;
127
+
128
+ using ElementB = cutlass::float_e4m3_t;
129
+ using LayoutB = layout::ColumnMajor;
130
+ using FragmentB = Array<ElementB, 8>;
131
+
132
+ using ElementC = float;
133
+ using LayoutC = layout::RowMajor;
134
+ using FragmentC = Array<float, 4>;
135
+
136
+ using Operator = Operator_;
137
+ using ArchTag = arch::Sm89;
138
+
139
+ CUTLASS_HOST_DEVICE
140
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
141
+ FragmentC const &c) const {
142
+
143
+ #if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED)
144
+
145
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
146
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
147
+ float const *C = reinterpret_cast<float const *>(&c);
148
+ float *D = reinterpret_cast<float *>(&d);
149
+
150
+ asm(
151
+ "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
152
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
153
+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
154
+ :
155
+ "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
156
+ "r"(B[0]), "r"(B[1]),
157
+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
158
+ );
159
+
160
+ #else
161
+
162
+ CUTLASS_UNUSED(d);
163
+ CUTLASS_UNUSED(a);
164
+ CUTLASS_UNUSED(b);
165
+ CUTLASS_UNUSED(c);
166
+ CUTLASS_NOT_IMPLEMENTED();
167
+
168
+ #endif
169
+ }
170
+ };
171
+
172
+ /// Matrix multiply-add operation - F32 = fe4m3 * fe5m2 + F32
173
+ template <typename Operator_>
174
+ struct Mma<
175
+ gemm::GemmShape<16, 8, 32>,
176
+ 32,
177
+ cutlass::float_e4m3_t,
178
+ layout::RowMajor,
179
+ cutlass::float_e5m2_t,
180
+ layout::ColumnMajor,
181
+ float,
182
+ layout::RowMajor,
183
+ Operator_> {
184
+ static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
185
+ platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
186
+ "Invalid operator for SM89 FP8 instruction");
187
+
188
+ using Shape = gemm::GemmShape<16, 8, 32>;
189
+
190
+ using ElementA = cutlass::float_e4m3_t;
191
+ using LayoutA = layout::RowMajor;
192
+ using FragmentA = Array<ElementA, 16>;
193
+
194
+ using ElementB = cutlass::float_e5m2_t;
195
+ using LayoutB = layout::ColumnMajor;
196
+ using FragmentB = Array<ElementB, 8>;
197
+
198
+ using ElementC = float;
199
+ using LayoutC = layout::RowMajor;
200
+ using FragmentC = Array<float, 4>;
201
+
202
+ using Operator = Operator_;
203
+ using ArchTag = arch::Sm89;
204
+
205
+ CUTLASS_HOST_DEVICE
206
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
207
+ FragmentC const &c) const {
208
+
209
+ #if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED)
210
+
211
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
212
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
213
+ float const *C = reinterpret_cast<float const *>(&c);
214
+ float *D = reinterpret_cast<float *>(&d);
215
+
216
+ asm(
217
+ "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32 "
218
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
219
+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
220
+ :
221
+ "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
222
+ "r"(B[0]), "r"(B[1]),
223
+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
224
+ );
225
+
226
+ #else
227
+
228
+ CUTLASS_UNUSED(d);
229
+ CUTLASS_UNUSED(a);
230
+ CUTLASS_UNUSED(b);
231
+ CUTLASS_UNUSED(c);
232
+ CUTLASS_NOT_IMPLEMENTED();
233
+
234
+ #endif
235
+ }
236
+ };
237
+
238
+ /// Matrix multiply-add operation - F32 = fe5m2 * fe4m3 + F32
239
+ template <typename Operator_>
240
+ struct Mma<
241
+ gemm::GemmShape<16, 8, 32>,
242
+ 32,
243
+ cutlass::float_e5m2_t,
244
+ layout::RowMajor,
245
+ cutlass::float_e4m3_t,
246
+ layout::ColumnMajor,
247
+ float,
248
+ layout::RowMajor,
249
+ Operator_> {
250
+ static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
251
+ platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
252
+ "Invalid operator for SM89 FP8 instruction");
253
+
254
+ using Shape = gemm::GemmShape<16, 8, 32>;
255
+
256
+ using ElementA = cutlass::float_e5m2_t;
257
+ using LayoutA = layout::RowMajor;
258
+ using FragmentA = Array<ElementA, 16>;
259
+
260
+ using ElementB = cutlass::float_e4m3_t;
261
+ using LayoutB = layout::ColumnMajor;
262
+ using FragmentB = Array<ElementB, 8>;
263
+
264
+ using ElementC = float;
265
+ using LayoutC = layout::RowMajor;
266
+ using FragmentC = Array<float, 4>;
267
+
268
+ using Operator = Operator_;
269
+ using ArchTag = arch::Sm89;
270
+
271
+ CUTLASS_HOST_DEVICE
272
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
273
+ FragmentC const &c) const {
274
+
275
+ #if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED)
276
+
277
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
278
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
279
+ float const *C = reinterpret_cast<float const *>(&c);
280
+ float *D = reinterpret_cast<float *>(&d);
281
+
282
+ asm(
283
+ "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32 "
284
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
285
+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
286
+ :
287
+ "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
288
+ "r"(B[0]), "r"(B[1]),
289
+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
290
+ );
291
+
292
+ #else
293
+
294
+ CUTLASS_UNUSED(d);
295
+ CUTLASS_UNUSED(a);
296
+ CUTLASS_UNUSED(b);
297
+ CUTLASS_UNUSED(c);
298
+ CUTLASS_NOT_IMPLEMENTED();
299
+
300
+ #endif
301
+ }
302
+ };
303
+
304
+ /// Matrix multiply-add operation - F32 = fe5m2 * fe5m2 + F32
305
+ template <typename Operator_>
306
+ struct Mma<
307
+ gemm::GemmShape<16, 8, 32>,
308
+ 32,
309
+ cutlass::float_e5m2_t,
310
+ layout::RowMajor,
311
+ cutlass::float_e5m2_t,
312
+ layout::ColumnMajor,
313
+ float,
314
+ layout::RowMajor,
315
+ Operator_> {
316
+ static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
317
+ platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
318
+ "Invalid operator for SM89 FP8 instruction");
319
+
320
+ using Shape = gemm::GemmShape<16, 8, 32>;
321
+
322
+ using ElementA = cutlass::float_e5m2_t;
323
+ using LayoutA = layout::RowMajor;
324
+ using FragmentA = Array<ElementA, 16>;
325
+
326
+ using ElementB = cutlass::float_e5m2_t;
327
+ using LayoutB = layout::ColumnMajor;
328
+ using FragmentB = Array<ElementB, 8>;
329
+
330
+ using ElementC = float;
331
+ using LayoutC = layout::RowMajor;
332
+ using FragmentC = Array<float, 4>;
333
+
334
+ using Operator = Operator_;
335
+ using ArchTag = arch::Sm89;
336
+
337
+ CUTLASS_HOST_DEVICE
338
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
339
+ FragmentC const &c) const {
340
+
341
+ #if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED)
342
+
343
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
344
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
345
+ float const *C = reinterpret_cast<float const *>(&c);
346
+ float *D = reinterpret_cast<float *>(&d);
347
+
348
+ asm(
349
+ "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
350
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
351
+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
352
+ :
353
+ "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
354
+ "r"(B[0]), "r"(B[1]),
355
+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
356
+ );
357
+
358
+ #else
359
+
360
+ CUTLASS_UNUSED(d);
361
+ CUTLASS_UNUSED(a);
362
+ CUTLASS_UNUSED(b);
363
+ CUTLASS_UNUSED(c);
364
+ CUTLASS_NOT_IMPLEMENTED();
365
+
366
+ #endif
367
+ }
368
+ };
369
+
370
+ ////////////////////////////////////////////////////////////////////////////////
371
+ //
372
+ // Matrix Multiply 16832 - Float {E4M3, E5M2}, FP16 accumulation
373
+ //
374
+ ////////////////////////////////////////////////////////////////////////////////
375
+
376
+ /// Matrix multiply-add operation - F16 = fe4m3 * fe4m3 + F16
377
+ template <typename Operator_>
378
+ struct Mma<
379
+ gemm::GemmShape<16, 8, 32>,
380
+ 32,
381
+ cutlass::float_e4m3_t,
382
+ layout::RowMajor,
383
+ cutlass::float_e4m3_t,
384
+ layout::ColumnMajor,
385
+ cutlass::half_t,
386
+ layout::RowMajor,
387
+ Operator_> {
388
+ static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
389
+ platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
390
+ "Invalid operator for SM89 FP8 instruction");
391
+
392
+ using Shape = gemm::GemmShape<16, 8, 32>;
393
+
394
+ using ElementA = cutlass::float_e4m3_t;
395
+ using LayoutA = layout::RowMajor;
396
+ using FragmentA = Array<ElementA, 16>;
397
+
398
+ using ElementB = cutlass::float_e4m3_t;
399
+ using LayoutB = layout::ColumnMajor;
400
+ using FragmentB = Array<ElementB, 8>;
401
+
402
+ using ElementC = cutlass::half_t;
403
+ using LayoutC = layout::RowMajor;
404
+ using FragmentC = Array<cutlass::half_t, 4>;
405
+
406
+ using Operator = Operator_;
407
+ using ArchTag = arch::Sm89;
408
+
409
+ CUTLASS_HOST_DEVICE
410
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
411
+ FragmentC const &c) const {
412
+
413
+ #if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED)
414
+
415
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
416
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
417
+ uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
418
+ uint32_t *D = reinterpret_cast<uint32_t *>(&d);
419
+
420
+ asm(
421
+ "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 "
422
+ "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
423
+ : "=r"(D[0]), "=r"(D[1])
424
+ :
425
+ "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
426
+ "r"(B[0]), "r"(B[1]),
427
+ "r"(C[0]), "r"(C[1])
428
+ );
429
+
430
+ #else
431
+
432
+ CUTLASS_UNUSED(d);
433
+ CUTLASS_UNUSED(a);
434
+ CUTLASS_UNUSED(b);
435
+ CUTLASS_UNUSED(c);
436
+ CUTLASS_NOT_IMPLEMENTED();
437
+
438
+ #endif
439
+ }
440
+ };
441
+
442
+ /// Matrix multiply-add operation - F16 = fe4m3 * fe5m2 + F16
443
+ template <typename Operator_>
444
+ struct Mma<
445
+ gemm::GemmShape<16, 8, 32>,
446
+ 32,
447
+ cutlass::float_e4m3_t,
448
+ layout::RowMajor,
449
+ cutlass::float_e5m2_t,
450
+ layout::ColumnMajor,
451
+ cutlass::half_t,
452
+ layout::RowMajor,
453
+ Operator_> {
454
+ static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
455
+ platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
456
+ "Invalid operator for SM89 FP8 instruction");
457
+
458
+ using Shape = gemm::GemmShape<16, 8, 32>;
459
+
460
+ using ElementA = cutlass::float_e4m3_t;
461
+ using LayoutA = layout::RowMajor;
462
+ using FragmentA = Array<ElementA, 16>;
463
+
464
+ using ElementB = cutlass::float_e5m2_t;
465
+ using LayoutB = layout::ColumnMajor;
466
+ using FragmentB = Array<ElementB, 8>;
467
+
468
+ using ElementC = cutlass::half_t;
469
+ using LayoutC = layout::RowMajor;
470
+ using FragmentC = Array<cutlass::half_t, 4>;
471
+
472
+ using Operator = Operator_;
473
+ using ArchTag = arch::Sm89;
474
+
475
+ CUTLASS_HOST_DEVICE
476
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
477
+ FragmentC const &c) const {
478
+
479
+ #if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED)
480
+
481
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
482
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
483
+ uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
484
+ uint32_t *D = reinterpret_cast<uint32_t *>(&d);
485
+
486
+ asm(
487
+ "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 "
488
+ "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
489
+ : "=r"(D[0]), "=r"(D[1])
490
+ :
491
+ "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
492
+ "r"(B[0]), "r"(B[1]),
493
+ "r"(C[0]), "r"(C[1])
494
+ );
495
+
496
+ #else
497
+
498
+ CUTLASS_UNUSED(d);
499
+ CUTLASS_UNUSED(a);
500
+ CUTLASS_UNUSED(b);
501
+ CUTLASS_UNUSED(c);
502
+ CUTLASS_NOT_IMPLEMENTED();
503
+
504
+ #endif
505
+ }
506
+ };
507
+
508
+ /// Matrix multiply-add operation - F16 = fe5m2 * fe4m3 + F16
509
+ template <typename Operator_>
510
+ struct Mma<
511
+ gemm::GemmShape<16, 8, 32>,
512
+ 32,
513
+ cutlass::float_e5m2_t,
514
+ layout::RowMajor,
515
+ cutlass::float_e4m3_t,
516
+ layout::ColumnMajor,
517
+ cutlass::half_t,
518
+ layout::RowMajor,
519
+ Operator_> {
520
+ static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
521
+ platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
522
+ "Invalid operator for SM89 FP8 instruction");
523
+
524
+ using Shape = gemm::GemmShape<16, 8, 32>;
525
+
526
+ using ElementA = cutlass::float_e5m2_t;
527
+ using LayoutA = layout::RowMajor;
528
+ using FragmentA = Array<ElementA, 16>;
529
+
530
+ using ElementB = cutlass::float_e4m3_t;
531
+ using LayoutB = layout::ColumnMajor;
532
+ using FragmentB = Array<ElementB, 8>;
533
+
534
+ using ElementC = cutlass::half_t;
535
+ using LayoutC = layout::RowMajor;
536
+ using FragmentC = Array<cutlass::half_t, 4>;
537
+
538
+ using Operator = Operator_;
539
+ using ArchTag = arch::Sm89;
540
+
541
+ CUTLASS_HOST_DEVICE
542
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
543
+ FragmentC const &c) const {
544
+
545
+ #if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED)
546
+
547
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
548
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
549
+ uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
550
+ uint32_t *D = reinterpret_cast<uint32_t *>(&d);
551
+
552
+ asm(
553
+ "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 "
554
+ "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
555
+ : "=r"(D[0]), "=r"(D[1])
556
+ :
557
+ "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
558
+ "r"(B[0]), "r"(B[1]),
559
+ "r"(C[0]), "r"(C[1])
560
+ );
561
+
562
+ #else
563
+
564
+ CUTLASS_UNUSED(d);
565
+ CUTLASS_UNUSED(a);
566
+ CUTLASS_UNUSED(b);
567
+ CUTLASS_UNUSED(c);
568
+ CUTLASS_NOT_IMPLEMENTED();
569
+
570
+ #endif
571
+ }
572
+ };
573
+
574
+ /// Matrix multiply-add operation - F16 = fe5m2 * fe5m2 + F16
575
+ template <typename Operator_>
576
+ struct Mma<
577
+ gemm::GemmShape<16, 8, 32>,
578
+ 32,
579
+ cutlass::float_e5m2_t,
580
+ layout::RowMajor,
581
+ cutlass::float_e5m2_t,
582
+ layout::ColumnMajor,
583
+ cutlass::half_t,
584
+ layout::RowMajor,
585
+ Operator_> {
586
+ static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
587
+ platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
588
+ "Invalid operator for SM89 FP8 instruction");
589
+
590
+ using Shape = gemm::GemmShape<16, 8, 32>;
591
+
592
+ using ElementA = cutlass::float_e5m2_t;
593
+ using LayoutA = layout::RowMajor;
594
+ using FragmentA = Array<ElementA, 16>;
595
+
596
+ using ElementB = cutlass::float_e5m2_t;
597
+ using LayoutB = layout::ColumnMajor;
598
+ using FragmentB = Array<ElementB, 8>;
599
+
600
+ using ElementC = cutlass::half_t;
601
+ using LayoutC = layout::RowMajor;
602
+ using FragmentC = Array<cutlass::half_t, 4>;
603
+
604
+ using Operator = Operator_;
605
+ using ArchTag = arch::Sm89;
606
+
607
+ CUTLASS_HOST_DEVICE
608
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
609
+ FragmentC const &c) const {
610
+
611
+ #if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED)
612
+
613
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
614
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
615
+ uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
616
+ uint32_t *D = reinterpret_cast<uint32_t *>(&d);
617
+
618
+ asm(
619
+ "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 "
620
+ "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
621
+ : "=r"(D[0]), "=r"(D[1])
622
+ :
623
+ "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
624
+ "r"(B[0]), "r"(B[1]),
625
+ "r"(C[0]), "r"(C[1])
626
+ );
627
+
628
+ #else
629
+
630
+ CUTLASS_UNUSED(d);
631
+ CUTLASS_UNUSED(a);
632
+ CUTLASS_UNUSED(b);
633
+ CUTLASS_UNUSED(c);
634
+ CUTLASS_NOT_IMPLEMENTED();
635
+
636
+ #endif
637
+ }
638
+ };
639
+
640
+ } // namespace arch
641
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply
33
+ */
34
+
35
+ #pragma once
36
+ #include "cutlass/cutlass.h"
37
+ #include CUDA_STD_HEADER(cassert)
38
+
39
+ #include "mma.h"
40
+ #include "cutlass/layout/matrix.h"
41
+ #include "cutlass/numeric_types.h"
42
+ #include "cutlass/arch/config.h"
43
+
44
+ ////////////////////////////////////////////////////////////////////////////////
45
+
46
+ namespace cutlass {
47
+ namespace arch {
48
+
49
+ ////////////////////////////////////////////////////////////////////////////////
50
+ /// Matrix Multiply-Add 16x8x4 fp64
51
+ ////////////////////////////////////////////////////////////////////////////////
52
+
53
+ /// Matrix multiply-add operation: F64 = F64 * F64 + F64
54
+ template <>
55
+ struct Mma<
56
+ gemm::GemmShape<16,8,4>,
57
+ 32,
58
+ double,
59
+ layout::RowMajor,
60
+ double,
61
+ layout::ColumnMajor,
62
+ double,
63
+ layout::RowMajor,
64
+ OpMultiplyAdd> {
65
+
66
+ using Shape = gemm::GemmShape<16,8,4>;
67
+
68
+ using ElementA = double;
69
+ using LayoutA = layout::RowMajor;
70
+ using FragmentA = Array<double, 2>;
71
+
72
+ using ElementB = double;
73
+ using LayoutB = layout::ColumnMajor;
74
+ using FragmentB = Array<double, 1>;
75
+
76
+ using ElementC = double;
77
+ using LayoutC = layout::RowMajor;
78
+ using FragmentC = Array<double, 4>;
79
+
80
+ using Operator = OpMultiplyAdd;
81
+
82
+ using ArchTag = arch::Sm90;
83
+
84
+ CUTLASS_HOST_DEVICE
85
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
86
+ FragmentC const &c) const {
87
+
88
+ #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
89
+
90
+ double const *A = reinterpret_cast<double const *>(&a);
91
+ double const *B = reinterpret_cast<double const *>(&b);
92
+
93
+ double const *C = reinterpret_cast<double const *>(&c);
94
+ double *D = reinterpret_cast<double *>(&d);
95
+
96
+ asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64.rn {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
97
+ : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
98
+ : "d"(A[0]), "d"(A[1]),
99
+ "d"(B[0]),
100
+ "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
101
+
102
+ #else
103
+ CUTLASS_UNUSED(d);
104
+ CUTLASS_UNUSED(a);
105
+ CUTLASS_UNUSED(b);
106
+ CUTLASS_UNUSED(c);
107
+ CUTLASS_NOT_IMPLEMENTED();
108
+ #endif
109
+ }
110
+ };
111
+
112
+ ////////////////////////////////////////////////////////////////////////////////
113
+ /// Matrix Multiply-Add 16x8x8 fp64
114
+ ////////////////////////////////////////////////////////////////////////////////
115
+
116
+ /// Matrix multiply-add operation: F64 = F64 * F64 + F64
117
+ template <>
118
+ struct Mma<
119
+ gemm::GemmShape<16,8,8>,
120
+ 32,
121
+ double,
122
+ layout::RowMajor,
123
+ double,
124
+ layout::ColumnMajor,
125
+ double,
126
+ layout::RowMajor,
127
+ OpMultiplyAdd> {
128
+
129
+ using Shape = gemm::GemmShape<16,8,8>;
130
+
131
+ using ElementA = double;
132
+ using LayoutA = layout::RowMajor;
133
+ using FragmentA = Array<double, 4>;
134
+
135
+ using ElementB = double;
136
+ using LayoutB = layout::ColumnMajor;
137
+ using FragmentB = Array<double, 2>;
138
+
139
+ using ElementC = double;
140
+ using LayoutC = layout::RowMajor;
141
+ using FragmentC = Array<double, 4>;
142
+
143
+ using Operator = OpMultiplyAdd;
144
+
145
+ using ArchTag = arch::Sm90;
146
+
147
+ CUTLASS_HOST_DEVICE
148
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
149
+ FragmentC const &c) const {
150
+
151
+ #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
152
+
153
+ double const *A = reinterpret_cast<double const *>(&a);
154
+ double const *B = reinterpret_cast<double const *>(&b);
155
+
156
+ double const *C = reinterpret_cast<double const *>(&c);
157
+ double *D = reinterpret_cast<double *>(&d);
158
+
159
+ asm volatile("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
160
+ : "=d"(D[0]), "=d"(d[1]), "=d"(d[2]), "=d"(d[3])
161
+ : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]),
162
+ "d"(B[0]), "d"(B[1]),
163
+ "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
164
+
165
+ #else
166
+
167
+ CUTLASS_UNUSED(d);
168
+ CUTLASS_UNUSED(a);
169
+ CUTLASS_UNUSED(b);
170
+ CUTLASS_UNUSED(c);
171
+ CUTLASS_NOT_IMPLEMENTED();
172
+ #endif
173
+ }
174
+ };
175
+
176
+ ////////////////////////////////////////////////////////////////////////////////
177
+ /// Matrix Multiply-Add 16x8x16 fp64
178
+ ////////////////////////////////////////////////////////////////////////////////
179
+
180
+ /// Matrix multiply-add operation: F64 = F64 * F64 + F64
181
+ template <>
182
+ struct Mma<
183
+ gemm::GemmShape<16,8,16>,
184
+ 32,
185
+ double,
186
+ layout::RowMajor,
187
+ double,
188
+ layout::ColumnMajor,
189
+ double,
190
+ layout::RowMajor,
191
+ OpMultiplyAdd> {
192
+
193
+ using Shape = gemm::GemmShape<16,8,16>;
194
+
195
+ using ElementA = double;
196
+ using LayoutA = layout::RowMajor;
197
+ using FragmentA = Array<double, 8>;
198
+
199
+ using ElementB = double;
200
+ using LayoutB = layout::ColumnMajor;
201
+ using FragmentB = Array<double, 4>;
202
+
203
+ using ElementC = double;
204
+ using LayoutC = layout::RowMajor;
205
+ using FragmentC = Array<double, 4>;
206
+
207
+ using Operator = OpMultiplyAdd;
208
+
209
+ using ArchTag = arch::Sm90;
210
+
211
+ CUTLASS_HOST_DEVICE
212
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
213
+ FragmentC const &c) const {
214
+
215
+ #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
216
+
217
+ double const *A = reinterpret_cast<double const *>(&a);
218
+ double const *B = reinterpret_cast<double const *>(&b);
219
+
220
+ double const *C = reinterpret_cast<double const *>(&c);
221
+ double *D = reinterpret_cast<double *>(&d);
222
+
223
+ asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n"
224
+ : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
225
+ : "d"(A[0]), "d"(A[2]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]),
226
+ "d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]),
227
+ "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
228
+
229
+ #else
230
+ CUTLASS_NOT_IMPLEMENTED();
231
+ #endif
232
+ }
233
+ };
234
+
235
+ /////////////////////////////////////////////////////////////////////////////////////////////////
236
+
237
+ } // namespace arch
238
+ } // namespace cutlass
239
+
240
+ /////////////////////////////////////////////////////////////////////////////////////////////////
241
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h ADDED
@@ -0,0 +1,1234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Sparse matrix multiply accumulate for SM80
34
+ */
35
+
36
+ #pragma once
37
+ #include "cutlass/cutlass.h"
38
+ #include CUDA_STD_HEADER(cassert)
39
+
40
+ #include "mma.h"
41
+ #include "cutlass/layout/matrix.h"
42
+ #include "cutlass/numeric_types.h"
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))
47
+
48
+ #define CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED 1
49
+
50
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
51
+ #define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED
52
+ #endif
53
+
54
+ #endif
55
+
56
+ /////////////////////////////////////////////////////////////////////////////////////////////////
57
+
58
+ namespace cutlass {
59
+ namespace arch {
60
+
61
+ /////////////////////////////////////////////////////////////////////////////////////////////////
62
+
63
+ ////////////////////////////////////////////////////////////////////////////////
64
+ //
65
+ // Sparse Matrix Multiply 16832
66
+ //
67
+ ////////////////////////////////////////////////////////////////////////////////
68
+
69
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
///
/// Sparse tensor-op MMA of shape m16n8k32 for SM80, issued via the PTX
/// `mma.sp` instruction. Operand A is stored in compressed form (kSparse = 2
/// logical elements per stored element, per the CUTLASS sparse-MMA
/// convention); the sparsity metadata travels separately in `E`.
template <>
struct SparseMma<
  gemm::GemmShape<16, 8, 32>,
  32,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd,
  SPFormatType::Thread
> {

  using Shape = gemm::GemmShape<16, 8, 32>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 8>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 8>;

  using ElementC = half_t;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<half_t, 4>;

  // Per-thread 32-bit word holding the sparsity metadata for this MMA.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  // Compression factor of operand A.
  static int const kSparse = 2;

  // Bits of metadata per logical element of A.
  static int const kMetaSizeInBits = 2;

  // Number of valid values of the `id2` selector (0 and 1 here).
  static int const kMaxID2 = 2;

  /// Computes multiply-add: d = a * b + c.
  ///
  /// `E`   - packed sparsity metadata consumed by mma.sp.
  /// `id2` - selects which portion of the metadata word the instruction uses;
  ///         encoded as the trailing 0x0 / 0x1 immediate of the instruction.
  ///         Values other than 0 or 1 are invalid (assert).
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c, uint32_t const &E, int const id2) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // View the fragments as the 32-bit registers the PTX operands require
    // (two half_t values per register).
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
    uint32_t *D = reinterpret_cast<uint32_t *>(&d);

// CUDA 12.5+ provides the `mma.sp::ordered_metadata` variant; older
// toolchains fall back to plain `mma.sp`.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, "
          "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E));
    }
    else if (id2 == 1) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, "
          "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n"
          : "=r"(D[0]), "=r"(D[1])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, "
          "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E));
    }
    else if (id2 == 1) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, "
          "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n"
          : "=r"(D[0]), "=r"(D[1])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E));
    }
    else {
      assert(0);
    }
#endif

#else
    // Sparse MMA not available on this architecture/toolchain: hard error.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
172
+
173
+ ////////////////////////////////////////////////////////////////////////////////
174
+
175
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
///
/// Sparse tensor-op MMA of shape m16n8k32 for SM80 with FP32 accumulation,
/// issued via the PTX `mma.sp` instruction. Operand A is stored compressed
/// (kSparse = 2); sparsity metadata is passed separately in `E`.
template <>
struct SparseMma<
  gemm::GemmShape<16, 8, 32>,
  32,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd,
  SPFormatType::Thread
> {

  using Shape = gemm::GemmShape<16, 8, 32>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 8>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 8>;

  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 4>;

  // Per-thread 32-bit word holding the sparsity metadata for this MMA.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  // Compression factor of operand A.
  static int const kSparse = 2;

  // Bits of metadata per logical element of A.
  static int const kMetaSizeInBits = 2;

  // Number of valid values of the `id2` selector (0 and 1 here).
  static int const kMaxID2 = 2;

  /// Computes multiply-add: d = a * b + c.
  ///
  /// `E`   - packed sparsity metadata consumed by mma.sp.
  /// `id2` - metadata selector, encoded as the trailing 0x0 / 0x1 immediate;
  ///         values other than 0 or 1 are invalid (assert).
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c, uint32_t const &E, int const id2) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // A/B are viewed as 32-bit registers (two half_t per register);
    // C/D are FP32 and bind directly to "f" operands.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

// CUDA 12.5+ provides the `mma.sp::ordered_metadata` variant.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
          "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
            "r"(E));
    }
    else if (id2 == 1) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
          "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
            "r"(E));
    }
    else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
          "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
            "r"(E));
    }
    else if (id2 == 1) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
          "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
            "r"(E));
    }
    else {
      assert(0);
    }

#endif

#else
    // Sparse MMA not available on this architecture/toolchain: hard error.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
283
+
284
+ ////////////////////////////////////////////////////////////////////////////////
285
+ //
286
+ // Sparse Matrix Multiply 16832 - Float BF16, FP32 accumulation
287
+ //
288
+ ////////////////////////////////////////////////////////////////////////////////
289
+
290
/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32
///
/// Sparse tensor-op MMA of shape m16n8k32 for SM80 with BF16 inputs and FP32
/// accumulation (PTX `mma.sp`). Operand A is stored compressed (kSparse = 2);
/// sparsity metadata is passed separately in `E`.
template <>
struct SparseMma<gemm::GemmShape<16, 8, 32>, 32, bfloat16_t, layout::RowMajor,
                 bfloat16_t, layout::ColumnMajor, float, layout::RowMajor,
                 OpMultiplyAdd, SPFormatType::Thread> {
  using Shape = gemm::GemmShape<16, 8, 32>;

  using ElementA = bfloat16_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<bfloat16_t, 8>;

  using ElementB = bfloat16_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<bfloat16_t, 8>;

  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 4>;

  // Per-thread 32-bit word holding the sparsity metadata for this MMA.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  // Compression factor of operand A.
  static int const kSparse = 2;

  // Bits of metadata per logical element of A.
  static int const kMetaSizeInBits = 2;

  // Number of valid values of the `id2` selector (0 and 1 here).
  static int const kMaxID2 = 2;

  /// Computes multiply-add: d = a * b + c. `id2` selects the metadata half
  /// (trailing 0x0 / 0x1 immediate); other values assert.
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c, uint32_t const &E, int const id2) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // A/B viewed as 32-bit registers (two bf16 per register); C/D are FP32.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

// CUDA 12.5+ provides the `mma.sp::ordered_metadata` variant.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    } else if (id2 == 1) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    } else if (id2 == 1) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#endif

#else

    // Sparse MMA not available on this architecture/toolchain: hard error.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
379
+
380
+ ////////////////////////////////////////////////////////////////////////////////
381
+ //
382
+ // Sparse Matrix Multiply 16816 - Float TF32
383
+ //
384
+ ////////////////////////////////////////////////////////////////////////////////
385
+
386
/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32
///
/// Sparse tensor-op MMA of shape m16n8k16 for SM80 with TF32 inputs and FP32
/// accumulation (PTX `mma.sp`). Operand A is stored compressed (kSparse = 2).
/// Note TF32 uses 4 metadata bits per element (kMetaSizeInBits = 4), unlike
/// the 16-bit-input variants.
template <>
struct SparseMma<gemm::GemmShape<16, 8, 16>, 32, tfloat32_t, layout::RowMajor,
                 tfloat32_t, layout::ColumnMajor, float, layout::RowMajor,
                 OpMultiplyAdd, SPFormatType::Thread> {
  using Shape = gemm::GemmShape<16, 8, 16>;

  using ElementA = tfloat32_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<tfloat32_t, 4>;

  using ElementB = tfloat32_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<tfloat32_t, 4>;

  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 4>;

  // Per-thread 32-bit word holding the sparsity metadata for this MMA.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  // Compression factor of operand A.
  static int const kSparse = 2;

  // Bits of metadata per logical element of A.
  static int const kMetaSizeInBits = 4;

  // Number of valid values of the `id2` selector (0 and 1 here).
  static int const kMaxID2 = 2;

  /// Computes multiply-add: d = a * b + c. `id2` selects the metadata half
  /// (trailing 0x0 / 0x1 immediate); other values assert.
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c, uint32_t const &E, int const id2) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // Each tf32 value occupies one 32-bit register; C/D are FP32.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

// CUDA 12.5+ provides the `mma.sp::ordered_metadata` variant.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    } else if (id2 == 1) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    } else if (id2 == 1) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#endif

#else

    // Sparse MMA not available on this architecture/toolchain: hard error.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
475
+
476
+ ////////////////////////////////////////////////////////////////////////////////
477
+ //
478
+ // Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE
479
+ //
480
+ ////////////////////////////////////////////////////////////////////////////////
481
+
482
/// Matrix multiply-add operation: S32 = S8 * S8 + S32
///
/// Sparse tensor-op MMA of shape m16n8k64 for SM80 with signed 8-bit inputs,
/// S32 accumulation and saturating output (`satfinite`), issued via PTX
/// `mma.sp`. Operand A is stored compressed (kSparse = 2). Only id2 == 0 is
/// valid (kMaxID2 = 1).
template <>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  int8_t,
  layout::RowMajor,
  int8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate,
  SPFormatType::Thread> {

  using Shape = gemm::GemmShape<16,8,64>;

  using ElementA = int8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int8_t, 16>;

  using ElementB = int8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int8_t, 16>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 4>;

  // Per-thread 32-bit word holding the sparsity metadata for this MMA.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  // Compression factor of operand A.
  static int const kSparse = 2;

  // Bits of metadata per logical element of A.
  static int const kMetaSizeInBits = 2;

  // Only id2 == 0 is supported for the 8-bit integer shape.
  static int const kMaxID2 = 1;

  /// Computes multiply-add with saturation: d = a * b + c.
  /// `E` carries the sparsity metadata; any id2 != 0 asserts.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // A/B viewed as 32-bit registers (four int8 per register).
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

// CUDA 12.5+ provides the `mma.sp::ordered_metadata` variant.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#endif

#else
    // Sparse MMA not available on this architecture/toolchain: hard error.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
573
+
574
/// Matrix multiply-add operation: S32 = S8 * U8 + S32
///
/// Sparse tensor-op MMA of shape m16n8k64 for SM80 mixing signed A with
/// unsigned B, S32 accumulation and saturating output (`satfinite`), issued
/// via PTX `mma.sp`. Operand A is stored compressed (kSparse = 2). Only
/// id2 == 0 is valid (kMaxID2 = 1).
template <>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  int8_t,
  layout::RowMajor,
  uint8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate,
  SPFormatType::Thread> {

  using Shape = gemm::GemmShape<16,8,64>;

  using ElementA = int8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int8_t, 16>;

  using ElementB = uint8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint8_t, 16>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 4>;

  // Per-thread 32-bit word holding the sparsity metadata for this MMA.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  // Compression factor of operand A.
  static int const kSparse = 2;

  // Bits of metadata per logical element of A.
  static int const kMetaSizeInBits = 2;

  // Only id2 == 0 is supported for the 8-bit integer shape.
  static int const kMaxID2 = 1;

  /// Computes multiply-add with saturation: d = a * b + c.
  /// `E` carries the sparsity metadata; any id2 != 0 asserts.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // A/B viewed as 32-bit registers (four 8-bit values per register).
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

// CUDA 12.5+ provides the `mma.sp::ordered_metadata` variant.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#endif

#else

    // Sparse MMA not available on this architecture/toolchain: hard error.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
666
+
667
/// Matrix multiply-add operation: S32 = U8 * S8 + S32
///
/// Sparse tensor-op MMA of shape m16n8k64 for SM80 mixing unsigned A with
/// signed B, S32 accumulation and saturating output (`satfinite`), issued
/// via PTX `mma.sp`. Operand A is stored compressed (kSparse = 2). Only
/// id2 == 0 is valid (kMaxID2 = 1).
template <>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  uint8_t,
  layout::RowMajor,
  int8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate,
  SPFormatType::Thread> {

  using Shape = gemm::GemmShape<16,8,64>;

  using ElementA = uint8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint8_t, 16>;

  using ElementB = int8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int8_t, 16>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 4>;

  // Per-thread 32-bit word holding the sparsity metadata for this MMA.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  // Compression factor of operand A.
  static int const kSparse = 2;

  // Bits of metadata per logical element of A.
  static int const kMetaSizeInBits = 2;

  // Only id2 == 0 is supported for the 8-bit integer shape.
  static int const kMaxID2 = 1;

  /// Computes multiply-add with saturation: d = a * b + c.
  /// `E` carries the sparsity metadata; any id2 != 0 asserts.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // A/B viewed as 32-bit registers (four 8-bit values per register).
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

// CUDA 12.5+ provides the `mma.sp::ordered_metadata` variant.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#endif

#else
    // Sparse MMA not available on this architecture/toolchain: hard error.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
758
+
759
/// Matrix multiply-add operation: S32 = U8 * U8 + S32
///
/// Sparse tensor-op MMA of shape m16n8k64 for SM80 with unsigned 8-bit
/// inputs, S32 accumulation and saturating output (`satfinite`), issued via
/// PTX `mma.sp`. Operand A is stored compressed (kSparse = 2). Only
/// id2 == 0 is valid (kMaxID2 = 1).
template <>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  uint8_t,
  layout::RowMajor,
  uint8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate,
  SPFormatType::Thread> {

  using Shape = gemm::GemmShape<16,8,64>;

  using ElementA = uint8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint8_t, 16>;

  using ElementB = uint8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint8_t, 16>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 4>;

  // Per-thread 32-bit word holding the sparsity metadata for this MMA.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  // Compression factor of operand A.
  static int const kSparse = 2;

  // Bits of metadata per logical element of A.
  static int const kMetaSizeInBits = 2;

  // Only id2 == 0 is supported for the 8-bit integer shape.
  static int const kMaxID2 = 1;

  /// Computes multiply-add with saturation: d = a * b + c.
  /// `E` carries the sparsity metadata; any id2 != 0 asserts.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // A/B viewed as 32-bit registers (four 8-bit values per register).
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

// CUDA 12.5+ provides the `mma.sp::ordered_metadata` variant.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#endif

#else
    // Sparse MMA not available on this architecture/toolchain: hard error.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
850
+
851
+ ////////////////////////////////////////////////////////////////////////////////
852
+ //
853
+ // Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE
854
+ //
855
+ ////////////////////////////////////////////////////////////////////////////////
856
+
857
/// Matrix multiply-add operation: S32 = S4 * S4 + S32
///
/// Sparse tensor-op MMA of shape m16n8k128 for SM80 with signed 4-bit
/// inputs, S32 accumulation and saturating output (`satfinite`), issued via
/// PTX `mma.sp`. Operand A is stored compressed (kSparse = 2). Only
/// id2 == 0 is valid (kMaxID2 = 1).
template <>
struct SparseMma<
  gemm::GemmShape<16,8,128>,
  32,
  cutlass::int4b_t,
  layout::RowMajor,
  cutlass::int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate,
  SPFormatType::Thread> {

  using Shape = gemm::GemmShape<16,8,128>;

  using ElementA = cutlass::int4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<cutlass::int4b_t, 32>;

  using ElementB = cutlass::int4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<cutlass::int4b_t, 32>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 4>;

  // Per-thread 32-bit word holding the sparsity metadata for this MMA.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  // Compression factor of operand A.
  static int const kSparse = 2;

  // Bits of metadata per logical element of A.
  static int const kMetaSizeInBits = 2;

  // Only id2 == 0 is supported for the 4-bit integer shape.
  static int const kMaxID2 = 1;

  /// Computes multiply-add with saturation: d = a * b + c.
  /// `E` carries the sparsity metadata; any id2 != 0 asserts.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // A/B viewed as 32-bit registers (eight 4-bit values per register).
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

// CUDA 12.5+ provides the `mma.sp::ordered_metadata` variant.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#endif

#else

    // Sparse MMA not available on this architecture/toolchain: hard error.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
949
+
950
/// Matrix multiply-add operation: S32 = S4 * U4 + S32
///
/// Sm80 structured-sparse tensor-op specialization for a 16x8x128 tile
/// computed cooperatively by a warp of 32 threads. Operand A is stored
/// compressed by a factor of kSparse along K; metadata register E selects
/// the kept A elements (see PTX `mma.sp` documentation).
template <>
struct SparseMma<
  gemm::GemmShape<16,8,128>,
  32,
  cutlass::int4b_t,
  layout::RowMajor,
  cutlass::uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate,
  SPFormatType::Thread> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16,8,128>;

  // Operand A: signed 4-bit integers, row-major; fragment holds only the
  // kept (non-pruned) values.
  using ElementA = cutlass::int4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<cutlass::int4b_t, 32>;

  // Operand B: unsigned 4-bit integers, column-major (dense).
  using ElementB = cutlass::uint4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<cutlass::uint4b_t, 32>;

  // Accumulator: 32-bit signed integers, four per thread.
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 4>;

  /// Sparsity metadata carried in one 32-bit register per thread
  using FragmentE = uint32_t;

  /// Saturating multiply-add (maps to the ".satfinite" PTX qualifier below)
  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  /// Compression factor of operand A along the K dimension
  static int const kSparse = 2;

  /// Width in bits of each metadata element held in E
  static int const kMetaSizeInBits = 2;

  /// Number of valid values for the id2 argument of operator()
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c (saturating), steered by metadata E.
  /// Only id2 == 0 is supported here (kMaxID2 == 1); other values assert.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // View the fragments as the packed 32-bit registers the PTX operands expect.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

  // CUDA 12.5+ exposes the mma.sp::ordered_metadata variant; older
  // toolkits fall back to plain mma.sp with identical operands.
  #if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
  #else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
  #endif

#else
    // Sparse MMA not available for this architecture/toolkit combination.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
1042
+
1043
/// Matrix multiply-add operation: S32 = U4 * S4 + S32
///
/// Sm80 structured-sparse tensor-op specialization for a 16x8x128 tile
/// computed cooperatively by a warp of 32 threads. Operand A is stored
/// compressed by a factor of kSparse along K; metadata register E selects
/// the kept A elements (see PTX `mma.sp` documentation).
template <>
struct SparseMma<
  gemm::GemmShape<16,8,128>,
  32,
  cutlass::uint4b_t,
  layout::RowMajor,
  cutlass::int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate,
  SPFormatType::Thread> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16,8,128>;

  // Operand A: unsigned 4-bit integers, row-major; fragment holds only the
  // kept (non-pruned) values.
  using ElementA = cutlass::uint4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<cutlass::uint4b_t, 32>;

  // Operand B: signed 4-bit integers, column-major (dense).
  using ElementB = cutlass::int4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<cutlass::int4b_t, 32>;

  // Accumulator: 32-bit signed integers, four per thread.
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 4>;

  /// Sparsity metadata carried in one 32-bit register per thread
  using FragmentE = uint32_t;

  /// Saturating multiply-add (maps to the ".satfinite" PTX qualifier below)
  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  /// Compression factor of operand A along the K dimension
  static int const kSparse = 2;

  /// Width in bits of each metadata element held in E
  static int const kMetaSizeInBits = 2;

  /// Number of valid values for the id2 argument of operator()
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c (saturating), steered by metadata E.
  /// Only id2 == 0 is supported here (kMaxID2 == 1); other values assert.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // View the fragments as the packed 32-bit registers the PTX operands expect.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

  // CUDA 12.5+ exposes the mma.sp::ordered_metadata variant; older
  // toolkits fall back to plain mma.sp with identical operands.
  #if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
  #else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
  #endif

#else
    // Sparse MMA not available for this architecture/toolkit combination.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
1135
+
1136
/// Matrix multiply-add operation: S32 = U4 * U4 + S32
///
/// Sm80 structured-sparse tensor-op specialization for a 16x8x128 tile
/// computed cooperatively by a warp of 32 threads. Operand A is stored
/// compressed by a factor of kSparse along K; metadata register E selects
/// the kept A elements (see PTX `mma.sp` documentation).
template <>
struct SparseMma<
  gemm::GemmShape<16,8,128>,
  32,
  cutlass::uint4b_t,
  layout::RowMajor,
  cutlass::uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate,
  SPFormatType::Thread> {

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16,8,128>;

  // Operand A: unsigned 4-bit integers, row-major; fragment holds only the
  // kept (non-pruned) values.
  using ElementA = cutlass::uint4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<cutlass::uint4b_t, 32>;

  // Operand B: unsigned 4-bit integers, column-major (dense).
  using ElementB = cutlass::uint4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<cutlass::uint4b_t, 32>;

  // Accumulator: 32-bit signed integers, four per thread.
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 4>;

  /// Sparsity metadata carried in one 32-bit register per thread
  using FragmentE = uint32_t;

  /// Saturating multiply-add (maps to the ".satfinite" PTX qualifier below)
  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  /// Compression factor of operand A along the K dimension
  static int const kSparse = 2;

  /// Width in bits of each metadata element held in E
  static int const kMetaSizeInBits = 2;

  /// Number of valid values for the id2 argument of operator()
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c (saturating), steered by metadata E.
  /// Only id2 == 0 is supported here (kMaxID2 == 1); other values assert.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // View the fragments as the packed 32-bit registers the PTX operands expect.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

  // CUDA 12.5+ exposes the mma.sp::ordered_metadata variant; older
  // toolkits fall back to plain mma.sp with identical operands.
  #if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
  #else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
  #endif

#else
    // Sparse MMA not available for this architecture/toolkit combination.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
1228
+
1229
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1230
+
1231
+ } // namespace arch
1232
+ } // namespace cutlass
1233
+
1234
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Sparse matrix multiply accumulate for SM89
34
+ */
35
+
36
+ #pragma once
37
+ #include "cutlass/cutlass.h"
38
+ #include CUDA_STD_HEADER(cassert)
39
+
40
+ #include "mma.h"
41
+ #include "cutlass/layout/matrix.h"
42
+ #include "cutlass/numeric_types.h"
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ #if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)
47
+ # define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED
48
+ #endif
49
+
50
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
51
+ # if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED)
52
+ # define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED
53
+ # endif
54
+ #endif
55
+
56
+ /////////////////////////////////////////////////////////////////////////////////////////////////
57
+
58
+ namespace cutlass {
59
+ namespace arch {
60
+
61
+ /////////////////////////////////////////////////////////////////////////////////////////////////
62
+
63
/// Matrix multiply-add operation: F32 = fe4m3 * fe4m3 + F32
///
/// SM89 FP8 structured-sparse tensor-op specialization: one 16x8x64 tile
/// per warp of 32 threads. Operand A is stored compressed by a factor of
/// kSparse along K; metadata register E selects the kept A elements.
template <typename Operator_>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  cutlass::float_e4m3_t,
  layout::RowMajor,
  cutlass::float_e4m3_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  Operator_,
  SPFormatType::Thread> {

  // The SM89 FP8 instruction supports only the plain and fast-accumulation
  // multiply-add operators.
  static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
                platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
                "Invalid operator for SM89 FP8 instruction");

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16,8,64>;

  // Operand A: e4m3 FP8, row-major; fragment holds only the kept values.
  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<ElementA, 16>;

  // Operand B: e4m3 FP8, column-major (dense).
  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<ElementB, 16>;

  // Accumulator: single-precision float, four per thread.
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<ElementC, 4>;

  /// Sparsity metadata carried in one 32-bit register per thread
  using FragmentE = uint32_t;

  using Operator = Operator_;
  using ArchTag = arch::Sm89;

  /// Compression factor of operand A along the K dimension
  static int const kSparse = 2;

  /// Width in bits of each metadata element held in E
  static int const kMetaSizeInBits = 2;

  /// Number of valid values for the id2 argument of operator()
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c, steered by metadata E.
  /// Only id2 == 0 is supported (kMaxID2 == 1); other values assert.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED)

    // View the FP8 fragments as the packed 32-bit registers the PTX expects;
    // the float accumulators bind directly via "f" constraints.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    // FP8 sparse MMA not available for this architecture/toolkit combination.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
145
+
146
+ /////////////////////////////////////////////////////////////////////////////////////////////////
147
+
148
/// Matrix multiply-add operation: F32 = fe4m3 * fe5m2 + F32
///
/// SM89 FP8 structured-sparse tensor-op specialization: one 16x8x64 tile
/// per warp of 32 threads. Operand A is stored compressed by a factor of
/// kSparse along K; metadata register E selects the kept A elements.
template <typename Operator_>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  cutlass::float_e4m3_t,
  layout::RowMajor,
  cutlass::float_e5m2_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  Operator_,
  SPFormatType::Thread> {

  // The SM89 FP8 instruction supports only the plain and fast-accumulation
  // multiply-add operators.
  static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
                platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
                "Invalid operator for SM89 FP8 instruction");

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16,8,64>;

  // Operand A: e4m3 FP8, row-major; fragment holds only the kept values.
  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<ElementA, 16>;

  // Operand B: e5m2 FP8, column-major (dense).
  using ElementB = cutlass::float_e5m2_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<ElementB, 16>;

  // Accumulator: single-precision float, four per thread.
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<ElementC, 4>;

  /// Sparsity metadata carried in one 32-bit register per thread
  using FragmentE = uint32_t;

  using Operator = Operator_;
  using ArchTag = arch::Sm89;

  /// Compression factor of operand A along the K dimension
  static int const kSparse = 2;

  /// Width in bits of each metadata element held in E
  static int const kMetaSizeInBits = 2;

  /// Number of valid values for the id2 argument of operator()
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c, steered by metadata E.
  /// Only id2 == 0 is supported (kMaxID2 == 1); other values assert.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED)

    // View the FP8 fragments as the packed 32-bit registers the PTX expects;
    // the float accumulators bind directly via "f" constraints.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    // FP8 sparse MMA not available for this architecture/toolkit combination.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
230
+
231
+ /////////////////////////////////////////////////////////////////////////////////////////////////
232
+
233
/// Matrix multiply-add operation: F32 = fe5m2 * fe4m3 + F32
///
/// SM89 FP8 structured-sparse tensor-op specialization: one 16x8x64 tile
/// per warp of 32 threads. Operand A is stored compressed by a factor of
/// kSparse along K; metadata register E selects the kept A elements.
template <typename Operator_>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  cutlass::float_e5m2_t,
  layout::RowMajor,
  cutlass::float_e4m3_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  Operator_,
  SPFormatType::Thread> {

  // The SM89 FP8 instruction supports only the plain and fast-accumulation
  // multiply-add operators.
  static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
                platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
                "Invalid operator for SM89 FP8 instruction");

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16,8,64>;

  // Operand A: e5m2 FP8, row-major; fragment holds only the kept values.
  using ElementA = cutlass::float_e5m2_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<ElementA, 16>;

  // Operand B: e4m3 FP8, column-major (dense).
  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<ElementB, 16>;

  // Accumulator: single-precision float, four per thread.
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<ElementC, 4>;

  /// Sparsity metadata carried in one 32-bit register per thread
  using FragmentE = uint32_t;

  using Operator = Operator_;
  using ArchTag = arch::Sm89;

  /// Compression factor of operand A along the K dimension
  static int const kSparse = 2;

  /// Width in bits of each metadata element held in E
  static int const kMetaSizeInBits = 2;

  /// Number of valid values for the id2 argument of operator()
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c, steered by metadata E.
  /// Only id2 == 0 is supported (kMaxID2 == 1); other values assert.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED)

    // View the FP8 fragments as the packed 32-bit registers the PTX expects;
    // the float accumulators bind directly via "f" constraints.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    // FP8 sparse MMA not available for this architecture/toolkit combination.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
315
+
316
+ /////////////////////////////////////////////////////////////////////////////////////////////////
317
+
318
/// Matrix multiply-add operation: F32 = fe5m2 * fe5m2 + F32
///
/// SM89 FP8 structured-sparse tensor-op specialization: one 16x8x64 tile
/// per warp of 32 threads. Operand A is stored compressed by a factor of
/// kSparse along K; metadata register E selects the kept A elements.
template <typename Operator_>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  cutlass::float_e5m2_t,
  layout::RowMajor,
  cutlass::float_e5m2_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  Operator_,
  SPFormatType::Thread> {

  // The SM89 FP8 instruction supports only the plain and fast-accumulation
  // multiply-add operators.
  static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
                platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
                "Invalid operator for SM89 FP8 instruction");

  /// Instruction shape (M-by-N-by-K)
  using Shape = gemm::GemmShape<16,8,64>;

  // Operand A: e5m2 FP8, row-major; fragment holds only the kept values.
  using ElementA = cutlass::float_e5m2_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<ElementA, 16>;

  // Operand B: e5m2 FP8, column-major (dense).
  using ElementB = cutlass::float_e5m2_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<ElementB, 16>;

  // Accumulator: single-precision float, four per thread.
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<ElementC, 4>;

  /// Sparsity metadata carried in one 32-bit register per thread
  using FragmentE = uint32_t;

  using Operator = Operator_;
  using ArchTag = arch::Sm89;

  /// Compression factor of operand A along the K dimension
  static int const kSparse = 2;

  /// Width in bits of each metadata element held in E
  static int const kMetaSizeInBits = 2;

  /// Number of valid values for the id2 argument of operator()
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c, steered by metadata E.
  /// Only id2 == 0 is supported (kMaxID2 == 1); other values assert.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED)

    // View the FP8 fragments as the packed 32-bit registers the PTX expects;
    // the float accumulators bind directly via "f" constraints.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    // FP8 sparse MMA not available for this architecture/toolkit combination.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
400
+
401
+ /////////////////////////////////////////////////////////////////////////////////////////////////
402
+
403
+ } // namespace arch
404
+ } // namespace cutlass
405
+
406
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief PTX for CTA Reconfiguration
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #if defined(__CUDACC_RTC__)
40
+ #include <cuda/std/cstdint>
41
+ #else
42
+ #include <cstdint>
43
+ #endif
44
+
45
+ #ifndef CUDA_CTA_RECONFIG_ACTIVATED
46
+ #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \
47
+ (__CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) \
48
+ || (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) \
49
+ || (__CUDA_ARCH__ == 1010 && defined(__CUDA_ARCH_FEAT_SM101_ALL)) \
50
+ || (__CUDA_ARCH__ == 1030 && defined(__CUDA_ARCH_FEAT_SM103_ALL)) \
51
+ || (__CUDA_ARCH__ == 1200 && defined(__CUDA_ARCH_FEAT_SM120_ALL)) \
52
+ || (__CUDA_ARCH__ == 1210 && defined(__CUDA_ARCH_FEAT_SM121_ALL)) \
53
+ )
54
+ #define CUDA_CTA_RECONFIG_ACTIVATED 1
55
+ #endif
56
+
57
+ #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \
58
+ (__CUDA_ARCH__ == 1000 && CUDA_ARCH_FAMILY(1000)) \
59
+ || (__CUDA_ARCH__ == 1010 && CUDA_ARCH_FAMILY(1010)) \
60
+ || (__CUDA_ARCH__ == 1030 && CUDA_ARCH_FAMILY(1030)) \
61
+ || (__CUDA_ARCH__ == 1200 && CUDA_ARCH_FAMILY(1200)) \
62
+ || (__CUDA_ARCH__ == 1210 && CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) \
63
+ )
64
+ #define CUDA_CTA_RECONFIG_ACTIVATED 1
65
+ #endif
66
+
67
+ #endif
68
+
69
+ namespace cutlass {
70
+ namespace arch {
71
+
72
+ template<uint32_t RegCount>
73
+ CUTLASS_DEVICE
74
+ void warpgroup_reg_alloc(){
75
+ #if CUDA_CTA_RECONFIG_ACTIVATED
76
+ asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) );
77
+ #endif
78
+ }
79
+
80
+ template<uint32_t RegCount>
81
+ CUTLASS_DEVICE
82
+ void warpgroup_reg_dealloc(){
83
+ #if CUDA_CTA_RECONFIG_ACTIVATED
84
+ asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) );
85
+ #endif
86
+ }
87
+
88
+ } // namespace arch
89
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates exposing SIMD operators
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/arch/array.h"
38
+ #include "cutlass/arch/numeric_types.h"
39
+
40
+ namespace cutlass {
41
+ namespace arch {
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ //
46
+ // Element-wise operators
47
+ //
48
+
49
+ CUTLASS_HOST_DEVICE
50
+ template <typename T, int N>
51
+ Array<T, N> operator*(Array<T, N> const &a, Array<T, N> const &b) {
52
+ Array<T, N> d;
53
+ CUTLASS_PRAGMA_UNROLL
54
+ for (int i = 0; i < N; ++i) {
55
+ d[i] = a[i] * b[i];
56
+ }
57
+ return d;
58
+ }
59
+
60
+ CUTLASS_HOST_DEVICE
61
+ template <typename T, int N>
62
+ Array<T, N> operator+(Array<T, N> const &a, Array<T, N> const &b) {
63
+ Array<T, N> d;
64
+ CUTLASS_PRAGMA_UNROLL
65
+ for (int i = 0; i < N; ++i) {
66
+ d[i] = a[i] + b[i];
67
+ }
68
+ return d;
69
+ }
70
+
71
+ CUTLASS_HOST_DEVICE
72
+ template <typename T, int N>
73
+ Array<T, N> operator-(Array<T, N> const &a, Array<T, N> const &b) {
74
+ Array<T, N> d;
75
+ CUTLASS_PRAGMA_UNROLL
76
+ for (int i = 0; i < N; ++i) {
77
+ d[i] = a[i] - b[i];
78
+ }
79
+ return d;
80
+ }
81
+
82
+ /////////////////////////////////////////////////////////////////////////////////////////////////
83
+
84
+ //
85
+ // Multiply-accumulate operators
86
+ //
87
+
88
+ CUTLASS_HOST_DEVICE
89
+ template <typename T, int N>
90
+ Array<T, N> mac(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) {
91
+ Array<T, N> d;
92
+ CUTLASS_PRAGMA_UNROLL
93
+ for (int i = 0; i < N; ++i) {
94
+ d[i] = a[i] * b[i] + c[i];
95
+ }
96
+ return d;
97
+ }
98
+
99
+ /////////////////////////////////////////////////////////////////////////////////////////////////
100
+
101
+ //
102
+ // Dot product operator
103
+ //
104
+
105
+ CUTLASS_HOST_DEVICE
106
+ template <typename Element, typename Accumulator, int N>
107
+ Accumulator dot(Array<T, N> const &a, Array<T, N> const &b, Accumulator accum) {
108
+ CUTLASS_PRAGMA_UNROLL
109
+ for (int i = 0; i < N; ++i) {
110
+ accum += a[i] * b[i];
111
+ }
112
+ return accum;
113
+ }
114
+
115
+ /////////////////////////////////////////////////////////////////////////////////////////////////
116
+
117
+ } // namespace arch
118
+ } // namespace cutlass
119
+
120
+ /////////////////////////////////////////////////////////////////////////////////////////////////
121
+
122
+ #include "simd_sm60.h"
123
+ #include "simd_sm61.h"
124
+
125
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates exposing SIMD operators for SM60
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "simd.h"
38
+
39
+ namespace cutlass {
40
+ namespace arch {
41
+
42
+ /////////////////////////////////////////////////////////////////////////////////////////////////
43
+
44
+ //
45
+ // Element-wise operators - specialized for half_t x 2
46
+ //
47
+
48
+ CUTLASS_HOST_DEVICE
49
+ template <>
50
+ Array<half_t, 2> operator*(Array<half_t, 2> const &a, Array<half_t, 2> const &b) {
51
+ Array<half_t, 2> d;
52
+
53
+ return d;
54
+ }
55
+
56
+ CUTLASS_HOST_DEVICE
57
+ template <>
58
+ Array<half_t, 2> operator+(AArray<half_t, 2> const &a, Array<half_t, 2> const &b) {
59
+ Array<half_t, 2> d;
60
+
61
+ return d;
62
+ }
63
+
64
+ CUTLASS_HOST_DEVICE
65
+ template <>
66
+ Array<half_t, 2> operator-(Array<half_t, 2> const &a, Array<half_t, 2> const &b) {
67
+ Array<T, N> d;
68
+
69
+ return d;
70
+ }
71
+
72
+ /////////////////////////////////////////////////////////////////////////////////////////////////
73
+
74
+ /// Multiply-accumulate operators - specialized for half_t x 2
75
+ CUTLASS_HOST_DEVICE
76
+ template <>
77
+ Array<half_t, 2> mac(Array<half_t, 2> const &a, Array<half_t, 2> const &b, Array<half_t, 2> const &c) {
78
+ Array<half_t, 2> d;
79
+
80
+ return d;
81
+ }
82
+
83
+ /////////////////////////////////////////////////////////////////////////////////////////////////
84
+
85
+ /// Dot product operator - specialized for half_t <- (half_t * half_t) x 2 + half_t
86
+ CUTLASS_HOST_DEVICE
87
+ template <>
88
+ half_t dot(Array<half_t, 2> const &a, Array<half_t, 2> const &b, half_t accum) {
89
+
90
+ return accum;
91
+ }
92
+
93
+ /// Dot product operator - specialized for float <- (half_t * half_t) x 2 + float
94
+ CUTLASS_HOST_DEVICE
95
+ template <>
96
+ float dot(Array<half_t, 2> const &a, Array<half_t, 2> const &b, float accum) {
97
+
98
+ return accum;
99
+ }
100
+
101
+ /////////////////////////////////////////////////////////////////////////////////////////////////
102
+
103
+ } // namespace arch
104
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates exposing SIMD operators for SM61
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "simd.h"
38
+
39
+ namespace cutlass {
40
+ namespace arch {
41
+
42
+ /////////////////////////////////////////////////////////////////////////////////////////////////
43
+
44
+ /// Dot product operator - specialized for int32_t <- (int8_t * int8_t) x 4 + int32_t
45
+ CUTLASS_HOST_DEVICE
46
+ template <>
47
+ int32_t dot(Array<int8_t, 4> const &a, Array<int8_t, 4> const &b, int32_t accum) {
48
+
49
+ return accum;
50
+ }
51
+
52
+ /// Dot product operator - specialized for int32_t <- (uint8_t * int8_t) x 4 + int32_t
53
+ CUTLASS_HOST_DEVICE
54
+ template <>
55
+ int32_t dot(Array<uint8_t, 4> const &a, Array<int8_t, 4> const &b, int32_t accum) {
56
+
57
+ return accum;
58
+ }
59
+
60
+ /// Dot product operator - specialized for int32_t <- (int8_t * uint8_t) x 4 + int32_t
61
+ CUTLASS_HOST_DEVICE
62
+ template <>
63
+ int32_t dot(Array<int8_t, 4> const &a, Array<uint8_t, 4> const &b, int32_t accum) {
64
+
65
+ return accum;
66
+ }
67
+
68
+ /// Dot product operator - specialized for int32_t <- (uint8_t * uint8_t) x 4 + int32_t
69
+ CUTLASS_HOST_DEVICE
70
+ template <>
71
+ int32_t dot(Array<uint8_t, 4> const &a, Array<uint8_t, 4> const &b, int32_t accum) {
72
+
73
+ return accum;
74
+ }
75
+
76
+ /////////////////////////////////////////////////////////////////////////////////////////////////
77
+
78
+ /// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t
79
+ CUTLASS_HOST_DEVICE
80
+ template <>
81
+ int32_t dot(Array<int16_t, 2> const &a, Array<int8_t, 2> const &b, int32_t accum) {
82
+
83
+ return accum;
84
+ }
85
+
86
+ /// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t
87
+ CUTLASS_HOST_DEVICE
88
+ template <>
89
+ int32_t dot(Array<uint16_t, 2> const &a, Array<int8_t, 2> const &b, int32_t accum) {
90
+
91
+ return accum;
92
+ }
93
+
94
+ /// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t
95
+ CUTLASS_HOST_DEVICE
96
+ template <>
97
+ int32_t dot(Array<int16_t, 2> const &a, Array<uint8_t, 2> const &b, int32_t accum) {
98
+
99
+ return accum;
100
+ }
101
+
102
+ /// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t
103
+ CUTLASS_HOST_DEVICE
104
+ template <>
105
+ int32_t dot(Array<uint16_t, 2> const &a, Array<uint8_t, 2> const &b, int32_t accum) {
106
+
107
+ return accum;
108
+ }
109
+
110
+ /////////////////////////////////////////////////////////////////////////////////////////////////
111
+
112
+ /// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t
113
+ CUTLASS_HOST_DEVICE
114
+ template <>
115
+ int32_t dot(Array<int16_t, 2> const &a, Array<int16_t, 2> const &b, int32_t accum) {
116
+
117
+ return accum;
118
+ }
119
+
120
+ /// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t
121
+ CUTLASS_HOST_DEVICE
122
+ template <>
123
+ int32_t dot(Array<uint16_t, 2> const &a, Array<int16_t, 2> const &b, int32_t accum) {
124
+
125
+ return accum;
126
+ }
127
+
128
+ /// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t
129
+ CUTLASS_HOST_DEVICE
130
+ template <>
131
+ int32_t dot(Array<int16_t, 2> const &a, Array<uint16_t, 2> const &b, int32_t accum) {
132
+
133
+ return accum;
134
+ }
135
+
136
+ /// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t
137
+ CUTLASS_HOST_DEVICE
138
+ template <>
139
+ int32_t dot(Array<uint16_t, 2> const &a, Array<uint16_t, 2> const &b, int32_t accum) {
140
+
141
+ return accum;
142
+ }
143
+
144
+ /////////////////////////////////////////////////////////////////////////////////////////////////
145
+
146
+ } // namespace arch
147
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp ADDED
@@ -0,0 +1,1271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Synchronization event logging for race condition debugging.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/detail/helper_macros.hpp"
38
+ #include "cutlass/cutlass.h"
39
+ #if defined(__CUDACC_RTC__)
40
+ #include CUDA_STD_HEADER(cstdint)
41
+ #else
42
+ #include <cstdint>
43
+ #endif
44
+
45
+ #if !defined(__CUDACC_RTC__)
46
+ #include <mutex>
47
+ #include <vector>
48
+ #endif
49
+
50
+ namespace cutlass {
51
+ namespace arch {
52
+
53
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
56
+
57
+ constexpr uint32_t synclog_cap = 1 << 26;
58
+
59
+ inline std::mutex synclog_mutex;
60
+ inline std::vector<uint32_t*> synclog_buf_list;
61
+ #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
62
+ CUTLASS_DEVICE uint32_t* synclog_buf;
63
+ #endif
64
+
65
/// Atomically reserves `n` uint32_t slots in the device-side sync log and
/// returns a pointer to the reserved region, or nullptr when logging is
/// unavailable (no buffer installed) or the buffer is full.
CUTLASS_DEVICE
uint32_t* synclog_alloc(uint32_t n) {
  #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  uint32_t* buf = synclog_buf;
  if (buf == nullptr) return nullptr;  // synclog_setup() has not published a buffer
  // buf[0] is the shared write cursor; claim n slots past it.
  uint32_t last = atomicAdd(&buf[0], n);
  if (last + n < synclog_cap) return buf + last + 1;  // +1 skips the cursor word itself
  // Buffer full: undo the reservation (unary minus on uint32_t wraps, i.e. subtracts n).
  // NOTE(review): rollback only fires once `last` itself passed the cap — presumably
  // to avoid racing rollbacks re-opening space mid-overflow; confirm intent.
  if (last >= synclog_cap) atomicAdd(&buf[0], -n);
  #endif
  return nullptr;
}
76
+
77
+ CUTLASS_DEVICE
78
+ void synclog_emit_prefix(uint32_t* to, uint32_t header, uint32_t line) {
79
+ #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
80
+ uint64_t time64;
81
+ asm volatile (
82
+ "mov.u64 %0, %%globaltimer;\n"
83
+ : "=l"(time64) :
84
+ );
85
+ to[0] = header;
86
+ to[1] = line;
87
+ to[2] = time64;
88
+ to[3] = time64 >> 32;
89
+ to[4] = threadIdx.x;
90
+ to[5] = threadIdx.y;
91
+ to[6] = threadIdx.z;
92
+ to[7] = blockIdx.x;
93
+ to[8] = blockIdx.y;
94
+ to[9] = blockIdx.z;
95
+ #endif
96
+ }
97
+
98
+ constexpr uint32_t synclog_header_none = 0;
99
+ constexpr uint32_t synclog_length_prefix = 1 + 1 + 2 + 3 + 3;
100
+
101
+ constexpr bool synclog_enable_syncthreads = true;
102
+ constexpr uint32_t synclog_header_syncthreads = 1;
103
+ constexpr uint32_t synclog_length_syncthreads = synclog_length_prefix + 0;
104
+
105
+ constexpr bool synclog_enable_syncwarp = true;
106
+ constexpr uint32_t synclog_header_syncwarp = 2;
107
+ constexpr uint32_t synclog_length_syncwarp = synclog_length_prefix + 0;
108
+
109
+ constexpr bool synclog_enable_named_barrier_arrive_and_wait = true;
110
+ constexpr uint32_t synclog_header_named_barrier_arrive_and_wait = 3;
111
+ constexpr uint32_t synclog_length_named_barrier_arrive_and_wait = synclog_length_prefix + 2;
112
+
113
+ constexpr bool synclog_enable_named_barrier_arrive = true;
114
+ constexpr uint32_t synclog_header_named_barrier_arrive = 4;
115
+ constexpr uint32_t synclog_length_named_barrier_arrive = synclog_length_prefix + 2;
116
+
117
+ constexpr bool synclog_enable_cluster_barrier_init = true;
118
+ constexpr uint32_t synclog_header_cluster_barrier_init = 5;
119
+ constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + 2;
120
+
121
+ constexpr bool synclog_enable_cluster_barrier_wait = true;
122
+ constexpr uint32_t synclog_header_cluster_barrier_wait = 6;
123
+ constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 2;
124
+ constexpr bool synclog_enable_cluster_barrier_test_wait = true;
125
+ constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7;
126
+ constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 3;
127
+ constexpr bool synclog_enable_cluster_barrier_try_wait = true;
128
+ constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8;
129
+ constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 2;
130
+ constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true;
131
+ constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9;
132
+ constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 3;
133
+ constexpr bool synclog_enable_cluster_barrier_arrive = true;
134
+ constexpr uint32_t synclog_header_cluster_barrier_arrive = 10;
135
+ constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 1;
136
+ constexpr bool synclog_enable_cluster_barrier_invalidate = true;
137
+ constexpr uint32_t synclog_header_cluster_barrier_invalidate = 11;
138
+ constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 1;
139
+ constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true;
140
+ constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12;
141
+ constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 2;
142
+ constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true;
143
+ constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13;
144
+ constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 4;
145
+ constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true;
146
+ constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14;
147
+ constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 2;
148
+ constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true;
149
+ constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15;
150
+ constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 4;
151
+ constexpr bool synclog_enable_fence_barrier_init = true;
152
+ constexpr uint32_t synclog_header_fence_barrier_init = 16;
153
+ constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0;
154
+
155
+ constexpr bool synclog_enable_fence_view_async_shared = true;
156
+ constexpr uint32_t synclog_header_fence_view_async_shared = 17;
157
+ constexpr uint32_t synclog_length_fence_view_async_shared = synclog_length_prefix + 0;
158
+
159
+ constexpr bool synclog_enable_cp_async_wait = true;
160
+ constexpr uint32_t synclog_header_cp_async_wait = 18;
161
+ constexpr uint32_t synclog_length_cp_async_wait = synclog_length_prefix + 1;
162
+
163
+ constexpr bool synclog_enable_cp_async_wait_all = true;
164
+ constexpr uint32_t synclog_header_cp_async_wait_all = 19;
165
+ constexpr uint32_t synclog_length_cp_async_wait_all = synclog_length_prefix + 0;
166
+
167
+ constexpr bool synclog_enable_cp_async_fence = true;
168
+ constexpr uint32_t synclog_header_cp_async_fence = 20;
169
+ constexpr uint32_t synclog_length_cp_async_fence = synclog_length_prefix + 0;
170
+
171
+ constexpr bool synclog_enable_cp_async_nan = true;
172
+ constexpr uint32_t synclog_header_cp_async_nan = 21;
173
+ constexpr uint32_t synclog_length_cp_async_nan = synclog_length_prefix + 4;
174
+
175
+ constexpr bool synclog_enable_cp_async_zfill = true;
176
+ constexpr uint32_t synclog_header_cp_async_zfill = 22;
177
+ constexpr uint32_t synclog_length_cp_async_zfill = synclog_length_prefix + 5;
178
+
179
+ constexpr bool synclog_enable_cp_async = true;
180
+ constexpr uint32_t synclog_header_cp_async = 23;
181
+ constexpr uint32_t synclog_length_cp_async = synclog_length_prefix + 5;
182
+
183
+ constexpr bool synclog_enable_tma_load = true;
184
+ constexpr uint32_t synclog_header_tma_load = 24;
185
+ constexpr uint32_t synclog_length_tma_load = synclog_length_prefix + 4;
186
+
187
+ constexpr bool synclog_enable_tma_store = true;
188
+ constexpr uint32_t synclog_header_tma_store = 25;
189
+ constexpr uint32_t synclog_length_tma_store = synclog_length_prefix + 3;
190
+
191
+ constexpr bool synclog_enable_tma_store_arrive = true;
192
+ constexpr uint32_t synclog_header_tma_store_arrive = 26;
193
+ constexpr uint32_t synclog_length_tma_store_arrive = synclog_length_prefix + 0;
194
+
195
+ constexpr bool synclog_enable_tma_store_wait = true;
196
+ constexpr uint32_t synclog_header_tma_store_wait = 27;
197
+ constexpr uint32_t synclog_length_tma_store_wait = synclog_length_prefix + 1;
198
+
199
+ constexpr bool synclog_enable_warpgroup_arrive = true;
200
+ constexpr uint32_t synclog_header_warpgroup_arrive = 28;
201
+ constexpr uint32_t synclog_length_warpgroup_arrive = synclog_length_prefix + 0;
202
+
203
+ constexpr bool synclog_enable_warpgroup_wait = true;
204
+ constexpr uint32_t synclog_header_warpgroup_wait = 29;
205
+ constexpr uint32_t synclog_length_warpgroup_wait = synclog_length_prefix + 1;
206
+
207
+ constexpr bool synclog_enable_warpgroup_commit_batch = true;
208
+ constexpr uint32_t synclog_header_warpgroup_commit_batch = 30;
209
+ constexpr uint32_t synclog_length_warpgroup_commit_batch = synclog_length_prefix + 0;
210
+
211
+ constexpr bool synclog_enable_wgmma_reg_smem = true;
212
+ constexpr uint32_t synclog_header_wgmma_reg_smem = 31;
213
+ constexpr uint32_t synclog_length_wgmma_reg_smem = synclog_length_prefix + 2;
214
+
215
+ constexpr bool synclog_enable_wgmma_smem_smem = true;
216
+ constexpr uint32_t synclog_header_wgmma_smem_smem = 32;
217
+ constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4;
218
+
219
+ constexpr bool synclog_enable_cpasync_barrier_arrive = true;
220
+ constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33;
221
+ constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 1;
222
+ CUTLASS_DEVICE
223
+ bool synclog_condition_emit() {
224
+ #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
225
+ return threadIdx.x % NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 &&
226
+ blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0;
227
+ #else
228
+ return 0;
229
+ #endif
230
+ }
231
+
232
/// Gate for log printing: true only for thread (0,0,0) of block (0,0,0),
/// so exactly one thread walks and prints the buffer.
CUTLASS_DEVICE
bool synclog_condition_print() {
  #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  return threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 &&
    blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0;
  #else
  return false;
  #endif
}
241
+
242
+ CUTLASS_DEVICE
243
+ void synclog_print_prefix(char const* header, uint32_t at) {
244
+ #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
245
+ uint32_t line = synclog_buf[at + 1];
246
+ uint32_t timeLo = synclog_buf[at + 2];
247
+ uint32_t timeHi = synclog_buf[at + 3];
248
+ uint32_t threadIdxX = synclog_buf[at + 4];
249
+ uint32_t threadIdxY = synclog_buf[at + 5];
250
+ uint32_t threadIdxZ = synclog_buf[at + 6];
251
+ uint32_t blockIdxX = synclog_buf[at + 7];
252
+ uint32_t blockIdxY = synclog_buf[at + 8];
253
+ uint32_t blockIdxZ = synclog_buf[at + 9];
254
+ printf(
255
+ "%s line=%u time=%lu thread=%u,%u,%u block=%u,%u,%u ",
256
+ header, line,
257
+ (uint64_t)timeHi << 32 | timeLo,
258
+ threadIdxX, threadIdxY, threadIdxZ,
259
+ blockIdxX, blockIdxY, blockIdxZ
260
+ );
261
+ #endif
262
+ }
263
+
264
/// Prints the shared-memory address encoded in the low word of a wgmma
/// descriptor. Bits [0,14) of `lo` hold the address >> 4; the shift is undone
/// before printing. `hi` is currently unused.
CUTLASS_DEVICE
void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) {
  CUTLASS_UNUSED(hi);
  uint32_t smem_int_ptr = (lo & ((1 << 14) - 1)) << 4;
  printf("%s_smem_int_ptr=%u%s", str, smem_int_ptr, sep);
}
270
+
271
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
272
+
273
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
274
+
275
/// Host-side initialization: on first call, allocates one log buffer per CUDA
/// device; on every call, zeroes each buffer and publishes its address to the
/// device symbol `synclog_buf`. Any CUDA failure prints a message and
/// terminates. Thread-safe via synclog_mutex. No-op unless
/// CUTLASS_ENABLE_SYNCLOG is defined and a CUDA compiler is in use.
inline void synclog_setup() {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  std::scoped_lock lock(synclog_mutex);
  auto fail = [] () {
    fprintf(stderr, "synclog_setup() failed\n");
    std::terminate();
  };
  // Remember the caller's active device so it can be restored at the end.
  int orig_device = 0;
  if (cudaGetDevice(&orig_device) != cudaSuccess) {
    fail();
  }
  int device_count = 0;
  if (cudaGetDeviceCount(&device_count) != cudaSuccess) {
    fail();
  }
  // Allocate buffers only once; subsequent calls reuse and reset them.
  if (synclog_buf_list.size() == 0) {
    for (int device = 0; device < device_count; device++) {
      uint32_t* buf = 0;
      if (cudaSetDevice(device) != cudaSuccess ||
        cudaMalloc(&buf, synclog_cap * sizeof(uint32_t)) != cudaSuccess) {
        fail();
      }
      synclog_buf_list.push_back(buf);
    }
  }
  // Zero each buffer (clears the write cursor at word 0) and install its
  // address into the device-side pointer.
  for (int device = 0; device < device_count; device++) {
    uint32_t* buf = synclog_buf_list.at(device);
    if (cudaSetDevice(device) != cudaSuccess ||
      cudaMemset(buf, 0, synclog_cap * sizeof(uint32_t)) != cudaSuccess ||
      cudaMemcpyToSymbol(synclog_buf, &buf, sizeof(buf)) != cudaSuccess) {
      fail();
    }
  }
  if (cudaSetDevice(orig_device) != cudaSuccess) {
    fail();
  }
  #endif
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
315
+
316
/// Logs a __syncthreads() event at source line `line` (no payload beyond the
/// common prefix). Silently drops the record if emission is disabled, this
/// thread is not an emitter, or the buffer is full.
CUTLASS_DEVICE
void synclog_emit_syncthreads(uint32_t line) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_syncthreads) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_syncthreads);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_syncthreads, line);
  #else
  CUTLASS_UNUSED(line);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
328
+
329
/// Logs a __syncwarp() event at source line `line` (no payload beyond the
/// common prefix).
CUTLASS_DEVICE
void synclog_emit_syncwarp(uint32_t line) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_syncwarp) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_syncwarp);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_syncwarp, line);
  #else
  CUTLASS_UNUSED(line);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
341
+
342
/// Logs an arrive-and-wait on a named barrier; payload: participating thread
/// count and barrier id.
CUTLASS_DEVICE
void synclog_emit_named_barrier_arrive_and_wait(
  uint32_t line,
  uint32_t num_threads,
  uint32_t barrier_id) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_named_barrier_arrive_and_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive_and_wait);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_named_barrier_arrive_and_wait, line);
  to[synclog_length_prefix + 0] = num_threads;
  to[synclog_length_prefix + 1] = barrier_id;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(num_threads);
  CUTLASS_UNUSED(barrier_id);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
361
+
362
/// Logs an arrive (without wait) on a named barrier; payload: participating
/// thread count and barrier id.
CUTLASS_DEVICE
void synclog_emit_named_barrier_arrive(
  uint32_t line,
  uint32_t num_threads,
  uint32_t barrier_id) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_named_barrier_arrive) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_named_barrier_arrive, line);
  to[synclog_length_prefix + 0] = num_threads;
  to[synclog_length_prefix + 1] = barrier_id;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(num_threads);
  CUTLASS_UNUSED(barrier_id);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
381
+
382
/// Logs initialization of a cluster barrier; payload: barrier shared-memory
/// address and expected arrival count.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_init(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t arrive_count) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_init) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_init);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_init, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = arrive_count;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(arrive_count);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
401
+
402
/// Logs a blocking wait on a cluster barrier; payload: barrier shared-memory
/// address and the phase being waited on.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_wait(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t phase) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_wait);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_wait, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = phase;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(phase);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
421
+
422
/// Logs a non-blocking test-wait on a cluster barrier; payload: barrier
/// shared-memory address, phase, and the test's predicate result.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_test_wait(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t phase,
  uint32_t pred) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_test_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_test_wait);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_test_wait, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = phase;
  to[synclog_length_prefix + 2] = pred;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(phase);
  CUTLASS_UNUSED(pred);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
444
+
445
+ CUTLASS_DEVICE
446
+ void synclog_emit_cluster_barrier_try_wait(
447
+ uint32_t line,
448
+ uint32_t smem_addr,
449
+ uint32_t phase) {
450
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
451
+ if constexpr (!synclog_enable_cluster_barrier_try_wait) return;
452
+ if (!synclog_condition_emit()) return;
453
+ uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_try_wait);
454
+ if (to == nullptr) return;
455
+ synclog_emit_prefix(to, synclog_header_cluster_barrier_try_wait, line);
456
+ to[synclog_length_prefix + 0] = smem_addr;
457
+ to[synclog_length_prefix + 1] = phase;
458
+ #else
459
+ CUTLASS_UNUSED(line);
460
+ CUTLASS_UNUSED(smem_addr);
461
+ CUTLASS_UNUSED(phase);
462
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
463
+ }
464
+
465
+ CUTLASS_DEVICE
466
+ void synclog_emit_cluster_barrier_arrive_cluster(
467
+ uint32_t line,
468
+ uint32_t smem_addr,
469
+ uint32_t cta_id,
470
+ uint32_t pred) {
471
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
472
+ if constexpr (!synclog_enable_cluster_barrier_arrive_cluster) return;
473
+ if (!synclog_condition_emit()) return;
474
+ uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster);
475
+ if (to == nullptr) return;
476
+ synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive_cluster, line);
477
+ to[synclog_length_prefix + 0] = smem_addr;
478
+ to[synclog_length_prefix + 1] = cta_id;
479
+ to[synclog_length_prefix + 2] = pred;
480
+ #else
481
+ CUTLASS_UNUSED(line);
482
+ CUTLASS_UNUSED(smem_addr);
483
+ CUTLASS_UNUSED(cta_id);
484
+ CUTLASS_UNUSED(pred);
485
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
486
+ }
487
+
488
+ CUTLASS_DEVICE
489
+ void synclog_emit_cluster_barrier_arrive(
490
+ uint32_t line,
491
+ uint32_t smem_addr) {
492
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
493
+ if constexpr (!synclog_enable_cluster_barrier_arrive) return;
494
+ if (!synclog_condition_emit()) return;
495
+ uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive);
496
+ if (to == nullptr) return;
497
+ synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive, line);
498
+ to[synclog_length_prefix + 0] = smem_addr;
499
+ #else
500
+ CUTLASS_UNUSED(line);
501
+ CUTLASS_UNUSED(smem_addr);
502
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
503
+ }
504
+
505
+ CUTLASS_DEVICE
506
+ void synclog_emit_cluster_barrier_invalidate(
507
+ uint32_t line,
508
+ uint32_t smem_addr) {
509
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
510
+ if constexpr (!synclog_enable_cluster_barrier_invalidate) return;
511
+ if (!synclog_condition_emit()) return;
512
+ uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_invalidate);
513
+ if (to == nullptr) return;
514
+ synclog_emit_prefix(to, synclog_header_cluster_barrier_invalidate, line);
515
+ to[synclog_length_prefix + 0] = smem_addr;
516
+ #else
517
+ CUTLASS_UNUSED(line);
518
+ CUTLASS_UNUSED(smem_addr);
519
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
520
+ }
521
+
522
+ CUTLASS_DEVICE
523
+ void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx(
524
+ uint32_t line,
525
+ uint32_t smem_addr,
526
+ uint32_t transaction_bytes) {
527
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
528
+ if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return;
529
+ if (!synclog_condition_emit()) return;
530
+ uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx);
531
+ if (to == nullptr) return;
532
+ synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line);
533
+ to[synclog_length_prefix + 0] = smem_addr;
534
+ to[synclog_length_prefix + 1] = transaction_bytes;
535
+ #else
536
+ CUTLASS_UNUSED(line);
537
+ CUTLASS_UNUSED(smem_addr);
538
+ CUTLASS_UNUSED(transaction_bytes);
539
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
540
+ }
541
+
542
+ CUTLASS_DEVICE
543
+ void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster(
544
+ uint32_t line,
545
+ uint32_t smem_addr,
546
+ uint32_t transaction_bytes,
547
+ uint32_t cta_id,
548
+ uint32_t pred) {
549
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
550
+ if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return;
551
+ if (!synclog_condition_emit()) return;
552
+ uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster);
553
+ if (to == nullptr) return;
554
+ synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster, line);
555
+ to[synclog_length_prefix + 0] = smem_addr;
556
+ to[synclog_length_prefix + 1] = transaction_bytes;
557
+ to[synclog_length_prefix + 2] = cta_id;
558
+ to[synclog_length_prefix + 3] = pred;
559
+ #else
560
+ CUTLASS_UNUSED(line);
561
+ CUTLASS_UNUSED(smem_addr);
562
+ CUTLASS_UNUSED(transaction_bytes);
563
+ CUTLASS_UNUSED(cta_id);
564
+ CUTLASS_UNUSED(pred);
565
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
566
+ }
567
+
568
+ CUTLASS_DEVICE
569
+ void synclog_emit_cluster_transaction_barrier_expect_transaction(
570
+ uint32_t line,
571
+ uint32_t smem_addr,
572
+ uint32_t transaction_bytes) {
573
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
574
+ if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return;
575
+ if (!synclog_condition_emit()) return;
576
+ uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction);
577
+ if (to == nullptr) return;
578
+ synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_expect_transaction, line);
579
+ to[synclog_length_prefix + 0] = smem_addr;
580
+ to[synclog_length_prefix + 1] = transaction_bytes;
581
+ #else
582
+ CUTLASS_UNUSED(line);
583
+ CUTLASS_UNUSED(smem_addr);
584
+ CUTLASS_UNUSED(transaction_bytes);
585
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
586
+ }
587
+
588
+ CUTLASS_DEVICE
589
+ void synclog_emit_cluster_transaction_barrier_complete_transaction(
590
+ uint32_t line,
591
+ uint32_t smem_addr,
592
+ uint32_t dst_cta_id,
593
+ uint32_t transaction_bytes,
594
+ uint32_t pred) {
595
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
596
+ if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return;
597
+ if (!synclog_condition_emit()) return;
598
+ uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction);
599
+ if (to == nullptr) return;
600
+ synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_complete_transaction, line);
601
+ to[synclog_length_prefix + 0] = smem_addr;
602
+ to[synclog_length_prefix + 1] = dst_cta_id;
603
+ to[synclog_length_prefix + 2] = transaction_bytes;
604
+ to[synclog_length_prefix + 3] = pred;
605
+ #else
606
+ CUTLASS_UNUSED(line);
607
+ CUTLASS_UNUSED(smem_addr);
608
+ CUTLASS_UNUSED(dst_cta_id);
609
+ CUTLASS_UNUSED(transaction_bytes);
610
+ CUTLASS_UNUSED(pred);
611
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
612
+ }
613
+
614
+ CUTLASS_DEVICE
615
+ void synclog_emit_fence_barrier_init(uint32_t line) {
616
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
617
+ if constexpr (!synclog_enable_fence_barrier_init) return;
618
+ if (!synclog_condition_emit()) return;
619
+ uint32_t* to = synclog_alloc(synclog_length_fence_barrier_init);
620
+ if (to == nullptr) return;
621
+ synclog_emit_prefix(to, synclog_header_fence_barrier_init, line);
622
+ #else
623
+ CUTLASS_UNUSED(line);
624
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
625
+ }
626
+
627
+ CUTLASS_DEVICE
628
+ void synclog_emit_fence_view_async_shared(uint32_t line) {
629
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
630
+ if constexpr (!synclog_enable_fence_view_async_shared) return;
631
+ if (!synclog_condition_emit()) return;
632
+ uint32_t* to = synclog_alloc(synclog_length_fence_view_async_shared);
633
+ if (to == nullptr) return;
634
+ synclog_emit_prefix(to, synclog_header_fence_view_async_shared, line);
635
+ #else
636
+ CUTLASS_UNUSED(line);
637
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
638
+ }
639
+
640
+ CUTLASS_DEVICE
641
+ void synclog_emit_cp_async_wait(
642
+ uint32_t line,
643
+ uint32_t n) {
644
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
645
+ if constexpr (!synclog_enable_cp_async_wait) return;
646
+ if (!synclog_condition_emit()) return;
647
+ uint32_t* to = synclog_alloc(synclog_length_cp_async_wait);
648
+ if (to == nullptr) return;
649
+ synclog_emit_prefix(to, synclog_header_cp_async_wait, line);
650
+ to[synclog_length_prefix + 0] = n;
651
+ #else
652
+ CUTLASS_UNUSED(line);
653
+ CUTLASS_UNUSED(n);
654
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
655
+ }
656
+
657
+ CUTLASS_DEVICE
658
+ void synclog_emit_cp_async_wait_all(uint32_t line) {
659
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
660
+ if constexpr (!synclog_enable_cp_async_wait_all) return;
661
+ if (!synclog_condition_emit()) return;
662
+ uint32_t* to = synclog_alloc(synclog_length_cp_async_wait_all);
663
+ if (to == nullptr) return;
664
+ synclog_emit_prefix(to, synclog_header_cp_async_wait_all, line);
665
+ #else
666
+ CUTLASS_UNUSED(line);
667
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
668
+ }
669
+
670
+ CUTLASS_DEVICE
671
+ void synclog_emit_cp_async_fence(uint32_t line) {
672
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
673
+ if constexpr (!synclog_enable_cp_async_fence) return;
674
+ if (!synclog_condition_emit()) return;
675
+ uint32_t* to = synclog_alloc(synclog_length_cp_async_fence);
676
+ if (to == nullptr) return;
677
+ synclog_emit_prefix(to, synclog_header_cp_async_fence, line);
678
+ #else
679
+ CUTLASS_UNUSED(line);
680
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
681
+ }
682
+
683
+ CUTLASS_DEVICE
684
+ void synclog_emit_cp_async_nan(
685
+ uint32_t line,
686
+ uint32_t smem_addr,
687
+ const void* gmem_ptr,
688
+ uint32_t pred) {
689
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
690
+ if constexpr (!synclog_enable_cp_async_nan) return;
691
+ if (!synclog_condition_emit()) return;
692
+ uint32_t* to = synclog_alloc(synclog_length_cp_async_nan);
693
+ if (to == nullptr) return;
694
+ synclog_emit_prefix(to, synclog_header_cp_async_nan, line);
695
+ to[synclog_length_prefix + 0] = smem_addr;
696
+ to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr);
697
+ to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32);
698
+ to[synclog_length_prefix + 3] = pred;
699
+ #else
700
+ CUTLASS_UNUSED(line);
701
+ CUTLASS_UNUSED(smem_addr);
702
+ CUTLASS_UNUSED(gmem_ptr);
703
+ CUTLASS_UNUSED(pred);
704
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
705
+ }
706
+
707
+ CUTLASS_DEVICE
708
+ void synclog_emit_cp_async_zfill(
709
+ uint32_t line,
710
+ uint32_t smem_addr,
711
+ const void* gmem_ptr,
712
+ uint32_t pred,
713
+ uint32_t size) {
714
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
715
+ if constexpr (!synclog_enable_cp_async_zfill) return;
716
+ if (!synclog_condition_emit()) return;
717
+ uint32_t* to = synclog_alloc(synclog_length_cp_async_zfill);
718
+ if (to == nullptr) return;
719
+ synclog_emit_prefix(to, synclog_header_cp_async_zfill, line);
720
+ to[synclog_length_prefix + 0] = smem_addr;
721
+ to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr);
722
+ to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32);
723
+ to[synclog_length_prefix + 3] = pred;
724
+ to[synclog_length_prefix + 4] = size;
725
+ #else
726
+ CUTLASS_UNUSED(line);
727
+ CUTLASS_UNUSED(smem_addr);
728
+ CUTLASS_UNUSED(gmem_ptr);
729
+ CUTLASS_UNUSED(pred);
730
+ CUTLASS_UNUSED(size);
731
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
732
+ }
733
+
734
+ CUTLASS_DEVICE
735
+ void synclog_emit_cp_async(
736
+ uint32_t line,
737
+ uint32_t smem_addr,
738
+ const void* gmem_ptr,
739
+ uint32_t pred,
740
+ uint32_t size) {
741
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
742
+ if constexpr (!synclog_enable_cp_async) return;
743
+ if (!synclog_condition_emit()) return;
744
+ uint32_t* to = synclog_alloc(synclog_length_cp_async);
745
+ if (to == nullptr) return;
746
+ synclog_emit_prefix(to, synclog_header_cp_async, line);
747
+ to[synclog_length_prefix + 0] = smem_addr;
748
+ to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr);
749
+ to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32);
750
+ to[synclog_length_prefix + 3] = pred;
751
+ to[synclog_length_prefix + 4] = size;
752
+ #else
753
+ CUTLASS_UNUSED(line);
754
+ CUTLASS_UNUSED(smem_addr);
755
+ CUTLASS_UNUSED(gmem_ptr);
756
+ CUTLASS_UNUSED(pred);
757
+ CUTLASS_UNUSED(size);
758
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
759
+ }
760
+
761
+ CUTLASS_DEVICE
762
+ void synclog_emit_tma_load(
763
+ uint32_t line,
764
+ uint64_t gmem_int_desc,
765
+ uint32_t smem_int_mbar,
766
+ uint32_t smem_int_ptr) {
767
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
768
+ if constexpr (!synclog_enable_tma_load) return;
769
+ if (!synclog_condition_emit()) return;
770
+ uint32_t* to = synclog_alloc(synclog_length_tma_load);
771
+ if (to == nullptr) return;
772
+ synclog_emit_prefix(to, synclog_header_tma_load, line);
773
+ to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc);
774
+ to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32);
775
+ to[synclog_length_prefix + 2] = smem_int_mbar;
776
+ to[synclog_length_prefix + 3] = smem_int_ptr;
777
+ #else
778
+ CUTLASS_UNUSED(line);
779
+ CUTLASS_UNUSED(gmem_int_desc);
780
+ CUTLASS_UNUSED(smem_int_mbar);
781
+ CUTLASS_UNUSED(smem_int_ptr);
782
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
783
+ }
784
+
785
+ CUTLASS_DEVICE
786
+ void synclog_emit_tma_store(
787
+ uint32_t line,
788
+ uint64_t gmem_int_desc,
789
+ uint32_t smem_int_ptr) {
790
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
791
+ if constexpr (!synclog_enable_tma_store) return;
792
+ if (!synclog_condition_emit()) return;
793
+ uint32_t* to = synclog_alloc(synclog_length_tma_store);
794
+ if (to == nullptr) return;
795
+ synclog_emit_prefix(to, synclog_header_tma_store, line);
796
+ to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc);
797
+ to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32);
798
+ to[synclog_length_prefix + 2] = smem_int_ptr;
799
+ #else
800
+ CUTLASS_UNUSED(line);
801
+ CUTLASS_UNUSED(gmem_int_desc);
802
+ CUTLASS_UNUSED(smem_int_ptr);
803
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
804
+ }
805
+
806
+ CUTLASS_DEVICE
807
+ void synclog_emit_tma_store_arrive(uint32_t line) {
808
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
809
+ if constexpr (!synclog_enable_tma_store_arrive) return;
810
+ if (!synclog_condition_emit()) return;
811
+ uint32_t* to = synclog_alloc(synclog_length_tma_store_arrive);
812
+ if (to == nullptr) return;
813
+ synclog_emit_prefix(to, synclog_header_tma_store_arrive, line);
814
+ #else
815
+ CUTLASS_UNUSED(line);
816
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
817
+ }
818
+
819
+ CUTLASS_DEVICE
820
+ void synclog_emit_tma_store_wait(
821
+ uint32_t line,
822
+ uint32_t count) {
823
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
824
+ if constexpr (!synclog_enable_tma_store_wait) return;
825
+ if (!synclog_condition_emit()) return;
826
+ uint32_t* to = synclog_alloc(synclog_length_tma_store_wait);
827
+ if (to == nullptr) return;
828
+ synclog_emit_prefix(to, synclog_header_tma_store_wait, line);
829
+ to[synclog_length_prefix + 0] = count;
830
+ #else
831
+ CUTLASS_UNUSED(line);
832
+ CUTLASS_UNUSED(count);
833
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
834
+ }
835
+
836
+ CUTLASS_DEVICE
837
+ void synclog_emit_warpgroup_arrive(
838
+ uint32_t line) {
839
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
840
+ if constexpr (!synclog_enable_warpgroup_arrive) return;
841
+ if (!synclog_condition_emit()) return;
842
+ uint32_t* to = synclog_alloc(synclog_length_warpgroup_arrive);
843
+ if (to == nullptr) return;
844
+ synclog_emit_prefix(to, synclog_header_warpgroup_arrive, line);
845
+ #else
846
+ CUTLASS_UNUSED(line);
847
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
848
+ }
849
+
850
+ CUTLASS_DEVICE
851
+ void synclog_emit_warpgroup_wait(
852
+ uint32_t line,
853
+ uint32_t n) {
854
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
855
+ if constexpr (!synclog_enable_warpgroup_wait) return;
856
+ if (!synclog_condition_emit()) return;
857
+ uint32_t* to = synclog_alloc(synclog_length_warpgroup_wait);
858
+ if (to == nullptr) return;
859
+ synclog_emit_prefix(to, synclog_header_warpgroup_wait, line);
860
+ to[synclog_length_prefix + 0] = n;
861
+ #else
862
+ CUTLASS_UNUSED(line);
863
+ CUTLASS_UNUSED(n);
864
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
865
+ }
866
+
867
+ CUTLASS_DEVICE
868
+ void synclog_emit_warpgroup_commit_batch(
869
+ uint32_t line) {
870
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
871
+ if constexpr (!synclog_enable_warpgroup_commit_batch) return;
872
+ if (!synclog_condition_emit()) return;
873
+ uint32_t* to = synclog_alloc(synclog_length_warpgroup_commit_batch);
874
+ if (to == nullptr) return;
875
+ synclog_emit_prefix(to, synclog_header_warpgroup_commit_batch, line);
876
+ #else
877
+ CUTLASS_UNUSED(line);
878
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
879
+ }
880
+
881
+ CUTLASS_DEVICE
882
+ void synclog_emit_wgmma_reg_smem(
883
+ uint32_t line,
884
+ uint64_t desc_b) {
885
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
886
+ if constexpr (!synclog_enable_wgmma_reg_smem) return;
887
+ if (!synclog_condition_emit()) return;
888
+ uint32_t* to = synclog_alloc(synclog_length_wgmma_reg_smem);
889
+ if (to == nullptr) return;
890
+ synclog_emit_prefix(to, synclog_header_wgmma_reg_smem, line);
891
+ to[synclog_length_prefix + 0] = desc_b;
892
+ to[synclog_length_prefix + 1] = desc_b >> 32;
893
+ #else
894
+ CUTLASS_UNUSED(line);
895
+ CUTLASS_UNUSED(desc_b);
896
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
897
+ }
898
+
899
+ CUTLASS_DEVICE
900
+ void synclog_emit_wgmma_smem_smem(
901
+ uint32_t line,
902
+ uint64_t desc_a,
903
+ uint64_t desc_b) {
904
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
905
+ if constexpr (!synclog_enable_wgmma_smem_smem) return;
906
+ if (!synclog_condition_emit()) return;
907
+ uint32_t* to = synclog_alloc(synclog_length_wgmma_smem_smem);
908
+ if (to == nullptr) return;
909
+ synclog_emit_prefix(to, synclog_header_wgmma_smem_smem, line);
910
+ to[synclog_length_prefix + 0] = desc_a;
911
+ to[synclog_length_prefix + 1] = desc_a >> 32;
912
+ to[synclog_length_prefix + 2] = desc_b;
913
+ to[synclog_length_prefix + 3] = desc_b >> 32;
914
+ #else
915
+ CUTLASS_UNUSED(line);
916
+ CUTLASS_UNUSED(desc_a);
917
+ CUTLASS_UNUSED(desc_b);
918
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
919
+ }
920
+
921
+ CUTLASS_DEVICE
922
+ void synclog_emit_cpasync_barrier_arrive(
923
+ uint32_t line,
924
+ uint32_t smem_addr) {
925
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
926
+ if constexpr (!synclog_enable_cpasync_barrier_arrive) return;
927
+ if (!synclog_condition_emit()) return;
928
+ uint32_t* to = synclog_alloc(synclog_length_cpasync_barrier_arrive);
929
+ if (to == nullptr) return;
930
+ synclog_emit_prefix(to, synclog_header_cpasync_barrier_arrive, line);
931
+ to[synclog_length_prefix + 0] = smem_addr;
932
+ #else
933
+ CUTLASS_UNUSED(line);
934
+ CUTLASS_UNUSED(smem_addr);
935
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
936
+ }
937
+
938
+ #if !defined(CUTLASS_ENABLE_SYNCLOG)
939
+ CUTLASS_DEVICE
940
+ #elif defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
941
+ static __attribute__((__noinline__)) __device__
942
+ #else
943
+ static __attribute__((__noinline__))
944
+ #endif
945
+ void synclog_print() {
946
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
947
+ #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
948
+ if (synclog_buf == nullptr || !synclog_condition_print()) {
949
+ return;
950
+ }
951
+ printf("synclog start\n");
952
+ for (uint32_t at = 1; at < synclog_cap; ) {
953
+ uint32_t header = synclog_buf[at];
954
+ if (header == synclog_header_none) {
955
+ break;
956
+ }
957
+ printf("synclog at %u: ", at);
958
+ if constexpr (synclog_enable_syncthreads) {
959
+ if (header == synclog_header_syncthreads) {
960
+ synclog_print_prefix("syncthreads", at);
961
+ at += synclog_length_syncthreads;
962
+ printf("\n");
963
+ continue;
964
+ }
965
+ }
966
+ if constexpr (synclog_enable_syncwarp) {
967
+ if (header == synclog_header_syncwarp) {
968
+ synclog_print_prefix("syncwarp", at);
969
+ at += synclog_length_syncwarp;
970
+ printf("\n");
971
+ continue;
972
+ }
973
+ }
974
+ if constexpr (synclog_enable_named_barrier_arrive_and_wait) {
975
+ if (header == synclog_header_named_barrier_arrive_and_wait) {
976
+ synclog_print_prefix("named_barrier_arrive_and_wait", at);
977
+ at += synclog_length_named_barrier_arrive_and_wait;
978
+ printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
979
+ continue;
980
+ }
981
+ }
982
+ if constexpr (synclog_enable_named_barrier_arrive) {
983
+ if (header == synclog_header_named_barrier_arrive) {
984
+ synclog_print_prefix("named_barrier_arrive", at);
985
+ at += synclog_length_named_barrier_arrive;
986
+ printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
987
+ continue;
988
+ }
989
+ }
990
+ if constexpr (synclog_enable_cluster_barrier_init) {
991
+ if (header == synclog_header_cluster_barrier_init) {
992
+ synclog_print_prefix("cluster_barrier_init", at);
993
+ at += synclog_length_cluster_barrier_init;
994
+ printf("smem_addr=%u arrive_count=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
995
+ continue;
996
+ }
997
+ }
998
+ if constexpr (synclog_enable_cluster_barrier_wait) {
999
+ if (header == synclog_header_cluster_barrier_wait) {
1000
+ synclog_print_prefix("cluster_barrier_wait", at);
1001
+ at += synclog_length_cluster_barrier_wait;
1002
+ printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
1003
+ continue;
1004
+ }
1005
+ }
1006
+ if constexpr (synclog_enable_cluster_barrier_test_wait) {
1007
+ if (header == synclog_header_cluster_barrier_test_wait) {
1008
+ synclog_print_prefix("cluster_barrier_test_wait", at);
1009
+ at += synclog_length_cluster_barrier_test_wait;
1010
+ printf("smem_addr=%u phase=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
1011
+ continue;
1012
+ }
1013
+ }
1014
+ if constexpr (synclog_enable_cluster_barrier_try_wait) {
1015
+ if (header == synclog_header_cluster_barrier_try_wait) {
1016
+ synclog_print_prefix("cluster_barrier_try_wait", at);
1017
+ at += synclog_length_cluster_barrier_try_wait;
1018
+ printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
1019
+ continue;
1020
+ }
1021
+ }
1022
+ if constexpr (synclog_enable_cluster_barrier_arrive_cluster) {
1023
+ if (header == synclog_header_cluster_barrier_arrive_cluster) {
1024
+ synclog_print_prefix("cluster_barrier_arrive_cluster", at);
1025
+ at += synclog_length_cluster_barrier_arrive_cluster;
1026
+ printf("smem_addr=%u cta_id=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
1027
+ continue;
1028
+ }
1029
+ }
1030
+ if constexpr (synclog_enable_cluster_barrier_arrive) {
1031
+ if (header == synclog_header_cluster_barrier_arrive) {
1032
+ synclog_print_prefix("cluster_barrier_arrive", at);
1033
+ at += synclog_length_cluster_barrier_arrive;
1034
+ printf("smem_addr=%u\n", synclog_buf[at-1]);
1035
+ continue;
1036
+ }
1037
+ }
1038
+ if constexpr (synclog_enable_cluster_barrier_invalidate) {
1039
+ if (header == synclog_header_cluster_barrier_invalidate) {
1040
+ synclog_print_prefix("cluster_barrier_invalidate", at);
1041
+ at += synclog_length_cluster_barrier_invalidate;
1042
+ printf("smem_addr=%u\n", synclog_buf[at-1]);
1043
+ continue;
1044
+ }
1045
+ }
1046
+ if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) {
1047
+ if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) {
1048
+ synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at);
1049
+ at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx;
1050
+ printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
1051
+ continue;
1052
+ }
1053
+ }
1054
+ if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) {
1055
+ if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) {
1056
+ synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", at);
1057
+ at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster;
1058
+ printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
1059
+ continue;
1060
+ }
1061
+ }
1062
+ if constexpr (synclog_enable_cluster_transaction_barrier_expect_transaction) {
1063
+ if (header == synclog_header_cluster_transaction_barrier_expect_transaction) {
1064
+ synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at);
1065
+ at += synclog_length_cluster_transaction_barrier_expect_transaction;
1066
+ printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
1067
+ continue;
1068
+ }
1069
+ }
1070
+ if constexpr (synclog_enable_cluster_transaction_barrier_complete_transaction) {
1071
+ if (header == synclog_header_cluster_transaction_barrier_complete_transaction) {
1072
+ synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at);
1073
+ at += synclog_length_cluster_transaction_barrier_complete_transaction;
1074
+ printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
1075
+ continue;
1076
+ }
1077
+ }
1078
+ if constexpr (synclog_enable_fence_barrier_init) {
1079
+ if (header == synclog_header_fence_barrier_init) {
1080
+ synclog_print_prefix("fence_barrier_init", at);
1081
+ at += synclog_length_fence_barrier_init;
1082
+ printf("\n");
1083
+ continue;
1084
+ }
1085
+ }
1086
+ if constexpr (synclog_enable_fence_view_async_shared) {
1087
+ if (header == synclog_header_fence_view_async_shared) {
1088
+ synclog_print_prefix("fence_view_async_shared", at);
1089
+ at += synclog_length_fence_view_async_shared;
1090
+ printf("\n");
1091
+ continue;
1092
+ }
1093
+ }
1094
+ if constexpr (synclog_enable_cp_async_wait) {
1095
+ if (header == synclog_header_cp_async_wait) {
1096
+ synclog_print_prefix("cp_async_wait", at);
1097
+ at += synclog_length_cp_async_wait;
1098
+ printf("n=%u\n", synclog_buf[at-1]);
1099
+ continue;
1100
+ }
1101
+ }
1102
+ if constexpr (synclog_enable_cp_async_wait_all) {
1103
+ if (header == synclog_header_cp_async_wait_all) {
1104
+ synclog_print_prefix("cp_async_wait_all", at);
1105
+ at += synclog_length_cp_async_wait_all;
1106
+ printf("\n");
1107
+ continue;
1108
+ }
1109
+ }
1110
+ if constexpr (synclog_enable_cp_async_fence) {
1111
+ if (header == synclog_header_cp_async_fence) {
1112
+ synclog_print_prefix("cp_async_fence", at);
1113
+ at += synclog_length_cp_async_fence;
1114
+ printf("\n");
1115
+ continue;
1116
+ }
1117
+ }
1118
+ if constexpr (synclog_enable_cp_async_nan) {
1119
+ if (header == synclog_header_cp_async_nan) {
1120
+ synclog_print_prefix("cp_async_nan", at);
1121
+ at += synclog_length_cp_async_nan;
1122
+ uint64_t gmem_addr = synclog_buf[at-3];
1123
+ gmem_addr += (uint64_t)synclog_buf[at-2] << 32;
1124
+ printf("smem_addr=%u gmem_addr=%llu pred=%u\n", synclog_buf[at-4], gmem_addr, synclog_buf[at-1]);
1125
+ continue;
1126
+ }
1127
+ }
1128
+ if constexpr (synclog_enable_cp_async_zfill) {
1129
+ if (header == synclog_header_cp_async_zfill) {
1130
+ synclog_print_prefix("cp_async_zfill", at);
1131
+ at += synclog_length_cp_async_zfill;
1132
+ uint64_t gmem_addr = synclog_buf[at-4];
1133
+ gmem_addr += (uint64_t)synclog_buf[at-3] << 32;
1134
+ printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]);
1135
+ continue;
1136
+ }
1137
+ }
1138
+ if constexpr (synclog_enable_cp_async) {
1139
+ if (header == synclog_header_cp_async) {
1140
+ synclog_print_prefix("cp_async", at);
1141
+ at += synclog_length_cp_async;
1142
+ uint64_t gmem_addr = synclog_buf[at-4];
1143
+ gmem_addr += (uint64_t)synclog_buf[at-3] << 32;
1144
+ printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]);
1145
+ continue;
1146
+ }
1147
+ }
1148
+ if constexpr (synclog_enable_tma_load) {
1149
+ if (header == synclog_header_tma_load) {
1150
+ synclog_print_prefix("tma_load", at);
1151
+ at += synclog_length_tma_load;
1152
+ uint64_t gmem_int_desc = synclog_buf[at-4];
1153
+ gmem_int_desc += (uint64_t)synclog_buf[at-3] << 32;
1154
+ printf("gmem_int_desc=%llu smem_int_mbar=%u smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-2], synclog_buf[at-1]);
1155
+ continue;
1156
+ }
1157
+ }
1158
+ if constexpr (synclog_enable_tma_store) {
1159
+ if (header == synclog_header_tma_store) {
1160
+ synclog_print_prefix("tma_store", at);
1161
+ at += synclog_length_tma_store;
1162
+ uint64_t gmem_int_desc = synclog_buf[at-3];
1163
+ gmem_int_desc += (uint64_t)synclog_buf[at-2] << 32;
1164
+ printf("gmem_int_desc=%llu smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-1]);
1165
+ continue;
1166
+ }
1167
+ }
1168
+ if constexpr (synclog_enable_tma_store_arrive) {
1169
+ if (header == synclog_header_tma_store_arrive) {
1170
+ synclog_print_prefix("tma_store_arrive", at);
1171
+ at += synclog_length_tma_store_arrive;
1172
+ printf("\n");
1173
+ continue;
1174
+ }
1175
+ }
1176
+ if constexpr (synclog_enable_tma_store_wait) {
1177
+ if (header == synclog_header_tma_store_wait) {
1178
+ synclog_print_prefix("tma_store_wait", at);
1179
+ at += synclog_length_tma_store_wait;
1180
+ printf("count=%u\n", synclog_buf[at-1]);
1181
+ continue;
1182
+ }
1183
+ }
1184
+ if constexpr (synclog_enable_warpgroup_arrive) {
1185
+ if (header == synclog_header_warpgroup_arrive) {
1186
+ synclog_print_prefix("warpgroup_arrive", at);
1187
+ at += synclog_length_warpgroup_arrive;
1188
+ printf("\n");
1189
+ continue;
1190
+ }
1191
+ }
1192
+ if constexpr (synclog_enable_warpgroup_wait) {
1193
+ if (header == synclog_header_warpgroup_wait) {
1194
+ synclog_print_prefix("warpgroup_wait", at);
1195
+ at += synclog_length_warpgroup_wait;
1196
+ printf("n=%u\n", synclog_buf[at-1]);
1197
+ continue;
1198
+ }
1199
+ }
1200
+ if constexpr (synclog_enable_warpgroup_commit_batch) {
1201
+ if (header == synclog_header_warpgroup_commit_batch) {
1202
+ synclog_print_prefix("warpgroup_commit_batch", at);
1203
+ at += synclog_length_warpgroup_commit_batch;
1204
+ printf("\n");
1205
+ continue;
1206
+ }
1207
+ }
1208
+ if constexpr (synclog_enable_wgmma_reg_smem) {
1209
+ if (header == synclog_header_wgmma_reg_smem) {
1210
+ synclog_print_prefix("wgmma_reg_smem", at);
1211
+ at += synclog_length_wgmma_reg_smem;
1212
+ synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], "");
1213
+ printf("\n");
1214
+ continue;
1215
+ }
1216
+ }
1217
+ if constexpr (synclog_enable_wgmma_smem_smem) {
1218
+ if (header == synclog_header_wgmma_smem_smem) {
1219
+ synclog_print_prefix("wgmma_smem_smem", at);
1220
+ at += synclog_length_wgmma_smem_smem;
1221
+ synclog_print_wgmma_desc("desc_a", synclog_buf[at-4], synclog_buf[at-3], " ");
1222
+ synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], "");
1223
+ printf("\n");
1224
+ continue;
1225
+ }
1226
+ }
1227
+ if constexpr (synclog_enable_cpasync_barrier_arrive) {
1228
+ if (header == synclog_header_cpasync_barrier_arrive) {
1229
+ synclog_print_prefix("cpasync_barrier_arrive", at);
1230
+ at += synclog_length_cpasync_barrier_arrive;
1231
+ printf("smem_addr=%u\n", synclog_buf[at-1]);
1232
+ continue;
1233
+ }
1234
+ }
1235
+ asm volatile ("brkpt;\n" ::);
1236
+ }
1237
+ if (synclog_buf[0] >= synclog_cap) {
1238
+ printf(
1239
+ "synclog was truncated (exceeded capacity of %lu bytes)\n",
1240
+ (synclog_cap - 1) * sizeof(uint32_t)
1241
+ );
1242
+ }
1243
+ printf("synclog end\n");
1244
+ #endif
1245
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
1246
+ }
1247
+
1248
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1249
+
1250
+
1251
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
1252
+ #undef __syncthreads
1253
+ #define __syncthreads() do {\
1254
+ cutlass::arch::synclog_emit_syncthreads(__LINE__);\
1255
+ __syncthreads();\
1256
+ } while (0)
1257
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
1258
+
1259
+ #if defined(CUTLASS_ENABLE_SYNCLOG)
1260
+ #undef __syncwarp
1261
+ #define __syncwarp(...) do {\
1262
+ cutlass::arch::synclog_emit_syncwarp(__LINE__);\
1263
+ __syncwarp(__VA_ARGS__);\
1264
+ } while (0)
1265
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG)
1266
+
1267
+
1268
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1269
+
1270
+ } // namespace arch
1271
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates exposing architecture support for warp matrix multiply-add (WMMA) operations
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #if (__CUDACC_VER_MAJOR__ >= 9)
38
+ #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700))
39
+ #define CUTLASS_ARCH_WMMA_ENABLED
40
+ #define CUTLASS_ARCH_WMMA_SM70_ENABLED
41
+ #endif
42
+ #endif
43
+
44
+ #if (__CUDACC_VER_MAJOR__ >= 10)
45
+ #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720))
46
+ #define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED
47
+ #define CUTLASS_ARCH_WMMA_SM72_ENABLED
48
+ #endif
49
+ #endif
50
+
51
+ #if (__CUDACC_VER_MAJOR__ >= 10)
52
+ #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750))
53
+ #define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED
54
+ #define CUTLASS_ARCH_WMMA_SM75_ENABLED
55
+ #endif
56
+ #endif
57
+
58
+ #if defined(CUTLASS_ARCH_WMMA_ENABLED)
59
+
60
+ #include <mma.h>
61
+ #include "cutlass/arch/mma.h"
62
+ #include "cutlass/array.h"
63
+ #include "cutlass/numeric_types.h"
64
+ #include "cutlass/gemm/gemm.h"
65
+
66
+
67
+ /////////////////////////////////////////////////////////////////////////////////////////////////
68
+
69
+ namespace cutlass {
70
+ namespace arch {
71
+
72
+ ////////////////////////////////////////////////////////////////////////////////////////////////
73
+ /// Statically maps cutlass data types => nvcuda::wmma data types
74
+ /////////////////////////////////////////////////////////////////////////////////////////////////
75
+ template <typename Type_>
76
+ struct CutlassToWmmaDataType{
77
+ using Type = Type_;
78
+ };
79
+
80
+ /// Statically maps cutlass::half_t => __half
81
+ template<>
82
+ struct CutlassToWmmaDataType<cutlass::half_t> {
83
+ using Type = __half;
84
+ };
85
+
86
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
87
+ template<>
88
+ struct CutlassToWmmaDataType<cutlass::bfloat16_t> {
89
+ using Type = __nv_bfloat16;
90
+ };
91
+ #endif
92
+
93
+ /// Statically maps int8_t => char
94
+ template<>
95
+ struct CutlassToWmmaDataType<int8_t> {
96
+ using Type = signed char;
97
+ };
98
+
99
+ /// Statically maps uint8_t => char
100
+ template<>
101
+ struct CutlassToWmmaDataType<uint8_t> {
102
+ using Type = unsigned char;
103
+ };
104
+
105
+ /// Statically maps int32_t => int
106
+ template<>
107
+ struct CutlassToWmmaDataType<int32_t> {
108
+ using Type = int;
109
+ };
110
+
111
+ #if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED)
112
+ /// Statically maps cutlass::int4b_t => experimental::precision::s4
113
+ template<>
114
+ struct CutlassToWmmaDataType<cutlass::int4b_t> {
115
+ using Type = nvcuda::wmma::experimental::precision::s4;
116
+ };
117
+
118
+ /// Statically maps cutlass::uint4b_t => experimental::precision::s4
119
+ template<>
120
+ struct CutlassToWmmaDataType<cutlass::uint4b_t> {
121
+ using Type = nvcuda::wmma::experimental::precision::u4;
122
+ };
123
+
124
+ /// Statically maps cutlass::uint1b_t => experimental::precision::b1
125
+ template<>
126
+ struct CutlassToWmmaDataType<cutlass::uint1b_t> {
127
+ using Type = nvcuda::wmma::experimental::precision::b1;
128
+ };
129
+ #endif
130
+
131
+ ////////////////////////////////////////////////////////////////////////////////////////////////
132
+ /// Statically maps cutlass::layout => nvcuda::wmma layout tags
133
+ ////////////////////////////////////////////////////////////////////////////////////////////////
134
+ template <typename Layout_>
135
+ struct CutlassToWmmaLayout {
136
+ };
137
+
138
+ /// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags
139
+ template <>
140
+ struct CutlassToWmmaLayout<cutlass::layout::RowMajor> {
141
+ using Layout = nvcuda::wmma::row_major;
142
+ static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_row_major;
143
+ };
144
+
145
+ ////////////////////////////////////////////////////////////////////////////////////////////////
146
+ /// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags
147
+ ////////////////////////////////////////////////////////////////////////////////////////////////
148
+ template <>
149
+ struct CutlassToWmmaLayout<cutlass::layout::ColumnMajor> {
150
+ using Layout = nvcuda::wmma::col_major;
151
+ static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_col_major;
152
+ };
153
+ ////////////////////////////////////////////////////////////////////////////////////////////////
154
+
155
+ ////////////////////////////////////////////////////////////////////////////////////////////////
156
+ /// Statically maps nvcuda::wmma data types => cutlass data types
157
+ /////////////////////////////////////////////////////////////////////////////////////////////////
158
+ template <typename Type_>
159
+ struct WmmaToCutlassDataType{
160
+ using Type = Type_;
161
+ };
162
+
163
+ /// Statically maps __half => cutlass::half_t
164
+ template<>
165
+ struct WmmaToCutlassDataType<__half> {
166
+ using Type = cutlass::half_t;
167
+ };
168
+
169
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
170
+ template<>
171
+ struct WmmaToCutlassDataType<__nv_bfloat16> {
172
+ using Type = cutlass::bfloat16_t;
173
+ };
174
+ #endif
175
+
176
+ ////////////////////////////////////////////////////////////////////////////////////////////////
177
+
178
+ /////////////////////////////////////////////////////////////////////////////////////////////////
179
+ // WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks
180
+ // for a specific template parameterized data type (Element[A|B|C]), layout (Layout[A|B|C]),
181
+ // and native wmma size (Shape)
182
+ /////////////////////////////////////////////////////////////////////////////////////////////////
183
+ template <
184
+ typename Shape_, ///< Size of the matrix product (concept: GemmShape)
185
+ typename ElementA_, ///< Data type of A elements
186
+ typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout)
187
+ typename ElementB_, ///< Data type of B elements
188
+ typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout)
189
+ typename ElementC_, ///< Element type of C matrix
190
+ typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout)
191
+ typename Operator_ = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc)
192
+ >
193
+ struct Wmma;
194
+ /////////////////////////////////////////////////////////////////////////////////////////////////
195
+
196
+ } // namespace arch
197
+ } // namespace cutlass
198
+
199
+ /////////////////////////////////////////////////////////////////////////////////////////////////
200
+
201
+ //
202
+ // Specializations for each compute capability
203
+ //
204
+ #ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED
205
+ #include "cutlass/arch/wmma_sm70.h"
206
+ #endif
207
+
208
+ #ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED
209
+ #include "cutlass/arch/wmma_sm72.h"
210
+ #endif
211
+
212
+ #ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED
213
+ #include "cutlass/arch/wmma_sm75.h"
214
+ #endif
215
+
216
+ /////////////////////////////////////////////////////////////////////////////////////////////////
217
+
218
+ #endif //CUTLASS_ARCH_WMMA_ENABLED
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply
33
+ */
34
+
35
+ #pragma once
36
+ #include "cutlass/cutlass.h"
37
+ #include CUDA_STD_HEADER(cassert)
38
+ #include "cutlass/layout/matrix.h"
39
+
40
+ ////////////////////////////////////////////////////////////////////////////////
41
+ namespace cutlass {
42
+ namespace arch {
43
+
44
+
45
+ ////////////////////////////////////////////////////////////////////////////////
46
+ //
47
+ // WMMA template structure defines nvcuda::wmma::fragments and static assert for
48
+ // wmma native instruction sizes supported for half
49
+ //
50
+ ////////////////////////////////////////////////////////////////////////////////
51
+ template <
52
+ typename Shape_,
53
+ typename LayoutA_,
54
+ typename LayoutB_,
55
+ typename ElementC_,
56
+ typename LayoutC_>
57
+ struct Wmma<
58
+ Shape_, ///< Size of the matrix product (concept: GemmShape)
59
+ cutlass::half_t, ///< ElementA
60
+ LayoutA_, ///< LayoutA
61
+ cutlass::half_t, ///< ElementB
62
+ LayoutB_, ///< LayoutB
63
+ ElementC_, ///< ElementC
64
+ LayoutC_, ///< LayoutC
65
+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc)
66
+ > {
67
+
68
+ #if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED)
69
+ using Shape = Shape_;
70
+ using ElementA = cutlass::half_t;
71
+ using LayoutA = LayoutA_;
72
+ using ElementB = cutlass::half_t;
73
+ using LayoutB = LayoutB_;
74
+ using ElementC = ElementC_;
75
+ using LayoutC = LayoutC_;
76
+ using Operator = cutlass::arch::OpMultiplyAdd;
77
+ using ArchTag = arch::Sm70;
78
+
79
+ // check supported wmma shape for the given multiplicand data types
80
+ static_assert(
81
+ platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
82
+ platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
83
+ platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
84
+ "Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");
85
+
86
+ // check supported wmma output data type for the given multiplicand data types
87
+ static_assert(
88
+ platform::is_same<cutlass::half_t, ElementC>::value || platform::is_same<float, ElementC>::value,
89
+ "Supported of wmma output data type for f16 multiplicands are: f16 and f32");
90
+
91
+ // Wmma Fragment
92
+ using FragmentA = nvcuda::wmma::fragment<
93
+ nvcuda::wmma::matrix_a,
94
+ Shape::kM,
95
+ Shape::kN,
96
+ Shape::kK,
97
+ typename CutlassToWmmaDataType<ElementA>::Type,
98
+ typename CutlassToWmmaLayout<LayoutA>::Layout>;
99
+
100
+ using FragmentB = nvcuda::wmma::fragment<
101
+ nvcuda::wmma::matrix_b,
102
+ Shape::kM,
103
+ Shape::kN,
104
+ Shape::kK,
105
+ typename CutlassToWmmaDataType<ElementB>::Type,
106
+ typename CutlassToWmmaLayout<LayoutB>::Layout>;
107
+
108
+ using FragmentC = nvcuda::wmma::fragment<
109
+ nvcuda::wmma::accumulator,
110
+ Shape::kM,
111
+ Shape::kN,
112
+ Shape::kK,
113
+ typename CutlassToWmmaDataType<ElementC>::Type>;
114
+
115
+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation
116
+ CUTLASS_DEVICE
117
+ void operator()(
118
+ FragmentC &D,
119
+ FragmentA const &A,
120
+ FragmentB const &B,
121
+ FragmentC const &C) const {
122
+
123
+ nvcuda::wmma::mma_sync(D, A, B, C);
124
+ }
125
+ #else
126
+ static_assert(false, "wmma.mma.sync for floating point multiplicands is available only for SM70 and beyond");
127
+ #endif
128
+
129
+ };
130
+
131
+ } // namespace arch
132
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply
33
+ */
34
+
35
+ #pragma once
36
+ #include "cutlass/cutlass.h"
37
+ #include CUDA_STD_HEADER(cassert)
38
+ #include "cutlass/layout/matrix.h"
39
+
40
+ ////////////////////////////////////////////////////////////////////////////////
41
+ namespace cutlass {
42
+ namespace arch {
43
+
44
+ ////////////////////////////////////////////////////////////////////////////////
45
+ //
46
+ // WMMA template structure defines nvcuda::wmma::fragments and static assert for
47
+ // wmma native instruction sizes supported for int8_t
48
+ //
49
+ ////////////////////////////////////////////////////////////////////////////////
50
+ template <
51
+ typename Shape_,
52
+ typename LayoutA_,
53
+ typename LayoutB_,
54
+ typename LayoutC_>
55
+ struct Wmma<
56
+ Shape_, ///< Size of the matrix product (concept: GemmShape)
57
+ int8_t, ///< ElementA
58
+ LayoutA_, ///< LayoutA
59
+ int8_t, ///< ElementB
60
+ LayoutB_, ///< LayoutB
61
+ int32_t, ///< ElementC
62
+ LayoutC_, ///< LayoutC
63
+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc)
64
+ > {
65
+ #if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)
66
+ using Shape = Shape_;
67
+ using ElementA = int8_t;
68
+ using LayoutA = LayoutA_;
69
+ using ElementB = int8_t;
70
+ using LayoutB = LayoutB_;
71
+ using ElementC = int32_t;
72
+ using LayoutC = LayoutC_;
73
+ using Operator = cutlass::arch::OpMultiplyAdd;
74
+ using ArchTag = arch::Sm72;
75
+
76
+ // check supported wmma shape for the given multiplicand data types
77
+ static_assert(
78
+ platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
79
+ platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
80
+ platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
81
+ "Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");
82
+
83
+
84
+ // Wmma Fragment
85
+ using FragmentA = nvcuda::wmma::fragment<
86
+ nvcuda::wmma::matrix_a,
87
+ Shape::kM,
88
+ Shape::kN,
89
+ Shape::kK,
90
+ typename CutlassToWmmaDataType<ElementA>::Type,
91
+ typename CutlassToWmmaLayout<LayoutA>::Layout>;
92
+
93
+ using FragmentB = nvcuda::wmma::fragment<
94
+ nvcuda::wmma::matrix_b,
95
+ Shape::kM,
96
+ Shape::kN,
97
+ Shape::kK,
98
+ typename CutlassToWmmaDataType<ElementB>::Type,
99
+ typename CutlassToWmmaLayout<LayoutB>::Layout>;
100
+
101
+ using FragmentC = nvcuda::wmma::fragment<
102
+ nvcuda::wmma::accumulator,
103
+ Shape::kM,
104
+ Shape::kN,
105
+ Shape::kK,
106
+ typename CutlassToWmmaDataType<ElementC>::Type>;
107
+
108
+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation
109
+ CUTLASS_DEVICE
110
+ void operator()(
111
+ FragmentC &D,
112
+ FragmentA const &A,
113
+ FragmentB const &B,
114
+ FragmentC const &C) const {
115
+
116
+ nvcuda::wmma::mma_sync(D, A, B, C);
117
+ }
118
+
119
+ #else
120
+ static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond");
121
+ #endif
122
+
123
+ };
124
+
125
+ ////////////////////////////////////////////////////////////////////////////////
126
+ //
127
+ // WMMA template structure defines nvcuda::wmma::fragments and static assert for
128
+ // wmma native instruction sizes supported for uint8_t
129
+ //
130
+ ////////////////////////////////////////////////////////////////////////////////
131
+ template <
132
+ typename Shape_,
133
+ typename LayoutA_,
134
+ typename LayoutB_,
135
+ typename LayoutC_>
136
+ struct Wmma<
137
+ Shape_, ///< Size of the matrix product (concept: GemmShape)
138
+ uint8_t, ///< ElementA
139
+ LayoutA_, ///< LayoutA
140
+ uint8_t, ///< ElementB
141
+ LayoutB_, ///< LayoutB
142
+ int32_t, ///< ElementC
143
+ LayoutC_, ///< LayoutC
144
+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc)
145
+ > {
146
+ #if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)
147
+ using Shape = Shape_;
148
+ using ElementA = uint8_t;
149
+ using LayoutA = LayoutA_;
150
+ using ElementB = uint8_t;
151
+ using LayoutB = LayoutB_;
152
+ using ElementC = int32_t;
153
+ using LayoutC = LayoutC_;
154
+ using Operator = cutlass::arch::OpMultiplyAdd;
155
+ using ArchTag = arch::Sm72;
156
+
157
+ // check supported wmma shape for the given multiplicand data types
158
+ static_assert(
159
+ platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
160
+ platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
161
+ platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
162
+ "Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");
163
+
164
+ // Wmma Fragment
165
+ using FragmentA = nvcuda::wmma::fragment<
166
+ nvcuda::wmma::matrix_a,
167
+ Shape::kM,
168
+ Shape::kN,
169
+ Shape::kK,
170
+ typename CutlassToWmmaDataType<ElementA>::Type,
171
+ typename CutlassToWmmaLayout<LayoutA>::Layout>;
172
+
173
+ using FragmentB = nvcuda::wmma::fragment<
174
+ nvcuda::wmma::matrix_b,
175
+ Shape::kM,
176
+ Shape::kN,
177
+ Shape::kK,
178
+ typename CutlassToWmmaDataType<ElementB>::Type,
179
+ typename CutlassToWmmaLayout<LayoutB>::Layout>;
180
+
181
+ using FragmentC = nvcuda::wmma::fragment<
182
+ nvcuda::wmma::accumulator,
183
+ Shape::kM,
184
+ Shape::kN,
185
+ Shape::kK,
186
+ typename CutlassToWmmaDataType<ElementC>::Type>;
187
+
188
+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation
189
+ CUTLASS_DEVICE
190
+ void operator()(
191
+ FragmentC &D,
192
+ FragmentA const &A,
193
+ FragmentB const &B,
194
+ FragmentC const &C) const {
195
+
196
+ nvcuda::wmma::mma_sync(D, A, B, C);
197
+ }
198
+
199
+ #else
200
+ static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond");
201
+ #endif
202
+
203
+ };
204
+
205
+ } // namespace arch
206
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Matrix multiply
33
+ */
34
+
35
+ #pragma once
36
+ #include "cutlass/cutlass.h"
37
+ #include CUDA_STD_HEADER(cassert)
38
+ #include "cutlass/layout/matrix.h"
39
+
40
+ ////////////////////////////////////////////////////////////////////////////////
41
+ namespace cutlass {
42
+ namespace arch {
43
+
44
+ ////////////////////////////////////////////////////////////////////////////////
45
+ //
46
+ // WMMA template structure defines nvcuda::wmma::fragments and static assert for
47
+ // wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4).
48
+ //
49
+ ////////////////////////////////////////////////////////////////////////////////
50
+ template <
51
+ typename Shape_,
52
+ typename LayoutA_,
53
+ typename LayoutB_,
54
+ typename LayoutC_>
55
+ struct Wmma<
56
+ Shape_, ///< Size of the matrix product (concept: GemmShape)
57
+ cutlass::int4b_t, ///< ElementA
58
+ LayoutA_, ///< LayoutA
59
+ cutlass::int4b_t, ///< ElementB
60
+ LayoutB_, ///< LayoutB
61
+ int32_t, ///< ElementC
62
+ LayoutC_, ///< LayoutC
63
+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc)
64
+ > {
65
+ #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
66
+ using Shape = Shape_;
67
+ using ElementA = cutlass::int4b_t;
68
+ using LayoutA = LayoutA_;
69
+ using ElementB = cutlass::int4b_t;
70
+ using LayoutB = LayoutB_;
71
+ using ElementC = int32_t;
72
+ using LayoutC = LayoutC_;
73
+ using Operator = cutlass::arch::OpMultiplyAdd;
74
+ using ArchTag = arch::Sm75;
75
+
76
+ // check supported wmma shape for the given multiplicand data types
77
+ static_assert(
78
+ platform::is_same<cutlass::gemm::GemmShape<8, 8, 32>, Shape>::value,
79
+ "Supported list of wmma operator shape for s8 multiplicands is: 8x8x32");
80
+
81
+
82
+ // Wmma Fragment
83
+ using FragmentA = nvcuda::wmma::fragment<
84
+ nvcuda::wmma::matrix_a,
85
+ Shape::kM,
86
+ Shape::kN,
87
+ Shape::kK,
88
+ typename CutlassToWmmaDataType<ElementA>::Type,
89
+ typename CutlassToWmmaLayout<LayoutA>::Layout>;
90
+
91
+ using FragmentB = nvcuda::wmma::fragment<
92
+ nvcuda::wmma::matrix_b,
93
+ Shape::kM,
94
+ Shape::kN,
95
+ Shape::kK,
96
+ typename CutlassToWmmaDataType<ElementB>::Type,
97
+ typename CutlassToWmmaLayout<LayoutB>::Layout>;
98
+
99
+ using FragmentC = nvcuda::wmma::fragment<
100
+ nvcuda::wmma::accumulator,
101
+ Shape::kM,
102
+ Shape::kN,
103
+ Shape::kK,
104
+ typename CutlassToWmmaDataType<ElementC>::Type>;
105
+
106
+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation
107
+ CUTLASS_DEVICE
108
+ void operator()(
109
+ FragmentC &D,
110
+ FragmentA const &A,
111
+ FragmentB const &B,
112
+ FragmentC const &C) const {
113
+ nvcuda::wmma::mma_sync(D, A, B, C);
114
+
115
+ }
116
+
117
+ #else
118
+ static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
119
+ #endif
120
+
121
+ };
122
+
123
+ ////////////////////////////////////////////////////////////////////////////////
124
+ //
125
+ // WMMA template structure defines nvcuda::wmma::fragments and static assert for
126
+ // wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1).
127
+ //
128
+ ////////////////////////////////////////////////////////////////////////////////
129
+ template <
130
+ typename Shape_,
131
+ typename LayoutA_,
132
+ typename LayoutB_,
133
+ typename LayoutC_>
134
+ struct Wmma<
135
+ Shape_, ///< Size of the matrix product (concept: GemmShape)
136
+ cutlass::uint1b_t, ///< ElementA
137
+ LayoutA_, ///< LayoutA
138
+ cutlass::uint1b_t, ///< ElementB
139
+ LayoutB_, ///< LayoutB
140
+ int32_t, ///< ElementC
141
+ LayoutC_, ///< LayoutC
142
+ cutlass::arch::OpXorPopc ///< Operator (multiply-add, xor.popc)
143
+ > {
144
+ #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
145
+ using Shape = Shape_;
146
+ using ElementA = cutlass::uint1b_t;
147
+ using LayoutA = LayoutA_;
148
+ using ElementB = cutlass::uint1b_t;
149
+ using LayoutB = LayoutB_;
150
+ using ElementC = int32_t;
151
+ using LayoutC = LayoutC_;
152
+ using Operator = cutlass::arch::OpXorPopc;
153
+ using ArchTag = arch::Sm75;
154
+
155
+ // check supported wmma shape for the given multiplicand data types
156
+ static_assert(
157
+ platform::is_same<cutlass::gemm::GemmShape<8, 8, 128>, Shape>::value,
158
+ "Supported list of wmma operator shape for b1 multiplicands is: 8x8x128");
159
+
160
+
161
+ // Wmma Fragment
162
+ using FragmentA = nvcuda::wmma::fragment<
163
+ nvcuda::wmma::matrix_a,
164
+ Shape::kM,
165
+ Shape::kN,
166
+ Shape::kK,
167
+ typename CutlassToWmmaDataType<ElementA>::Type,
168
+ typename CutlassToWmmaLayout<LayoutA>::Layout>;
169
+
170
+ using FragmentB = nvcuda::wmma::fragment<
171
+ nvcuda::wmma::matrix_b,
172
+ Shape::kM,
173
+ Shape::kN,
174
+ Shape::kK,
175
+ typename CutlassToWmmaDataType<ElementB>::Type,
176
+ typename CutlassToWmmaLayout<LayoutB>::Layout>;
177
+
178
+ using FragmentC = nvcuda::wmma::fragment<
179
+ nvcuda::wmma::accumulator,
180
+ Shape::kM,
181
+ Shape::kN,
182
+ Shape::kK,
183
+ typename CutlassToWmmaDataType<ElementC>::Type>;
184
+
185
+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation
186
+ CUTLASS_DEVICE
187
+ void operator()(
188
+ FragmentC &D,
189
+ FragmentA const &A,
190
+ FragmentB const &B,
191
+ FragmentC const &C) const {
192
+ nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
193
+ nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
194
+ }
195
+
196
+ #else
197
+ static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
198
+ #endif
199
+
200
+ };
201
+
202
+ } // namespace arch
203
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h ADDED
@@ -0,0 +1,2860 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types
33
+ and is safe to use in a union.
34
+ */
35
+
36
+ #pragma once
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/functional.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/platform/platform.h"
41
+ namespace cutlass {
42
+
43
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ /// Statically sized array for any data type
46
+ template <
47
+ typename T,
48
+ int N,
49
+ bool RegisterSized = sizeof_bits<T>::value >= 32
50
+ >
51
+ struct Array;
52
+
53
+ namespace detail {
54
+
55
+ template<class T>
56
+ struct is_Array : platform::false_type {};
57
+
58
+ template <
59
+ typename T,
60
+ int N,
61
+ bool RegisterSized
62
+ >
63
+ struct is_Array<Array<T, N, RegisterSized> > : platform::true_type {};
64
+
65
+ template<typename T>
66
+ constexpr bool is_Array_v = is_Array<T>::value;
67
+
68
+ } // namespace detail
69
+
70
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
71
+
72
+ /// Defines the size of an Array<> in bits
73
+ template <typename T, int N, bool RegisterSized>
74
+ struct sizeof_bits<Array<T, N, RegisterSized> > {
75
+ static constexpr int value = sizeof(Array<T, N, RegisterSized>) * 8;
76
+ };
77
+
78
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
79
+
80
/// Returns true if the argument is a power of 2
CUTLASS_HOST_DEVICE
constexpr bool ispow2(unsigned x) {
  // A power of two has exactly one bit set, so clearing its lowest set bit
  // yields zero. Zero itself is not a power of two.
  return (x != 0u) && ((x & (x - 1u)) == 0u);
}
85
+
86
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
87
+
88
/// Returns the largest power of two not greater than the argument.
CUTLASS_HOST_DEVICE
constexpr unsigned floor_pow_2(unsigned x) {
  // Iteratively clear the lowest set bit until only the highest remains.
  // floor_pow_2(0) == 0, matching the original recursive formulation.
  unsigned highest = x;
  while (highest & (highest - 1u)) {
    highest &= highest - 1u;
  }
  return highest;
}
93
+
94
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
95
+
96
+ /// Statically sized array for any data type
97
+ template <
98
+ typename T,
99
+ int N
100
+ >
101
+ struct Array<T, N, true> {
102
+
103
+ /// Storage type
104
+ using Storage = T;
105
+
106
+ /// Element type
107
+ using Element = T;
108
+
109
+ /// Number of storage elements
110
+ //static std::size_t const kStorageElements = N;
111
+ static constexpr size_t kStorageElements = N;
112
+
113
+ /// Number of logical elements
114
+ static constexpr size_t kElements = N;
115
+
116
+ //
117
+ // C++ standard members
118
+ //
119
+
120
+ typedef T value_type;
121
+ typedef size_t size_type;
122
+ typedef ptrdiff_t difference_type;
123
+ typedef value_type &reference;
124
+ typedef value_type const & const_reference;
125
+ typedef value_type *pointer;
126
+ typedef value_type const * const_pointer;
127
+
128
+ //
129
+ // Iterators
130
+ //
131
+
132
+ /// Bidirectional iterator over elements
133
+ class iterator {
134
+
135
+ /// Pointer to object
136
+ T *ptr_;
137
+
138
+ public:
139
+
140
+ CUTLASS_HOST_DEVICE
141
+ iterator(): ptr_(nullptr) { }
142
+
143
+ CUTLASS_HOST_DEVICE
144
+ iterator(T *_ptr): ptr_(_ptr) { }
145
+
146
+ CUTLASS_HOST_DEVICE
147
+ iterator &operator++() {
148
+ ++ptr_;
149
+ return *this;
150
+ }
151
+
152
+ CUTLASS_HOST_DEVICE
153
+ iterator &operator--() {
154
+ --ptr_;
155
+ return *this;
156
+ }
157
+
158
+ CUTLASS_HOST_DEVICE
159
+ iterator operator++(int) {
160
+ iterator ret(*this);
161
+ ++ptr_;
162
+ return ret;
163
+ }
164
+
165
+ CUTLASS_HOST_DEVICE
166
+ iterator operator--(int) {
167
+ iterator ret(*this);
168
+ --ptr_;
169
+ return ret;
170
+ }
171
+
172
+ CUTLASS_HOST_DEVICE
173
+ T &operator*() const {
174
+ return *ptr_;
175
+ }
176
+
177
+ CUTLASS_HOST_DEVICE
178
+ bool operator==(iterator const &other) const {
179
+ return ptr_ == other.ptr_;
180
+ }
181
+
182
+ CUTLASS_HOST_DEVICE
183
+ bool operator!=(iterator const &other) const {
184
+ return ptr_ != other.ptr_;
185
+ }
186
+ };
187
+
188
+ /// Bidirectional constant iterator over elements
189
+ class const_iterator {
190
+
191
+ /// Pointer to object
192
+ const T *ptr_;
193
+
194
+ public:
195
+
196
+ CUTLASS_HOST_DEVICE
197
+ const_iterator(): ptr_(nullptr) { }
198
+
199
+ CUTLASS_HOST_DEVICE
200
+ const_iterator(T const *_ptr): ptr_(_ptr) { }
201
+
202
+ CUTLASS_HOST_DEVICE
203
+ const_iterator &operator++() {
204
+ ++ptr_;
205
+ return *this;
206
+ }
207
+
208
+ CUTLASS_HOST_DEVICE
209
+ const_iterator &operator--() {
210
+ --ptr_;
211
+ return *this;
212
+ }
213
+
214
+ CUTLASS_HOST_DEVICE
215
+ const_iterator operator++(int) {
216
+ const_iterator ret(*this);
217
+ ++ptr_;
218
+ return ret;
219
+ }
220
+
221
+ CUTLASS_HOST_DEVICE
222
+ const_iterator operator--(int) {
223
+ const_iterator ret(*this);
224
+ --ptr_;
225
+ return ret;
226
+ }
227
+
228
+ CUTLASS_HOST_DEVICE
229
+ T const &operator*() const {
230
+ return *ptr_;
231
+ }
232
+
233
+ CUTLASS_HOST_DEVICE
234
+ bool operator==(const_iterator const &other) const {
235
+ return ptr_ == other.ptr_;
236
+ }
237
+
238
+ CUTLASS_HOST_DEVICE
239
+ bool operator!=(const_iterator const &other) const {
240
+ return ptr_ != other.ptr_;
241
+ }
242
+ };
243
+
244
+ /// Bidirectional iterator over elements
245
+ class reverse_iterator {
246
+
247
+ /// Pointer to object
248
+ T *ptr_;
249
+
250
+ public:
251
+
252
+ CUTLASS_HOST_DEVICE
253
+ reverse_iterator(): ptr_(nullptr) { }
254
+
255
+ CUTLASS_HOST_DEVICE
256
+ reverse_iterator(T *_ptr): ptr_(_ptr) { }
257
+
258
+ CUTLASS_HOST_DEVICE
259
+ reverse_iterator &operator++() {
260
+ --ptr_;
261
+ return *this;
262
+ }
263
+
264
+ CUTLASS_HOST_DEVICE
265
+ reverse_iterator &operator--() {
266
+ ++ptr_;
267
+ return *this;
268
+ }
269
+
270
+ CUTLASS_HOST_DEVICE
271
+ reverse_iterator operator++(int) {
272
+ iterator ret(*this);
273
+ --ptr_;
274
+ return ret;
275
+ }
276
+
277
+ CUTLASS_HOST_DEVICE
278
+ reverse_iterator operator--(int) {
279
+ iterator ret(*this);
280
+ ++ptr_;
281
+ return ret;
282
+ }
283
+
284
+ CUTLASS_HOST_DEVICE
285
+ T &operator*() const {
286
+ return *(ptr_ - 1);
287
+ }
288
+
289
+ CUTLASS_HOST_DEVICE
290
+ bool operator==(reverse_iterator const &other) const {
291
+ return ptr_ == other.ptr_;
292
+ }
293
+
294
+ CUTLASS_HOST_DEVICE
295
+ bool operator!=(reverse_iterator const &other) const {
296
+ return ptr_ != other.ptr_;
297
+ }
298
+ };
299
+
300
+ /// Bidirectional constant iterator over elements
301
+ class const_reverse_iterator {
302
+
303
+ /// Pointer to object
304
+ T const *ptr_;
305
+
306
+ public:
307
+
308
+ CUTLASS_HOST_DEVICE
309
+ const_reverse_iterator(): ptr_(nullptr) { }
310
+
311
+ CUTLASS_HOST_DEVICE
312
+ const_reverse_iterator(T const *_ptr): ptr_(_ptr) { }
313
+
314
+ CUTLASS_HOST_DEVICE
315
+ const_reverse_iterator &operator++() {
316
+ --ptr_;
317
+ return *this;
318
+ }
319
+
320
+ CUTLASS_HOST_DEVICE
321
+ const_reverse_iterator &operator--() {
322
+ ++ptr_;
323
+ return *this;
324
+ }
325
+
326
+ CUTLASS_HOST_DEVICE
327
+ const_reverse_iterator operator++(int) {
328
+ const_reverse_iterator ret(*this);
329
+ --ptr_;
330
+ return ret;
331
+ }
332
+
333
+ CUTLASS_HOST_DEVICE
334
+ const_reverse_iterator operator--(int) {
335
+ const_reverse_iterator ret(*this);
336
+ ++ptr_;
337
+ return ret;
338
+ }
339
+
340
+ CUTLASS_HOST_DEVICE
341
+ T const &operator*() const {
342
+ return *(ptr_ - 1);
343
+ }
344
+
345
+ CUTLASS_HOST_DEVICE
346
+ bool operator==(const_iterator const &other) const {
347
+ return ptr_ == other.ptr_;
348
+ }
349
+
350
+ CUTLASS_HOST_DEVICE
351
+ bool operator!=(const_iterator const &other) const {
352
+ return ptr_ != other.ptr_;
353
+ }
354
+ };
355
+
356
+ /// Internal storage
357
+ Storage storage[kElements];
358
+
359
+ /// Efficient clear method
360
+ CUTLASS_HOST_DEVICE
361
+ void clear() {
362
+ fill(T(0));
363
+ }
364
+
365
+ CUTLASS_HOST_DEVICE
366
+ reference at(size_type pos) {
367
+ return reinterpret_cast<reference>(storage[pos]);
368
+ }
369
+
370
+ CUTLASS_HOST_DEVICE
371
+ const_reference at(size_type pos) const {
372
+ return reinterpret_cast<const_reference>(storage[pos]);
373
+ }
374
+
375
+ CUTLASS_HOST_DEVICE
376
+ reference operator[](size_type pos) {
377
+ return reinterpret_cast<reference>(storage[pos]);
378
+ }
379
+
380
+ CUTLASS_HOST_DEVICE
381
+ const_reference operator[](size_type pos) const {
382
+ return reinterpret_cast<const_reference>(storage[pos]);
383
+ }
384
+
385
+ CUTLASS_HOST_DEVICE
386
+ reference front() {
387
+ return reinterpret_cast<reference>(storage[0]);
388
+ }
389
+
390
+ CUTLASS_HOST_DEVICE
391
+ const_reference front() const {
392
+ return reinterpret_cast<const_reference>(storage[0]);
393
+ }
394
+
395
+ CUTLASS_HOST_DEVICE
396
+ reference back() {
397
+ return reinterpret_cast<reference>(storage[kStorageElements - 1]);
398
+ }
399
+
400
+ CUTLASS_HOST_DEVICE
401
+ const_reference back() const {
402
+ return reinterpret_cast<const_reference>(storage[kStorageElements - 1]);
403
+ }
404
+
405
+ CUTLASS_HOST_DEVICE
406
+ pointer data() {
407
+ return reinterpret_cast<pointer>(storage);
408
+ }
409
+
410
+ CUTLASS_HOST_DEVICE
411
+ const_pointer data() const {
412
+ return reinterpret_cast<const_pointer>(storage);
413
+ }
414
+
415
+ CUTLASS_HOST_DEVICE
416
+ pointer raw_data() {
417
+ return reinterpret_cast<pointer>(storage);
418
+ }
419
+
420
+ CUTLASS_HOST_DEVICE
421
+ const_pointer raw_data() const {
422
+ return reinterpret_cast<const_pointer>(storage);
423
+ }
424
+
425
+
426
+ CUTLASS_HOST_DEVICE
427
+ constexpr bool empty() const {
428
+ return !kElements;
429
+ }
430
+
431
+ CUTLASS_HOST_DEVICE
432
+ constexpr size_type size() const {
433
+ return kElements;
434
+ }
435
+
436
+ CUTLASS_HOST_DEVICE
437
+ constexpr size_type max_size() const {
438
+ return kElements;
439
+ }
440
+
441
+ CUTLASS_HOST_DEVICE
442
+ void fill(T const &value) {
443
+ CUTLASS_PRAGMA_UNROLL
444
+ for (int i = 0; i < int(kElements); ++i) {
445
+ storage[i] = static_cast<Storage>(value);
446
+ }
447
+ }
448
+
449
+ CUTLASS_HOST_DEVICE
450
+ iterator begin() {
451
+ return iterator(storage);
452
+ }
453
+
454
+ CUTLASS_HOST_DEVICE
455
+ const_iterator begin() const {
456
+ return cbegin();
457
+ }
458
+
459
+ CUTLASS_HOST_DEVICE
460
+ const_iterator cbegin() const {
461
+ return const_iterator(storage);
462
+ }
463
+
464
+ CUTLASS_HOST_DEVICE
465
+ iterator end() {
466
+ return iterator(reinterpret_cast<pointer>(storage + kStorageElements));
467
+ }
468
+
469
+ CUTLASS_HOST_DEVICE
470
+ const_iterator end() const {
471
+ return cend();
472
+ }
473
+
474
+ CUTLASS_HOST_DEVICE
475
+ const_iterator cend() const {
476
+ return const_iterator(reinterpret_cast<const_pointer>(storage + kStorageElements));
477
+ }
478
+
479
+ CUTLASS_HOST_DEVICE
480
+ reverse_iterator rbegin() {
481
+ return reverse_iterator(reinterpret_cast<pointer>(storage + kStorageElements));
482
+ }
483
+
484
+ CUTLASS_HOST_DEVICE
485
+ const_reverse_iterator rbegin() const {
486
+ return crbegin();
487
+ }
488
+
489
+ CUTLASS_HOST_DEVICE
490
+ const_reverse_iterator crbegin() const {
491
+ return const_reverse_iterator(reinterpret_cast<const_pointer>(storage + kStorageElements));
492
+ }
493
+
494
+ CUTLASS_HOST_DEVICE
495
+ reverse_iterator rend() {
496
+ return reverse_iterator(reinterpret_cast<pointer>(storage));
497
+ }
498
+
499
+ CUTLASS_HOST_DEVICE
500
+ const_reverse_iterator rend() const {
501
+ return crend();
502
+ }
503
+
504
+ CUTLASS_HOST_DEVICE
505
+ const_reverse_iterator crend() const {
506
+ return const_reverse_iterator(reinterpret_cast<const_pointer>(storage));
507
+ }
508
+
509
+ //
510
+ // Comparison operators
511
+ //
512
+
513
+ };
514
+
515
+
516
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
517
+ // Factories
518
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
519
+
520
+ template <typename Element>
521
+ CUTLASS_HOST_DEVICE
522
+ Array<Element, 1> make_Array(Element x) {
523
+ return {x};
524
+ }
525
+
526
+ template <typename Element>
527
+ CUTLASS_HOST_DEVICE
528
+ Array<Element, 2> make_Array(Element x, Element y) {
529
+ return {x,y};
530
+ }
531
+
532
+ template <typename Element>
533
+ CUTLASS_HOST_DEVICE
534
+ Array<Element, 3> make_Array(Element x, Element y, Element z) {
535
+ return {x,y,z};
536
+ }
537
+
538
+ template <typename Element>
539
+ CUTLASS_HOST_DEVICE
540
+ Array<Element, 4> make_Array(Element x, Element y, Element z, Element w) {
541
+ return {x,y,z,w};
542
+ }
543
+
544
+
545
+ /////////////////////////////////////////////////////////////////////////////////////////////////
546
+ // functional.h numeric specializations
547
+ /////////////////////////////////////////////////////////////////////////////////////////////////
548
+
549
+ template <typename T, int N>
550
+ struct absolute_value_op< Array<T, N> > {
551
+
552
+ CUTLASS_HOST_DEVICE
553
+ Array<T, N> operator()(Array<T, N> const &lhs) const {
554
+
555
+ Array<T, N> result;
556
+ absolute_value_op<T> scalar_op;
557
+
558
+ CUTLASS_PRAGMA_UNROLL
559
+ for (int i = 0; i < N; ++i) {
560
+ result[i] = scalar_op(lhs[i]);
561
+ }
562
+
563
+ return result;
564
+ }
565
+ };
566
+
567
+ template <typename T, int N>
568
+ struct plus<Array<T, N>> {
569
+ CUTLASS_HOST_DEVICE
570
+ Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
571
+
572
+ Array<T, N> result;
573
+ plus<T> scalar_op;
574
+
575
+ CUTLASS_PRAGMA_UNROLL
576
+ for (int i = 0; i < N; ++i) {
577
+ result[i] = scalar_op(lhs[i], rhs[i]);
578
+ }
579
+
580
+ return result;
581
+ }
582
+
583
+ CUTLASS_HOST_DEVICE
584
+ Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
585
+
586
+ Array<T, N> result;
587
+ plus<T> scalar_op;
588
+
589
+ CUTLASS_PRAGMA_UNROLL
590
+ for (int i = 0; i < N; ++i) {
591
+ result[i] = scalar_op(lhs[i], scalar);
592
+ }
593
+
594
+ return result;
595
+ }
596
+
597
+ CUTLASS_HOST_DEVICE
598
+ Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
599
+
600
+ Array<T, N> result;
601
+ plus<T> scalar_op;
602
+
603
+ CUTLASS_PRAGMA_UNROLL
604
+ for (int i = 0; i < N; ++i) {
605
+ result[i] = scalar_op(scalar, rhs[i]);
606
+ }
607
+
608
+ return result;
609
+ }
610
+ };
611
+ template <typename T, int N>
612
+ struct minus<Array<T, N>> {
613
+
614
+ CUTLASS_HOST_DEVICE
615
+ Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
616
+
617
+ Array<T, N> result;
618
+ minus<T> scalar_op;
619
+
620
+ CUTLASS_PRAGMA_UNROLL
621
+ for (int i = 0; i < N; ++i) {
622
+ result[i] = scalar_op(lhs[i], rhs[i]);
623
+ }
624
+
625
+ return result;
626
+ }
627
+
628
+ CUTLASS_HOST_DEVICE
629
+ Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
630
+
631
+ Array<T, N> result;
632
+ minus<T> scalar_op;
633
+
634
+ CUTLASS_PRAGMA_UNROLL
635
+ for (int i = 0; i < N; ++i) {
636
+ result[i] = scalar_op(lhs[i], scalar);
637
+ }
638
+
639
+ return result;
640
+ }
641
+
642
+ CUTLASS_HOST_DEVICE
643
+ Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
644
+
645
+ Array<T, N> result;
646
+ minus<T> scalar_op;
647
+
648
+ CUTLASS_PRAGMA_UNROLL
649
+ for (int i = 0; i < N; ++i) {
650
+ result[i] = scalar_op(scalar, rhs[i]);
651
+ }
652
+
653
+ return result;
654
+ }
655
+ };
656
+
657
+ template <typename T, int N>
658
+ struct multiplies<Array<T, N>> {
659
+
660
+ CUTLASS_HOST_DEVICE
661
+ Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
662
+
663
+ Array<T, N> result;
664
+ multiplies<T> scalar_op;
665
+
666
+ CUTLASS_PRAGMA_UNROLL
667
+ for (int i = 0; i < N; ++i) {
668
+ result[i] = scalar_op(lhs[i], rhs[i]);
669
+ }
670
+
671
+ return result;
672
+ }
673
+
674
+ CUTLASS_HOST_DEVICE
675
+ Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
676
+
677
+ Array<T, N> result;
678
+ multiplies<T> scalar_op;
679
+
680
+ CUTLASS_PRAGMA_UNROLL
681
+ for (int i = 0; i < N; ++i) {
682
+ result[i] = scalar_op(lhs[i], scalar);
683
+ }
684
+
685
+ return result;
686
+ }
687
+
688
+ CUTLASS_HOST_DEVICE
689
+ Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
690
+
691
+ Array<T, N> result;
692
+ multiplies<T> scalar_op;
693
+
694
+ CUTLASS_PRAGMA_UNROLL
695
+ for (int i = 0; i < N; ++i) {
696
+ result[i] = scalar_op(scalar, rhs[i]);
697
+ }
698
+
699
+ return result;
700
+ }
701
+ };
702
+
703
+ template <typename T, int N, bool PropogateNaN>
704
+ struct maximum_absolute_value_reduction<Array<T, N>, PropogateNaN> {
705
+
706
+ CUTLASS_HOST_DEVICE
707
+ T operator() (T const& scalar, Array<T, N> const& rhs) const {
708
+
709
+ T result = scalar;
710
+ maximum_absolute_value_reduction<T, PropogateNaN> scalar_op;
711
+
712
+ CUTLASS_PRAGMA_UNROLL
713
+ for (int i = 0; i < N; ++i) {
714
+ result = scalar_op(result, rhs[i]);
715
+ }
716
+
717
+ return result;
718
+ }
719
+ };
720
+
721
+ template <typename T, int N>
722
+ struct scale<Array<T, N>> {
723
+ T const scaling_factor_;
724
+
725
+ CUTLASS_HOST_DEVICE
726
+ scale(T scaling_factor) : scaling_factor_(scaling_factor) {
727
+ }
728
+
729
+ CUTLASS_HOST_DEVICE
730
+ Array<T, N> operator()(Array<T, N> const & rhs) const {
731
+ Array<T, N> result;
732
+
733
+ CUTLASS_PRAGMA_UNROLL
734
+ for (int i = 0; i < N; ++i) {
735
+ result[i] = rhs[i] * scaling_factor_;
736
+ }
737
+
738
+ return result;
739
+ }
740
+ };
741
+
742
+ template <typename T, int N>
743
+ struct divides<Array<T, N>> {
744
+
745
+ CUTLASS_HOST_DEVICE
746
+ Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
747
+
748
+ Array<T, N> result;
749
+ divides<T> scalar_op;
750
+
751
+ CUTLASS_PRAGMA_UNROLL
752
+ for (int i = 0; i < N; ++i) {
753
+ result[i] = scalar_op(lhs[i], rhs[i]);
754
+ }
755
+
756
+ return result;
757
+ }
758
+
759
+ CUTLASS_HOST_DEVICE
760
+ Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
761
+
762
+ Array<T, N> result;
763
+ divides<T> scalar_op;
764
+
765
+ CUTLASS_PRAGMA_UNROLL
766
+ for (int i = 0; i < N; ++i) {
767
+ result[i] = scalar_op(lhs[i], scalar);
768
+ }
769
+
770
+ return result;
771
+ }
772
+
773
+ CUTLASS_HOST_DEVICE
774
+ Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
775
+
776
+ Array<T, N> result;
777
+ divides<T> scalar_op;
778
+
779
+ CUTLASS_PRAGMA_UNROLL
780
+ for (int i = 0; i < N; ++i) {
781
+ result[i] = scalar_op(scalar, rhs[i]);
782
+ }
783
+
784
+ return result;
785
+ }
786
+ };
787
+
788
+ template <typename T, int N>
789
+ struct reciprocal_approximate<Array<T, N>> {
790
+
791
+ CUTLASS_HOST_DEVICE
792
+ Array<T, N> operator()(Array<T, N> const &lhs) const {
793
+
794
+ Array<T, N> result;
795
+ reciprocal_approximate<T> scalar_op;
796
+
797
+ CUTLASS_PRAGMA_UNROLL
798
+ for (int i = 0; i < N; ++i) {
799
+ result[i] = scalar_op(lhs[i]);
800
+ }
801
+
802
+ return result;
803
+ }
804
+ };
805
+
806
+ template <typename T, int N>
807
+ struct reciprocal_approximate_ftz<Array<T, N>> {
808
+
809
+ CUTLASS_HOST_DEVICE
810
+ Array<T, N> operator()(Array<T, N> const &lhs) const {
811
+
812
+ Array<T, N> result;
813
+ reciprocal_approximate_ftz<T> scalar_op;
814
+
815
+ CUTLASS_PRAGMA_UNROLL
816
+ for (int i = 0; i < N; ++i) {
817
+ result[i] = scalar_op(lhs[i]);
818
+ }
819
+
820
+ return result;
821
+ }
822
+ };
823
+
824
+ template <typename T, int N, bool PropagateNaN>
825
+ struct maximum<Array<T, N>, PropagateNaN> {
826
+
827
+ CUTLASS_HOST_DEVICE
828
+ Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
829
+
830
+ Array<T, N> result;
831
+ maximum<T, PropagateNaN> scalar_op;
832
+
833
+ CUTLASS_PRAGMA_UNROLL
834
+ for (int i = 0; i < N; ++i) {
835
+ result[i] = scalar_op(lhs[i], rhs[i]);
836
+ }
837
+
838
+ return result;
839
+ }
840
+
841
+ CUTLASS_HOST_DEVICE
842
+ Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
843
+
844
+ Array<T, N> result;
845
+ maximum<T, PropagateNaN> scalar_op;
846
+
847
+ CUTLASS_PRAGMA_UNROLL
848
+ for (int i = 0; i < N; ++i) {
849
+ result[i] = scalar_op(lhs[i], scalar);
850
+ }
851
+
852
+ return result;
853
+ }
854
+
855
+ CUTLASS_HOST_DEVICE
856
+ Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
857
+
858
+ Array<T, N> result;
859
+ maximum<T, PropagateNaN> scalar_op;
860
+
861
+ CUTLASS_PRAGMA_UNROLL
862
+ for (int i = 0; i < N; ++i) {
863
+ result[i] = scalar_op(scalar, rhs[i]);
864
+ }
865
+
866
+ return result;
867
+ }
868
+ };
869
+
870
+ template <typename T, int N, bool PropagateNaN>
871
+ struct minimum<Array<T, N>, PropagateNaN> {
872
+
873
+ CUTLASS_HOST_DEVICE
874
+ static T scalar_op(T const &lhs, T const &rhs) {
875
+ return (rhs < lhs ? rhs : lhs);
876
+ }
877
+
878
+ CUTLASS_HOST_DEVICE
879
+ Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
880
+
881
+ Array<T, N> result;
882
+ minimum<T, PropagateNaN> scalar_op;
883
+
884
+ CUTLASS_PRAGMA_UNROLL
885
+ for (int i = 0; i < N; ++i) {
886
+ result[i] = scalar_op(lhs[i], rhs[i]);
887
+ }
888
+
889
+ return result;
890
+ }
891
+
892
+ CUTLASS_HOST_DEVICE
893
+ Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
894
+
895
+ Array<T, N> result;
896
+ minimum<T, PropagateNaN> scalar_op;
897
+
898
+ CUTLASS_PRAGMA_UNROLL
899
+ for (int i = 0; i < N; ++i) {
900
+ result[i] = scalar_op(lhs[i], scalar);
901
+ }
902
+
903
+ return result;
904
+ }
905
+
906
+ CUTLASS_HOST_DEVICE
907
+ Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
908
+
909
+ Array<T, N> result;
910
+ minimum<T, PropagateNaN> scalar_op;
911
+
912
+ CUTLASS_PRAGMA_UNROLL
913
+ for (int i = 0; i < N; ++i) {
914
+ result[i] = scalar_op(scalar, rhs[i]);
915
+ }
916
+
917
+ return result;
918
+ }
919
+ };
920
+
921
+ template <typename T, int N>
922
+ struct minimum_with_nan_propagation<Array<T, N>> : minimum<Array<T, N>, true>
923
+ {};
924
+
925
+ template <typename T, int N>
926
+ struct negate<Array<T, N>> {
927
+
928
+ CUTLASS_HOST_DEVICE
929
+ Array<T, N> operator()(Array<T, N> const &lhs) const {
930
+
931
+ Array<T, N> result;
932
+ negate<T> scalar_op;
933
+
934
+ CUTLASS_PRAGMA_UNROLL
935
+ for (int i = 0; i < N; ++i) {
936
+ result[i] = scalar_op(lhs[i]);
937
+ }
938
+
939
+ return result;
940
+ }
941
+ };
942
+
943
+ /// Fused multiply-add
944
+ template <typename T, int N>
945
+ struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {
946
+
947
+ CUTLASS_HOST_DEVICE
948
+ Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
949
+
950
+ Array<T, N> result;
951
+ multiply_add<T> scalar_op;
952
+
953
+ CUTLASS_PRAGMA_UNROLL
954
+ for (int i = 0; i < N; ++i) {
955
+ result[i] = scalar_op(a[i], b[i], c[i]);
956
+ }
957
+
958
+ return result;
959
+ }
960
+
961
+ CUTLASS_HOST_DEVICE
962
+ Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
963
+
964
+ Array<T, N> result;
965
+ multiply_add<T> scalar_op;
966
+
967
+ CUTLASS_PRAGMA_UNROLL
968
+ for (int i = 0; i < N; ++i) {
969
+ result[i] = scalar_op(a[i], scalar, c[i]);
970
+ }
971
+
972
+ return result;
973
+ }
974
+
975
+ CUTLASS_HOST_DEVICE
976
+ Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
977
+
978
+ Array<T, N> result;
979
+ multiply_add<T> scalar_op;
980
+
981
+ CUTLASS_PRAGMA_UNROLL
982
+ for (int i = 0; i < N; ++i) {
983
+ result[i] = scalar_op(scalar, b[i], c[i]);
984
+ }
985
+
986
+ return result;
987
+ }
988
+
989
+ CUTLASS_HOST_DEVICE
990
+ Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, T const &scalar) const {
991
+
992
+ Array<T, N> result;
993
+ multiply_add<T> scalar_op;
994
+
995
+ CUTLASS_PRAGMA_UNROLL
996
+ for (int i = 0; i < N; ++i) {
997
+ result[i] = scalar_op(a[i], b[i], scalar);
998
+ }
999
+
1000
+ return result;
1001
+ }
1002
+
1003
+
1004
+ CUTLASS_HOST_DEVICE
1005
+ Array<T, N> operator()(Array<T, N> const &a, T const &scalar_b, T const &scalar_c) const {
1006
+
1007
+ Array<T, N> result;
1008
+ multiply_add<T> scalar_op;
1009
+
1010
+ CUTLASS_PRAGMA_UNROLL
1011
+ for (int i = 0; i < N; ++i) {
1012
+ result[i] = scalar_op(a[i], scalar_b, scalar_c);
1013
+ }
1014
+
1015
+ return result;
1016
+ }
1017
+ };
1018
+
1019
+ /// Fused square-and-plus
1020
+ template <typename T, int N>
1021
+ struct square_and_plus<Array<T, N>> {
1022
+
1023
+ CUTLASS_HOST_DEVICE
1024
+ Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
1025
+ multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> ma_op;
1026
+ return ma_op(rhs, rhs, lhs);
1027
+ }
1028
+
1029
+ CUTLASS_HOST_DEVICE
1030
+ Array<T, N> operator()(Array<T, N> const &lhs, T const &rhs) const {
1031
+ plus<Array<T, N>> plus_op;
1032
+ multiplies<T> multiplies_op;
1033
+ return plus_op(multiplies_op(rhs, rhs), lhs);
1034
+ }
1035
+ };
1036
+
1037
+ /// Inverse-square-root
1038
+ template <typename T, int N>
1039
+ struct inverse_square_root<Array<T, N>> {
1040
+ CUTLASS_HOST_DEVICE
1041
+ Array<T, N> operator()(Array<T, N> const &a) const {
1042
+ Array<T, N> result;
1043
+ inverse_square_root<T> scalar_op;
1044
+
1045
+ CUTLASS_PRAGMA_UNROLL
1046
+ for (int i = 0; i < N; ++i) {
1047
+ result[i] = scalar_op(a[i]);
1048
+ }
1049
+ return result;
1050
+ }
1051
+ };
1052
+
1053
/// Inverse-square-root specialized for half_t: on SM53+ device code, arrays
/// are reinterpreted as packed __half2 vectors so two elements are processed
/// per h2rsqrt instruction; host / older-arch code falls back to the scalar
/// functor.
template <int N>
struct inverse_square_root<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & a) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // View input and output as packed pairs of halves.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = h2rsqrt(a_ptr[i]);
    }

    // Odd N: the final element does not form a pair; handle it with the
    // scalar intrinsic.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half d_residual = hrsqrt(a_residual_ptr[N - 1]);
      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    // Fallback: elementwise scalar inverse-square-root.
    inverse_square_root<half_t> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(a[i]);
    }

#endif

    return result;
  }
};
1088
+
1089
+ /// Fused multiply-add-relu0
1090
+ template <typename T, int N>
1091
+ struct multiply_add_relu0<Array<T, N>, Array<T, N>, Array<T, N>> {
1092
+
1093
+ CUTLASS_HOST_DEVICE
1094
+ Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
1095
+
1096
+ Array<T, N> result;
1097
+ multiply_add<T> scalar_op;
1098
+ maximum<T> mx;
1099
+
1100
+ CUTLASS_PRAGMA_UNROLL
1101
+ for (int i = 0; i < N; ++i) {
1102
+ result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0));
1103
+ }
1104
+
1105
+ return result;
1106
+ }
1107
+
1108
+ CUTLASS_HOST_DEVICE
1109
+ Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
1110
+
1111
+ Array<T, N> result;
1112
+ multiply_add<T> scalar_op;
1113
+ maximum<T> mx;
1114
+
1115
+ CUTLASS_PRAGMA_UNROLL
1116
+ for (int i = 0; i < N; ++i) {
1117
+ result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0));
1118
+ }
1119
+
1120
+ return result;
1121
+ }
1122
+
1123
+ CUTLASS_HOST_DEVICE
1124
+ Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
1125
+
1126
+ Array<T, N> result;
1127
+ multiply_add<T> scalar_op;
1128
+ maximum<T> mx;
1129
+
1130
+ CUTLASS_PRAGMA_UNROLL
1131
+ for (int i = 0; i < N; ++i) {
1132
+ result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0));
1133
+ }
1134
+
1135
+ return result;
1136
+ }
1137
+ };
1138
+
1139
+
1140
+ template <typename T, int N>
1141
+ struct conjugate<Array<T, N> > {
1142
+ CUTLASS_HOST_DEVICE
1143
+ Array<T, N> operator()(Array<T, N> const &a) const {
1144
+
1145
+ conjugate<T> conj_op;
1146
+
1147
+ Array<T, N> ca;
1148
+ CUTLASS_PRAGMA_UNROLL
1149
+ for (int i = 0; i < N; ++i) {
1150
+ ca[i] = conj_op(a[i]);
1151
+ }
1152
+ return ca;
1153
+ }
1154
+ };
1155
+
1156
+
1157
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1158
+ // functional.h numeric specializations targeting SIMD instructions in device code.
1159
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1160
+
1161
/// Elementwise addition specialized for half_t: on SM53+ device code the
/// arrays are processed two elements at a time with __hadd2; otherwise a
/// plain scalar loop is used. Overloads broadcast a scalar on either side.
template <int N>
struct plus<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // View operands as packed __half2 pairs.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]);
    }

    // Odd N: last element handled with the scalar intrinsic.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    // Host / pre-SM53 fallback.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] + rhs[i];
    }
#endif

    return result;
  }

  /// Scalar lhs broadcast: result[i] = lhs + rhs[i]
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    // Replicate the scalar into both lanes of a __half2.
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]);
    }

    if constexpr (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs + rhs[i];
    }
#endif

    return result;
  }

  /// Scalar rhs broadcast: result[i] = lhs[i] + rhs
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair);
    }

    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] + rhs;
    }
#endif

    return result;
  }
};
1260
+
1261
/// Elementwise subtraction specialized for half_t: packed __hsub2 pairs on
/// SM53+ device code, scalar loop elsewhere. Overloads broadcast a scalar
/// on either side.
template <int N>
struct minus<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // View operands as packed __half2 pairs.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]);
    }

    // Odd N: last element handled with the scalar intrinsic.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] - rhs[i];
    }
#endif

    return result;
  }

  /// Scalar lhs broadcast: result[i] = lhs - rhs[i]
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]);
    }

    if constexpr (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs - rhs[i];
    }
#endif

    return result;
  }

  /// Scalar rhs broadcast: result[i] = lhs[i] - rhs
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair);
    }

    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] - rhs;
    }
#endif

    return result;
  }
};
1360
+
1361
/// Elementwise multiplication specialized for half_t: packed __hmul2 pairs
/// on SM53+ device code, scalar loop elsewhere. Overloads broadcast a scalar
/// on either side.
template <int N>
struct multiplies<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // View operands as packed __half2 pairs.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]);
    }

    // Odd N: last element handled with the scalar intrinsic.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] * rhs[i];
    }
#endif

    return result;
  }

  /// Scalar lhs broadcast: result[i] = lhs * rhs[i]
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]);
    }

    if constexpr (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

      __half d_residual = __hmul(
        reinterpret_cast<__half const &>(lhs),
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs * rhs[i];
    }
#endif

    return result;
  }

  /// Scalar rhs broadcast: result[i] = lhs[i] * rhs
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair);
    }

    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

      __half d_residual = __hmul(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] * rhs;
    }
#endif

    return result;
  }
};
1466
+
1467
/// Elementwise division specialized for half_t: packed __h2div pairs on
/// SM53+ device code, scalar loop elsewhere. Overloads broadcast a scalar
/// on either side.
template <int N>
struct divides<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // View operands as packed __half2 pairs.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]);
    }

    // Odd N: last element handled with the scalar intrinsic.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

      __half d_residual = __hdiv(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] / rhs[i];
    }
#endif

    return result;
  }

  /// Scalar lhs broadcast: result[i] = lhs / rhs[i]
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]);
    }

    if constexpr (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

      __half d_residual = __hdiv(
        reinterpret_cast<__half const &>(lhs),
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs / rhs[i];
    }
#endif

    return result;
  }

  /// Scalar rhs broadcast: result[i] = lhs[i] / rhs
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair);
    }

    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

      __half d_residual = __hdiv(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] / rhs;
    }
#endif

    return result;
  }
};
1575
+
1576
/// Elementwise negation specialized for half_t: packed __hneg2 pairs on
/// SM53+ device code, scalar loop elsewhere.
template <int N>
struct negate<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // View input and output as packed __half2 pairs.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hneg2(source_ptr[i]);
    }

    // Odd N: negate the last element with half_t's own operator.
    if constexpr (N % 2) {
      half_t x = -lhs[N - 1];
      __half lhs_val = reinterpret_cast<__half const &>(x);
      result[N - 1] = reinterpret_cast<half_t const &>(lhs_val);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = -lhs[i];
    }
#endif

    return result;
  }
};
1608
+
1609
/// Fused multiply-add specialized for half_t: d = a * b + c using packed
/// __hfma2 pairs on SM53+ device code, with the generic scalar functor as
/// fallback. Overloads broadcast a scalar into any one operand position,
/// plus one with scalars in both trailing positions.
template <int N>
struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {

  /// result[i] = a[i] * b[i] + c[i]
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // View all operands as packed __half2 pairs.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]);
    }

    // Odd N: last element handled with the scalar intrinsic.
    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c[i]);
    }
#endif

    return result;
  }

  /// Scalar a broadcast: result[i] = a * b[i] + c[i]
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    half_t const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    // Replicate the scalar into both lanes.
    __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]);
    }

    if constexpr (N % 2) {

      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
      __half d_residual = __hfma(
        reinterpret_cast<__half const &>(a),
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a, b[i], c[i]);
    }
#endif

    return result;
  }

  /// Scalar b broadcast: result[i] = a[i] * b + c[i]
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    half_t const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]);
    }

    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(b),
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b, c[i]);
    }
#endif

    return result;
  }

  /// Scalar c broadcast: result[i] = a[i] * b[i] + c
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    half_t const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair);
    }

    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(c));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c);
    }
#endif

    return result;
  }

  /// Scalars b and c broadcast: result[i] = a[i] * b + c
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    half_t const &b,
    half_t const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_pair);
    }

    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(b),
        reinterpret_cast<__half const &>(c));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b, c);
    }
#endif

    return result;
  }
};
1837
+
1838
+ /// Fused multiply-add-relu0
1839
+ template <int N>
1840
+ struct multiply_add_relu0<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {
1841
+
1842
+ CUTLASS_HOST_DEVICE
1843
+ Array<half_t, N> operator()(
1844
+ Array<half_t, N> const &a,
1845
+ Array<half_t, N> const &b,
1846
+ Array<half_t, N> const &c) const {
1847
+
1848
+ Array<half_t, N> result;
1849
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
1850
+
1851
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1852
+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
1853
+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
1854
+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
1855
+
1856
+ CUTLASS_PRAGMA_UNROLL
1857
+ for (int i = 0; i < N / 2; ++i) {
1858
+ result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]);
1859
+ }
1860
+
1861
+ if constexpr (N % 2) {
1862
+
1863
+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
1864
+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
1865
+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
1866
+
1867
+ __half d_residual = __hfma_relu(
1868
+ a_residual_ptr[N - 1],
1869
+ b_residual_ptr[N - 1],
1870
+ c_residual_ptr[N - 1]);
1871
+
1872
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1873
+ }
1874
+
1875
+ #else
1876
+
1877
+ multiply_add<half_t> op;
1878
+ maximum<half_t> mx;
1879
+
1880
+ CUTLASS_PRAGMA_UNROLL
1881
+ for (int i = 0; i < N; ++i) {
1882
+ result[i] = mx(op(a[i], b[i], c[i]), (half_t)0);
1883
+ }
1884
+ #endif
1885
+
1886
+ return result;
1887
+ }
1888
+
1889
+ CUTLASS_HOST_DEVICE
1890
+ Array<half_t, N> operator()(
1891
+ half_t const &a,
1892
+ Array<half_t, N> const &b,
1893
+ Array<half_t, N> const &c) const {
1894
+
1895
+ Array<half_t, N> result;
1896
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
1897
+
1898
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1899
+ __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
1900
+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
1901
+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
1902
+
1903
+ CUTLASS_PRAGMA_UNROLL
1904
+ for (int i = 0; i < N / 2; ++i) {
1905
+ result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]);
1906
+ }
1907
+
1908
+ if constexpr (N % 2) {
1909
+
1910
+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
1911
+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
1912
+ __half d_residual = __hfma_relu(
1913
+ reinterpret_cast<__half const &>(a),
1914
+ b_residual_ptr[N - 1],
1915
+ c_residual_ptr[N - 1]);
1916
+
1917
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1918
+ }
1919
+
1920
+ #else
1921
+
1922
+ multiply_add<half_t> op;
1923
+ maximum<half_t> mx;
1924
+
1925
+ CUTLASS_PRAGMA_UNROLL
1926
+ for (int i = 0; i < N; ++i) {
1927
+ result[i] = mx(op(a, b[i], c[i]), half_t(0));
1928
+ }
1929
+ #endif
1930
+
1931
+ return result;
1932
+ }
1933
+
1934
+ CUTLASS_HOST_DEVICE
1935
+ Array<half_t, N> operator()(
1936
+ Array<half_t, N> const &a,
1937
+ half_t const &b,
1938
+ Array<half_t, N> const &c) const {
1939
+
1940
+ Array<half_t, N> result;
1941
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
1942
+
1943
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1944
+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
1945
+ __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
1946
+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
1947
+
1948
+ CUTLASS_PRAGMA_UNROLL
1949
+ for (int i = 0; i < N / 2; ++i) {
1950
+ result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]);
1951
+ }
1952
+
1953
+ if constexpr (N % 2) {
1954
+
1955
+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
1956
+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
1957
+
1958
+ __half d_residual = __hfma_relu(
1959
+ a_residual_ptr[N - 1],
1960
+ reinterpret_cast<__half const &>(b),
1961
+ c_residual_ptr[N - 1]);
1962
+
1963
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1964
+ }
1965
+
1966
+ #else
1967
+
1968
+ multiply_add<half_t> op;
1969
+ maximum<half_t> mx;
1970
+
1971
+ CUTLASS_PRAGMA_UNROLL
1972
+ for (int i = 0; i < N; ++i) {
1973
+ result[i] = mx(op(a[i], b, c[i]), half_t(0));
1974
+ }
1975
+ #endif
1976
+
1977
+ return result;
1978
+ }
1979
+
1980
+ CUTLASS_HOST_DEVICE
1981
+ Array<half_t, N> operator()(
1982
+ Array<half_t, N> const &a,
1983
+ Array<half_t, N> const &b,
1984
+ half_t const &c) const {
1985
+
1986
+ Array<half_t, N> result;
1987
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
1988
+
1989
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1990
+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
1991
+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
1992
+ __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));
1993
+
1994
+ CUTLASS_PRAGMA_UNROLL
1995
+ for (int i = 0; i < N / 2; ++i) {
1996
+ result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair);
1997
+ }
1998
+
1999
+ if constexpr (N % 2) {
2000
+
2001
+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
2002
+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
2003
+
2004
+ __half d_residual = __hfma_relu(
2005
+ a_residual_ptr[N - 1],
2006
+ b_residual_ptr[N - 1],
2007
+ reinterpret_cast<__half const &>(c));
2008
+
2009
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
2010
+ }
2011
+
2012
+ #else
2013
+
2014
+ multiply_add<half_t> op;
2015
+ maximum<half_t> mx;
2016
+
2017
+ CUTLASS_PRAGMA_UNROLL
2018
+ for (int i = 0; i < N; ++i) {
2019
+ result[i] = mx(op(a[i], b[i], c), half_t(0));
2020
+ }
2021
+ #endif
2022
+
2023
+ return result;
2024
+ }
2025
+ };
2026
+
2027
+ template <int N, bool PropagateNaN>
2028
+ struct minimum<Array<half_t, N>, PropagateNaN> {
2029
+ CUTLASS_HOST_DEVICE
2030
+ Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
2031
+ Array<half_t, N> result;
2032
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2033
+
2034
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
2035
+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
2036
+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
2037
+
2038
+ CUTLASS_PRAGMA_UNROLL
2039
+ for (int i = 0; i < N / 2; ++i) {
2040
+ result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_ptr[i])
2041
+ : __hmin2(lhs_ptr[i], rhs_ptr[i]);
2042
+ }
2043
+
2044
+ if constexpr (N % 2) {
2045
+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
2046
+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
2047
+
2048
+ __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1])
2049
+ : __hmin(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
2050
+
2051
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
2052
+ }
2053
+
2054
+ #else
2055
+
2056
+ minimum<half_t,PropagateNaN> mn;
2057
+
2058
+ CUTLASS_PRAGMA_UNROLL
2059
+ for (int i = 0; i < N; ++i) {
2060
+ result[i] = mn(lhs[i],rhs[i]);
2061
+ }
2062
+ #endif
2063
+
2064
+ return result;
2065
+ }
2066
+
2067
+ CUTLASS_HOST_DEVICE
2068
+ Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
2069
+ Array<half_t, N> result;
2070
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2071
+
2072
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
2073
+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
2074
+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
2075
+
2076
+ CUTLASS_PRAGMA_UNROLL
2077
+ for (int i = 0; i < N / 2; ++i) {
2078
+ result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_pair, rhs_ptr[i])
2079
+ : __hmin2(lhs_pair, rhs_ptr[i]);
2080
+ }
2081
+
2082
+ if constexpr (N % 2) {
2083
+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
2084
+
2085
+ __half d_residual = PropagateNaN ? __hmin_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1])
2086
+ : __hmin(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
2087
+
2088
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
2089
+ }
2090
+
2091
+ #else
2092
+
2093
+ minimum<half_t,PropagateNaN> mn;
2094
+
2095
+ CUTLASS_PRAGMA_UNROLL
2096
+ for (int i = 0; i < N; ++i) {
2097
+ result[i] = mn(lhs, rhs[i]);
2098
+ }
2099
+ #endif
2100
+
2101
+ return result;
2102
+ }
2103
+
2104
+ CUTLASS_HOST_DEVICE
2105
+ Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
2106
+ Array<half_t, N> result;
2107
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2108
+
2109
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
2110
+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
2111
+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
2112
+
2113
+ CUTLASS_PRAGMA_UNROLL
2114
+ for (int i = 0; i < N / 2; ++i) {
2115
+ result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_pair)
2116
+ : __hmin2(lhs_ptr[i], rhs_pair);
2117
+ }
2118
+
2119
+ if constexpr (N % 2) {
2120
+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
2121
+
2122
+ __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs))
2123
+ : __hmin(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
2124
+
2125
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
2126
+ }
2127
+
2128
+ #else
2129
+
2130
+ minimum<half_t, PropagateNaN> mn;
2131
+
2132
+ CUTLASS_PRAGMA_UNROLL
2133
+ for (int i = 0; i < N; ++i) {
2134
+ result[i] = mn(lhs[i], rhs);
2135
+ }
2136
+ #endif
2137
+
2138
+ return result;
2139
+ }
2140
+ };
2141
+
2142
+ template <int N, bool PropagateNaN>
2143
+ struct maximum<Array<half_t, N>, PropagateNaN> {
2144
+ CUTLASS_HOST_DEVICE
2145
+ Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
2146
+ Array<half_t, N> result;
2147
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2148
+
2149
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
2150
+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
2151
+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
2152
+
2153
+ CUTLASS_PRAGMA_UNROLL
2154
+ for (int i = 0; i < N / 2; ++i) {
2155
+ result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_ptr[i])
2156
+ : __hmax2(lhs_ptr[i], rhs_ptr[i]);
2157
+ }
2158
+
2159
+ if constexpr (N % 2) {
2160
+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
2161
+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
2162
+
2163
+ __half d_residual = PropagateNaN ? __hmax(a_residual_ptr[N - 1], b_residual_ptr[N - 1])
2164
+ : __hmax_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
2165
+
2166
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
2167
+ }
2168
+
2169
+ #else
2170
+
2171
+ maximum<half_t,PropagateNaN> mx;
2172
+
2173
+ CUTLASS_PRAGMA_UNROLL
2174
+ for (int i = 0; i < N; ++i) {
2175
+ result[i] = mx(lhs[i], rhs[i]);
2176
+ }
2177
+ #endif
2178
+
2179
+ return result;
2180
+ }
2181
+
2182
+ CUTLASS_HOST_DEVICE
2183
+ Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
2184
+ Array<half_t, N> result;
2185
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2186
+
2187
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
2188
+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
2189
+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
2190
+
2191
+ CUTLASS_PRAGMA_UNROLL
2192
+ for (int i = 0; i < N / 2; ++i) {
2193
+ result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_pair, rhs_ptr[i])
2194
+ : __hmax2(lhs_pair, rhs_ptr[i]);
2195
+ }
2196
+
2197
+ if constexpr (N % 2) {
2198
+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
2199
+
2200
+ __half d_residual = PropagateNaN ? __hmax_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1])
2201
+ : __hmax(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
2202
+
2203
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
2204
+ }
2205
+
2206
+ #else
2207
+
2208
+ maximum<half_t,PropagateNaN> mx;
2209
+
2210
+ CUTLASS_PRAGMA_UNROLL
2211
+ for (int i = 0; i < N; ++i) {
2212
+ result[i] = mx(lhs, rhs[i]);
2213
+ }
2214
+ #endif
2215
+
2216
+ return result;
2217
+ }
2218
+
2219
+ CUTLASS_HOST_DEVICE
2220
+ Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
2221
+ Array<half_t, N> result;
2222
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2223
+
2224
+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
2225
+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
2226
+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
2227
+
2228
+ CUTLASS_PRAGMA_UNROLL
2229
+ for (int i = 0; i < N / 2; ++i) {
2230
+ result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_pair)
2231
+ : __hmax2(lhs_ptr[i], rhs_pair);
2232
+ }
2233
+
2234
+ if constexpr (N % 2) {
2235
+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
2236
+
2237
+ __half d_residual = PropagateNaN ? __hmax_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs))
2238
+ : __hmax(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
2239
+
2240
+ result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
2241
+ }
2242
+
2243
+ #else
2244
+
2245
+ maximum<half_t,PropagateNaN> mx;
2246
+
2247
+ CUTLASS_PRAGMA_UNROLL
2248
+ for (int i = 0; i < N; ++i) {
2249
+ result[i] = mx(lhs[i], rhs);
2250
+ }
2251
+ #endif
2252
+
2253
+ return result;
2254
+ }
2255
+ };
2256
+
2257
+ /// Fused multiply-add
2258
+ template <int N>
2259
+ struct multiply_add<Array<bfloat16_t, N>, Array<bfloat16_t, N>, Array<bfloat16_t, N>> {
2260
+
2261
+ CUTLASS_HOST_DEVICE
2262
+ Array<bfloat16_t, N> operator()(
2263
+ Array<bfloat16_t, N> const &a,
2264
+ Array<bfloat16_t, N> const &b,
2265
+ Array<bfloat16_t, N> const &c) const {
2266
+
2267
+ Array<bfloat16_t, N> result;
2268
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2269
+
2270
+ unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
2271
+ unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
2272
+ unsigned const *b_ptr = reinterpret_cast<unsigned const *>(&b);
2273
+ unsigned const *c_ptr = reinterpret_cast<unsigned const *>(&c);
2274
+
2275
+ CUTLASS_PRAGMA_UNROLL
2276
+ for (int i = 0; i < N / 2; ++i) {
2277
+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
2278
+ : "=r"(result_ptr[i])
2279
+ : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i])
2280
+ );
2281
+ }
2282
+
2283
+ if constexpr (N % 2) {
2284
+
2285
+ uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
2286
+ uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
2287
+ uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
2288
+ uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);
2289
+
2290
+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
2291
+ : "=h"(result_ptr[N - 1])
2292
+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1])
2293
+ );
2294
+ }
2295
+
2296
+ #else
2297
+
2298
+ multiply_add<bfloat16_t> op;
2299
+
2300
+ CUTLASS_PRAGMA_UNROLL
2301
+ for (int i = 0; i < N; ++i) {
2302
+ result[i] = op(a[i], b[i], c[i]);
2303
+ }
2304
+ #endif
2305
+
2306
+ return result;
2307
+ }
2308
+
2309
+ CUTLASS_HOST_DEVICE
2310
+ Array<bfloat16_t, N> operator()(
2311
+ bfloat16_t const &a,
2312
+ Array<bfloat16_t, N> const &b,
2313
+ Array<bfloat16_t, N> const &c) const {
2314
+
2315
+ Array<bfloat16_t, N> result;
2316
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2317
+
2318
+ unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
2319
+
2320
+ unsigned const *b_ptr = reinterpret_cast<unsigned const *>(&b);
2321
+ unsigned const *c_ptr = reinterpret_cast<unsigned const *>(&c);
2322
+
2323
+ unsigned a_packed = static_cast<unsigned>(a.raw());
2324
+ a_packed = (a_packed | (a_packed << 16));
2325
+
2326
+ CUTLASS_PRAGMA_UNROLL
2327
+ for (int i = 0; i < N / 2; ++i) {
2328
+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
2329
+ : "=r"(result_ptr[i])
2330
+ : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i])
2331
+ );
2332
+ }
2333
+
2334
+ if constexpr (N % 2) {
2335
+
2336
+ uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
2337
+ uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
2338
+ uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
2339
+ uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);
2340
+
2341
+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
2342
+ : "=h"(result_ptr[N - 1])
2343
+ : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1])
2344
+ );
2345
+ }
2346
+
2347
+ #else
2348
+
2349
+ multiply_add<bfloat16_t> op;
2350
+
2351
+ CUTLASS_PRAGMA_UNROLL
2352
+ for (int i = 0; i < N; ++i) {
2353
+ result[i] = op(a, b[i], c[i]);
2354
+ }
2355
+ #endif
2356
+
2357
+ return result;
2358
+ }
2359
+
2360
+ CUTLASS_HOST_DEVICE
2361
+ Array<bfloat16_t, N> operator()(
2362
+ Array<bfloat16_t, N> const &a,
2363
+ bfloat16_t const &b,
2364
+ Array<bfloat16_t, N> const &c) const {
2365
+
2366
+ Array<bfloat16_t, N> result;
2367
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2368
+
2369
+ unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
2370
+
2371
+ unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
2372
+ unsigned const *c_ptr = reinterpret_cast<unsigned const *>(&c);
2373
+
2374
+ unsigned b_packed = static_cast<unsigned>(b.raw());
2375
+ b_packed = (b_packed | (b_packed << 16));
2376
+
2377
+ CUTLASS_PRAGMA_UNROLL
2378
+ for (int i = 0; i < N / 2; ++i) {
2379
+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
2380
+ : "=r"(result_ptr[i])
2381
+ : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i])
2382
+ );
2383
+ }
2384
+
2385
+ if constexpr (N % 2) {
2386
+
2387
+ uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
2388
+ uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
2389
+ uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
2390
+ uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);
2391
+
2392
+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
2393
+ : "=h"(result_ptr[N - 1])
2394
+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1])
2395
+ );
2396
+ }
2397
+
2398
+ #else
2399
+
2400
+ multiply_add<bfloat16_t> op;
2401
+
2402
+ CUTLASS_PRAGMA_UNROLL
2403
+ for (int i = 0; i < N; ++i) {
2404
+ result[i] = op(a[i], b, c[i]);
2405
+ }
2406
+ #endif
2407
+
2408
+ return result;
2409
+ }
2410
+
2411
+ CUTLASS_HOST_DEVICE
2412
+ Array<bfloat16_t, N> operator()(
2413
+ Array<bfloat16_t, N> const &a,
2414
+ Array<bfloat16_t, N> const &b,
2415
+ bfloat16_t const &c) const {
2416
+
2417
+ Array<bfloat16_t, N> result;
2418
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2419
+
2420
+ unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
2421
+
2422
+ unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
2423
+ unsigned const *b_ptr = reinterpret_cast<unsigned const *>(&b);
2424
+
2425
+ unsigned c_packed = static_cast<unsigned>(c.raw());
2426
+ c_packed = (c_packed | (c_packed << 16));
2427
+
2428
+ CUTLASS_PRAGMA_UNROLL
2429
+ for (int i = 0; i < N / 2; ++i) {
2430
+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
2431
+ : "=r"(result_ptr[i])
2432
+ : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed)
2433
+ );
2434
+ }
2435
+
2436
+ if constexpr (N % 2) {
2437
+
2438
+ uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
2439
+ uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
2440
+ uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
2441
+ uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);
2442
+
2443
+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
2444
+ : "=h"(result_ptr[N - 1])
2445
+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0])
2446
+ );
2447
+ }
2448
+
2449
+ #else
2450
+
2451
+ multiply_add<bfloat16_t> op;
2452
+
2453
+ CUTLASS_PRAGMA_UNROLL
2454
+ for (int i = 0; i < N; ++i) {
2455
+ result[i] = op(a[i], b[i], c);
2456
+ }
2457
+ #endif
2458
+
2459
+ return result;
2460
+ }
2461
+
2462
+ CUTLASS_HOST_DEVICE
2463
+ Array<bfloat16_t, N> operator()(
2464
+ Array<bfloat16_t, N> const &a,
2465
+ bfloat16_t const &b,
2466
+ bfloat16_t const &c) const {
2467
+
2468
+ Array<bfloat16_t, N> result;
2469
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
2470
+
2471
+ unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
2472
+
2473
+ unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
2474
+
2475
+ unsigned b_packed = static_cast<unsigned>(b.raw());
2476
+ b_packed = (b_packed | (b_packed << 16));
2477
+
2478
+ unsigned c_packed = static_cast<unsigned>(c.raw());
2479
+ c_packed = (c_packed | (c_packed << 16));
2480
+
2481
+ CUTLASS_PRAGMA_UNROLL
2482
+ for (int i = 0; i < N / 2; ++i) {
2483
+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
2484
+ : "=r"(result_ptr[i])
2485
+ : "r"(a_ptr[i]), "r"(b_packed), "r"(c_packed)
2486
+ );
2487
+ }
2488
+
2489
+ if constexpr (N % 2) {
2490
+
2491
+ uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
2492
+ uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
2493
+ uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
2494
+ uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);
2495
+
2496
+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
2497
+ : "=h"(result_ptr[N - 1])
2498
+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[0])
2499
+ );
2500
+ }
2501
+
2502
+
2503
+ #else
2504
+
2505
+ multiply_add<bfloat16_t> op;
2506
+
2507
+ CUTLASS_PRAGMA_UNROLL
2508
+ for (int i = 0; i < N; ++i) {
2509
+ result[i] = op(a[i], b, c);
2510
+ }
2511
+ #endif
2512
+
2513
+ return result;
2514
+ }
2515
+ };
2516
+
2517
+
2518
+ /// bit_and
2519
+ template <int N>
2520
+ struct bit_and<Array<uint1b_t, N>> {
2521
+ CUTLASS_HOST_DEVICE
2522
+ Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a, Array<uint1b_t, N> const &b) const {
2523
+ using ArrayType = Array<uint1b_t, N>;
2524
+ using Storage = typename ArrayType::Storage;
2525
+ ArrayType result;
2526
+
2527
+ Storage *result_data = result.raw_data();
2528
+ Storage const *a_data = a.raw_data();
2529
+ Storage const *b_data = b.raw_data();
2530
+
2531
+ CUTLASS_PRAGMA_UNROLL
2532
+ for (int i = 0; i < ArrayType::kStorageElements; ++i) {
2533
+ result_data[i] = (a_data[i] & b_data[i]);
2534
+ }
2535
+
2536
+ return result;
2537
+ }
2538
+ };
2539
+
2540
+
2541
+ /// bit_or
2542
+ template <int N>
2543
+ struct bit_or<Array<uint1b_t, N>> {
2544
+ CUTLASS_HOST_DEVICE
2545
+ Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a, Array<uint1b_t, N> const &b) const {
2546
+ using ArrayType = Array<uint1b_t, N>;
2547
+ using Storage = typename ArrayType::Storage;
2548
+ ArrayType result;
2549
+
2550
+ Storage *result_data = result.raw_data();
2551
+ Storage const *a_data = a.raw_data();
2552
+ Storage const *b_data = b.raw_data();
2553
+
2554
+ CUTLASS_PRAGMA_UNROLL
2555
+ for (int i = 0; i < ArrayType::kStorageElements; ++i) {
2556
+ result_data[i] = (a_data[i] | b_data[i]);
2557
+ }
2558
+
2559
+ return result;
2560
+ }
2561
+ };
2562
+
2563
+
2564
+ /// bit_not
2565
+ template <int N>
2566
+ struct bit_not<Array<uint1b_t, N>> {
2567
+ CUTLASS_HOST_DEVICE
2568
+ Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a) const {
2569
+ using ArrayType = Array<uint1b_t, N>;
2570
+ using Storage = typename ArrayType::Storage;
2571
+ ArrayType result;
2572
+
2573
+ Storage *result_data = result.raw_data();
2574
+ Storage const *a_data = a.raw_data();
2575
+
2576
+ CUTLASS_PRAGMA_UNROLL
2577
+ for (int i = 0; i < ArrayType::kStorageElements; ++i) {
2578
+ result_data[i] = (~a_data[i]);
2579
+ }
2580
+
2581
+ return result;
2582
+ }
2583
+ };
2584
+
2585
+ /// bit_xor
2586
+ template <int N>
2587
+ struct bit_xor<Array<uint1b_t, N>> {
2588
+ CUTLASS_HOST_DEVICE
2589
+ Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a, Array<uint1b_t, N> const &b) const {
2590
+ using ArrayType = Array<uint1b_t, N>;
2591
+ using Storage = typename ArrayType::Storage;
2592
+ ArrayType result;
2593
+
2594
+ Storage *result_data = result.raw_data();
2595
+ Storage const *a_data = a.raw_data();
2596
+ Storage const *b_data = b.raw_data();
2597
+
2598
+ CUTLASS_PRAGMA_UNROLL
2599
+ for (int i = 0; i < ArrayType::kStorageElements; ++i) {
2600
+ result_data[i] = (a_data[i] ^ b_data[i]);
2601
+ }
2602
+
2603
+ return result;
2604
+ }
2605
+ };
2606
+
2607
+ /// Fused and-popc-add
2608
+ template <typename T, int N>
2609
+ struct and_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
2610
+ CUTLASS_HOST_DEVICE
2611
+ Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
2612
+ Array<T, N> result;
2613
+ and_popc_add<T> scalar_op;
2614
+
2615
+ CUTLASS_PRAGMA_UNROLL
2616
+ for (int i = 0; i < N; ++i) {
2617
+ result[i] = scalar_op(a[i], b[i], c[i]);
2618
+ }
2619
+
2620
+ return result;
2621
+ }
2622
+
2623
+ CUTLASS_HOST_DEVICE
2624
+ Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
2625
+ Array<T, N> result;
2626
+ and_popc_add<T> scalar_op;
2627
+
2628
+ CUTLASS_PRAGMA_UNROLL
2629
+ for (int i = 0; i < N; ++i) {
2630
+ result[i] = scalar_op(a[i], scalar, c[i]);
2631
+ }
2632
+
2633
+ return result;
2634
+ }
2635
+
2636
+ CUTLASS_HOST_DEVICE
2637
+ Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
2638
+ Array<T, N> result;
2639
+ and_popc_add<T> scalar_op;
2640
+
2641
+ CUTLASS_PRAGMA_UNROLL
2642
+ for (int i = 0; i < N; ++i) {
2643
+ result[i] = scalar_op(scalar, b[i], c[i]);
2644
+ }
2645
+
2646
+ return result;
2647
+ }
2648
+ };
2649
+
2650
+
2651
+ /// Fused or-popc-add
2652
+ template <typename T, int N>
2653
+ struct or_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
2654
+ CUTLASS_HOST_DEVICE
2655
+ Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
2656
+ Array<T, N> result;
2657
+ or_popc_add<T> scalar_op;
2658
+
2659
+ CUTLASS_PRAGMA_UNROLL
2660
+ for (int i = 0; i < N; ++i) {
2661
+ result[i] = scalar_op(a[i], b[i], c[i]);
2662
+ }
2663
+
2664
+ return result;
2665
+ }
2666
+
2667
+ CUTLASS_HOST_DEVICE
2668
+ Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
2669
+ Array<T, N> result;
2670
+ or_popc_add<T> scalar_op;
2671
+
2672
+ CUTLASS_PRAGMA_UNROLL
2673
+ for (int i = 0; i < N; ++i) {
2674
+ result[i] = scalar_op(a[i], scalar, c[i]);
2675
+ }
2676
+
2677
+ return result;
2678
+ }
2679
+
2680
+ CUTLASS_HOST_DEVICE
2681
+ Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
2682
+ Array<T, N> result;
2683
+ or_popc_add<T> scalar_op;
2684
+
2685
+ CUTLASS_PRAGMA_UNROLL
2686
+ for (int i = 0; i < N; ++i) {
2687
+ result[i] = scalar_op(scalar, b[i], c[i]);
2688
+ }
2689
+
2690
+ return result;
2691
+ }
2692
+ };
2693
+
2694
+ /// Fused xor-popc-add
2695
+ template <typename T, int N>
2696
+ struct xor_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
2697
+ CUTLASS_HOST_DEVICE
2698
+ Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
2699
+ Array<T, N> result;
2700
+ xor_popc_add<T> scalar_op;
2701
+
2702
+ CUTLASS_PRAGMA_UNROLL
2703
+ for (int i = 0; i < N; ++i) {
2704
+ result[i] = scalar_op(a[i], b[i], c[i]);
2705
+ }
2706
+
2707
+ return result;
2708
+ }
2709
+
2710
+ CUTLASS_HOST_DEVICE
2711
+ Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
2712
+ Array<T, N> result;
2713
+ xor_popc_add<T> scalar_op;
2714
+
2715
+ CUTLASS_PRAGMA_UNROLL
2716
+ for (int i = 0; i < N; ++i) {
2717
+ result[i] = scalar_op(a[i], scalar, c[i]);
2718
+ }
2719
+
2720
+ return result;
2721
+ }
2722
+
2723
+ CUTLASS_HOST_DEVICE
2724
+ Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
2725
+ Array<T, N> result;
2726
+ xor_popc_add<T> scalar_op;
2727
+
2728
+ CUTLASS_PRAGMA_UNROLL
2729
+ for (int i = 0; i < N; ++i) {
2730
+ result[i] = scalar_op(scalar, b[i], c[i]);
2731
+ }
2732
+
2733
+ return result;
2734
+ }
2735
+ };
2736
+
2737
+
2738
+ /////////////////////////////////////////////////////////////////////////////////////////////////
2739
+ // Operator overloads
2740
+ /////////////////////////////////////////////////////////////////////////////////////////////////
2741
+
2742
+ template <typename T, int N>
2743
+ CUTLASS_HOST_DEVICE
2744
+ Array<T, N> operator+(Array<T, N> const &lhs, Array<T, N> const &rhs) {
2745
+ plus<Array<T, N>> op;
2746
+ return op(lhs, rhs);
2747
+ }
2748
+
2749
+ template <typename T, int N>
2750
+ CUTLASS_HOST_DEVICE
2751
+ Array<T, N> operator+(T const &lhs, Array<T, N> const &rhs) {
2752
+ plus<Array<T, N>> op;
2753
+ return op(lhs, rhs);
2754
+ }
2755
+
2756
+ template <typename T, int N>
2757
+ CUTLASS_HOST_DEVICE
2758
+ Array<T, N> operator+(Array<T, N> const &lhs, T const &rhs) {
2759
+ plus<Array<T, N>> op;
2760
+ return op(lhs, rhs);
2761
+ }
2762
+
2763
+ template <typename T, int N>
2764
+ CUTLASS_HOST_DEVICE
2765
+ Array<T, N> operator-(Array<T, N> const &lhs, Array<T, N> const &rhs) {
2766
+ minus<Array<T, N>> op;
2767
+ return op(lhs, rhs);
2768
+ }
2769
+
2770
+ template <typename T, int N>
2771
+ CUTLASS_HOST_DEVICE
2772
+ Array<T, N> operator-(Array<T, N> const &lhs) {
2773
+ negate<Array<T, N>> op;
2774
+ return op(lhs);
2775
+ }
2776
+
2777
+ template <typename T, int N>
2778
+ CUTLASS_HOST_DEVICE
2779
+ Array<T, N> operator*(Array<T, N> const &lhs, Array<T, N> const &rhs) {
2780
+ multiplies<Array<T, N>> op;
2781
+ return op(lhs, rhs);
2782
+ }
2783
+
2784
+ template <typename T, int N>
2785
+ CUTLASS_HOST_DEVICE
2786
+ Array<T, N> operator*(T lhs, Array<T, N> const &rhs) {
2787
+ multiplies<Array<T, N>> op;
2788
+ return op(lhs, rhs);
2789
+ }
2790
+
2791
+ template <typename T, int N>
2792
+ CUTLASS_HOST_DEVICE
2793
+ Array<T, N> operator*(Array<T, N> const &lhs, T rhs) {
2794
+ multiplies<Array<T, N>> op;
2795
+ return op(lhs, rhs);
2796
+ }
2797
+
2798
+ template <typename T, int N>
2799
+ CUTLASS_HOST_DEVICE
2800
+ Array<T, N> operator/(Array<T, N> const &lhs, Array<T, N> const &rhs) {
2801
+ divides<Array<T, N>> op;
2802
+ return op(lhs, rhs);
2803
+ }
2804
+
2805
+ template <typename T, int N>
2806
+ CUTLASS_HOST_DEVICE
2807
+ Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) {
2808
+ multiply_add<Array<T, N>> op;
2809
+ return op(a, b, c);
2810
+ }
2811
+
2812
+ template <typename T, int N>
2813
+ CUTLASS_HOST_DEVICE
2814
+ Array<T, N> fma(T a, Array<T, N> const &b, Array<T, N> const &c) {
2815
+ multiply_add<Array<T, N>> op;
2816
+ return op(a, b, c);
2817
+ }
2818
+
2819
+ template <typename T, int N>
2820
+ CUTLASS_HOST_DEVICE
2821
+ Array<T, N> fma(Array<T, N> const &a, T b, Array<T, N> const &c) {
2822
+ multiply_add<Array<T, N>> op;
2823
+ return op(a, b, c);
2824
+ }
2825
+
2826
+ template <typename T, int N>
2827
+ CUTLASS_HOST_DEVICE
2828
+ Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, T c) {
2829
+ multiply_add<Array<T, N>> op;
2830
+ return op(a, b, c);
2831
+ }
2832
+
2833
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
2834
+
2835
+
2836
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
2837
+ // AlignedArray
2838
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
2839
+
2840
+ /// Aligned array type
2841
+ template <
2842
+ /// Element type
2843
+ typename T,
2844
+ /// Number of elements in the array
2845
+ int N,
2846
+ /// Alignment requirement in bytes
2847
+ int Alignment = ( sizeof_bits<T>::value * N + 7 ) / 8
2848
+ >
2849
+ class alignas(Alignment) AlignedArray: public Array<T, N> {
2850
+ public:
2851
+
2852
+ };
2853
+
2854
+ } // namespace cutlass
2855
+
2856
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
2857
+
2858
+ #include "cutlass/array_subbyte.h"
2859
+
2860
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Array type holding planar complex data (separate real and imaginary fragments).
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/array.h"
39
+
40
+ /////////////////////////////////////////////////////////////////////////////////////////////////
41
+
42
+ namespace cutlass {
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ /// Array holding planar complex elements
47
+ template <typename Element_, int N>
48
+ struct ArrayPlanarComplex {
49
+
50
+ /// Underlying real element
51
+ using Element = Element_;
52
+
53
+ /// Number of logical elements
54
+ static constexpr size_t kElements = N;
55
+
56
+ /// Underlying Fragment of real-valued elemenets
57
+ using ArrayReal = cutlass::Array<Element, N>;
58
+
59
+ public:
60
+ /// Fragment of real-valued elements representing the real part
61
+ ArrayReal real;
62
+
63
+ /// Fragment of real-valued elements representing the imaginary part
64
+ ArrayReal imag;
65
+
66
+ public:
67
+ /// Sets the array to zero efficiently
68
+ CUTLASS_HOST_DEVICE
69
+ void clear() {
70
+ real.clear();
71
+ imag.clear();
72
+ }
73
+ };
74
+
75
+ /////////////////////////////////////////////////////////////////////////////////////////////////
76
+
77
+ /// Helper to deduce template arguments
78
+ template <typename Element, int N>
79
+ CUTLASS_HOST_DEVICE
80
+ ArrayPlanarComplex<Element, N>
81
+ make_ArrayPlanarComplex(Array<Element, N> const &real, Array<Element, N> const &imag) {
82
+ return ArrayPlanarComplex<Element, N>{real, imag};
83
+ }
84
+
85
+ /////////////////////////////////////////////////////////////////////////////////////////////////
86
+
87
+ } // namespace cutlass
88
+
89
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types
33
+ and is safe to use in a union.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/platform/platform.h"
41
+
42
+ namespace cutlass {
43
+
44
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ /// Statically sized array for any data type
47
+ template <
48
+ typename T,
49
+ int N
50
+ >
51
+ struct Array<T, N, false> {
52
+ static constexpr int kSizeBits = sizeof_bits<T>::value * N;
53
+
54
+ /// Storage type
55
+ using Storage = typename platform::conditional<
56
+ ((kSizeBits % 32) != 0),
57
+ typename platform::conditional<
58
+ ((kSizeBits % 16) != 0),
59
+ uint8_t,
60
+ uint16_t
61
+ >::type,
62
+ uint32_t
63
+ >::type;
64
+
65
+ /// Element type
66
+ using Element = T;
67
+
68
+ /// Number of logical elements per stored object
69
+ static constexpr int kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits<T>::value;
70
+
71
+ /// Number of storage elements
72
+ static constexpr size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem;
73
+
74
+ /// Number of logical elements
75
+ static constexpr size_t kElements = N;
76
+
77
+ /// Bitmask for covering one item
78
+ static constexpr Storage kMask = ((Storage(1) << sizeof_bits<T>::value) - 1);
79
+
80
+ //
81
+ // C++ standard members with pointer types removed
82
+ //
83
+
84
+ typedef T value_type;
85
+ typedef size_t size_type;
86
+ typedef ptrdiff_t difference_type;
87
+ typedef value_type *pointer;
88
+ typedef value_type const *const_pointer;
89
+
90
+ //
91
+ // References
92
+ //
93
+
94
+ /// Reference object inserts or extracts sub-byte items
95
+ class reference {
96
+ /// Pointer to storage element
97
+ Storage *ptr_{nullptr};
98
+
99
+ /// Index into elements packed into Storage object
100
+ int idx_{0};
101
+
102
+ public:
103
+
104
+ reference() = default;
105
+
106
+ /// Ctor
107
+ CUTLASS_HOST_DEVICE
108
+ reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
109
+
110
+ /// Assignment
111
+ CUTLASS_HOST_DEVICE
112
+ reference &operator=(T x) {
113
+ // `*ptr_ & kUpdateMask` will read ptr_ before write to it
114
+ // This means code pattern like
115
+ //
116
+ // ```cpp
117
+ // Array<half_t, N> result;
118
+ // result[0] = xxx;
119
+ // ```
120
+ //
121
+ // Will leads to compiler warning on use of uninitialized member variable. Although we know
122
+ // this read of uninitialized member variable is harmeless.
123
+
124
+ #if defined(__clang__)
125
+ # pragma clang diagnostic push
126
+ # pragma clang diagnostic ignored "-Wuninitialized"
127
+ #elif defined(__GNUC__)
128
+ # pragma GCC diagnostic push
129
+ # pragma GCC diagnostic ignored "-Wuninitialized"
130
+ # pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
131
+ #endif
132
+
133
+ Storage item = (reinterpret_cast<Storage const &>(x) & kMask);
134
+
135
+ Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits<T>::value)));
136
+
137
+ *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits<T>::value)));
138
+
139
+ #if defined(__clang__)
140
+ # pragma clang diagnostic pop
141
+ #elif defined(__GNUC__)
142
+ # pragma GCC diagnostic pop
143
+ #endif
144
+
145
+ return *this;
146
+ }
147
+
148
+ CUTLASS_HOST_DEVICE
149
+ T get() const {
150
+ Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits<T>::value)) & kMask);
151
+ return reinterpret_cast<T const &>(item);
152
+ }
153
+
154
+ /// Extract
155
+ CUTLASS_HOST_DEVICE
156
+ operator T() const {
157
+ return get();
158
+ }
159
+
160
+ /// Explicit cast to int
161
+ CUTLASS_HOST_DEVICE
162
+ explicit operator int() const {
163
+ return int(get());
164
+ }
165
+
166
+ /// Explicit cast to float
167
+ CUTLASS_HOST_DEVICE
168
+ explicit operator float() const {
169
+ return float(get());
170
+ }
171
+ };
172
+
173
+ /// Reference object extracts sub-byte items
174
+ class const_reference {
175
+
176
+ /// Pointer to storage element
177
+ Storage const *ptr_{nullptr};
178
+
179
+ /// Index into elements packed into Storage object
180
+ int idx_{0};
181
+
182
+ public:
183
+
184
+ const_reference() = default;
185
+
186
+ /// Ctor
187
+ CUTLASS_HOST_DEVICE
188
+ const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
189
+
190
+ CUTLASS_HOST_DEVICE
191
+ const T get() const {
192
+ Storage item = (*ptr_ >> (idx_ * sizeof_bits<T>::value)) & kMask;
193
+ return reinterpret_cast<T const &>(item);
194
+ }
195
+
196
+ /// Extract
197
+ CUTLASS_HOST_DEVICE
198
+ operator T() const {
199
+ Storage item = Storage(Storage(*ptr_ >> Storage(idx_ * sizeof_bits<T>::value)) & kMask);
200
+ return reinterpret_cast<T const &>(item);
201
+ }
202
+
203
+ /// Explicit cast to int
204
+ CUTLASS_HOST_DEVICE
205
+ explicit operator int() const {
206
+ return int(get());
207
+ }
208
+
209
+ /// Explicit cast to float
210
+ CUTLASS_HOST_DEVICE
211
+ explicit operator float() const {
212
+ return float(get());
213
+ }
214
+ };
215
+
216
+ //
217
+ // Iterators
218
+ //
219
+
220
+ /// Bidirectional iterator over elements
221
+ class iterator {
222
+
223
+ /// Pointer to storage element
224
+ Storage *ptr_{nullptr};
225
+
226
+ /// Index into elements packed into Storage object
227
+ int idx_{0};
228
+
229
+ public:
230
+
231
+ iterator() = default;
232
+
233
+ CUTLASS_HOST_DEVICE
234
+ iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
235
+
236
+ CUTLASS_HOST_DEVICE
237
+ iterator &operator++() {
238
+ ++idx_;
239
+ if (idx_ == kElementsPerStoredItem) {
240
+ ++ptr_;
241
+ idx_ = 0;
242
+ }
243
+ return *this;
244
+ }
245
+
246
+ CUTLASS_HOST_DEVICE
247
+ iterator &operator--() {
248
+ if (!idx_) {
249
+ --ptr_;
250
+ idx_ = kElementsPerStoredItem - 1;
251
+ }
252
+ else {
253
+ --idx_;
254
+ }
255
+ return *this;
256
+ }
257
+
258
+ CUTLASS_HOST_DEVICE
259
+ iterator operator++(int) {
260
+ iterator ret(*this);
261
+ ++idx_;
262
+ if (idx_ == kElementsPerStoredItem) {
263
+ ++ptr_;
264
+ idx_ = 0;
265
+ }
266
+ return ret;
267
+ }
268
+
269
+ CUTLASS_HOST_DEVICE
270
+ iterator operator--(int) {
271
+ iterator ret(*this);
272
+ if (!idx_) {
273
+ --ptr_;
274
+ idx_ = kElementsPerStoredItem - 1;
275
+ }
276
+ else {
277
+ --idx_;
278
+ }
279
+ return ret;
280
+ }
281
+
282
+ CUTLASS_HOST_DEVICE
283
+ reference operator*() const {
284
+ return reference(ptr_, idx_);
285
+ }
286
+
287
+ CUTLASS_HOST_DEVICE
288
+ bool operator==(iterator const &other) const {
289
+ return ptr_ == other.ptr_ && idx_ == other.idx_;
290
+ }
291
+
292
+ CUTLASS_HOST_DEVICE
293
+ bool operator!=(iterator const &other) const {
294
+ return !(*this == other);
295
+ }
296
+ };
297
+
298
+ /// Bidirectional constant iterator over elements
299
+ class const_iterator {
300
+
301
+ /// Pointer to storage element
302
+ Storage const *ptr_{nullptr};
303
+
304
+ /// Index into elements packed into Storage object
305
+ int idx_{0};
306
+
307
+ public:
308
+
309
+ const_iterator() = default;
310
+
311
+ CUTLASS_HOST_DEVICE
312
+ const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
313
+
314
+ CUTLASS_HOST_DEVICE
315
+ iterator &operator++() {
316
+ ++idx_;
317
+ if (idx_ == kElementsPerStoredItem) {
318
+ ++ptr_;
319
+ idx_ = 0;
320
+ }
321
+ return *this;
322
+ }
323
+
324
+ CUTLASS_HOST_DEVICE
325
+ iterator &operator--() {
326
+ if (!idx_) {
327
+ --ptr_;
328
+ idx_ = kElementsPerStoredItem - 1;
329
+ }
330
+ else {
331
+ --idx_;
332
+ }
333
+ return *this;
334
+ }
335
+
336
+ CUTLASS_HOST_DEVICE
337
+ iterator operator++(int) {
338
+ iterator ret(*this);
339
+ ++idx_;
340
+ if (idx_ == kElementsPerStoredItem) {
341
+ ++ptr_;
342
+ idx_ = 0;
343
+ }
344
+ return ret;
345
+ }
346
+
347
+ CUTLASS_HOST_DEVICE
348
+ iterator operator--(int) {
349
+ iterator ret(*this);
350
+ if (!idx_) {
351
+ --ptr_;
352
+ idx_ = kElementsPerStoredItem - 1;
353
+ }
354
+ else {
355
+ --idx_;
356
+ }
357
+ return ret;
358
+ }
359
+
360
+ CUTLASS_HOST_DEVICE
361
+ const_reference operator*() const {
362
+ return const_reference(ptr_, idx_);
363
+ }
364
+
365
+ CUTLASS_HOST_DEVICE
366
+ bool operator==(iterator const &other) const {
367
+ return ptr_ == other.ptr_ && idx_ == other.idx_;
368
+ }
369
+
370
+ CUTLASS_HOST_DEVICE
371
+ bool operator!=(iterator const &other) const {
372
+ return !(*this == other);
373
+ }
374
+ };
375
+
376
+ /// Bidirectional iterator over elements
377
+ class reverse_iterator {
378
+
379
+ /// Pointer to storage element
380
+ Storage *ptr_{nullptr};
381
+
382
+ /// Index into elements packed into Storage object
383
+ int idx_{0};
384
+
385
+ public:
386
+
387
+ reverse_iterator() = default;
388
+
389
+ CUTLASS_HOST_DEVICE
390
+ reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
391
+ };
392
+
393
+ /// Bidirectional constant iterator over elements
394
+ class const_reverse_iterator {
395
+
396
+ /// Pointer to storage element
397
+ Storage const *ptr_{nullptr};
398
+
399
+ /// Index into elements packed into Storage object
400
+ int idx_{0};
401
+
402
+ public:
403
+
404
+ const_reverse_iterator() = default;
405
+
406
+ CUTLASS_HOST_DEVICE
407
+ const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
408
+ };
409
+
410
+ /// Efficient clear method
411
+ CUTLASS_HOST_DEVICE
412
+ void clear() {
413
+
414
+ CUTLASS_PRAGMA_UNROLL
415
+ for (int i = 0; i < int(kStorageElements); ++i) {
416
+ storage[i] = Storage(0);
417
+ }
418
+ }
419
+
420
+ CUTLASS_HOST_DEVICE
421
+ reference at(size_type pos) {
422
+ return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
423
+ }
424
+
425
+ CUTLASS_HOST_DEVICE
426
+ const_reference at(size_type pos) const {
427
+ return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
428
+ }
429
+
430
+ CUTLASS_HOST_DEVICE
431
+ reference operator[](size_type pos) {
432
+ return at(pos);
433
+ }
434
+
435
+ CUTLASS_HOST_DEVICE
436
+ const_reference operator[](size_type pos) const {
437
+ return at(pos);
438
+ }
439
+
440
+ CUTLASS_HOST_DEVICE
441
+ reference front() {
442
+ return at(0);
443
+ }
444
+
445
+ CUTLASS_HOST_DEVICE
446
+ const_reference front() const {
447
+ return at(0);
448
+ }
449
+
450
+ CUTLASS_HOST_DEVICE
451
+ reference back() {
452
+ return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
453
+ }
454
+
455
+ CUTLASS_HOST_DEVICE
456
+ const_reference back() const {
457
+ return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
458
+ }
459
+
460
+ CUTLASS_HOST_DEVICE
461
+ pointer data() {
462
+ return reinterpret_cast<pointer>(storage);
463
+ }
464
+
465
+ CUTLASS_HOST_DEVICE
466
+ const_pointer data() const {
467
+ return reinterpret_cast<const_pointer>(storage);
468
+ }
469
+
470
+ CUTLASS_HOST_DEVICE
471
+ Storage * raw_data() {
472
+ return storage;
473
+ }
474
+
475
+ CUTLASS_HOST_DEVICE
476
+ Storage const * raw_data() const {
477
+ return storage;
478
+ }
479
+
480
+ CUTLASS_HOST_DEVICE
481
+ constexpr bool empty() const {
482
+ return !kElements;
483
+ }
484
+
485
+ CUTLASS_HOST_DEVICE
486
+ constexpr size_type size() const {
487
+ return kElements;
488
+ }
489
+
490
+ CUTLASS_HOST_DEVICE
491
+ constexpr size_type max_size() const {
492
+ return kElements;
493
+ }
494
+
495
+ CUTLASS_HOST_DEVICE
496
+ void fill(T const &value) {
497
+
498
+ CUTLASS_PRAGMA_UNROLL
499
+ for (int i = 0; i < kElementsPerStoredItem; ++i) {
500
+ reference ref(storage, i);
501
+ ref = value;
502
+ }
503
+
504
+ CUTLASS_PRAGMA_UNROLL
505
+ for (int i = 1; i < kStorageElements; ++i) {
506
+ storage[i] = storage[0];
507
+ }
508
+ }
509
+
510
+ CUTLASS_HOST_DEVICE
511
+ iterator begin() {
512
+ return iterator(storage);
513
+ }
514
+
515
+ CUTLASS_HOST_DEVICE
516
+ const_iterator cbegin() const {
517
+ return const_iterator(storage);
518
+ }
519
+
520
+ CUTLASS_HOST_DEVICE
521
+ iterator end() {
522
+ return iterator(storage + kStorageElements);
523
+ }
524
+
525
+ CUTLASS_HOST_DEVICE
526
+ const_iterator cend() const {
527
+ return const_iterator(storage + kStorageElements);
528
+ }
529
+
530
+ CUTLASS_HOST_DEVICE
531
+ reverse_iterator rbegin() {
532
+ return reverse_iterator(storage + kStorageElements);
533
+ }
534
+
535
+ CUTLASS_HOST_DEVICE
536
+ const_reverse_iterator crbegin() const {
537
+ return const_reverse_iterator(storage + kStorageElements);
538
+ }
539
+
540
+ CUTLASS_HOST_DEVICE
541
+ reverse_iterator rend() {
542
+ return reverse_iterator(storage);
543
+ }
544
+
545
+ CUTLASS_HOST_DEVICE
546
+ const_reverse_iterator crend() const {
547
+ return const_reverse_iterator(storage);
548
+ }
549
+
550
+ private:
551
+ /// Internal storage
552
+ Storage storage[kStorageElements];
553
+ };
554
+
555
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
556
+
557
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
558
+
559
+ } // namespace cutlass
560
+
561
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implementation of a CTA-wide barrier for inter-CTA synchronization.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/arch/barrier.h"
39
+
40
+ /////////////////////////////////////////////////////////////////////////////////////////////////
41
+
42
+ namespace cutlass {
43
+
44
+ namespace detail {
45
+
46
+ //
47
+ // Utilities for abstracting synchronization methods for barriers
48
+ //
49
+
50
/// Synchronization policy that barriers across all threads of the CTA
/// via __syncthreads(). Used as the Sync parameter of GenericBarrier.
struct SyncthreadsSync {
  CUTLASS_DEVICE
  static void sync() {
    __syncthreads();
  }
};
56
+
57
/// Synchronization policy that barriers across the threads of a single warp
/// via __syncwarp(). Used as the Sync parameter of GenericBarrier.
struct SyncwarpSync {
  CUTLASS_DEVICE
  static void sync() {
    __syncwarp();
  }
};
63
+
64
/// Synchronization policy that barriers `ThreadCount` threads on the named
/// barrier identified by the compile-time constant `BarrierId`. Used as the
/// Sync parameter of GenericBarrier.
///
/// BarrierId is cast to arch::ReservedNamedBarriers, so it must be a valid
/// reserved-barrier identifier for cutlass::arch::NamedBarrier::sync().
template <
  int ThreadCount,
  int BarrierId
>
struct NamedBarrierSync {
  CUTLASS_DEVICE
  static void sync() {
    cutlass::arch::NamedBarrier::sync(ThreadCount, static_cast<arch::ReservedNamedBarriers>(BarrierId));
  }
};
74
+
75
+ } // namespace detail
76
+
77
+ /////////////////////////////////////////////////////////////////////////////////////////////////
78
+
79
+ /// Group or CTA-wide semaphore for inter-CTA synchronization.
80
/// Group or CTA-wide semaphore for inter-CTA synchronization.
///
/// Flags live in memory shared by the participating CTAs (lock_ptr is
/// interpreted as an array of int counters indexed by flag_idx). Thread 0 of
/// each group spin-waits or signals; all other threads rendezvous through
/// Sync::sync() (e.g. __syncthreads, __syncwarp, or a named barrier).
///
/// \tparam Sync policy type exposing a static `sync()` that barriers all
///         threads of the participating group
template <class Sync>
struct GenericBarrier {

public:

  /// Flag type
  using T = int;

  /// Initial flag value
  static const T INIT = 0;


protected:

  /// Load flag, as a strong acquire operation (int specialization)
  CUTLASS_DEVICE
  static int ld_acquire(int *ptr)
  {
    int state = 0;

#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers

    // Acquire pattern using acquire modifier
    asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));

#else
    asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#endif // (__CUDA_ARCH__ >= 700)

    return state;
  }


  /// Reduce into flag, with release pattern (int specialization)
  CUTLASS_DEVICE
  static void red_release(int *ptr, int val)
  {
#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers

    // Release pattern using acq_rel fence + relaxed modifier. (The fence also releases data
    // that was weakly-written by other threads prior to the last syncthreads)
    asm volatile ("fence.acq_rel.gpu;\n");
    asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));

#else
    __threadfence();
    atomicAdd(ptr, val);
#endif // (__CUDA_ARCH__ >= 700)
  }


public:

  /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter
  CUTLASS_DEVICE
  static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count)
  {
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop (`#pragma unroll 1` keeps the compiler from unrolling it)
      #pragma unroll 1
      while(ld_acquire(flag_ptr) < count) {}
    }

    // Remaining threads wait here until thread 0 observes the flag
    Sync::sync();
  }

  /// Uses thread[0] to wait until the flag counter equals `val`
  CUTLASS_DEVICE
  static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1)
  {
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(ld_acquire(flag_ptr) != val) {}
    }
    Sync::sync();
  }

  /// Uses thread[0] to wait until the flag counter equals `val`, then
  /// atomically resets it to zero (compare-and-swap) for reuse
  CUTLASS_DEVICE
  static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(atomicCAS(flag_ptr, val, 0) != val) {}
    }

    Sync::sync();
  }

  /// Increment the arrival count for a flag
  CUTLASS_DEVICE
  static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx, int val = 1)
  {
    T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    // Barrier first so all group threads' prior writes are covered by the release
    Sync::sync();

    if (thread_idx == 0)
    {
      red_release(flag_ptr, val);
    }
  }


  /// Increment the arrival counts for a range of flags: thread i increments
  /// flag (first_flag_idx + i) for i in [0, count)
  CUTLASS_DEVICE
  static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1)
  {
    int flag_idx = first_flag_idx + thread_idx;
    T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    // Barrier to make sure all other threads in group have written their data
    Sync::sync();

    // Select threads increment their flags
    if (thread_idx < count) {
      red_release(flag_ptr, val);
    }
  }
};
212
+
213
+ using Barrier = GenericBarrier<detail::SyncthreadsSync>;
214
+
215
+ /////////////////////////////////////////////////////////////////////////////////////////////////
216
+
217
+ /** Structure for managing multiple NamedBarriers to be used by different warp groups, allowing
218
+ * runtime index values to be used to call into named barriers with compile-time-constant IDs.
219
+ *
220
+ * @param ThreadCount_ Number of threads that will wait on a NamedBarrier with a given ID
221
+ * @param Offset Value added to the ID passed in by the user to determine the NamedBarrier ID to call into
222
+ * @param MaxNumNamedBarriers The maximum number of unique barrier IDs that will be requested on this type
223
+ **/
224
/// Manages multiple NamedBarriers for different warp groups, translating a
/// runtime barrier index into calls on compile-time-constant barrier IDs.
/// Each public method dispatches over all MaxNumNamedBarriers candidate IDs
/// with a fold expression; only the branch whose Idx matches the runtime
/// index executes.
template <
  uint32_t ThreadCount_,
  uint32_t Offset = 0,
  uint32_t MaxNumNamedBarriers = 16
>
struct NamedBarrierManager {

  static_assert(MaxNumNamedBarriers <= arch::NamedBarrier::HardwareMaxNumNamedBarriers);
  static_assert(MaxNumNamedBarriers + Offset <= arch::NamedBarrier::HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15");

  // Number of threads participating in the barrier
  static constexpr uint32_t ThreadCount = ThreadCount_;

  /// GenericBarrier whose group synchronization uses named barrier `BarrierId`
  template <uint32_t BarrierId>
  using BarrierSync = cutlass::GenericBarrier<cutlass::detail::NamedBarrierSync<ThreadCount, BarrierId>>;

  // Underlying type used by all barriers for synchronization. Does not depend on
  // template parameter BarrierId, so passing in 0 suffices.
  using T = typename BarrierSync<0>::T;

  /// Sequence 0..MaxNumNamedBarriers-1 used to expand the dispatch folds
  using IntegerSequence = cute::make_integer_sequence<uint32_t, MaxNumNamedBarriers>;

  CUTLASS_DEVICE
  static
  void wait_lt(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count) {
    wait_lt_helper(idx, lock_ptr, thread_idx, flag_idx, count, IntegerSequence{});
  }

  CUTLASS_DEVICE
  static void
  wait_eq(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
    wait_eq_helper<false>(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{});
  }

  CUTLASS_DEVICE
  static void
  wait_eq_reset(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
    wait_eq_helper<true>(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{});
  }

  CUTLASS_DEVICE
  static void
  arrive_inc(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) {
    arrive_inc_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{});
  }

  CUTLASS_DEVICE
  static void
  arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) {
    arrive_range_inc_helper(idx, lock_ptr, thread_idx, first_flag_idx, count, val, IntegerSequence{});
  }

  private:
  /// Debug-build guard: the runtime index must name one of the managed barriers
  CUTLASS_DEVICE
  static void
  check_barrier_in_range([[maybe_unused]] uint32_t idx) {
    assert((idx < MaxNumNamedBarriers) && "Index exceeds barrier count");
  }

  // Each helper expands a fold over Idx = 0..MaxNumNamedBarriers-1; the
  // short-circuiting `||` stops after the branch where Idx == idx runs.

  template <uint32_t... Idx>
  CUTLASS_DEVICE
  static void
  wait_lt_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count, cute::integer_sequence<uint32_t, Idx...>) {
    check_barrier_in_range(idx);
    ((Idx == idx && (BarrierSync<Idx + Offset>::wait_lt(lock_ptr, thread_idx, flag_idx, count), true)) || ...);
  }

  template <bool Reset, uint32_t... Idx>
  CUTLASS_DEVICE
  static void
  wait_eq_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val, cute::integer_sequence<uint32_t, Idx...>) {
    check_barrier_in_range(idx);
    if constexpr (Reset) {
      ((Idx == idx && (BarrierSync<Idx + Offset>::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val), true)) || ...);
    }
    else {
      ((Idx == idx && (BarrierSync<Idx + Offset>::wait_eq(lock_ptr, thread_idx, flag_idx, val), true)) || ...);
    }
  }

  template <uint32_t... Idx>
  CUTLASS_DEVICE
  static void
  arrive_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val, cute::integer_sequence<uint32_t, Idx...>) {
    check_barrier_in_range(idx);
    ((Idx == idx && (BarrierSync<Idx + Offset>::arrive_inc(lock_ptr, thread_idx, flag_idx, val), true)) || ...);
  }

  template <uint32_t... Idx>
  CUTLASS_DEVICE
  static void
  arrive_range_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count, int val, cute::integer_sequence<uint32_t, Idx...>) {
    check_barrier_in_range(idx);
    ((Idx == idx && (BarrierSync<Idx + Offset>::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val), true)) || ...);
  }
};
320
+
321
+ /////////////////////////////////////////////////////////////////////////////////////////////////
322
+
323
+ /** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads)
324
+ * via an API that mirrors that of NamedBarrierManager
325
+ *
326
+ * @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization
327
+ **/
328
+ template <
329
+ class Synchronizer,
330
+ uint32_t ThreadCount_
331
+ >
332
+ struct SyncManager {
333
+
334
+ // Number of threads participating in the barrier
335
+ static constexpr uint32_t ThreadCount = ThreadCount_;
336
+
337
+ using BarrierSync = cutlass::GenericBarrier<Synchronizer>;
338
+
339
+ // Underlying type used by all barriers for synchronization.
340
+ using T = typename BarrierSync::T;
341
+
342
+ CUTLASS_DEVICE
343
+ static
344
+ void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) {
345
+ BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count);
346
+ }
347
+
348
+ CUTLASS_DEVICE
349
+ static void
350
+ wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
351
+ BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val);
352
+ }
353
+
354
+ CUTLASS_DEVICE
355
+ static void
356
+ wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
357
+ BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val);
358
+ }
359
+
360
+ CUTLASS_DEVICE
361
+ static void
362
+ arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) {
363
+ BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val);
364
+ }
365
+
366
+ CUTLASS_DEVICE
367
+ static void
368
+ arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) {
369
+ BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val);
370
+ }
371
+ };
372
+
373
+ /////////////////////////////////////////////////////////////////////////////////////////////////
374
+
375
+ } // namespace cutlass
376
+
377
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*!
32
+ \file
33
+ \brief Defines a proxy class for storing non-standard 16-bit floating point values with
34
+ 8 bits of exponent and 7 bit of mantissa.
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #if defined(__CUDACC_RTC__)
40
+ #include "cutlass/floating_point_nvrtc.h"
41
+ #else
42
+ #include <cmath>
43
+ #include <limits>
44
+ #include <cstdint>
45
+ #include <cstring>
46
+ #endif
47
+
48
+ #include <cuda_bf16.h>
49
+ #include "cutlass/cutlass.h"
50
+ #include "cutlass/platform/platform.h"
51
+
52
+ namespace cutlass {
53
+
54
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ /// Floating-point type with 8 bits of exponent and 7 bits of mantissa.
57
+ struct alignas(2) bfloat16_t {
58
+
59
+ //
60
+ // Data members
61
+ //
62
+
63
+ /// Storage type
64
+ uint16_t storage;
65
+
66
+ //
67
+ // Methods
68
+ //
69
+
70
+ /// Constructs from an unsigned short
71
+ CUTLASS_HOST_DEVICE
72
+ static bfloat16_t bitcast(uint16_t x) {
73
+ bfloat16_t h;
74
+ h.storage = x;
75
+ return h;
76
+ }
77
+
78
+ private:
79
+ struct from_32_bit_integer_t {};
80
+ static constexpr from_32_bit_integer_t from_32_bit_integer{};
81
+
82
+ template<class T>
83
+ CUTLASS_HOST_DEVICE
84
+ explicit bfloat16_t(from_32_bit_integer_t, T x) {
85
+ static_assert(cutlass::platform::is_integral<T>::value && sizeof(T) == 4, "Requires 32-bit integer");
86
+
87
+ float flt = static_cast<float>(x);
88
+ uint32_t bits;
89
+
90
+ #if defined(__CUDA_ARCH__)
91
+ bits = reinterpret_cast<uint32_t &>(flt);
92
+ #else
93
+ std::memcpy(&bits, &flt, sizeof(bits));
94
+ #endif
95
+
96
+ storage = uint16_t(bits >> 16);
97
+ }
98
+
99
+ public:
100
+ /// Default constructor
101
+ bfloat16_t() = default;
102
+
103
+ /// Reinterpret cast from CUDA's __nv_bfloat16 type
104
+ CUTLASS_HOST_DEVICE
105
+ explicit bfloat16_t(__nv_bfloat16 const & x) {
106
+ #if defined(__CUDA_ARCH__)
107
+ storage = reinterpret_cast<uint16_t const &>(x);
108
+ #else
109
+ __nv_bfloat16_raw raw(x);
110
+ std::memcpy(&storage, &raw.x, sizeof(storage));
111
+ #endif
112
+ }
113
+
114
+ /// Floating-point conversion - round toward nearest
115
+ CUTLASS_HOST_DEVICE
116
+ explicit bfloat16_t(float x) {
117
+
118
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
119
+
120
+ asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x));
121
+
122
+ #else
123
+ uint32_t bits;
124
+
125
+ #if defined(__CUDA_ARCH__)
126
+ bits = reinterpret_cast<uint32_t &>(x);
127
+ #else
128
+ std::memcpy(&bits, &x, sizeof(bits));
129
+ #endif
130
+
131
+ if ((bits & 0x7f800000) != 0x7f800000) {
132
+
133
+ bool mantissa_bit = ((bits & (1 << 16)) != 0);
134
+ bool round_bit = ((bits & (1 << 15)) != 0);
135
+ bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0);
136
+
137
+ if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) {
138
+ bits += uint32_t(1 << 16);
139
+ }
140
+ }
141
+ else if (bits & ~0xff800000) {
142
+ bits = 0x7fffffff;
143
+ }
144
+
145
+ storage = uint16_t((bits >> 16) & 0xffff);
146
+ #endif
147
+ }
148
+
149
+ /// Floating-point conversion - round toward nearest
150
+ CUTLASS_HOST_DEVICE
151
+ explicit bfloat16_t(double x): bfloat16_t(float(x)) {
152
+
153
+ }
154
+
155
+ /// Integer conversion - round toward nearest
156
+ CUTLASS_HOST_DEVICE
157
+ explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {}
158
+
159
+ CUTLASS_HOST_DEVICE
160
+ explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {}
161
+
162
+ /// Converts to float
163
+ CUTLASS_HOST_DEVICE
164
+ operator float() const {
165
+ unsigned bits = (unsigned(storage) << 16);
166
+ #if defined(__CUDA_ARCH__)
167
+ return reinterpret_cast<float const &>(bits);
168
+ #else
169
+ float flt;
170
+ std::memcpy(&flt, &bits, sizeof(flt));
171
+ return flt;
172
+ #endif
173
+ }
174
+
175
+ /// Converts to float
176
+ CUTLASS_HOST_DEVICE
177
+ explicit operator double() const {
178
+ return double(float(*this));
179
+ }
180
+
181
+ /// Converts to int
182
+ CUTLASS_HOST_DEVICE
183
+ explicit operator int() const {
184
+ return int(float(*this));
185
+ }
186
+
187
+ /// Casts to bool
188
+ CUTLASS_HOST_DEVICE
189
+ explicit operator bool() const {
190
+ return (float(*this) != 0.0f);
191
+ }
192
+
193
+ /// Bitcasts to CUDA's bf16 type
194
+ CUTLASS_DEVICE
195
+ __nv_bfloat16 to_nv_bfloat16() const {
196
+ return reinterpret_cast<__nv_bfloat16 const &>(storage);
197
+ }
198
+
199
+ /// Obtains raw bits
200
+ CUTLASS_HOST_DEVICE
201
+ uint16_t raw() const {
202
+ return storage;
203
+ }
204
+ /// Returns the sign bit
205
+ CUTLASS_HOST_DEVICE
206
+ bool signbit() const {
207
+ return ((raw() & 0x8000) != 0);
208
+ }
209
+
210
+ /// Returns the biased exponent
211
+ CUTLASS_HOST_DEVICE
212
+ int exponent_biased() const {
213
+ return int((raw() >> 7) & 0x0ff);
214
+ }
215
+
216
+ /// Returns the unbiased exponent
217
+ CUTLASS_HOST_DEVICE
218
+ int exponent() const {
219
+ return exponent_biased() - 127;
220
+ }
221
+
222
+ /// Returns the mantissa
223
+ CUTLASS_HOST_DEVICE
224
+ int mantissa() const {
225
+ return int(raw() & 0x7f);
226
+ }
227
+ };
228
+
229
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
230
+
231
+ CUTLASS_HOST_DEVICE
232
+ bool signbit(cutlass::bfloat16_t const& h) {
233
+ return h.signbit();
234
+ }
235
+
236
+ CUTLASS_HOST_DEVICE
237
+ cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) {
238
+ return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fff);
239
+ }
240
+
241
+ CUTLASS_HOST_DEVICE
242
+ bool isnan(cutlass::bfloat16_t const& h) {
243
+ return (h.exponent_biased() == 0x0ff) && h.mantissa();
244
+ }
245
+
246
+ CUTLASS_HOST_DEVICE
247
+ bool isfinite(cutlass::bfloat16_t const& h) {
248
+ return (h.exponent_biased() != 0x0ff);
249
+ }
250
+
251
+ CUTLASS_HOST_DEVICE
252
+ cutlass::bfloat16_t nan_bf16(const char*) {
253
+ // NVIDIA canonical NaN
254
+ return cutlass::bfloat16_t::bitcast(0x7fff);
255
+ }
256
+
257
+ CUTLASS_HOST_DEVICE
258
+ bool isinf(cutlass::bfloat16_t const& h) {
259
+ return (h.exponent_biased() == 0x0ff) && !h.mantissa();
260
+ }
261
+
262
+ CUTLASS_HOST_DEVICE
263
+ bool isnormal(cutlass::bfloat16_t const& h) {
264
+ return h.exponent_biased() && h.exponent_biased() != 0x0ff;
265
+ }
266
+
267
+ CUTLASS_HOST_DEVICE
268
+ int fpclassify(cutlass::bfloat16_t const& h) {
269
+ int exp = h.exponent_biased();
270
+ int mantissa = h.mantissa();
271
+ if (exp == 0x0ff) {
272
+ if (mantissa) {
273
+ return FP_NAN;
274
+ }
275
+ else {
276
+ return FP_INFINITE;
277
+ }
278
+ }
279
+ else if (!exp) {
280
+ if (mantissa) {
281
+ return FP_SUBNORMAL;
282
+ }
283
+ else {
284
+ return FP_ZERO;
285
+ }
286
+ }
287
+ return FP_NORMAL;
288
+ }
289
+
290
+ CUTLASS_HOST_DEVICE
291
+ cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) {
292
+ #if defined(__CUDACC_RTC__)
293
+ return cutlass::bfloat16_t(sqrtf(float(h)));
294
+ #else
295
+ return cutlass::bfloat16_t(std::sqrt(float(h)));
296
+ #endif
297
+ }
298
+
299
+ CUTLASS_HOST_DEVICE
300
+ bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) {
301
+
302
+ uint16_t a_bits;
303
+ uint16_t b_bits;
304
+
305
+ #if defined(__CUDA_ARCH__)
306
+ a_bits = reinterpret_cast<uint16_t const &>(a);
307
+ b_bits = reinterpret_cast<uint16_t const &>(b);
308
+ #else
309
+ std::memcpy(&a_bits, &a, sizeof(a_bits));
310
+ std::memcpy(&b_bits, &b, sizeof(b_bits));
311
+ #endif
312
+
313
+ uint16_t a_mag = (a_bits & 0x7fff);
314
+ uint16_t b_sign = (b_bits & 0x8000);
315
+ uint16_t result = (a_mag | b_sign);
316
+
317
+ return bfloat16_t::bitcast(result);
318
+ }
319
+
320
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
321
+
322
+ } // namespace cutlass
323
+
324
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
325
+ //
326
+ // Standard Library operations and definitions
327
+ //
328
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
329
+
330
+ #if !defined(__CUDACC_RTC__)
331
+ namespace std {
332
+
333
+ /// Numeric limits
334
+ template <>
335
+ struct numeric_limits<cutlass::bfloat16_t> {
336
+ static bool const is_specialized = true;
337
+ static bool const is_signed = true;
338
+ static bool const is_integer = false;
339
+ static bool const is_exact = false;
340
+ static bool const has_infinity = true;
341
+ static bool const has_quiet_NaN = true;
342
+ static bool const has_signaling_NaN = false;
343
+ static std::float_denorm_style const has_denorm = std::denorm_present;
344
+ static bool const has_denorm_loss = true;
345
+ static std::float_round_style const round_style = std::round_to_nearest;
346
+ static bool const is_iec559 = false;
347
+ static bool const is_bounded = true;
348
+ static bool const is_modulo = false;
349
+ static int const digits = 7;
350
+
351
+ /// Least positive value
352
+ CUTLASS_HOST_DEVICE
353
+ static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); }
354
+
355
+ /// Minimum finite value
356
+ CUTLASS_HOST_DEVICE
357
+ static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); }
358
+
359
+ /// Maximum finite value
360
+ CUTLASS_HOST_DEVICE
361
+ static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); }
362
+
363
+ /// Returns smallest finite value
364
+ CUTLASS_HOST_DEVICE
365
+ static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); }
366
+
367
+ /// Returns smallest finite value
368
+ CUTLASS_HOST_DEVICE
369
+ static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); }
370
+
371
+ /// Returns smallest finite value
372
+ CUTLASS_HOST_DEVICE
373
+ static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); }
374
+
375
+ /// Returns smallest finite value
376
+ CUTLASS_HOST_DEVICE
377
+ static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }
378
+
379
+ /// Returns smallest finite value
380
+ CUTLASS_HOST_DEVICE
381
+ static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }
382
+
383
+ /// Returns smallest finite value
384
+ CUTLASS_HOST_DEVICE
385
+ static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); }
386
+ };
387
+
388
+ } // namespace std
389
+ #endif
390
+
391
+ namespace cutlass {
392
+ namespace platform {
393
+
394
+ /// Forward Declaration
395
+ template <class T>
396
+ struct numeric_limits;
397
+
398
+ /// Numeric limits
399
+ template <>
400
+ struct numeric_limits<cutlass::bfloat16_t> {
401
+ static bool const is_specialized = true;
402
+ static bool const is_signed = true;
403
+ static bool const is_integer = false;
404
+ static bool const is_exact = false;
405
+ static bool const has_infinity = true;
406
+ static bool const has_quiet_NaN = true;
407
+ static bool const has_signaling_NaN = false;
408
+ #if !defined(__CUDACC_RTC__)
409
+ static std::float_denorm_style const has_denorm = std::denorm_present;
410
+ #endif
411
+ static bool const has_denorm_loss = true;
412
+ #if !defined(__CUDACC_RTC__)
413
+ static std::float_round_style const round_style = std::round_to_nearest;
414
+ #endif
415
+ static bool const is_iec559 = false;
416
+ static bool const is_bounded = true;
417
+ static bool const is_modulo = false;
418
+ static int const digits = 7;
419
+
420
+ /// Least positive value
421
+ CUTLASS_HOST_DEVICE
422
+ static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); }
423
+
424
+ /// Minimum finite value
425
+ CUTLASS_HOST_DEVICE
426
+ static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); }
427
+
428
+ /// Maximum finite value
429
+ CUTLASS_HOST_DEVICE
430
+ static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); }
431
+
432
+ /// Returns smallest finite value
433
+ CUTLASS_HOST_DEVICE
434
+ static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); }
435
+
436
+ /// Returns smallest finite value
437
+ CUTLASS_HOST_DEVICE
438
+ static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); }
439
+
440
+ /// Returns smallest finite value
441
+ CUTLASS_HOST_DEVICE
442
+ static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); }
443
+
444
+ /// Returns smallest finite value
445
+ CUTLASS_HOST_DEVICE
446
+ static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }
447
+
448
+ /// Returns smallest finite value
449
+ CUTLASS_HOST_DEVICE
450
+ static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }
451
+
452
+ /// Returns smallest finite value
453
+ CUTLASS_HOST_DEVICE
454
+ static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); }
455
+ };
456
+
457
+ } // namespace platform
458
+ } // namespace cutlass
459
+
460
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
461
+ //
462
+ // Arithmetic operators
463
+ //
464
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
465
+
466
+ namespace cutlass {
467
+
468
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
469
+
470
+ CUTLASS_HOST_DEVICE
471
+ bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) {
472
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
473
+ return __heq(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
474
+ #else
475
+ return float(lhs) == float(rhs);
476
+ #endif
477
+ }
478
+
479
+ CUTLASS_HOST_DEVICE
480
+ bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
481
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
482
+ return __hne(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
483
+ #else
484
+ return float(lhs) != float(rhs);
485
+ #endif
486
+ }
487
+
488
+ CUTLASS_HOST_DEVICE
489
+ bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) {
490
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
491
+ return __hlt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
492
+ #else
493
+ return float(lhs) < float(rhs);
494
+ #endif
495
+ }
496
+
497
+ CUTLASS_HOST_DEVICE
498
+ bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
499
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
500
+ return __hle(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
501
+ #else
502
+ return float(lhs) <= float(rhs);
503
+ #endif
504
+ }
505
+
506
+ CUTLASS_HOST_DEVICE
507
+ bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) {
508
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
509
+ return __hgt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
510
+ #else
511
+ return float(lhs) > float(rhs);
512
+ #endif
513
+ }
514
+
515
+ CUTLASS_HOST_DEVICE
516
+ bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
517
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
518
+ return __hge(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
519
+ #else
520
+ return float(lhs) >= float(rhs);
521
+ #endif
522
+ }
523
+
524
+ CUTLASS_HOST_DEVICE
525
+ bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) {
526
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
527
+ return bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
528
+ #else
529
+ return bfloat16_t(float(lhs) + float(rhs));
530
+ #endif
531
+ }
532
+
533
+ CUTLASS_HOST_DEVICE
534
+ bfloat16_t operator-(bfloat16_t const& lhs) {
535
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
536
+ return bfloat16_t(__hneg(lhs.to_nv_bfloat16()));
537
+ #else
538
+ return bfloat16_t(-float(lhs));
539
+ #endif
540
+ }
541
+
542
+ CUTLASS_HOST_DEVICE
543
+ bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) {
544
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
545
+ return bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
546
+ #else
547
+ return bfloat16_t(float(lhs) - float(rhs));
548
+ #endif
549
+ }
550
+
551
+ CUTLASS_HOST_DEVICE
552
+ bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) {
553
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
554
+ return bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
555
+ #else
556
+ return bfloat16_t(float(lhs) * float(rhs));
557
+ #endif
558
+ }
559
+
560
+ CUTLASS_HOST_DEVICE
561
+ bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) {
562
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
563
+ return bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
564
+ #else
565
+ return bfloat16_t(float(lhs) / float(rhs));
566
+ #endif
567
+ }
568
+
569
+ CUTLASS_HOST_DEVICE
570
+ bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) {
571
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
572
+ lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
573
+ #else
574
+ lhs = bfloat16_t(float(lhs) + float(rhs));
575
+ #endif
576
+ return lhs;
577
+ }
578
+
579
+ CUTLASS_HOST_DEVICE
580
+ bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) {
581
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
582
+ lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
583
+ #else
584
+ lhs = bfloat16_t(float(lhs) - float(rhs));
585
+ #endif
586
+ return lhs;
587
+ }
588
+
589
+ CUTLASS_HOST_DEVICE
590
+ bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) {
591
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
592
+ lhs = bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
593
+ #else
594
+ lhs = bfloat16_t(float(lhs) * float(rhs));
595
+ #endif
596
+ return lhs;
597
+ }
598
+
599
+ CUTLASS_HOST_DEVICE
600
+ bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) {
601
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
602
+ lhs = bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
603
+ #else
604
+ lhs = bfloat16_t(float(lhs) / float(rhs));
605
+ #endif
606
+ return lhs;
607
+ }
608
+
609
+ CUTLASS_HOST_DEVICE
610
+ bfloat16_t& operator++(bfloat16_t & lhs) {
611
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
612
+ lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16()));
613
+ #else
614
+ float tmp(lhs);
615
+ ++tmp;
616
+ lhs = bfloat16_t(tmp);
617
+ #endif
618
+ return lhs;
619
+ }
620
+
621
+ CUTLASS_HOST_DEVICE
622
+ bfloat16_t& operator--(bfloat16_t & lhs) {
623
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
624
+ lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16()));
625
+ #else
626
+ float tmp(lhs);
627
+ --tmp;
628
+ lhs = bfloat16_t(tmp);
629
+ #endif
630
+ return lhs;
631
+ }
632
+
633
+ CUTLASS_HOST_DEVICE
634
+ bfloat16_t operator++(bfloat16_t & lhs, int) {
635
+ bfloat16_t ret(lhs);
636
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
637
+ lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16()));
638
+ #else
639
+ float tmp(lhs);
640
+ tmp++;
641
+ lhs = bfloat16_t(tmp);
642
+ #endif
643
+ return ret;
644
+ }
645
+
646
+ CUTLASS_HOST_DEVICE
647
+ bfloat16_t operator--(bfloat16_t & lhs, int) {
648
+ bfloat16_t ret(lhs);
649
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
650
+ lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16()));
651
+ #else
652
+ float tmp(lhs);
653
+ tmp--;
654
+ lhs = bfloat16_t(tmp);
655
+ #endif
656
+ return ret;
657
+ }
658
+
659
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
660
+
661
+ } // namespace cutlass
662
+
663
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
664
+
665
+ //
666
+ // User-defined literals
667
+ //
668
+
669
+ CUTLASS_HOST_DEVICE
670
+ cutlass::bfloat16_t operator "" _bf16(long double x) {
671
+ return cutlass::bfloat16_t(float(x));
672
+ }
673
+
674
+ CUTLASS_HOST_DEVICE
675
+ cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) {
676
+ return cutlass::bfloat16_t(int(x));
677
+ }
678
+
679
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Basic include for CUTLASS BLAS3/HPC code.
34
+
35
+
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/array.h"
42
+ #include "cutlass/blas3_types.h"
43
+ #include "cutlass/coord.h"
44
+ #include "cutlass/complex.h"
45
+ #include "cutlass/functional.h"
46
+ #include "cutlass/numeric_types.h"
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+
52
+ /////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ /// Defines FillMode inversions
55
+ template <FillMode kFillMode>
56
+ struct InvertFillMode;
57
+
58
+ /// Invert FillMode lower to upper
59
+ template <>
60
+ struct InvertFillMode<FillMode::kLower> {
61
+ static FillMode const mode = FillMode::kUpper;
62
+ };
63
+
64
+ /// Invert FillMode upper to lower
65
+ template <>
66
+ struct InvertFillMode<FillMode::kUpper> {
67
+ static FillMode const mode = FillMode::kLower;
68
+ };
69
+
70
+ /////////////////////////////////////////////////////////////////////////////////////////////////
71
+ /// Defines SideMode inversions
72
+ template <SideMode kSideMode>
73
+ struct InvertSideMode;
74
+
75
+ /// Invert SideMode left to right
76
+ template <>
77
+ struct InvertSideMode<SideMode::kLeft> {
78
+ static SideMode const mode = SideMode::kRight;
79
+ };
80
+
81
+ /// Invert SideMode right to left
82
+ template <>
83
+ struct InvertSideMode<SideMode::kRight> {
84
+ static SideMode const mode = SideMode::kLeft;
85
+ };
86
+
87
+ /////////////////////////////////////////////////////////////////////////////////////////////////
88
+ /// Defines correct compare operation for Triangular matrix boundary
89
+ template <FillMode kFillMode, DiagType kDiagType = DiagType::kNonUnit>
90
+ struct TrMatrixCompareOp {
91
+ using Index = int32_t;
92
+ using Type = typename platform::conditional<
93
+ (kFillMode == FillMode::kLower),
94
+ greater_equal<Index>,
95
+ less_equal<Index>>::type;
96
+ };
97
+
98
+ template <FillMode kFillMode>
99
+ struct TrMatrixCompareOp <kFillMode, DiagType::kUnit> {
100
+ using Index = int32_t;
101
+ using Type = typename platform::conditional<
102
+ (kFillMode == FillMode::kLower),
103
+ greater_equal<Index>,
104
+ less_equal<Index>>::type;
105
+ };
106
+
107
+ template <FillMode kFillMode>
108
+ struct TrMatrixCompareOp <kFillMode, DiagType::kZero> {
109
+ using Index = int32_t;
110
+ using Type = typename platform::conditional<
111
+ (kFillMode == FillMode::kLower),
112
+ greater<Index>,
113
+ less<Index>>::type;
114
+ };
115
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
116
+ // Returns precision in terms of bits (based on datatype) to fill tensors with.
117
+ // Defaults to 5 bits of mantissa for TF32 and FP32 (with implicit round-offs).
118
+ // Also defines acceptable mantissa result variance/error.
119
+ template <typename Element>
120
+ struct MantissaInBits {
121
+ static int constexpr bits = 5;
122
+ static double constexpr error = 1.0e-7;
123
+ };
124
+
125
+ // Full precision is supported for FP64
126
+ template <>
127
+ struct MantissaInBits<double> {
128
+ static int constexpr bits = 30;
129
+ static double constexpr error = 1.0e-15;
130
+ };
131
+
132
+ template <>
133
+ struct MantissaInBits<cutlass::complex<double>> {
134
+ static int constexpr bits = 30;
135
+ static double constexpr error = 1.0e-14;
136
+ };
137
+
138
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
139
+
140
+ } // namespace cutlass
141
+
142
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
143
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ /////////////////////////////////////////////////////////////////////////////////////////////////
35
+
36
+ namespace cutlass {
37
+
38
+ /////////////////////////////////////////////////////////////////////////////////////////////////
39
+
40
+ /// Enumerated type describing the type of kernel (based on input or output matrices).
41
+ enum class BlasMode {
42
+ kGemm,
43
+ kSymmetric,
44
+ kHermitian,
45
+ kTriangular,
46
+ kInvalid
47
+ };
48
+
49
+ /// Enumerated type describing the fill mode for matrices for BLAS functions.
50
+ enum class FillMode {
51
+ kFull, /// The entire tensor is covered.
52
+ kLower, /// The 'lower' part of a tensor is covered including diagonal
53
+ kUpper, /// The 'upper' part of a tensor is covered including diaognal
54
+ kDiagonal, /// Only diagonal elements are covered.
55
+ kNone, /// No element is covered.
56
+ kInvalid
57
+ };
58
+
59
+ /// Enumerated type describing the diagonal property of matrices for BLAS functions.
60
+ enum class DiagType {
61
+ kNonUnit,
62
+ kUnit,
63
+ kZero, // Only used internally for computing SYMM/HEMM
64
+ kInvalid
65
+ };
66
+
67
+ /// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions.
68
+ enum class SideMode {
69
+ kLeft,
70
+ kRight,
71
+ kInvalid
72
+ };
73
+
74
+ /////////////////////////////////////////////////////////////////////////////////////////////////
75
+
76
+ } // namespace cutlass
77
+
78
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable,
33
+ statically-sized array types to global memory.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/wmma_array.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/complex.h"
43
+
44
+ namespace cutlass {
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+ // AccessWidth
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ /// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit
51
+ template <
52
+ typename T,
53
+ int Limit>
54
+ struct AccessWidth
55
+ {
56
+ // Inductive case
57
+ template <
58
+ int ObjectBytes, /// Size of T in bytes
59
+ int AlignBytes, /// Template induction variable
60
+ bool IsAligned = /// Whether ObjectBytes is an even multiple of AlignBytes
61
+ ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))>
62
+ struct Detail
63
+ {
64
+ static const int value = Detail<ObjectBytes, AlignBytes * 2>::value;
65
+ };
66
+
67
+ // Base case (ObjectBytes is not an even multiple of AlignBytes)
68
+ template <
69
+ int ObjectBytes, /// Size of T in bytes
70
+ int AlignBytes> /// Template induction variable
71
+ struct Detail<ObjectBytes, AlignBytes, false>
72
+ {
73
+ static const int value = AlignBytes / 2;
74
+ };
75
+
76
+ /// The maximal power-of-two that evenly divides the size of T
77
+ static const int value = Detail<
78
+ (int) sizeof(T),
79
+ 1>::value;
80
+ };
81
+
82
+
83
+
84
+ /////////////////////////////////////////////////////////////////////////////////////////////////
85
+ // StripedAccessType
86
+ /////////////////////////////////////////////////////////////////////////////////////////////////
87
+
88
+ /// ReinterpretCast type for striping a trivially-copyable type in global memory
89
+ /// (Default specialization. Striping granularity is type T.)
90
+ template <
91
+ typename T, /// Data type
92
+ int TransferBytes = /// Data access width (16 byte max for global memory access on current architectures)
93
+ AccessWidth<T, 16>::value>
94
+ struct alignas(TransferBytes) StripedAccessType : public T
95
+ {};
96
+
97
+
98
+ /// ReinterpretCast type for striping a trivially-copyable type in global memory
99
+ /// (Specialization for cutlass::Array<T>. Striping granularity is a multiple of T.)
100
+ template <
101
+ typename T, /// Array element type
102
+ int N, /// Number of elements in array
103
+ bool RegisterSized, /// T is register-sized
104
+ int TransferBytes> /// Data access width
105
+ struct StripedAccessType<
106
+ Array<T, N, RegisterSized>,
107
+ TransferBytes>
108
+ : public AlignedArray<
109
+ T, // Element type of StripedAccessType
110
+ __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType
111
+ TransferBytes> // Alignment of StripedAccessType
112
+ {};
113
+
114
+
115
+ #if defined(CUTLASS_ARCH_WMMA_ENABLED)
116
+
117
+ /// ReinterpretCast type for striping a trivially-copyable type in global memory
118
+ /// (Specialization for cutlass::WmmaFragmentArray<T>. Striping granularity is a multiple of T.)
119
+ template<
120
+ typename Use,
121
+ int m,
122
+ int n,
123
+ int k,
124
+ typename ElementT,
125
+ typename Layout,
126
+ int kFragments,
127
+ int TransferBytes>
128
+ struct StripedAccessType<
129
+ WmmaFragmentArray<nvcuda::wmma::fragment<Use, m, n, k, ElementT, Layout>, kFragments>,
130
+ TransferBytes>
131
+ : public AlignedArray<
132
+ ElementT,
133
+ __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)),
134
+ TransferBytes>
135
+ {};
136
+
137
+ #endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)
138
+
139
+
140
+ /////////////////////////////////////////////////////////////////////////////////////////////////
141
+ // BlockStriped
142
+ /////////////////////////////////////////////////////////////////////////////////////////////////
143
+
144
+ /// Utility for performing block-striped access (load, store) of trivially-copyable,
145
+ /// statically-sized array types to global memory
146
+ template <
147
+ int BlockThreads,
148
+ typename ArrayT,
149
+ typename AccessT = StripedAccessType<ArrayT> >
150
+ struct BlockStriped
151
+ {
152
+ /// Number of striped accesses
153
+ static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT));
154
+ static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type");
155
+
156
+ /// Load
157
+ CUTLASS_DEVICE
158
+ static void load(ArrayT &data, ArrayT *ptr, int thread_idx)
159
+ {
160
+ AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
161
+ AccessT *access_data = reinterpret_cast<AccessT*>(&data);
162
+
163
+ CUTLASS_PRAGMA_UNROLL
164
+ for (int i = 0; i < kStripes; ++i) {
165
+ access_data[i] = access_input[(BlockThreads * i) + thread_idx];
166
+ }
167
+ }
168
+
169
+ /// Load & Add
170
+ CUTLASS_DEVICE
171
+ static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx)
172
+ {
173
+ AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
174
+ AccessT *access_data = reinterpret_cast<AccessT*>(&data);
175
+
176
+ plus<AccessT> add;
177
+
178
+ CUTLASS_PRAGMA_UNROLL
179
+ for (int i = 0; i < kStripes; ++i)
180
+ {
181
+ access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]);
182
+ }
183
+ }
184
+
185
+ /// Store
186
+ CUTLASS_DEVICE
187
+ static void store(ArrayT *ptr, const ArrayT &data, int thread_idx)
188
+ {
189
+ AccessT *access_output = reinterpret_cast<AccessT*>(ptr);
190
+ const AccessT *access_data = reinterpret_cast<const AccessT*>(&data);
191
+
192
+ CUTLASS_PRAGMA_UNROLL
193
+ for (int i = 0; i < kStripes; ++i) {
194
+ access_output[(BlockThreads * i) + thread_idx] = access_data[i];
195
+ }
196
+ }
197
+
198
+ };
199
+
200
+
201
+ /////////////////////////////////////////////////////////////////////////////////////////////////
202
+ // BlockStripedReduce
203
+ /////////////////////////////////////////////////////////////////////////////////////////////////
204
+
205
+
206
+ /// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
207
+ /// statically-sized array types to global memory.
208
+ /// (Default specialization)
209
+ template <
210
+ int BlockThreads,
211
+ typename ArrayT,
212
+ typename ElementT = typename StripedAccessType<ArrayT>::Element>
213
+ struct BlockStripedReduce :
214
+ BlockStriped<
215
+ BlockThreads,
216
+ ArrayT,
217
+ ElementT>
218
+ {
219
+ /// Reduce
220
+ CUTLASS_DEVICE
221
+ static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
222
+ {
223
+ cutlass::atomic_add<ElementT> reduce;
224
+ ElementT *access_output = reinterpret_cast<ElementT*>(ptr);
225
+ const ElementT *access_data = reinterpret_cast<const ElementT*>(&data);
226
+
227
+ CUTLASS_PRAGMA_UNROLL
228
+ for (int i = 0; i < BlockStripedReduce::kStripes; ++i) {
229
+ reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
230
+ }
231
+ }
232
+ };
233
+
234
+
235
+ /// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
236
+ /// statically-sized array types to global memory.
237
+ /// (Specialization for half_t. Uses half2 vectorized-reduction.)
238
+ template <
239
+ int BlockThreads,
240
+ typename ArrayT>
241
+ struct BlockStripedReduce<BlockThreads, ArrayT, half_t> :
242
+ BlockStriped<
243
+ BlockThreads,
244
+ ArrayT,
245
+ half2>
246
+ {
247
+ static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length");
248
+
249
+ /// Reduce
250
+ CUTLASS_DEVICE
251
+ static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
252
+ {
253
+ cutlass::atomic_add<half2> reduce;
254
+ half2 *access_output = reinterpret_cast<half2*>(ptr);
255
+ const half2 *access_data = reinterpret_cast<const half2*>(&data);
256
+
257
+ CUTLASS_PRAGMA_UNROLL
258
+ for (int i = 0; i < BlockStripedReduce::kStripes; ++i)
259
+ {
260
+ reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
261
+ }
262
+ }
263
+ };
264
+
265
+
266
+ } // namespace cutlass
267
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief CUDA interfaces to launch CUTLASS device-level operators (for >= SM90) that use thread-block clusters.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include <cuda_runtime_api.h>
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/trace.h"
41
+ #include <cute/arch/cluster_sm100.hpp>
42
+ #include "cutlass/arch/synclog.hpp"
43
+
44
+ #if defined(__CUDACC_RTC__)
45
+ #include CUDA_STD_HEADER(type_traits)
46
+ #else
47
+ #include <type_traits>
48
+ #include <cstdio>
49
+ #endif
50
+
51
+ #if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))
52
+ # define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED
53
+ #endif
54
+
55
+ #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
56
+ # define CUDA_ENABLE_PREFERRED_CLUSTER
57
+ #endif
58
+ namespace cutlass {
59
+
60
+ #ifndef NDEBUG
61
+ #define Return_Status(cudaError_t_status) \
62
+ if (cudaError_t_status != cudaSuccess) { \
63
+ fprintf(stderr, \
64
+ "[ ERROR: CUDA Runtime ] %s:%d: %s\n", \
65
+ __FILE__, \
66
+ __LINE__, \
67
+ cudaGetErrorString(cudaError_t_status)); \
68
+ return Status::kInvalid; \
69
+ } else { \
70
+ return Status::kSuccess; \
71
+ }
72
+ #else
73
+ #define Return_Status(cudaError_t_status) \
74
+ if (cudaError_t_status != cudaSuccess) { \
75
+ return Status::kInvalid; \
76
+ } else { \
77
+ return Status::kSuccess; \
78
+ }
79
+ #endif
80
+
81
+ struct ClusterLauncher {
82
+ constexpr static int MaxClusterSize = 32;
83
+
84
+ struct LaunchConfig {
85
+ #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
86
+ cudaLaunchConfig_t launch_config;
87
+
88
+ #if defined(CUDA_ENABLE_PREFERRED_CLUSTER)
89
+ constexpr static int numAttrs = 3;
90
+ #else
91
+
92
+ constexpr static int numAttrs = 2;
93
+ #endif
94
+ cudaLaunchAttribute launch_attribute[numAttrs];
95
+ // Commonly used utility functions
96
+ dim3 gridDim() { return launch_config.gridDim; }
97
+ dim3 blockDim() { return launch_config.blockDim; }
98
+ #endif
99
+ };
100
+
101
+ // Check for hardware compatibility
102
+ static inline CUTLASS_HOST
103
+ Status check_cluster_dims(dim3 grid, dim3 cluster) {
104
+ if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) &&
105
+ (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) {
106
+ return Status::kSuccess;
107
+ }
108
+ else {
109
+ CUTLASS_TRACE_HOST("ClusterLauncher: Invalid cluster configuration -- aborting launch.");
110
+ return Status::kInvalid;
111
+ }
112
+ }
113
+
114
+ static inline CUTLASS_HOST
115
+ Status
116
+ #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
117
+ init(void const* kernel_function)
118
+ #else
119
+ init(void const* /* kernel_function */)
120
+ #endif
121
+ {
122
+ #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
123
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
124
+ if (kernel_function == nullptr) {
125
+ CUTLASS_TRACE_HOST("kernel_function is null");
126
+ return Status::kInvalid;
127
+ }
128
+ CUTLASS_TRACE_HOST("Checking previous error state before calling cudaFuncSetAttribute");
129
+ cudaError_t prevStatus = cudaGetLastError();
130
+ if (prevStatus != cudaSuccess) {
131
+ fprintf(stderr,
132
+ "[ ERROR: CUDA Runtime ] %s:%d: %s\n",
133
+ __FILE__,
134
+ __LINE__,
135
+ cudaGetErrorString(prevStatus));
136
+ return Status::kInvalid;
137
+ }
138
+ CUTLASS_TRACE_HOST("Calling cudaFuncSetAttribute");
139
+ #endif
140
+ // This attribute was added in CUDA 11.8.
141
+ cudaError_t status =
142
+ cudaFuncSetAttribute(
143
+ kernel_function, cudaFuncAttributeNonPortableClusterSizeAllowed, 1);
144
+ Return_Status(status);
145
+ #else
146
+ return Status::kInvalid;
147
+ #endif
148
+ }
149
+
150
+ static inline CUTLASS_HOST
151
+ LaunchConfig make_cluster_launch_config(
152
+ dim3 const grid_dims,
153
+ dim3 const cluster_dims,
154
+ dim3 const block_dims,
155
+ size_t const smem_size = 0,
156
+ cudaStream_t cuda_stream = 0,
157
+ bool launch_with_pdl = false
158
+ , dim3 const fallback_cluster_dims = {0, 0, 0}
159
+ ) {
160
+ LaunchConfig cluster_launch_config;
161
+ #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
162
+ auto &launch_config = cluster_launch_config.launch_config;
163
+ auto &launch_attribute = cluster_launch_config.launch_attribute;
164
+ auto numAttrs = cluster_launch_config.numAttrs;
165
+
166
+ launch_attribute[0].id = cudaLaunchAttributeClusterDimension;
167
+
168
+ bool have_fallback = fallback_cluster_dims.x * fallback_cluster_dims.y * fallback_cluster_dims.z > 0;
169
+
170
+ if (have_fallback) {
171
+ launch_attribute[0].val.clusterDim = {fallback_cluster_dims.x, fallback_cluster_dims.y, fallback_cluster_dims.z};
172
+ CUTLASS_TRACE_HOST("ClusterLauncher: Setting fallback ClusterDims = "
173
+ "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n");
174
+ }
175
+ else {
176
+
177
+ launch_attribute[0].val.clusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z};
178
+ CUTLASS_TRACE_HOST("ClusterLauncher: Setting ClusterDims = "
179
+ "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n");
180
+
181
+ }
182
+
183
+ #if defined(CUDA_ENABLE_PREFERRED_CLUSTER)
184
+ if (have_fallback) {
185
+ if (cute::initialize_preferred_cluster_launch(nullptr, grid_dims, cluster_dims, fallback_cluster_dims)) {
186
+ launch_attribute[1].id = cudaLaunchAttributePreferredClusterDimension;
187
+ launch_attribute[1].val.preferredClusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z};
188
+ CUTLASS_TRACE_HOST("ClusterLauncher: Setting preferred ClusterDims = "
189
+ "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n");
190
+ }
191
+ }
192
+ else {
193
+ numAttrs--;
194
+ }
195
+ #endif
196
+
197
+
198
+ // PDL attributes
199
+ launch_attribute[numAttrs - 1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
200
+ launch_attribute[numAttrs - 1].val.programmaticStreamSerializationAllowed = 1;
201
+
202
+ launch_config.gridDim = {grid_dims.x, grid_dims.y, grid_dims.z};
203
+ launch_config.blockDim = {block_dims.x, block_dims.y, block_dims.z};
204
+ launch_config.dynamicSmemBytes = smem_size;
205
+ launch_config.stream = cuda_stream;
206
+ launch_config.numAttrs = launch_with_pdl ? numAttrs : numAttrs - 1;
207
+ launch_config.attrs = launch_attribute;
208
+ return cluster_launch_config;
209
+ #else
210
+ CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch.");
211
+ return cluster_launch_config;
212
+ #endif
213
+ }
214
+
215
+ // This is the method we expect to use going forward
216
+ static inline CUTLASS_HOST
217
+ Status launch(
218
+ dim3 const grid_dims,
219
+ dim3 const cluster_dims,
220
+ dim3 const block_dims,
221
+ size_t const smem_size,
222
+ cudaStream_t cuda_stream,
223
+ void const* kernel,
224
+ void** kernel_params,
225
+ bool launch_with_pdl = false) {
226
+ #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
227
+ LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, cluster_dims,
228
+ block_dims, smem_size, cuda_stream, launch_with_pdl);
229
+
230
+ auto launch_grid_dims = cluster_launch_config.gridDim();
231
+ if (check_cluster_dims(launch_grid_dims, cluster_dims) != Status::kSuccess) {
232
+ CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting.");
233
+ return Status::kInvalid;
234
+ }
235
+
236
+ auto init_status = init(kernel);
237
+ if (init_status != Status::kSuccess) {
238
+ CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting.");
239
+ return Status::kInvalid;
240
+ }
241
+
242
+ CUTLASS_TRACE_HOST("ClusterLauncher: Launching GridDims = "
243
+ "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), "
244
+ "And ClusterDims = "
245
+ "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n");
246
+
247
+ cutlass::arch::synclog_setup();
248
+ cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params);
249
+ Return_Status(status);
250
+ #else
251
+ CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch.");
252
+ return Status::kInvalid;
253
+ #endif
254
+ }
255
+
256
+
257
+ // This is the method we expect to use going forward
258
+ // Launch a preferred cluster grid
259
+ static inline CUTLASS_HOST
260
+ Status launch_with_fallback_cluster(
261
+ dim3 const grid_dims,
262
+ dim3 const preferred_cluster_dims,
263
+ dim3 const fallback_cluster_dims,
264
+ dim3 const block_dims,
265
+ size_t const smem_size,
266
+ cudaStream_t cuda_stream,
267
+ void const* kernel,
268
+ void** kernel_params,
269
+ bool launch_with_pdl = false) {
270
+ #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
271
+ LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, preferred_cluster_dims,
272
+ block_dims, smem_size, cuda_stream, launch_with_pdl, fallback_cluster_dims);
273
+
274
+ auto launch_grid_dims = cluster_launch_config.gridDim();
275
+ if (check_cluster_dims(launch_grid_dims, preferred_cluster_dims) != Status::kSuccess) {
276
+ CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting.");
277
+ return Status::kInvalid;
278
+ }
279
+
280
+ auto init_status = init(kernel);
281
+ if (init_status != Status::kSuccess) {
282
+ CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting.");
283
+ return Status::kInvalid;
284
+ }
285
+
286
+ CUTLASS_TRACE_HOST("ClusterLauncher: Launching \n\tGridDims = "
287
+ "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), "
288
+ "\n\tPreferred ClusterDims = "
289
+ "(" << preferred_cluster_dims.x << ", " << preferred_cluster_dims.y << ", " << preferred_cluster_dims.z << "),"
290
+ "\n\tFallback ClusterDims = "
291
+ "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n");
292
+
293
+ cutlass::arch::synclog_setup();
294
+ cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params);
295
+ Return_Status(status);
296
+ #else
297
+ CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch.");
298
+ return Status::kInvalid;
299
+ #endif
300
+ }
301
+
302
+
303
+ };
304
+
305
+ namespace detail {
306
+
307
+ template<class Arg>
308
+ void* checked_addressof(Arg&& arg) {
309
+ static_assert(! std::is_rvalue_reference_v<Arg> || ! std::is_const_v<Arg>, "You cannot take the address of a const rvalue reference (const T&&).");
310
+ // We use std::addressof to ensure we get the address,
311
+ // in case the type has an overloaded operator&.
312
+ // Note that this precludes `const T&&` references.
313
+ return const_cast<void*>(reinterpret_cast<void const*>(std::addressof(arg)));
314
+ }
315
+
316
+ } // namespace detail
317
+
318
+ //! Parameters for launch_on_cluster (see below).
319
+ struct ClusterLaunchParams {
320
+ //! Grid dimensions
321
+ dim3 grid_dims{1, 1, 1};
322
+
323
+ //! Block dimensions
324
+ dim3 block_dims{1, 1, 1};
325
+
326
+ //! Cluster dimensions
327
+ dim3 cluster_dims{1, 1, 1};
328
+
329
+ //! Number of bytes required for the kernel's shared memory.
330
+ int smem_size_in_bytes = 0;
331
+
332
+ //! CUDA stream on which to launch the kernel.
333
+ cudaStream_t cuda_stream = nullptr;
334
+ };
335
+
336
+ /// @brief Launch the kernel on the stream using cluster launch.
337
+ ///
338
+ /// @param params Cluster launch parameters (see above).
339
+ /// @param kernel_ptr Pointer to the kernel function (see example).
340
+ /// @param args Zero or more arguments to pass to the kernel.
341
+ ///
342
+ /// @tparam Args Types of the arguments passed to the kernel.
343
+ /// Don't specify this/these template argument(s) explicitly.
344
+ ///
345
+ /// @return Status::Success on success, else an error code.
346
+ ///
347
+ /// @code
348
+ /// template<class SharedMemoryType, class A, class B, class C>
349
+ /// __global__ void kernel(A a, B b, C c);
350
+ ///
351
+ /// X x = get_x();
352
+ /// Y y = get_y();
353
+ /// Z z = get_z();
354
+ ///
355
+ /// void const* kernel_ptr =
356
+ /// const_cast<void const*>(reinterpret_cast<void*>(
357
+ /// &kernel<SharedMemory, X, Y, Z>));
358
+ /// auto status = launch_kernel_on_cluster(
359
+ /// {grid_dims, block_dims, cluster_dims, sizeof(SharedMemory)},
360
+ /// kernel_ptr, x, y, z);
361
+ /// @endcode
362
+ template<class ... Args>
363
+ CUTLASS_HOST cutlass::Status
364
+ launch_kernel_on_cluster(const ClusterLaunchParams& params,
365
+ void const* kernel_ptr,
366
+ Args&& ... args)
367
+ {
368
+ // Unfortunately, we find ourselves needing to pass in
369
+ // the parameters as an array of raw pointers.
370
+ if constexpr (sizeof...(Args) == 0) {
371
+ return cutlass::ClusterLauncher::launch(
372
+ params.grid_dims,
373
+ params.cluster_dims,
374
+ params.block_dims,
375
+ params.smem_size_in_bytes,
376
+ params.cuda_stream,
377
+ kernel_ptr, nullptr);
378
+ }
379
+ else {
380
+ void* kernel_params[sizeof...(Args)] = {
381
+ detail::checked_addressof(std::forward<Args>(args))...
382
+ };
383
+ return cutlass::ClusterLauncher::launch(
384
+ params.grid_dims,
385
+ params.cluster_dims,
386
+ params.block_dims,
387
+ params.smem_size_in_bytes,
388
+ params.cuda_stream,
389
+ kernel_ptr,
390
+ kernel_params);
391
+ }
392
+ }
393
+
394
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include <cuComplex.h>
35
+
36
+ #include <cuda_fp16.h>
37
+ #include "cutlass/cutlass.h"
38
+ #if defined(__CUDACC_RTC__)
39
+ #include CUDA_STD_HEADER(cstdint)
40
+ #else
41
+ #include <cstdint>
42
+ #endif
43
+ #include "cutlass/functional.h"
44
+ #include "cutlass/platform/platform.h"
45
+ #include "cutlass/real.h"
46
+
47
+ #include "cutlass/numeric_types.h"
48
+
49
+ #include "cutlass/fast_math.h"
50
+
51
+ #if !defined(__CUDACC_RTC__)
52
+ #include <iosfwd>
53
+ #endif
54
+
55
+ namespace cutlass {
56
+
57
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Enumerated type describing a transformation applied to a complex value.
enum class ComplexTransform {
  kNone,       ///< use the value unchanged
  kConjugate   ///< use the complex conjugate of the value
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Toggles a ComplexTransform at compile time (kNone <-> kConjugate).
template <ComplexTransform kTransform>
struct InvertComplexTransform;

/// Toggling kNone yields kConjugate.
template <>
struct InvertComplexTransform<ComplexTransform::kNone> {
  static ComplexTransform const transform = ComplexTransform::kConjugate;
};

/// Toggling kConjugate yields kNone.
template <>
struct InvertComplexTransform<ComplexTransform::kConjugate> {
  static ComplexTransform const transform = ComplexTransform::kNone;
};
80
+ /////////////////////////////////////////////////////////////////////////////////////////////////
81
+ //////////////////////////////////////////////////////////////////////////////////////////////////
82
+
83
+ //
84
+ // Accessors for CUDA complex types
85
+ //
86
+
87
+ #if !defined(__CUDACC_RTC__)
88
+ /// Returns the real part of the complex number
89
+ CUTLASS_HOST_DEVICE
90
+ float const &real(cuFloatComplex const &z) { return z.x; }
91
+
92
+ /// Returns the real part of the complex number
93
+ CUTLASS_HOST_DEVICE
94
+ float &real(cuFloatComplex &z) { return z.x; }
95
+
96
+ /// Returns the real part of the complex number
97
+ CUTLASS_HOST_DEVICE
98
+ double const &real(cuDoubleComplex const &z) { return z.x; }
99
+
100
+ /// Returns the real part of the complex number
101
+ CUTLASS_HOST_DEVICE
102
+ double &real(cuDoubleComplex &z) { return z.x; }
103
+
104
+ /// Returns the imaginary part of the complex number
105
+ CUTLASS_HOST_DEVICE
106
+ float const &imag(cuFloatComplex const &z) { return z.y; }
107
+
108
+ /// Returns the imaginary part of the complex number
109
+ CUTLASS_HOST_DEVICE
110
+ float &imag(cuFloatComplex &z) { return z.y; }
111
+
112
+ /// Returns the imaginary part of the complex number
113
+ CUTLASS_HOST_DEVICE
114
+ double const &imag(cuDoubleComplex const &z) { return z.y; }
115
+
116
+ /// Returns the imaginary part of the complex number
117
+ CUTLASS_HOST_DEVICE
118
+ double &imag(cuDoubleComplex &z) { return z.y; }
119
+
120
+ // Returns the conjugate of the complex number
121
+ CUTLASS_HOST_DEVICE cuFloatComplex
122
+ conj(cuFloatComplex const& z) {
123
+ return make_cuFloatComplex(z.x, -z.y);
124
+ }
125
+
126
+ // Returns the conjugate of the complex number
127
+ CUTLASS_HOST_DEVICE cuDoubleComplex
128
+ conj(cuDoubleComplex const& z) {
129
+ return make_cuDoubleComplex(z.x, -z.y);
130
+ }
131
+ #endif
132
+
133
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
134
+
135
+ /// Class for representing and manipulating complex numbers with conversions from built-in CUDA
136
+ /// complex types.
137
+
138
+ template <typename T>
139
+ class complex
140
+ {
141
+ public:
142
+ /// Type alias for scalar type
143
+ using value_type = T;
144
+
145
+ private:
146
+ //
147
+ // Data members
148
+ //
149
+
150
+ /// Real part
151
+ T _real;
152
+
153
+ /// Imaginary part
154
+ T _imag;
155
+
156
+ public:
157
+
158
+ //
159
+ // Methods
160
+ //
161
+
162
+ /// Default constructor
163
+ complex() = default;
164
+
165
+ /// Constructor
166
+ CUTLASS_HOST_DEVICE
167
+ complex(T r) : _real(r), _imag(T(0)) {}
168
+
169
+ /// Constructor
170
+ CUTLASS_HOST_DEVICE
171
+ complex(T r, T i) : _real(r), _imag(i) {}
172
+
173
+ /// Constructor
174
+ template<typename A>
175
+ CUTLASS_HOST_DEVICE
176
+ complex(complex<A> const &z) : _real(static_cast<T>(z.real())), _imag(static_cast<T>(z.imag())) {}
177
+
178
+
179
+ #if !defined(__CUDACC_RTC__)
180
+ /// Conversion from cuFloatComplex
181
+ CUTLASS_HOST_DEVICE
182
+ complex(cuFloatComplex const &z) : _real(static_cast<T>(cuCrealf(z))), _imag(static_cast<T>(cuCimagf(z))) {}
183
+
184
+ /// Conversion from cuDoubleComplex
185
+ CUTLASS_HOST_DEVICE
186
+ complex(cuDoubleComplex const &z) : _real(static_cast<T>(cuCreal(z))), _imag(static_cast<T>(cuCimag(z))) {}
187
+ #endif
188
+
189
+ /// Equality operator
190
+ CUTLASS_HOST_DEVICE bool operator==(complex<T> const &rhs) const {
191
+ return this->real() == rhs.real() && this->imag() == rhs.imag();
192
+ }
193
+
194
+ /// Inequality operator
195
+ CUTLASS_HOST_DEVICE bool operator!=(complex<T> const &rhs) const {
196
+ return !(*this == rhs);
197
+ }
198
+
199
+ /// Addition
200
+ template <typename A>
201
+ CUTLASS_HOST_DEVICE complex<T> operator+(complex<A> const &rhs) const {
202
+ return complex<T>(this->real() + rhs.real(), this->imag() + rhs.imag());
203
+ }
204
+
205
+ /// Reduction into memory address. Components may update out of order.
206
+ template <typename OtherT>
207
+ CUTLASS_DEVICE void red(complex<OtherT> *ptr) const {
208
+ static_assert(platform::is_same<T, OtherT>::value, "Component type must match");
209
+ cutlass::atomic_add<T> reduce;
210
+ reduce(&ptr->_real, _real);
211
+ reduce(&ptr->_imag, _imag);
212
+ }
213
+
214
+ /// Reduction into memory address. Components may update out of order. (Half specialization)
215
+ CUTLASS_DEVICE void red(complex<half_t> *ptr) const {
216
+ static_assert(platform::is_same<T, half_t>::value, "Component type must match");
217
+ half2 *h2_ptr = reinterpret_cast<half2*>(ptr);
218
+ half2 h2_data = reinterpret_cast<half2&>(*this);
219
+ cutlass::atomic_add<half2> reduce;
220
+ reduce(h2_ptr, h2_data);
221
+ }
222
+
223
+ /// Subtraction
224
+ template <typename A>
225
+ CUTLASS_HOST_DEVICE complex<T> operator-(complex<A> const &rhs) const {
226
+ return complex<T>(this->real() - rhs.real(), this->imag() - rhs.imag());
227
+ }
228
+
229
+ /// Multiplication
230
+ template <typename A>
231
+ CUTLASS_HOST_DEVICE complex<T> operator*(complex<A> const &rhs) const {
232
+ return complex<T>(this->real() * rhs.real() - this->imag() * rhs.imag(),
233
+ this->real() * rhs.imag() + this->imag() * rhs.real());
234
+ }
235
+
236
+ /// Scalar Multiplication
237
+ template <typename A>
238
+ CUTLASS_HOST_DEVICE complex<T> operator*(A const &s) const {
239
+ return complex<T>(this->real() * s, this->imag() * s);
240
+ }
241
+
242
+ /// Division
243
+ template <typename A>
244
+ CUTLASS_HOST_DEVICE complex<T> operator/(complex<A> const &rhs) const {
245
+ T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag());
246
+
247
+ return complex<T>(
248
+ (real() * rhs.real() + imag() * rhs.imag()) / d,
249
+ (imag() * rhs.real() - real() * rhs.imag()) / d
250
+ );
251
+ }
252
+
253
+ /// Scalar Division
254
+ template <typename A>
255
+ CUTLASS_HOST_DEVICE complex<T> operator/(A const &s) const {
256
+ return complex<T>(this->real() / s, this->imag() / s);
257
+ }
258
+
259
+ /// Addition
260
+ template <typename A>
261
+ CUTLASS_HOST_DEVICE complex<T> &operator+=(complex<A> const &rhs) {
262
+ *this = *this + rhs;
263
+ return *this;
264
+ }
265
+
266
+ /// Subtraction
267
+ template <typename A>
268
+ CUTLASS_HOST_DEVICE complex<T> &operator-=(complex<A> const &rhs) {
269
+ *this = *this - rhs;
270
+ return *this;
271
+ }
272
+
273
+ /// Multiplication
274
+ template <typename A>
275
+ CUTLASS_HOST_DEVICE complex<T> &operator*=(complex<A> const &rhs) {
276
+ *this = *this * rhs;
277
+ return *this;
278
+ }
279
+
280
+ /// Scalar multiplication
281
+ template <typename A>
282
+ CUTLASS_HOST_DEVICE complex<T> &operator*=(A s) {
283
+ *this = *this * s;
284
+ return *this;
285
+ }
286
+
287
+ /// Division
288
+ template <typename A>
289
+ CUTLASS_HOST_DEVICE complex<T> &operator/=(complex<A> const &rhs) {
290
+ *this = *this / rhs;
291
+ return *this;
292
+ }
293
+
294
+ /// Accesses the real part of the complex number
295
+ CUTLASS_HOST_DEVICE
296
+ T const &real() const { return _real; }
297
+
298
+ /// Accesses the real part of the complex number
299
+ CUTLASS_HOST_DEVICE
300
+ T &real() { return _real; }
301
+
302
+ /// Accesses the imaginary part of the complex number
303
+ CUTLASS_HOST_DEVICE
304
+ T const &imag() const { return _imag; }
305
+
306
+ /// Accesses the imaginary part of the complex number
307
+ CUTLASS_HOST_DEVICE
308
+ T &imag() { return _imag; }
309
+
310
+ /// Set the real part of the complex number
311
+ CUTLASS_HOST_DEVICE
312
+ void real(T real) { _real = real; }
313
+
314
+ /// Set the imaginary part of the complex number
315
+ CUTLASS_HOST_DEVICE
316
+ void imag(T imag) { _imag = imag; }
317
+
318
+ #if !defined(__CUDACC_RTC__)
319
+ /// Converts to cuFloatComplex
320
+ CUTLASS_HOST_DEVICE
321
+ explicit operator cuFloatComplex() const { return make_cuFloatComplex(float(real()), float(imag())); }
322
+
323
+ /// Converts to cuDoubleComplex
324
+ CUTLASS_HOST_DEVICE
325
+ explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); }
326
+ #endif
327
+ };
328
+
329
+ // Complex conjugate
330
+ template<class T>
331
+ CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const& z) {
332
+ return {z.real(), -z.imag()};
333
+ }
334
+
335
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
336
+
337
+ //
338
+ // Accessors for complex template
339
+ //
340
+
341
+ // Nonmember real and imag need to work for non-complex numbers too.
342
+ // That means cutlass::complex, std::complex, cuda::std::complex, and
343
+ // any user-defined complex number type that looks like std::complex.
344
+ // It's reasonable to assume that a "complex number type" has
345
+ // zero-argument real() and imag() member functions returning
346
+ // non-void. While cuFloatComplex and cuDoubleComplex lack those
347
+ // member functions, one-argument nonmember real and imag overloads
348
+ // for those types are defined above.
349
+
350
+ namespace detail {
351
+
352
+ template <typename T, typename Enable = void>
353
+ struct has_zero_argument_real_member_function :
354
+ cutlass::platform::false_type
355
+ {};
356
+
357
+ template <typename T>
358
+ struct has_zero_argument_real_member_function<T,
359
+ cutlass::platform::enable_if_t<
360
+ ! cutlass::platform::is_void_v<
361
+ decltype(cutlass::platform::declval<T>().real())
362
+ >
363
+ >
364
+ > : cutlass::platform::true_type
365
+ {};
366
+
367
+ template <typename T>
368
+ constexpr bool has_zero_argument_real_member_function_v =
369
+ has_zero_argument_real_member_function<T>::value;
370
+
371
+ template <typename T, typename Enable = void>
372
+ struct has_zero_argument_imag_member_function :
373
+ cutlass::platform::false_type
374
+ {};
375
+
376
+ template <typename T>
377
+ struct has_zero_argument_imag_member_function<T,
378
+ cutlass::platform::enable_if_t<
379
+ ! cutlass::platform::is_void_v<
380
+ decltype(cutlass::platform::declval<T>().imag())
381
+ >
382
+ >
383
+ > : cutlass::platform::true_type
384
+ {};
385
+
386
+ template <typename T>
387
+ constexpr bool has_zero_argument_imag_member_function_v =
388
+ has_zero_argument_imag_member_function<T>::value;
389
+
390
+ } // namespace detail
391
+
392
+ template<typename T>
393
+ CUTLASS_HOST_DEVICE auto real(T z) {
394
+ if constexpr (detail::has_zero_argument_real_member_function_v<T>) {
395
+ return z.real();
396
+ } else {
397
+ return z;
398
+ }
399
+ }
400
+
401
+ template<typename T>
402
+ CUTLASS_HOST_DEVICE auto imag(T z) {
403
+ if constexpr (detail::has_zero_argument_imag_member_function_v<T>) {
404
+ return z.imag();
405
+ } else {
406
+ // Imaginary part of a non-complex input has the same type as the
407
+ // input, and its value is zero. CUTLASS assumes in this case
408
+ // that value-initializing T is well-formed and results in zero.
409
+ return T{};
410
+ }
411
+ }
412
+
413
+ //
414
+ // Output operators
415
+ //
416
+
417
+ #if !defined(__CUDACC_RTC__)
418
+ template <typename T>
419
+ std::ostream &operator<<(std::ostream &out, complex<T> const &z) {
420
+ T _r = real(z);
421
+ T _i = imag(z);
422
+
423
+ if (bool(_i)) {
424
+ return out << _r << "+i" << _i;
425
+ }
426
+ return out << _r;
427
+ }
428
+ #endif
429
+
430
+ //
431
+ // Non-member operators defined for complex types
432
+ //
433
+
434
+
435
+ //
436
+ // Non-member functions defined for complex numbers
437
+ //
438
+
439
+ // abs returns the magnitude of the complex number.
440
+
441
+ CUTLASS_HOST_DEVICE float abs(complex<float> const &z) {
442
+ return ::hypot(z.real(), z.imag());
443
+ }
444
+
445
+ CUTLASS_HOST_DEVICE double abs(complex<double> const &z) {
446
+ return ::hypot(z.real(), z.imag());
447
+ }
448
+
449
+ // In theory, it would make sense to add a complex<long double>
450
+ // specialization of abs here, since hypot works for long double too.
451
+ // In practice, long double doesn't have a portable number of bits or
452
+ // behavior, so users who care about higher-precision floating-point
453
+ // computation should probably insist on an actual FP128 type.
454
+
455
+ template <typename T>
456
+ CUTLASS_HOST_DEVICE T abs(complex<T> const &z) {
457
+ // cutlass::complex permits all kinds of T, including types that
458
+ // don't have NaN. For a generic floating-point type with Inf
459
+ // and/or NaN, LAPACK's DLAPY2 algorithm would make sense, as it
460
+ // would handle issues like avoiding unwarranted overflow if
461
+ // z.real() or z.imag() is slightly bigger than the square root of
462
+ // the max finite number. That could be a future improvement; for
463
+ // now, the code just uses the naive algorithm.
464
+ //
465
+ // Use the "swap two-step" idiom so that argument-dependent lookup
466
+ // can find any CUTLASS-specific overloads.
467
+ using cutlass::sqrt;
468
+ return sqrt(z.real() * z.real() + z.imag() * z.imag());
469
+ }
470
+
471
+ /// Returns the magnitude of the complex number
472
+ template <typename T>
473
+ CUTLASS_HOST_DEVICE T arg(complex<T> const &z) {
474
+ return atan2(imag(z), real(z));
475
+ }
476
+
477
+ /// Returns the squared magnitude of a real number
478
+ template <typename T>
479
+ CUTLASS_HOST_DEVICE T norm(T const &z) {
480
+ return z * z;
481
+ }
482
+
483
+ /// Returns the squared magnitude of a real number
484
+ template <>
485
+ CUTLASS_HOST_DEVICE int8_t norm(int8_t const &z) {
486
+ return static_cast<int8_t>(z * z);
487
+ }
488
+
489
+ /// Returns the squared magnitude of a complex number
490
+ template <typename T>
491
+ CUTLASS_HOST_DEVICE double norm(complex<T> const &z) {
492
+ return real(z) * real(z) + imag(z) * imag(z);
493
+ }
494
+
495
+ /// Norm-accumulate calculation
496
+ template <typename T, typename R>
497
+ CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) {
498
+ return accumulator + static_cast<R>(x) * static_cast<R>(x);
499
+ }
500
+
501
+ /// Norm accumulate specialized for complex types
502
+ template <typename T, typename R>
503
+ CUTLASS_HOST_DEVICE R norm_accumulate(complex<T> const &z, R const &accumulator) {
504
+ return accumulator + static_cast<R>(real(z)) * static_cast<R>(real(z)) +
505
+ static_cast<R>(imag(z)) * static_cast<R>(imag(z));
506
+ }
507
+
508
+ namespace detail {
509
+
510
+ template<class T>
511
+ CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::true_type) {
512
+ return conj(z);
513
+ }
514
+
515
+ template<class T>
516
+ CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::false_type) {
517
+ return z;
518
+ }
519
+
520
+ template<class T>
521
+ CUTLASS_HOST_DEVICE T conj_impl(T const& z) {
522
+ constexpr bool use_unqualified_conj =
523
+ ! cutlass::platform::is_arithmetic_v<T> &&
524
+ ! detail::has_cutlass_conj_v<T> &&
525
+ detail::has_unqualified_conj_v<T>;
526
+ return conj_impl(z, cutlass::platform::bool_constant<use_unqualified_conj>{});
527
+ }
528
+
529
+ } // namespace detail
530
+
531
+ // Return the complex conjugate of the input.
532
+ //
533
+ // This MUST be a function and not a function object, because it may
534
+ // be common practice for downstream types to define specifically
535
+ // cutlass::conj overloads, instead of overloads in their namespace.
536
+ //
537
+ // As a result of this being a function and not a function object,
538
+ // CUTLASS code needs to declare "using cutlass::conj;" in scope and
539
+ // then call this function unqualified, just like std::swap.
540
+ //
541
+ // If an overload already exists for cutlass::conj(T), that overload
542
+ // will be called instead of this one. Otherwise:
543
+ //
544
+ // 1. for arithmetic types, return z;
545
+ //
546
+ // 2. for types where (namespace-unqualified) conj(z) is well formed
547
+ // and cutlass::conj(z) is NOT well formed, return conj(z); and,
548
+ //
549
+ // 3. for everything else, return z.
550
+ //
551
+ // Regarding (1), the C++ Standard Library makes std::conj always
552
+ // return std::complex, even for (noncomplex) arithmetic types.
553
+ // cutlass::conj(T t) needs to return type T. This follows the
554
+ // convention of linear algebra software like the BLAS, where
555
+ // "conjugate transpose" means the same thing as "transpose" for a
556
+ // matrix of noncomplex numbers.
557
+ //
558
+ // Case (2) covers std::complex, cuda::std::complex, and non-Standard
559
+ // (including user-defined) complex number types (for which "conj(z)"
560
+ // is findable via argument-dependent lookup, but does not live in the
561
+ // cutlass namespace). It excludes cutlass::conj(z) in order to
562
+ // prevent infinite recursion.
563
+ //
564
+ // Case (3) covers non-Standard non-complex number types.
565
+ template<class T>
566
+ CUTLASS_HOST_DEVICE T conj(T const& z) {
567
+ return detail::conj_impl(z);
568
+ }
569
+
570
+ /// Projects the complex number z onto the Riemann sphere
571
+ template <typename T>
572
+ CUTLASS_HOST_DEVICE complex<T> proj(complex<T> const &z) {
573
+ T d = real(z) * real(z) + imag(z) * imag(z) + T(1);
574
+ return complex<T>((T(2) * real(z)) / d, (T(2) * imag(z)) / d);
575
+ }
576
+
577
+ /// Returns a complex number with magnitude r and phase theta
578
+ template <typename T>
579
+ CUTLASS_HOST_DEVICE complex<T> polar(T const &r, T const &theta = T()) {
580
+ return complex<T>(r * cos(theta), r * sin(theta));
581
+ }
582
+
583
+ /// Computes the complex exponential of z.
584
+ template <typename T>
585
+ CUTLASS_HOST_DEVICE complex<T> exp(complex<T> const &z) {
586
+ return complex<T>(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z)));
587
+ }
588
+
589
+ /// Computes the log of z
590
+ template <typename T>
591
+ CUTLASS_HOST_DEVICE complex<T> log(complex<T> const &z) {
592
+ return complex<T>(log(abs(z)), arg(z));
593
+ }
594
+
595
+ /// Computes the log base 10 of z
596
+ template <typename T>
597
+ CUTLASS_HOST_DEVICE complex<T> log10(complex<T> const &z) {
598
+ return log(z) / T(log(T(10)));
599
+ }
600
+
601
+ /// Computes the square root of complex number z
602
+ template <typename T>
603
+ CUTLASS_HOST_DEVICE complex<T> sqrt(complex<T> const &z) {
604
+ return sqrt(T(2)) / T(2) *
605
+ complex<T>(sqrt(sqrt(norm(z)) + real(z)),
606
+ (imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z)));
607
+ }
608
+
609
+ /// Computes the cosine of complex z.
610
+ template <typename T>
611
+ CUTLASS_HOST_DEVICE complex<T> cos(complex<T> const &z) {
612
+ return (exp(z) + exp(-z)) / T(2);
613
+ }
614
+
615
+ /// Computes the sin of complex z.
616
+ template <typename T>
617
+ CUTLASS_HOST_DEVICE complex<T> sin(complex<T> const &z) {
618
+ return (exp(-z) - exp(z)) * complex<T>(T(0), T(1) / T(2));
619
+ }
620
+
621
+ /// Comparison
622
+ template <typename T>
623
+ CUTLASS_HOST_DEVICE bool operator<(complex<T> const &lhs, complex<T> const &rhs) {
624
+ return true;
625
+ }
626
+
627
+ //////////////////////////////////////////////////////////////////////////////////////////////////
628
+
629
+ /// Partial specialization for complex-valued type.
630
+ template <typename T>
631
+ struct RealType< complex<T> >
632
+ {
633
+ using Type = T;
634
+
635
+ /// Number of elements
636
+ static int const kExtent = 2;
637
+
638
+ CUTLASS_HOST_DEVICE
639
+ static complex<T> from_real(double x) {
640
+ return complex<T>(static_cast<T>(x));
641
+ }
642
+ };
643
+
644
+ /////////////////////////////////////////////////////////////////////////////////////////////////
645
+
646
+ template <>
647
+ CUTLASS_HOST_DEVICE
648
+ cutlass::complex<half_t> from_real<cutlass::complex<half_t> >(double r) {
649
+ return cutlass::complex<half_t>(half_t(r));
650
+ }
651
+
652
+ template <>
653
+ CUTLASS_HOST_DEVICE
654
+ cutlass::complex<float> from_real<cutlass::complex<float> >(double r) {
655
+ return cutlass::complex<float>(float(r));
656
+ }
657
+
658
+ template <>
659
+ CUTLASS_HOST_DEVICE
660
+ cutlass::complex<double> from_real<cutlass::complex<double> >(double r) {
661
+ return cutlass::complex<double>(r);
662
+ }
663
+
664
+ //////////////////////////////////////////////////////////////////////////////////////////////////
665
+
666
+ template <typename T>
667
+ struct is_complex {
668
+ static bool const value = false;
669
+ };
670
+
671
+ template <typename T>
672
+ struct is_complex<complex<T>> {
673
+ static bool const value = true;
674
+ };
675
+
676
+
677
+ /////////////////////////////////////////////////////////////////////////////////////////////////
678
+ // functional.h numeric specializations
679
+ /////////////////////////////////////////////////////////////////////////////////////////////////
680
+
681
+ /// Squares with optional conversion
682
+ template <typename T, typename Output>
683
+ struct magnitude_squared<complex<T>, Output> {
684
+ CUTLASS_HOST_DEVICE
685
+ Output operator()(complex<T> lhs) const {
686
+ multiplies<Output> mul_op;
687
+
688
+ Output y_r = Output(lhs.real());
689
+ Output y_i = Output(lhs.imag());
690
+
691
+ return mul_op(y_r, y_r) + mul_op(y_i, y_i);
692
+ }
693
+ };
694
+
695
+ /// Fused multiply-add
696
+ template <typename T>
697
+ struct multiply_add<complex<T>, complex<T>, complex<T>> {
698
+ CUTLASS_HOST_DEVICE
699
+ complex<T> operator()(
700
+ complex<T> const &a,
701
+ complex<T> const &b,
702
+ complex<T> const &c) const {
703
+
704
+ T real = c.real();
705
+ T imag = c.imag();
706
+
707
+ real += a.real() * b.real();
708
+ real += -a.imag() * b.imag();
709
+ imag += a.real() * b.imag();
710
+ imag += a.imag () * b.real();
711
+
712
+ return complex<T>{
713
+ real,
714
+ imag
715
+ };
716
+ }
717
+ };
718
+
719
+ /// Fused multiply-add
720
+ template <typename T>
721
+ struct multiply_add<complex<T>, T, complex<T>> {
722
+ CUTLASS_HOST_DEVICE
723
+ complex<T> operator()(
724
+ complex<T> const &a,
725
+ T const &b,
726
+ complex<T> const &c) const {
727
+
728
+ T real = c.real();
729
+ T imag = c.imag();
730
+
731
+ real += a.real() * b;
732
+ imag += a.imag () * b;
733
+
734
+ return complex<T>{
735
+ real,
736
+ imag
737
+ };
738
+ }
739
+ };
740
+
741
+ /// Fused multiply-add
742
+ template <typename T>
743
+ struct multiply_add<T, complex<T>, complex<T>> {
744
+ CUTLASS_HOST_DEVICE
745
+ complex<T> operator()(
746
+ T const &a,
747
+ complex<T> const &b,
748
+ complex<T> const &c) const {
749
+
750
+ T real = c.real();
751
+ T imag = c.imag();
752
+
753
+ real += a * b.real();
754
+ imag += a * b.imag();
755
+
756
+ return complex<T>{
757
+ real,
758
+ imag
759
+ };
760
+ }
761
+ };
762
+
763
+ /// Conjugate
764
+ template <typename T>
765
+ struct conjugate<complex<T>> {
766
+ CUTLASS_HOST_DEVICE
767
+ complex<T> operator()(complex<T> const &a) const {
768
+ // Invoke the complex<T> overload specifically, rather than
769
+ // wasting the compiler's effort on overload resolution.
770
+ return cutlass::conj(a);
771
+ }
772
+ };
773
+
774
+ #if ! defined(__CUDACC_RTC__)
775
+ template <>
776
+ struct conjugate<cuFloatComplex> {
777
+ CUTLASS_HOST_DEVICE
778
+ cuFloatComplex operator()(cuFloatComplex const& z) const {
779
+ return make_cuFloatComplex(z.x, -z.y);
780
+ }
781
+ };
782
+
783
+ template <>
784
+ struct conjugate<cuDoubleComplex> {
785
+ CUTLASS_HOST_DEVICE
786
+ cuDoubleComplex operator()(cuDoubleComplex const& z) const {
787
+ return make_cuDoubleComplex(z.x, -z.y);
788
+ }
789
+ };
790
+ #endif
791
+
792
+ /// Computes the square of a difference with optional conversion
793
+ template <typename T, typename Output>
794
+ struct magnitude_squared_difference<complex<T>, Output> {
795
+ CUTLASS_HOST_DEVICE
796
+ Output operator()(complex<T> lhs, complex<T> rhs) const {
797
+ multiplies<Output> mul_op;
798
+
799
+ Output y_r = Output(lhs.real()) - Output(rhs.real());
800
+ Output y_i = Output(lhs.imag()) - Output(rhs.imag());
801
+
802
+ return mul_op(y_r, y_r) + mul_op(y_i, y_i);
803
+ }
804
+ };
805
+
806
+ /// Reduces value into the data pointed to by ptr (complex<T> specialization)
807
+ template <typename T>
808
+ struct atomic_add<complex<T>> {
809
+ CUTLASS_DEVICE
810
+ void operator()(complex<T> *ptr, const complex<T> &data)
811
+ {
812
+ data.red(ptr);
813
+ }
814
+ };
815
+
816
+
817
+ //////////////////////////////////////////////////////////////////////////////////////////////////
818
+
819
+ } // namespace cutlass
820
+
821
+ //////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h ADDED
@@ -0,0 +1,1239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /* \file
33
+ \brief Boost-style constant definitions for floating-point types.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/numeric_types.h"
40
+
41
+ #include "cutlass/complex.h"
42
+
43
+ ///////////////////////////////////////////////////////////////////////////////////
44
+
45
+ namespace cutlass {
46
+ namespace constants {
47
+
48
+ ///////////////////////////////////////////////////////////////////////////////////
49
+
50
+ //
51
+ // Primary templates
52
+ //
53
+
54
+ /// Returns 1, the multiplicative identity element
55
+ template <typename T> CUTLASS_HOST_DEVICE T one();
56
+
57
+ /// Returns 0, the additive identity element
58
+ template <typename T> CUTLASS_HOST_DEVICE T zero();
59
+
60
+ /// Returns 2
61
+ template <typename T> CUTLASS_HOST_DEVICE T two();
62
+
63
+ /// Returns pi, approximately 3.141
64
+ template <typename T> CUTLASS_HOST_DEVICE T pi();
65
+
66
+ /// Returns 2 * pi
67
+ template <typename T> CUTLASS_HOST_DEVICE T two_pi();
68
+
69
+ /// Returns pi / 2
70
+ template <typename T> CUTLASS_HOST_DEVICE T half_pi();
71
+
72
+ /// Returns sqrt(pi)
73
+ template <typename T> CUTLASS_HOST_DEVICE T root_pi();
74
+
75
+ /// Returns sqrt(pi / 2)
76
+ template <typename T> CUTLASS_HOST_DEVICE T root_half_pi();
77
+
78
+ /// Returns sqrt(2 * pi)
79
+ template <typename T> CUTLASS_HOST_DEVICE T root_two_pi();
80
+
81
+ /// Returns sqrt(ln(4))
82
+ template <typename T> CUTLASS_HOST_DEVICE T root_ln_four();
83
+
84
+ /// Returns e, approximately 2.718...
85
+ template <typename T> CUTLASS_HOST_DEVICE T e();
86
+
87
+ /// Returns (1/2)
88
+ template <typename T> CUTLASS_HOST_DEVICE T half();
89
+
90
+ /// Returns sqrt(2), approximately 1.414...
91
+ template <typename T> CUTLASS_HOST_DEVICE T root_two();
92
+
93
+ /// Returns sqrt(2)/2, approximately 0.707...
94
+ template <typename T> CUTLASS_HOST_DEVICE T half_root_two();
95
+
96
+ /// Returns ln(2), approximately 0.693...
97
+ template <typename T> CUTLASS_HOST_DEVICE T ln_two();
98
+
99
+ /// Returns ln(ln(2)), approximately -0.3665...
100
+ template <typename T> CUTLASS_HOST_DEVICE T ln_ln_two();
101
+
102
+ /// Returns 1/3, approximately 0.333...
103
+ template <typename T> CUTLASS_HOST_DEVICE T third();
104
+
105
+ /// Returns 2/3, approximately 0.666...
106
+ template <typename T> CUTLASS_HOST_DEVICE T twothirds();
107
+
108
+ /// Returns pi - 3, approximately 0.1416...
109
+ template <typename T> CUTLASS_HOST_DEVICE T pi_minus_three();
110
+
111
+ /// Returns 4 - pi, approximately 0.858...
112
+ template <typename T> CUTLASS_HOST_DEVICE T four_minus_pi();
113
+
114
+
115
+ /////////////////////////////////////////////////////////////////////////////////////
116
+
117
+ // Specialization for double
118
+
119
+ /// Returns 1, the multiplicative identity element (specialization for double)
120
+ template <> CUTLASS_HOST_DEVICE double one<double>() {
121
+ uint64_t bits = 0x3ff0000000000000ull;
122
+ return reinterpret_cast<double const &>(bits);
123
+ }
124
+
125
+ /// Returns 1, the multiplicative identity element (specialization for complex<double>)
126
+ template <> CUTLASS_HOST_DEVICE complex<double> one< complex<double> >() {
127
+ return complex<double>(one<double>(), double());
128
+ }
129
+
130
+ /// Returns 0, the additive identity element (specialization for double)
131
+ template <> CUTLASS_HOST_DEVICE double zero<double>() {
132
+ uint64_t bits = 0x0ull;
133
+ return reinterpret_cast<double const &>(bits);
134
+ }
135
+
136
+ /// Returns 0, the additive identity element (specialization for complex<double>)
137
+ template <> CUTLASS_HOST_DEVICE complex<double> zero< complex<double> >() {
138
+ return complex<double>(zero<double>(), double());
139
+ }
140
+
141
+ /// Returns 2 (specialization for double)
142
+ template <> CUTLASS_HOST_DEVICE double two<double>() {
143
+ uint64_t bits = 0x4000000000000000ull;
144
+ return reinterpret_cast<double const &>(bits);
145
+ }
146
+
147
+ /// Returns 2 (specialization for complex<double>)
148
+ template <> CUTLASS_HOST_DEVICE complex<double> two< complex<double> >() {
149
+ return complex<double>(two<double>(), double());
150
+ }
151
+
152
+ /// Returns pi, approximately 3.141 (specialization for double)
153
+ template <> CUTLASS_HOST_DEVICE double pi<double>() {
154
+ uint64_t bits = 0x400921fb54442d18ull;
155
+ return reinterpret_cast<double const &>(bits);
156
+ }
157
+
158
+ /// Returns pi, approximately 3.141 (specialization for complex<double>)
159
+ template <> CUTLASS_HOST_DEVICE complex<double> pi< complex<double> >() {
160
+ return complex<double>(pi<double>(), double());
161
+ }
162
+
163
+ /// Returns 2 * pi (specialization for double)
164
+ template <> CUTLASS_HOST_DEVICE double two_pi<double>() {
165
+ uint64_t bits = 0x401921fb54442d18ull;
166
+ return reinterpret_cast<double const &>(bits);
167
+ }
168
+
169
+ /// Returns 2 * pi (specialization for complex<double>)
170
+ template <> CUTLASS_HOST_DEVICE complex<double> two_pi< complex<double> >() {
171
+ return complex<double>(two_pi<double>(), double());
172
+ }
173
+
174
+ /// Returns pi / 2 (specialization for double)
175
+ template <> CUTLASS_HOST_DEVICE double half_pi<double>() {
176
+ uint64_t bits = 0x3ff921fb54442d18ull;
177
+ return reinterpret_cast<double const &>(bits);
178
+ }
179
+
180
+ /// Returns pi / 2 (specialization for complex<double>)
181
+ template <> CUTLASS_HOST_DEVICE complex<double> half_pi< complex<double> >() {
182
+ return complex<double>(half_pi<double>(), double());
183
+ }
184
+
185
+ /// Returns sqrt(pi) (specialization for double)
186
+ template <> CUTLASS_HOST_DEVICE double root_pi<double>() {
187
+ uint64_t bits = 0x3ffc5bf891b4ef6aull;
188
+ return reinterpret_cast<double const &>(bits);
189
+ }
190
+
191
+ /// Returns sqrt(pi) (specialization for complex<double>)
192
+ template <> CUTLASS_HOST_DEVICE complex<double> root_pi< complex<double> >() {
193
+ return complex<double>(root_pi<double>(), double());
194
+ }
195
+
196
+ /// Returns sqrt(pi / 2) (specialization for double)
197
+ template <> CUTLASS_HOST_DEVICE double root_half_pi<double>() {
198
+ uint64_t bits = 0x3ff40d931ff62705ull;
199
+ return reinterpret_cast<double const &>(bits);
200
+ }
201
+
202
+ /// Returns sqrt(pi / 2) (specialization for complex<double>)
203
+ template <> CUTLASS_HOST_DEVICE complex<double> root_half_pi< complex<double> >() {
204
+ return complex<double>(root_half_pi<double>(), double());
205
+ }
206
+
207
+ /// Returns sqrt(2 * pi) (specialization for double)
208
+ template <> CUTLASS_HOST_DEVICE double root_two_pi<double>() {
209
+ uint64_t bits = 0x40040d931ff62705ull;
210
+ return reinterpret_cast<double const &>(bits);
211
+ }
212
+
213
+ /// Returns sqrt(2 * pi) (specialization for complex<double>)
214
+ template <> CUTLASS_HOST_DEVICE complex<double> root_two_pi< complex<double> >() {
215
+ return complex<double>(root_two_pi<double>(), double());
216
+ }
217
+
218
+ /// Returns sqrt(ln(4)) (specialization for double)
219
+ template <> CUTLASS_HOST_DEVICE double root_ln_four<double>() {
220
+ uint64_t bits = 0x3ff2d6abe44afc43ull;
221
+ return reinterpret_cast<double const &>(bits);
222
+ }
223
+
224
+ /// Returns sqrt(ln(4)) (specialization for complex<double>)
225
+ template <> CUTLASS_HOST_DEVICE complex<double> root_ln_four< complex<double> >() {
226
+ return complex<double>(root_ln_four<double>(), double());
227
+ }
228
+
229
+ /// Returns e, approximately 2.718... (specialization for double)
230
+ template <> CUTLASS_HOST_DEVICE double e<double>() {
231
+ uint64_t bits = 0x4005bf0a8b145769ull;
232
+ return reinterpret_cast<double const &>(bits);
233
+ }
234
+
235
+ /// Returns e, approximately 2.718... (specialization for complex<double>)
236
+ template <> CUTLASS_HOST_DEVICE complex<double> e< complex<double> >() {
237
+ return complex<double>(e<double>(), double());
238
+ }
239
+
240
+ /// Returns (1/2) (specialization for double)
241
+ template <> CUTLASS_HOST_DEVICE double half<double>() {
242
+ uint64_t bits = 0x3fe0000000000000ull;
243
+ return reinterpret_cast<double const &>(bits);
244
+ }
245
+
246
+ /// Returns (1/2) (specialization for complex<double>)
247
+ template <> CUTLASS_HOST_DEVICE complex<double> half< complex<double> >() {
248
+ return complex<double>(half<double>(), double());
249
+ }
250
+
251
+ /// Returns sqrt(2), approximately 1.414... (specialization for double)
252
+ template <> CUTLASS_HOST_DEVICE double root_two<double>() {
253
+ uint64_t bits = 0x3ff6a09e667f3bcdull;
254
+ return reinterpret_cast<double const &>(bits);
255
+ }
256
+
257
+ /// Returns sqrt(2), approximately 1.414... (specialization for complex<double>)
258
+ template <> CUTLASS_HOST_DEVICE complex<double> root_two< complex<double> >() {
259
+ return complex<double>(root_two<double>(), double());
260
+ }
261
+
262
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for double)
263
+ template <> CUTLASS_HOST_DEVICE double half_root_two<double>() {
264
+ uint64_t bits = 0x3fe6a09e667f3bcdull;
265
+ return reinterpret_cast<double const &>(bits);
266
+ }
267
+
268
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<double>)
269
+ template <> CUTLASS_HOST_DEVICE complex<double> half_root_two< complex<double> >() {
270
+ return complex<double>(half_root_two<double>(), double());
271
+ }
272
+
273
+ /// Returns ln(2), approximately 0.693... (specialization for double)
274
+ template <> CUTLASS_HOST_DEVICE double ln_two<double>() {
275
+ uint64_t bits = 0x3fe62e42fefa39efull;
276
+ return reinterpret_cast<double const &>(bits);
277
+ }
278
+
279
+ /// Returns ln(2), approximately 0.693... (specialization for complex<double>)
280
+ template <> CUTLASS_HOST_DEVICE complex<double> ln_two< complex<double> >() {
281
+ return complex<double>(ln_two<double>(), double());
282
+ }
283
+
284
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for double)
285
+ template <> CUTLASS_HOST_DEVICE double ln_ln_two<double>() {
286
+ uint64_t bits = 0xbfd774f29bdd6b9full;
287
+ return reinterpret_cast<double const &>(bits);
288
+ }
289
+
290
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<double>)
291
+ template <> CUTLASS_HOST_DEVICE complex<double> ln_ln_two< complex<double> >() {
292
+ return complex<double>(ln_ln_two<double>(), double());
293
+ }
294
+
295
+ /// Returns 1/3, approximately 0.333... (specialization for double)
296
+ template <> CUTLASS_HOST_DEVICE double third<double>() {
297
+ uint64_t bits = 0x3fd5555555555555ull;
298
+ return reinterpret_cast<double const &>(bits);
299
+ }
300
+
301
+ /// Returns 1/3, approximately 0.333... (specialization for complex<double>)
302
+ template <> CUTLASS_HOST_DEVICE complex<double> third< complex<double> >() {
303
+ return complex<double>(third<double>(), double());
304
+ }
305
+
306
+ /// Returns 2/3, approximately 0.666... (specialization for double)
307
+ template <> CUTLASS_HOST_DEVICE double twothirds<double>() {
308
+ uint64_t bits = 0x3fe5555555555555ull;
309
+ return reinterpret_cast<double const &>(bits);
310
+ }
311
+
312
+ /// Returns 2/3, approximately 0.666... (specialization for complex<double>)
313
+ template <> CUTLASS_HOST_DEVICE complex<double> twothirds< complex<double> >() {
314
+ return complex<double>(twothirds<double>(), double());
315
+ }
316
+
317
+ /// Returns pi - 3, approximately 0.1416... (specialization for double)
318
+ template <> CUTLASS_HOST_DEVICE double pi_minus_three<double>() {
319
+ uint64_t bits = 0x3fc21fb54442d180ull;
320
+ return reinterpret_cast<double const &>(bits);
321
+ }
322
+
323
+ /// Returns pi - 3, approximately 0.1416... (specialization for complex<double>)
324
+ template <> CUTLASS_HOST_DEVICE complex<double> pi_minus_three< complex<double> >() {
325
+ return complex<double>(pi_minus_three<double>(), double());
326
+ }
327
+
328
+ /// Returns 4 - pi, approximately 0.858... (specialization for double)
329
+ template <> CUTLASS_HOST_DEVICE double four_minus_pi<double>() {
330
+ uint64_t bits = 0x3feb7812aeef4ba0ull;
331
+ return reinterpret_cast<double const &>(bits);
332
+ }
333
+
334
+ /// Returns 4 - pi, approximately 0.858... (specialization for complex<double>)
335
+ template <> CUTLASS_HOST_DEVICE complex<double> four_minus_pi< complex<double> >() {
336
+ return complex<double>(four_minus_pi<double>(), double());
337
+ }
338
+
339
+ /////////////////////////////////////////////////////////////////////////////////////
340
+
341
+ // Specialization for float
342
+
343
+ /// Returns 1, the multiplicative identity element (specialization for float)
344
+ template <> CUTLASS_HOST_DEVICE float one<float>() {
345
+ uint32_t bits = 0x3f800000u;
346
+ return reinterpret_cast<float const &>(bits);
347
+ }
348
+
349
+ /// Returns 1, the multiplicative identity element (specialization for complex<float>)
350
+ template <> CUTLASS_HOST_DEVICE complex<float> one< complex<float> >() {
351
+ return complex<float>(one<float>(), float());
352
+ }
353
+
354
+ /// Returns 0, the additive identity element (specialization for float)
355
+ template <> CUTLASS_HOST_DEVICE float zero<float>() {
356
+ uint32_t bits = 0x0u;
357
+ return reinterpret_cast<float const &>(bits);
358
+ }
359
+
360
+ /// Returns 0, the additive identity element (specialization for complex<float>)
361
+ template <> CUTLASS_HOST_DEVICE complex<float> zero< complex<float> >() {
362
+ return complex<float>(zero<float>(), float());
363
+ }
364
+
365
+ /// Returns 2 (specialization for float)
366
+ template <> CUTLASS_HOST_DEVICE float two<float>() {
367
+ uint32_t bits = 0x40000000u;
368
+ return reinterpret_cast<float const &>(bits);
369
+ }
370
+
371
+ /// Returns 2 (specialization for complex<float>)
372
+ template <> CUTLASS_HOST_DEVICE complex<float> two< complex<float> >() {
373
+ return complex<float>(two<float>(), float());
374
+ }
375
+
376
+ /// Returns pi, approximately 3.141 (specialization for float)
377
+ template <> CUTLASS_HOST_DEVICE float pi<float>() {
378
+ uint32_t bits = 0x40490fdbu;
379
+ return reinterpret_cast<float const &>(bits);
380
+ }
381
+
382
+ /// Returns pi, approximately 3.141 (specialization for complex<float>)
383
+ template <> CUTLASS_HOST_DEVICE complex<float> pi< complex<float> >() {
384
+ return complex<float>(pi<float>(), float());
385
+ }
386
+
387
+ /// Returns 2 * pi (specialization for float)
388
+ template <> CUTLASS_HOST_DEVICE float two_pi<float>() {
389
+ uint32_t bits = 0x40c90fdbu;
390
+ return reinterpret_cast<float const &>(bits);
391
+ }
392
+
393
+ /// Returns 2 * pi (specialization for complex<float>)
394
+ template <> CUTLASS_HOST_DEVICE complex<float> two_pi< complex<float> >() {
395
+ return complex<float>(two_pi<float>(), float());
396
+ }
397
+
398
+ /// Returns pi / 2 (specialization for float)
399
+ template <> CUTLASS_HOST_DEVICE float half_pi<float>() {
400
+ uint32_t bits = 0x3fc90fdbu;
401
+ return reinterpret_cast<float const &>(bits);
402
+ }
403
+
404
+ /// Returns pi / 2 (specialization for complex<float>)
405
+ template <> CUTLASS_HOST_DEVICE complex<float> half_pi< complex<float> >() {
406
+ return complex<float>(half_pi<float>(), float());
407
+ }
408
+
409
+ /// Returns sqrt(pi) (specialization for float)
410
+ template <> CUTLASS_HOST_DEVICE float root_pi<float>() {
411
+ uint32_t bits = 0x3fe2dfc5u;
412
+ return reinterpret_cast<float const &>(bits);
413
+ }
414
+
415
+ /// Returns sqrt(pi) (specialization for complex<float>)
416
+ template <> CUTLASS_HOST_DEVICE complex<float> root_pi< complex<float> >() {
417
+ return complex<float>(root_pi<float>(), float());
418
+ }
419
+
420
+ /// Returns sqrt(pi / 2) (specialization for float)
421
+ template <> CUTLASS_HOST_DEVICE float root_half_pi<float>() {
422
+ uint32_t bits = 0x3fa06c99u;
423
+ return reinterpret_cast<float const &>(bits);
424
+ }
425
+
426
+ /// Returns sqrt(pi / 2) (specialization for complex<float>)
427
+ template <> CUTLASS_HOST_DEVICE complex<float> root_half_pi< complex<float> >() {
428
+ return complex<float>(root_half_pi<float>(), float());
429
+ }
430
+
431
+ /// Returns sqrt(2 * pi) (specialization for float)
432
+ template <> CUTLASS_HOST_DEVICE float root_two_pi<float>() {
433
+ uint32_t bits = 0x40206c99u;
434
+ return reinterpret_cast<float const &>(bits);
435
+ }
436
+
437
+ /// Returns sqrt(2 * pi) (specialization for complex<float>)
438
+ template <> CUTLASS_HOST_DEVICE complex<float> root_two_pi< complex<float> >() {
439
+ return complex<float>(root_two_pi<float>(), float());
440
+ }
441
+
442
+ /// Returns sqrt(ln(4)) (specialization for float)
443
+ template <> CUTLASS_HOST_DEVICE float root_ln_four<float>() {
444
+ uint32_t bits = 0x3f96b55fu;
445
+ return reinterpret_cast<float const &>(bits);
446
+ }
447
+
448
+ /// Returns sqrt(ln(4)) (specialization for complex<float>)
449
+ template <> CUTLASS_HOST_DEVICE complex<float> root_ln_four< complex<float> >() {
450
+ return complex<float>(root_ln_four<float>(), float());
451
+ }
452
+
453
+ /// Returns e, approximately 2.718... (specialization for float)
454
+ template <> CUTLASS_HOST_DEVICE float e<float>() {
455
+ uint32_t bits = 0x402df854u;
456
+ return reinterpret_cast<float const &>(bits);
457
+ }
458
+
459
+ /// Returns e, approximately 2.718... (specialization for complex<float>)
460
+ template <> CUTLASS_HOST_DEVICE complex<float> e< complex<float> >() {
461
+ return complex<float>(e<float>(), float());
462
+ }
463
+
464
+ /// Returns (1/2) (specialization for float)
465
+ template <> CUTLASS_HOST_DEVICE float half<float>() {
466
+ uint32_t bits = 0x3f000000u;
467
+ return reinterpret_cast<float const &>(bits);
468
+ }
469
+
470
+ /// Returns (1/2) (specialization for complex<float>)
471
+ template <> CUTLASS_HOST_DEVICE complex<float> half< complex<float> >() {
472
+ return complex<float>(half<float>(), float());
473
+ }
474
+
475
+ /// Returns sqrt(2), approximately 1.414... (specialization for float)
476
+ template <> CUTLASS_HOST_DEVICE float root_two<float>() {
477
+ uint32_t bits = 0x3fb504f3u;
478
+ return reinterpret_cast<float const &>(bits);
479
+ }
480
+
481
+ /// Returns sqrt(2), approximately 1.414... (specialization for complex<float>)
482
+ template <> CUTLASS_HOST_DEVICE complex<float> root_two< complex<float> >() {
483
+ return complex<float>(root_two<float>(), float());
484
+ }
485
+
486
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for float)
487
+ template <> CUTLASS_HOST_DEVICE float half_root_two<float>() {
488
+ uint32_t bits = 0x3f3504f3u;
489
+ return reinterpret_cast<float const &>(bits);
490
+ }
491
+
492
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<float>)
493
+ template <> CUTLASS_HOST_DEVICE complex<float> half_root_two< complex<float> >() {
494
+ return complex<float>(half_root_two<float>(), float());
495
+ }
496
+
497
+ /// Returns ln(2), approximately 0.693... (specialization for float)
498
+ template <> CUTLASS_HOST_DEVICE float ln_two<float>() {
499
+ uint32_t bits = 0x3f317218u;
500
+ return reinterpret_cast<float const &>(bits);
501
+ }
502
+
503
+ /// Returns ln(2), approximately 0.693... (specialization for complex<float>)
504
+ template <> CUTLASS_HOST_DEVICE complex<float> ln_two< complex<float> >() {
505
+ return complex<float>(ln_two<float>(), float());
506
+ }
507
+
508
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for float)
509
+ template <> CUTLASS_HOST_DEVICE float ln_ln_two<float>() {
510
+ uint32_t bits = 0xbebba795u;
511
+ return reinterpret_cast<float const &>(bits);
512
+ }
513
+
514
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<float>)
515
+ template <> CUTLASS_HOST_DEVICE complex<float> ln_ln_two< complex<float> >() {
516
+ return complex<float>(ln_ln_two<float>(), float());
517
+ }
518
+
519
+ /// Returns 1/3, approximately 0.333... (specialization for float)
520
+ template <> CUTLASS_HOST_DEVICE float third<float>() {
521
+ uint32_t bits = 0x3eaaaaabu;
522
+ return reinterpret_cast<float const &>(bits);
523
+ }
524
+
525
+ /// Returns 1/3, approximately 0.333... (specialization for complex<float>)
526
+ template <> CUTLASS_HOST_DEVICE complex<float> third< complex<float> >() {
527
+ return complex<float>(third<float>(), float());
528
+ }
529
+
530
+ /// Returns 2/3, approximately 0.666... (specialization for float)
531
+ template <> CUTLASS_HOST_DEVICE float twothirds<float>() {
532
+ uint32_t bits = 0x3f2aaaabu;
533
+ return reinterpret_cast<float const &>(bits);
534
+ }
535
+
536
+ /// Returns 2/3, approximately 0.666... (specialization for complex<float>)
537
+ template <> CUTLASS_HOST_DEVICE complex<float> twothirds< complex<float> >() {
538
+ return complex<float>(twothirds<float>(), float());
539
+ }
540
+
541
+ /// Returns pi - 3, approximately 0.1416... (specialization for float)
542
+ template <> CUTLASS_HOST_DEVICE float pi_minus_three<float>() {
543
+ uint32_t bits = 0x3e10fdaau;
544
+ return reinterpret_cast<float const &>(bits);
545
+ }
546
+
547
+ /// Returns pi - 3, approximately 0.1416... (specialization for complex<float>)
548
+ template <> CUTLASS_HOST_DEVICE complex<float> pi_minus_three< complex<float> >() {
549
+ return complex<float>(pi_minus_three<float>(), float());
550
+ }
551
+
552
+ /// Returns 4 - pi, approximately 0.858... (specialization for float)
553
+ template <> CUTLASS_HOST_DEVICE float four_minus_pi<float>() {
554
+ uint32_t bits = 0x3f5bc095u;
555
+ return reinterpret_cast<float const &>(bits);
556
+ }
557
+
558
+ /// Returns 4 - pi, approximately 0.858... (specialization for complex<float>)
559
+ template <> CUTLASS_HOST_DEVICE complex<float> four_minus_pi< complex<float> >() {
560
+ return complex<float>(four_minus_pi<float>(), float());
561
+ }
562
+
563
+ /////////////////////////////////////////////////////////////////////////////////////
564
+
565
+ // Specialization for tfloat32_t
566
+
567
+ /// Returns 1, the multiplicative identity element (specialization for tfloat32_t)
568
+ template <> CUTLASS_HOST_DEVICE tfloat32_t one<tfloat32_t>() {
569
+ uint32_t bits = 0x3f801000u;
570
+ return reinterpret_cast<tfloat32_t const &>(bits);
571
+ }
572
+
573
+ /// Returns 1, the multiplicative identity element (specialization for complex<tfloat32_t>)
574
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> one< complex<tfloat32_t> >() {
575
+ return complex<tfloat32_t>(one<tfloat32_t>(), tfloat32_t());
576
+ }
577
+
578
+ /// Returns 0, the additive identity element (specialization for tfloat32_t)
579
+ template <> CUTLASS_HOST_DEVICE tfloat32_t zero<tfloat32_t>() {
580
+ uint32_t bits = 0x1000u;
581
+ return reinterpret_cast<tfloat32_t const &>(bits);
582
+ }
583
+
584
+ /// Returns 0, the additive identity element (specialization for complex<tfloat32_t>)
585
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> zero< complex<tfloat32_t> >() {
586
+ return complex<tfloat32_t>(zero<tfloat32_t>(), tfloat32_t());
587
+ }
588
+
589
+ /// Returns 2 (specialization for tfloat32_t)
590
+ template <> CUTLASS_HOST_DEVICE tfloat32_t two<tfloat32_t>() {
591
+ uint32_t bits = 0x40001000u;
592
+ return reinterpret_cast<tfloat32_t const &>(bits);
593
+ }
594
+
595
+ /// Returns 2 (specialization for complex<tfloat32_t>)
596
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> two< complex<tfloat32_t> >() {
597
+ return complex<tfloat32_t>(two<tfloat32_t>(), tfloat32_t());
598
+ }
599
+
600
+ /// Returns pi, approximately 3.141 (specialization for tfloat32_t)
601
+ template <> CUTLASS_HOST_DEVICE tfloat32_t pi<tfloat32_t>() {
602
+ uint32_t bits = 0x40491fdbu;
603
+ return reinterpret_cast<tfloat32_t const &>(bits);
604
+ }
605
+
606
+ /// Returns pi, approximately 3.141 (specialization for complex<tfloat32_t>)
607
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> pi< complex<tfloat32_t> >() {
608
+ return complex<tfloat32_t>(pi<tfloat32_t>(), tfloat32_t());
609
+ }
610
+
611
+ /// Returns 2 * pi (specialization for tfloat32_t)
612
+ template <> CUTLASS_HOST_DEVICE tfloat32_t two_pi<tfloat32_t>() {
613
+ uint32_t bits = 0x40c91fdbu;
614
+ return reinterpret_cast<tfloat32_t const &>(bits);
615
+ }
616
+
617
+ /// Returns 2 * pi (specialization for complex<tfloat32_t>)
618
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> two_pi< complex<tfloat32_t> >() {
619
+ return complex<tfloat32_t>(two_pi<tfloat32_t>(), tfloat32_t());
620
+ }
621
+
622
+ /// Returns pi / 2 (specialization for tfloat32_t)
623
+ template <> CUTLASS_HOST_DEVICE tfloat32_t half_pi<tfloat32_t>() {
624
+ uint32_t bits = 0x3fc91fdbu;
625
+ return reinterpret_cast<tfloat32_t const &>(bits);
626
+ }
627
+
628
+ /// Returns pi / 2 (specialization for complex<tfloat32_t>)
629
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> half_pi< complex<tfloat32_t> >() {
630
+ return complex<tfloat32_t>(half_pi<tfloat32_t>(), tfloat32_t());
631
+ }
632
+
633
+ /// Returns sqrt(pi) (specialization for tfloat32_t)
634
+ template <> CUTLASS_HOST_DEVICE tfloat32_t root_pi<tfloat32_t>() {
635
+ uint32_t bits = 0x3fe2efc5u;
636
+ return reinterpret_cast<tfloat32_t const &>(bits);
637
+ }
638
+
639
+ /// Returns sqrt(pi) (specialization for complex<tfloat32_t>)
640
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_pi< complex<tfloat32_t> >() {
641
+ return complex<tfloat32_t>(root_pi<tfloat32_t>(), tfloat32_t());
642
+ }
643
+
644
+ /// Returns sqrt(pi / 2) (specialization for tfloat32_t)
645
+ template <> CUTLASS_HOST_DEVICE tfloat32_t root_half_pi<tfloat32_t>() {
646
+ uint32_t bits = 0x3fa07c99u;
647
+ return reinterpret_cast<tfloat32_t const &>(bits);
648
+ }
649
+
650
+ /// Returns sqrt(pi / 2) (specialization for complex<tfloat32_t>)
651
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_half_pi< complex<tfloat32_t> >() {
652
+ return complex<tfloat32_t>(root_half_pi<tfloat32_t>(), tfloat32_t());
653
+ }
654
+
655
+ /// Returns sqrt(2 * pi) (specialization for tfloat32_t)
656
+ template <> CUTLASS_HOST_DEVICE tfloat32_t root_two_pi<tfloat32_t>() {
657
+ uint32_t bits = 0x40207c99u;
658
+ return reinterpret_cast<tfloat32_t const &>(bits);
659
+ }
660
+
661
+ /// Returns sqrt(2 * pi) (specialization for complex<tfloat32_t>)
662
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_two_pi< complex<tfloat32_t> >() {
663
+ return complex<tfloat32_t>(root_two_pi<tfloat32_t>(), tfloat32_t());
664
+ }
665
+
666
+ /// Returns sqrt(ln(4)) (specialization for tfloat32_t)
667
+ template <> CUTLASS_HOST_DEVICE tfloat32_t root_ln_four<tfloat32_t>() {
668
+ uint32_t bits = 0x3f96c55fu;
669
+ return reinterpret_cast<tfloat32_t const &>(bits);
670
+ }
671
+
672
+ /// Returns sqrt(ln(4)) (specialization for complex<tfloat32_t>)
673
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_ln_four< complex<tfloat32_t> >() {
674
+ return complex<tfloat32_t>(root_ln_four<tfloat32_t>(), tfloat32_t());
675
+ }
676
+
677
+ /// Returns e, approximately 2.718... (specialization for tfloat32_t)
678
+ template <> CUTLASS_HOST_DEVICE tfloat32_t e<tfloat32_t>() {
679
+ uint32_t bits = 0x402e0854u;
680
+ return reinterpret_cast<tfloat32_t const &>(bits);
681
+ }
682
+
683
+ /// Returns e, approximately 2.718... (specialization for complex<tfloat32_t>)
684
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> e< complex<tfloat32_t> >() {
685
+ return complex<tfloat32_t>(e<tfloat32_t>(), tfloat32_t());
686
+ }
687
+
688
+ /// Returns (1/2) (specialization for tfloat32_t)
689
+ template <> CUTLASS_HOST_DEVICE tfloat32_t half<tfloat32_t>() {
690
+ uint32_t bits = 0x3f001000u;
691
+ return reinterpret_cast<tfloat32_t const &>(bits);
692
+ }
693
+
694
+ /// Returns (1/2) (specialization for complex<tfloat32_t>)
695
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> half< complex<tfloat32_t> >() {
696
+ return complex<tfloat32_t>(half<tfloat32_t>(), tfloat32_t());
697
+ }
698
+
699
+ /// Returns sqrt(2), approximately 1.414... (specialization for tfloat32_t)
700
+ template <> CUTLASS_HOST_DEVICE tfloat32_t root_two<tfloat32_t>() {
701
+ uint32_t bits = 0x3fb514f3u;
702
+ return reinterpret_cast<tfloat32_t const &>(bits);
703
+ }
704
+
705
+ /// Returns sqrt(2), approximately 1.414... (specialization for complex<tfloat32_t>)
706
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_two< complex<tfloat32_t> >() {
707
+ return complex<tfloat32_t>(root_two<tfloat32_t>(), tfloat32_t());
708
+ }
709
+
710
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for tfloat32_t)
711
+ template <> CUTLASS_HOST_DEVICE tfloat32_t half_root_two<tfloat32_t>() {
712
+ uint32_t bits = 0x3f3514f3u;
713
+ return reinterpret_cast<tfloat32_t const &>(bits);
714
+ }
715
+
716
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<tfloat32_t>)
717
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> half_root_two< complex<tfloat32_t> >() {
718
+ return complex<tfloat32_t>(half_root_two<tfloat32_t>(), tfloat32_t());
719
+ }
720
+
721
+ /// Returns ln(2), approximately 0.693... (specialization for tfloat32_t)
722
+ template <> CUTLASS_HOST_DEVICE tfloat32_t ln_two<tfloat32_t>() {
723
+ uint32_t bits = 0x3f318218u;
724
+ return reinterpret_cast<tfloat32_t const &>(bits);
725
+ }
726
+
727
+ /// Returns ln(2), approximately 0.693... (specialization for complex<tfloat32_t>)
728
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> ln_two< complex<tfloat32_t> >() {
729
+ return complex<tfloat32_t>(ln_two<tfloat32_t>(), tfloat32_t());
730
+ }
731
+
732
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for tfloat32_t)
733
+ template <> CUTLASS_HOST_DEVICE tfloat32_t ln_ln_two<tfloat32_t>() {
734
+ uint32_t bits = 0xbebbb795u;
735
+ return reinterpret_cast<tfloat32_t const &>(bits);
736
+ }
737
+
738
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<tfloat32_t>)
739
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> ln_ln_two< complex<tfloat32_t> >() {
740
+ return complex<tfloat32_t>(ln_ln_two<tfloat32_t>(), tfloat32_t());
741
+ }
742
+
743
+ /// Returns 1/3, approximately 0.333... (specialization for tfloat32_t)
744
+ template <> CUTLASS_HOST_DEVICE tfloat32_t third<tfloat32_t>() {
745
+ uint32_t bits = 0x3eaabaabu;
746
+ return reinterpret_cast<tfloat32_t const &>(bits);
747
+ }
748
+
749
+ /// Returns 1/3, approximately 0.333... (specialization for complex<tfloat32_t>)
750
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> third< complex<tfloat32_t> >() {
751
+ return complex<tfloat32_t>(third<tfloat32_t>(), tfloat32_t());
752
+ }
753
+
754
+ /// Returns 2/3, approximately 0.666... (specialization for tfloat32_t)
755
+ template <> CUTLASS_HOST_DEVICE tfloat32_t twothirds<tfloat32_t>() {
756
+ uint32_t bits = 0x3f2abaabu;
757
+ return reinterpret_cast<tfloat32_t const &>(bits);
758
+ }
759
+
760
+ /// Returns 2/3, approximately 0.666... (specialization for complex<tfloat32_t>)
761
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> twothirds< complex<tfloat32_t> >() {
762
+ return complex<tfloat32_t>(twothirds<tfloat32_t>(), tfloat32_t());
763
+ }
764
+
765
+ /// Returns pi - 3, approximately 0.1416... (specialization for tfloat32_t)
766
+ template <> CUTLASS_HOST_DEVICE tfloat32_t pi_minus_three<tfloat32_t>() {
767
+ uint32_t bits = 0x3e110daau;
768
+ return reinterpret_cast<tfloat32_t const &>(bits);
769
+ }
770
+
771
+ /// Returns pi - 3, approximately 0.1416... (specialization for complex<tfloat32_t>)
772
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> pi_minus_three< complex<tfloat32_t> >() {
773
+ return complex<tfloat32_t>(pi_minus_three<tfloat32_t>(), tfloat32_t());
774
+ }
775
+
776
+ /// Returns 4 - pi, approximately 0.858... (specialization for tfloat32_t)
777
+ template <> CUTLASS_HOST_DEVICE tfloat32_t four_minus_pi<tfloat32_t>() {
778
+ uint32_t bits = 0x3f5bd095u;
779
+ return reinterpret_cast<tfloat32_t const &>(bits);
780
+ }
781
+
782
+ /// Returns 4 - pi, approximately 0.858... (specialization for complex<tfloat32_t>)
783
+ template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> four_minus_pi< complex<tfloat32_t> >() {
784
+ return complex<tfloat32_t>(four_minus_pi<tfloat32_t>(), tfloat32_t());
785
+ }
786
+
787
+ /////////////////////////////////////////////////////////////////////////////////////
788
+
789
+ // Specialization for half_t
790
+
791
+ /// Returns 1, the multiplicative identity element (specialization for half_t)
792
+ template <> CUTLASS_HOST_DEVICE half_t one<half_t>() {
793
+ uint16_t bits = 0x3c00u;
794
+ return reinterpret_cast<half_t const &>(bits);
795
+ }
796
+
797
+ /// Returns 1, the multiplicative identity element (specialization for complex<half_t>)
798
+ template <> CUTLASS_HOST_DEVICE complex<half_t> one< complex<half_t> >() {
799
+ return complex<half_t>(one<half_t>(), half_t());
800
+ }
801
+
802
+ /// Returns 0, the additive identity element (specialization for half_t)
803
+ template <> CUTLASS_HOST_DEVICE half_t zero<half_t>() {
804
+ uint16_t bits = 0x0u;
805
+ return reinterpret_cast<half_t const &>(bits);
806
+ }
807
+
808
+ /// Returns 0, the additive identity element (specialization for complex<half_t>)
809
+ template <> CUTLASS_HOST_DEVICE complex<half_t> zero< complex<half_t> >() {
810
+ return complex<half_t>(zero<half_t>(), half_t());
811
+ }
812
+
813
+ /// Returns 2 (specialization for half_t)
814
+ template <> CUTLASS_HOST_DEVICE half_t two<half_t>() {
815
+ uint16_t bits = 0x4000u;
816
+ return reinterpret_cast<half_t const &>(bits);
817
+ }
818
+
819
+ /// Returns 2 (specialization for complex<half_t>)
820
+ template <> CUTLASS_HOST_DEVICE complex<half_t> two< complex<half_t> >() {
821
+ return complex<half_t>(two<half_t>(), half_t());
822
+ }
823
+
824
+ /// Returns pi, approximately 3.141 (specialization for half_t)
825
+ template <> CUTLASS_HOST_DEVICE half_t pi<half_t>() {
826
+ uint16_t bits = 0x4248u;
827
+ return reinterpret_cast<half_t const &>(bits);
828
+ }
829
+
830
+ /// Returns pi, approximately 3.141 (specialization for complex<half_t>)
831
+ template <> CUTLASS_HOST_DEVICE complex<half_t> pi< complex<half_t> >() {
832
+ return complex<half_t>(pi<half_t>(), half_t());
833
+ }
834
+
835
+ /// Returns 2 * pi (specialization for half_t)
836
+ template <> CUTLASS_HOST_DEVICE half_t two_pi<half_t>() {
837
+ uint16_t bits = 0x4648u;
838
+ return reinterpret_cast<half_t const &>(bits);
839
+ }
840
+
841
+ /// Returns 2 * pi (specialization for complex<half_t>)
842
+ template <> CUTLASS_HOST_DEVICE complex<half_t> two_pi< complex<half_t> >() {
843
+ return complex<half_t>(two_pi<half_t>(), half_t());
844
+ }
845
+
846
+ /// Returns pi / 2 (specialization for half_t)
847
+ template <> CUTLASS_HOST_DEVICE half_t half_pi<half_t>() {
848
+ uint16_t bits = 0x3e48u;
849
+ return reinterpret_cast<half_t const &>(bits);
850
+ }
851
+
852
+ /// Returns pi / 2 (specialization for complex<half_t>)
853
+ template <> CUTLASS_HOST_DEVICE complex<half_t> half_pi< complex<half_t> >() {
854
+ return complex<half_t>(half_pi<half_t>(), half_t());
855
+ }
856
+
857
+ /// Returns sqrt(pi) (specialization for half_t)
858
+ template <> CUTLASS_HOST_DEVICE half_t root_pi<half_t>() {
859
+ uint16_t bits = 0x3f17u;
860
+ return reinterpret_cast<half_t const &>(bits);
861
+ }
862
+
863
+ /// Returns sqrt(pi) (specialization for complex<half_t>)
864
+ template <> CUTLASS_HOST_DEVICE complex<half_t> root_pi< complex<half_t> >() {
865
+ return complex<half_t>(root_pi<half_t>(), half_t());
866
+ }
867
+
868
+ /// Returns sqrt(pi / 2) (specialization for half_t)
869
+ template <> CUTLASS_HOST_DEVICE half_t root_half_pi<half_t>() {
870
+ uint16_t bits = 0x3d03u;
871
+ return reinterpret_cast<half_t const &>(bits);
872
+ }
873
+
874
+ /// Returns sqrt(pi / 2) (specialization for complex<half_t>)
875
+ template <> CUTLASS_HOST_DEVICE complex<half_t> root_half_pi< complex<half_t> >() {
876
+ return complex<half_t>(root_half_pi<half_t>(), half_t());
877
+ }
878
+
879
+ /// Returns sqrt(2 * pi) (specialization for half_t)
880
+ template <> CUTLASS_HOST_DEVICE half_t root_two_pi<half_t>() {
881
+ uint16_t bits = 0x4103u;
882
+ return reinterpret_cast<half_t const &>(bits);
883
+ }
884
+
885
+ /// Returns sqrt(2 * pi) (specialization for complex<half_t>)
886
+ template <> CUTLASS_HOST_DEVICE complex<half_t> root_two_pi< complex<half_t> >() {
887
+ return complex<half_t>(root_two_pi<half_t>(), half_t());
888
+ }
889
+
890
+ /// Returns sqrt(ln(4)) (specialization for half_t)
891
+ template <> CUTLASS_HOST_DEVICE half_t root_ln_four<half_t>() {
892
+ uint16_t bits = 0x3cb6u;
893
+ return reinterpret_cast<half_t const &>(bits);
894
+ }
895
+
896
+ /// Returns sqrt(ln(4)) (specialization for complex<half_t>)
897
+ template <> CUTLASS_HOST_DEVICE complex<half_t> root_ln_four< complex<half_t> >() {
898
+ return complex<half_t>(root_ln_four<half_t>(), half_t());
899
+ }
900
+
901
+ /// Returns e, approximately 2.718... (specialization for half_t)
902
+ template <> CUTLASS_HOST_DEVICE half_t e<half_t>() {
903
+ uint16_t bits = 0x4170u;
904
+ return reinterpret_cast<half_t const &>(bits);
905
+ }
906
+
907
+ /// Returns e, approximately 2.718... (specialization for complex<half_t>)
908
+ template <> CUTLASS_HOST_DEVICE complex<half_t> e< complex<half_t> >() {
909
+ return complex<half_t>(e<half_t>(), half_t());
910
+ }
911
+
912
+ /// Returns (1/2) (specialization for half_t)
913
+ template <> CUTLASS_HOST_DEVICE half_t half<half_t>() {
914
+ uint16_t bits = 0x3800u;
915
+ return reinterpret_cast<half_t const &>(bits);
916
+ }
917
+
918
+ /// Returns (1/2) (specialization for complex<half_t>)
919
+ template <> CUTLASS_HOST_DEVICE complex<half_t> half< complex<half_t> >() {
920
+ return complex<half_t>(half<half_t>(), half_t());
921
+ }
922
+
923
+ /// Returns sqrt(2), approximately 1.414... (specialization for half_t)
924
+ template <> CUTLASS_HOST_DEVICE half_t root_two<half_t>() {
925
+ uint16_t bits = 0x3da8u;
926
+ return reinterpret_cast<half_t const &>(bits);
927
+ }
928
+
929
+ /// Returns sqrt(2), approximately 1.414... (specialization for complex<half_t>)
930
+ template <> CUTLASS_HOST_DEVICE complex<half_t> root_two< complex<half_t> >() {
931
+ return complex<half_t>(root_two<half_t>(), half_t());
932
+ }
933
+
934
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for half_t)
935
+ template <> CUTLASS_HOST_DEVICE half_t half_root_two<half_t>() {
936
+ uint16_t bits = 0x39a8u;
937
+ return reinterpret_cast<half_t const &>(bits);
938
+ }
939
+
940
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<half_t>)
941
+ template <> CUTLASS_HOST_DEVICE complex<half_t> half_root_two< complex<half_t> >() {
942
+ return complex<half_t>(half_root_two<half_t>(), half_t());
943
+ }
944
+
945
+ /// Returns ln(2), approximately 0.693... (specialization for half_t)
946
+ template <> CUTLASS_HOST_DEVICE half_t ln_two<half_t>() {
947
+ uint16_t bits = 0x398cu;
948
+ return reinterpret_cast<half_t const &>(bits);
949
+ }
950
+
951
+ /// Returns ln(2), approximately 0.693... (specialization for complex<half_t>)
952
+ template <> CUTLASS_HOST_DEVICE complex<half_t> ln_two< complex<half_t> >() {
953
+ return complex<half_t>(ln_two<half_t>(), half_t());
954
+ }
955
+
956
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for half_t)
957
+ template <> CUTLASS_HOST_DEVICE half_t ln_ln_two<half_t>() {
958
+ uint16_t bits = 0xb5ddu;
959
+ return reinterpret_cast<half_t const &>(bits);
960
+ }
961
+
962
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<half_t>)
963
+ template <> CUTLASS_HOST_DEVICE complex<half_t> ln_ln_two< complex<half_t> >() {
964
+ return complex<half_t>(ln_ln_two<half_t>(), half_t());
965
+ }
966
+
967
+ /// Returns 1/3, approximately 0.333... (specialization for half_t)
968
+ template <> CUTLASS_HOST_DEVICE half_t third<half_t>() {
969
+ uint16_t bits = 0x3555u;
970
+ return reinterpret_cast<half_t const &>(bits);
971
+ }
972
+
973
+ /// Returns 1/3, approximately 0.333... (specialization for complex<half_t>)
974
+ template <> CUTLASS_HOST_DEVICE complex<half_t> third< complex<half_t> >() {
975
+ return complex<half_t>(third<half_t>(), half_t());
976
+ }
977
+
978
+ /// Returns 2/3, approximately 0.666... (specialization for half_t)
979
+ template <> CUTLASS_HOST_DEVICE half_t twothirds<half_t>() {
980
+ uint16_t bits = 0x3955u;
981
+ return reinterpret_cast<half_t const &>(bits);
982
+ }
983
+
984
+ /// Returns 2/3, approximately 0.666... (specialization for complex<half_t>)
985
+ template <> CUTLASS_HOST_DEVICE complex<half_t> twothirds< complex<half_t> >() {
986
+ return complex<half_t>(twothirds<half_t>(), half_t());
987
+ }
988
+
989
+ /// Returns pi - 3, approximately 0.1416... (specialization for half_t)
990
+ template <> CUTLASS_HOST_DEVICE half_t pi_minus_three<half_t>() {
991
+ uint16_t bits = 0x3088u;
992
+ return reinterpret_cast<half_t const &>(bits);
993
+ }
994
+
995
+ /// Returns pi - 3, approximately 0.1416... (specialization for complex<half_t>)
996
+ template <> CUTLASS_HOST_DEVICE complex<half_t> pi_minus_three< complex<half_t> >() {
997
+ return complex<half_t>(pi_minus_three<half_t>(), half_t());
998
+ }
999
+
1000
+ /// Returns 4 - pi, approximately 0.858... (specialization for half_t)
1001
+ template <> CUTLASS_HOST_DEVICE half_t four_minus_pi<half_t>() {
1002
+ uint16_t bits = 0x3adeu;
1003
+ return reinterpret_cast<half_t const &>(bits);
1004
+ }
1005
+
1006
+ /// Returns 4 - pi, approximately 0.858... (specialization for complex<half_t>)
1007
+ template <> CUTLASS_HOST_DEVICE complex<half_t> four_minus_pi< complex<half_t> >() {
1008
+ return complex<half_t>(four_minus_pi<half_t>(), half_t());
1009
+ }
1010
+
1011
+ /////////////////////////////////////////////////////////////////////////////////////
1012
+
1013
+ // Specialization for bfloat16_t
1014
+
1015
+ /// Returns 1, the multiplicative identity element (specialization for bfloat16_t)
1016
+ template <> CUTLASS_HOST_DEVICE bfloat16_t one<bfloat16_t>() {
1017
+ uint16_t bits = 0x3f80u;
1018
+ return reinterpret_cast<bfloat16_t const &>(bits);
1019
+ }
1020
+
1021
+ /// Returns 1, the multiplicative identity element (specialization for complex<bfloat16_t>)
1022
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> one< complex<bfloat16_t> >() {
1023
+ return complex<bfloat16_t>(one<bfloat16_t>(), bfloat16_t());
1024
+ }
1025
+
1026
+ /// Returns 0, the additive identity element (specialization for bfloat16_t)
1027
+ template <> CUTLASS_HOST_DEVICE bfloat16_t zero<bfloat16_t>() {
1028
+ uint16_t bits = 0x0u;
1029
+ return reinterpret_cast<bfloat16_t const &>(bits);
1030
+ }
1031
+
1032
+ /// Returns 0, the additive identity element (specialization for complex<bfloat16_t>)
1033
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> zero< complex<bfloat16_t> >() {
1034
+ return complex<bfloat16_t>(zero<bfloat16_t>(), bfloat16_t());
1035
+ }
1036
+
1037
+ /// Returns 2 (specialization for bfloat16_t)
1038
+ template <> CUTLASS_HOST_DEVICE bfloat16_t two<bfloat16_t>() {
1039
+ uint16_t bits = 0x4000u;
1040
+ return reinterpret_cast<bfloat16_t const &>(bits);
1041
+ }
1042
+
1043
+ /// Returns 2 (specialization for complex<bfloat16_t>)
1044
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> two< complex<bfloat16_t> >() {
1045
+ return complex<bfloat16_t>(two<bfloat16_t>(), bfloat16_t());
1046
+ }
1047
+
1048
+ /// Returns pi, approximately 3.141 (specialization for bfloat16_t)
1049
+ template <> CUTLASS_HOST_DEVICE bfloat16_t pi<bfloat16_t>() {
1050
+ uint16_t bits = 0x4049u;
1051
+ return reinterpret_cast<bfloat16_t const &>(bits);
1052
+ }
1053
+
1054
+ /// Returns pi, approximately 3.141 (specialization for complex<bfloat16_t>)
1055
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> pi< complex<bfloat16_t> >() {
1056
+ return complex<bfloat16_t>(pi<bfloat16_t>(), bfloat16_t());
1057
+ }
1058
+
1059
+ /// Returns 2 * pi (specialization for bfloat16_t)
1060
+ template <> CUTLASS_HOST_DEVICE bfloat16_t two_pi<bfloat16_t>() {
1061
+ uint16_t bits = 0x40c9u;
1062
+ return reinterpret_cast<bfloat16_t const &>(bits);
1063
+ }
1064
+
1065
+ /// Returns 2 * pi (specialization for complex<bfloat16_t>)
1066
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> two_pi< complex<bfloat16_t> >() {
1067
+ return complex<bfloat16_t>(two_pi<bfloat16_t>(), bfloat16_t());
1068
+ }
1069
+
1070
+ /// Returns pi / 2 (specialization for bfloat16_t)
1071
+ template <> CUTLASS_HOST_DEVICE bfloat16_t half_pi<bfloat16_t>() {
1072
+ uint16_t bits = 0x3fc9u;
1073
+ return reinterpret_cast<bfloat16_t const &>(bits);
1074
+ }
1075
+
1076
+ /// Returns pi / 2 (specialization for complex<bfloat16_t>)
1077
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> half_pi< complex<bfloat16_t> >() {
1078
+ return complex<bfloat16_t>(half_pi<bfloat16_t>(), bfloat16_t());
1079
+ }
1080
+
1081
+ /// Returns sqrt(pi) (specialization for bfloat16_t)
1082
+ template <> CUTLASS_HOST_DEVICE bfloat16_t root_pi<bfloat16_t>() {
1083
+ uint16_t bits = 0x3fe3u;
1084
+ return reinterpret_cast<bfloat16_t const &>(bits);
1085
+ }
1086
+
1087
+ /// Returns sqrt(pi) (specialization for complex<bfloat16_t>)
1088
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_pi< complex<bfloat16_t> >() {
1089
+ return complex<bfloat16_t>(root_pi<bfloat16_t>(), bfloat16_t());
1090
+ }
1091
+
1092
+ /// Returns sqrt(pi / 2) (specialization for bfloat16_t)
1093
+ template <> CUTLASS_HOST_DEVICE bfloat16_t root_half_pi<bfloat16_t>() {
1094
+ uint16_t bits = 0x3fa0u;
1095
+ return reinterpret_cast<bfloat16_t const &>(bits);
1096
+ }
1097
+
1098
+ /// Returns sqrt(pi / 2) (specialization for complex<bfloat16_t>)
1099
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_half_pi< complex<bfloat16_t> >() {
1100
+ return complex<bfloat16_t>(root_half_pi<bfloat16_t>(), bfloat16_t());
1101
+ }
1102
+
1103
+ /// Returns sqrt(2 * pi) (specialization for bfloat16_t)
1104
+ template <> CUTLASS_HOST_DEVICE bfloat16_t root_two_pi<bfloat16_t>() {
1105
+ uint16_t bits = 0x4020u;
1106
+ return reinterpret_cast<bfloat16_t const &>(bits);
1107
+ }
1108
+
1109
+ /// Returns sqrt(2 * pi) (specialization for complex<bfloat16_t>)
1110
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_two_pi< complex<bfloat16_t> >() {
1111
+ return complex<bfloat16_t>(root_two_pi<bfloat16_t>(), bfloat16_t());
1112
+ }
1113
+
1114
+ /// Returns sqrt(ln(4)) (specialization for bfloat16_t)
1115
+ template <> CUTLASS_HOST_DEVICE bfloat16_t root_ln_four<bfloat16_t>() {
1116
+ uint16_t bits = 0x3f97u;
1117
+ return reinterpret_cast<bfloat16_t const &>(bits);
1118
+ }
1119
+
1120
+ /// Returns sqrt(ln(4)) (specialization for complex<bfloat16_t>)
1121
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_ln_four< complex<bfloat16_t> >() {
1122
+ return complex<bfloat16_t>(root_ln_four<bfloat16_t>(), bfloat16_t());
1123
+ }
1124
+
1125
+ /// Returns e, approximately 2.718... (specialization for bfloat16_t)
1126
+ template <> CUTLASS_HOST_DEVICE bfloat16_t e<bfloat16_t>() {
1127
+ uint16_t bits = 0x402eu;
1128
+ return reinterpret_cast<bfloat16_t const &>(bits);
1129
+ }
1130
+
1131
+ /// Returns e, approximately 2.718... (specialization for complex<bfloat16_t>)
1132
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> e< complex<bfloat16_t> >() {
1133
+ return complex<bfloat16_t>(e<bfloat16_t>(), bfloat16_t());
1134
+ }
1135
+
1136
+ /// Returns (1/2) (specialization for bfloat16_t)
1137
+ template <> CUTLASS_HOST_DEVICE bfloat16_t half<bfloat16_t>() {
1138
+ uint16_t bits = 0x3f00u;
1139
+ return reinterpret_cast<bfloat16_t const &>(bits);
1140
+ }
1141
+
1142
+ /// Returns (1/2) (specialization for complex<bfloat16_t>)
1143
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> half< complex<bfloat16_t> >() {
1144
+ return complex<bfloat16_t>(half<bfloat16_t>(), bfloat16_t());
1145
+ }
1146
+
1147
+ /// Returns sqrt(2), approximately 1.414... (specialization for bfloat16_t)
1148
+ template <> CUTLASS_HOST_DEVICE bfloat16_t root_two<bfloat16_t>() {
1149
+ uint16_t bits = 0x3fb5u;
1150
+ return reinterpret_cast<bfloat16_t const &>(bits);
1151
+ }
1152
+
1153
+ /// Returns sqrt(2), approximately 1.414... (specialization for complex<bfloat16_t>)
1154
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_two< complex<bfloat16_t> >() {
1155
+ return complex<bfloat16_t>(root_two<bfloat16_t>(), bfloat16_t());
1156
+ }
1157
+
1158
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for bfloat16_t)
1159
+ template <> CUTLASS_HOST_DEVICE bfloat16_t half_root_two<bfloat16_t>() {
1160
+ uint16_t bits = 0x3f35u;
1161
+ return reinterpret_cast<bfloat16_t const &>(bits);
1162
+ }
1163
+
1164
+ /// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<bfloat16_t>)
1165
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> half_root_two< complex<bfloat16_t> >() {
1166
+ return complex<bfloat16_t>(half_root_two<bfloat16_t>(), bfloat16_t());
1167
+ }
1168
+
1169
+ /// Returns ln(2), approximately 0.693... (specialization for bfloat16_t)
1170
+ template <> CUTLASS_HOST_DEVICE bfloat16_t ln_two<bfloat16_t>() {
1171
+ uint16_t bits = 0x3f31u;
1172
+ return reinterpret_cast<bfloat16_t const &>(bits);
1173
+ }
1174
+
1175
+ /// Returns ln(2), approximately 0.693... (specialization for complex<bfloat16_t>)
1176
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> ln_two< complex<bfloat16_t> >() {
1177
+ return complex<bfloat16_t>(ln_two<bfloat16_t>(), bfloat16_t());
1178
+ }
1179
+
1180
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for bfloat16_t)
1181
+ template <> CUTLASS_HOST_DEVICE bfloat16_t ln_ln_two<bfloat16_t>() {
1182
+ uint16_t bits = 0xbebcu;
1183
+ return reinterpret_cast<bfloat16_t const &>(bits);
1184
+ }
1185
+
1186
+ /// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<bfloat16_t>)
1187
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> ln_ln_two< complex<bfloat16_t> >() {
1188
+ return complex<bfloat16_t>(ln_ln_two<bfloat16_t>(), bfloat16_t());
1189
+ }
1190
+
1191
+ /// Returns 1/3, approximately 0.333... (specialization for bfloat16_t)
1192
+ template <> CUTLASS_HOST_DEVICE bfloat16_t third<bfloat16_t>() {
1193
+ uint16_t bits = 0x3eabu;
1194
+ return reinterpret_cast<bfloat16_t const &>(bits);
1195
+ }
1196
+
1197
+ /// Returns 1/3, approximately 0.333... (specialization for complex<bfloat16_t>)
1198
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> third< complex<bfloat16_t> >() {
1199
+ return complex<bfloat16_t>(third<bfloat16_t>(), bfloat16_t());
1200
+ }
1201
+
1202
+ /// Returns 2/3, approximately 0.666... (specialization for bfloat16_t)
1203
+ template <> CUTLASS_HOST_DEVICE bfloat16_t twothirds<bfloat16_t>() {
1204
+ uint16_t bits = 0x3f2bu;
1205
+ return reinterpret_cast<bfloat16_t const &>(bits);
1206
+ }
1207
+
1208
+ /// Returns 2/3, approximately 0.666... (specialization for complex<bfloat16_t>)
1209
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> twothirds< complex<bfloat16_t> >() {
1210
+ return complex<bfloat16_t>(twothirds<bfloat16_t>(), bfloat16_t());
1211
+ }
1212
+
1213
+ /// Returns pi - 3, approximately 0.1416... (specialization for bfloat16_t)
1214
+ template <> CUTLASS_HOST_DEVICE bfloat16_t pi_minus_three<bfloat16_t>() {
1215
+ uint16_t bits = 0x3e11u;
1216
+ return reinterpret_cast<bfloat16_t const &>(bits);
1217
+ }
1218
+
1219
+ /// Returns pi - 3, approximately 0.1416... (specialization for complex<bfloat16_t>)
1220
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> pi_minus_three< complex<bfloat16_t> >() {
1221
+ return complex<bfloat16_t>(pi_minus_three<bfloat16_t>(), bfloat16_t());
1222
+ }
1223
+
1224
+ /// Returns 4 - pi, approximately 0.858... (specialization for bfloat16_t)
1225
+ template <> CUTLASS_HOST_DEVICE bfloat16_t four_minus_pi<bfloat16_t>() {
1226
+ uint16_t bits = 0x3f5cu;
1227
+ return reinterpret_cast<bfloat16_t const &>(bits);
1228
+ }
1229
+
1230
+ /// Returns 4 - pi, approximately 0.858... (specialization for complex<bfloat16_t>)
1231
+ template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> four_minus_pi< complex<bfloat16_t> >() {
1232
+ return complex<bfloat16_t>(four_minus_pi<bfloat16_t>(), bfloat16_t());
1233
+ }
1234
+ ///////////////////////////////////////////////////////////////////////////////////
1235
+
1236
+ } // namespace constants
1237
+ } // namespace cutlass
1238
+
1239
+ ///////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include "cutlass/detail/dependent_false.hpp"
34
+ #include "cutlass/conv/collective/collective_conv.hpp"
35
+
36
+ /////////////////////////////////////////////////////////////////////////////////////////////////
37
+
38
+ namespace cutlass::conv::collective {
39
+
40
+ /////////////////////////////////////////////////////////////////////////////////////////////////
41
+
42
+ // Used to specify stage counts or dispatch to automatic computation of stage count
43
+ template<int num_stages>
44
+ struct StageCount {
45
+ static constexpr int value = num_stages;
46
+
47
+ StageCount() = default;
48
+ explicit StageCount(cute::Int<num_stages>) {}
49
+ };
50
+
51
+ template<int carveout_bytes>
52
+ struct StageCountAutoCarveout {
53
+ static constexpr int bytes = carveout_bytes;
54
+
55
+ StageCountAutoCarveout() = default;
56
+ explicit StageCountAutoCarveout(cute::Int<carveout_bytes>) {}
57
+ };
58
+
59
+ // Used to automatically let the builder pick the kernel schedule.
60
+ // Can be overridden with kernel schedule tags in cutlass/conv/dispatch_policy.hpp
61
+ struct KernelScheduleAuto {};
62
+
63
+ /////////////////////////////////////////////////////////////////////////////////////////////////
64
+
65
+ template <
66
+ class ArchTag,
67
+ class OpClass,
68
+ conv::Operator,
69
+ class ElementA,
70
+ class GmemLayoutA,
71
+ int AlignmentA,
72
+ class ElementB,
73
+ class GmemLayoutB,
74
+ int AlignmentB,
75
+ class ElementAccumulator,
76
+ class TileShape_MNK,
77
+ class ClusterShape_MNK,
78
+ class StageCountType,
79
+ class KernelScheduleType,
80
+ class Enable = void
81
+ >
82
+ struct CollectiveBuilder {
83
+ static_assert(cutlass::detail::dependent_false<ElementA>, "Could not build a collective for given parameters.");
84
+ };
85
+
86
+ /////////////////////////////////////////////////////////////////////////////////////////////////
87
+
88
+ } // namespace cutlass::conv::collective
89
+
90
+ /////////////////////////////////////////////////////////////////////////////////////////////////
91
+
92
+ #include "builders/sm90_gmma_builder.inl"
93
+ #include "builders/sm100_umma_builder.inl"
94
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include "cutlass/detail/dependent_false.hpp"
34
+ #include "cutlass/conv/collective/detail.hpp"
35
+
36
+ /////////////////////////////////////////////////////////////////////////////////////////////////
37
+
38
+ namespace cutlass::conv::collective {
39
+
40
+ /////////////////////////////////////////////////////////////////////////////////////////////////
41
+
42
+ template <
43
+ class DispatchPolicy,
44
+ class TileShape,
45
+ class ElementA,
46
+ class ElementB,
47
+ class TiledMma,
48
+ class TileTraitsA,
49
+ class TileTraitsB
50
+ >
51
+ struct CollectiveConv {
52
+ static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
53
+ };
54
+
55
+ /////////////////////////////////////////////////////////////////////////////////////////////////
56
+
57
+ } // namespace cutlass::conv::collective
58
+
59
+ /////////////////////////////////////////////////////////////////////////////////////////////////
60
+
61
+ #include "sm90_implicit_gemm_gmma_ss_warpspecialized.hpp"
62
+ #include "sm100_implicit_gemm_umma_warpspecialized.hpp"
63
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include "cutlass/conv/convnd_problem_shape.hpp"
34
+
35
+ /////////////////////////////////////////////////////////////////////////////////////////////////
36
+
37
+ namespace cutlass::conv::collective::detail {
38
+
39
+ /////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ // Construct the stride types for conv collectives based on the dispatch policy, strides 64b by default
42
+ template <class DispatchPolicy>
43
+ constexpr auto
44
+ sm90_dispatch_policy_to_stride_A() {
45
+ if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) {
46
+ // Maps to modes ((w,n), C)
47
+ if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
48
+ return cute::Stride<cute::Stride<int64_t, int64_t>,
49
+ cute::Int<1>>{};
50
+ }
51
+ // Maps to modes ((w,h,n), C)
52
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
53
+ return cute::Stride<cute::Stride<int64_t, int64_t, int64_t>,
54
+ cute::Int<1>>{};
55
+ }
56
+ // Maps to modes ((w,h,d,n), C)
57
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
58
+ return cute::Stride<cute::Stride<int64_t, int64_t, int64_t, int64_t>,
59
+ cute::Int<1>>{};
60
+ }
61
+ // error dims assert
62
+ else {
63
+ static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
64
+ }
65
+ }
66
+ else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) {
67
+ // Maps to modes (k, nq/npq/nzpq)
68
+ if constexpr (DispatchPolicy::NumSpatialDimensions == 1 ||
69
+ DispatchPolicy::NumSpatialDimensions == 2 ||
70
+ DispatchPolicy::NumSpatialDimensions == 3) {
71
+ return cute::Stride<cute::Int<1>, int64_t>{};
72
+ }
73
+ // error dims assert
74
+ else {
75
+ static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
76
+ }
77
+ }
78
+ else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) {
79
+ // Maps to modes ((q,n), K)
80
+ if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
81
+ return cute::Stride<cute::Stride<int64_t, int64_t>,
82
+ cute::Int<1>>{};
83
+ }
84
+ // Maps to modes ((q,p,n), K)
85
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
86
+ return cute::Stride<cute::Stride<int64_t, int64_t, int64_t>,
87
+ cute::Int<1>>{};
88
+ }
89
+ // Maps to modes ((q,p,z,n), K)
90
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
91
+ return cute::Stride<cute::Stride<int64_t, int64_t, int64_t, int64_t>,
92
+ cute::Int<1>>{};
93
+ }
94
+ // error dims assert
95
+ else {
96
+ static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
97
+ }
98
+ }
99
+ else {
100
+ static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported ConvOp.");
101
+ }
102
+ }
103
+
104
+ // Construct the stirde types for conv collectives based on the dispatch policy, strides 64b by default
105
+ template <class DispatchPolicy>
106
+ constexpr auto
107
+ sm90_dispatch_policy_to_stride_B() {
108
+ if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) {
109
+ // Maps to modes (k, (C,s))
110
+ if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
111
+ return cute::Stride<int64_t, cute::Stride<cute::Int<1>, int64_t>>{};
112
+ }
113
+ // Maps to modes (k, (C,s,r))
114
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
115
+ return cute::Stride<int64_t, cute::Stride<cute::Int<1>, int64_t, int64_t>>{};
116
+ }
117
+ // Maps to modes (k, (C,s,r,t))
118
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
119
+ return cute::Stride<int64_t, cute::Stride<cute::Int<1>, int64_t, int64_t, int64_t>>{};
120
+ }
121
+ // error dims assert
122
+ else {
123
+ static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
124
+ }
125
+ }
126
+ else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) {
127
+ // Maps to modes (C, (w,n))
128
+ if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
129
+ return cute::Stride<cute::Int<1>,
130
+ cute::Stride<int64_t, int64_t>>{};
131
+ }
132
+ // Maps to modes (C, (w,h,n))
133
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
134
+ return cute::Stride<cute::Int<1>,
135
+ cute::Stride<int64_t, int64_t, int64_t>>{};
136
+ }
137
+ // Maps to modes (C, (w,h,d,n))
138
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
139
+ return cute::Stride<cute::Int<1>,
140
+ cute::Stride<int64_t, int64_t, int64_t, int64_t>>{};
141
+ }
142
+ // error dims assert
143
+ else {
144
+ static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
145
+ }
146
+ }
147
+ else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) {
148
+ // Maps to modes (C, (k,s))
149
+ if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
150
+ return cute::Stride<cute::Int<1>, cute::Stride<int64_t, int64_t>>{};
151
+ }
152
+ // Maps to modes (C, (k,s,r))
153
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
154
+ return cute::Stride<cute::Int<1>, cute::Stride<int64_t, int64_t, int64_t>>{};
155
+ }
156
+ // Maps to modes (C, (k,s,r,t))
157
+ else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
158
+ return cute::Stride<cute::Int<1>, cute::Stride<int64_t, int64_t, int64_t, int64_t>>{};
159
+ }
160
+ // error dims assert
161
+ else {
162
+ static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
163
+ }
164
+ }
165
+ else {
166
+ static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported ConvOp.");
167
+ }
168
+ }
169
+
170
+
171
+ template <class DispatchPolicy>
172
+ constexpr auto
173
+ sm100_dispatch_policy_to_stride_A() {
174
+ return sm90_dispatch_policy_to_stride_A<DispatchPolicy>();
175
+ }
176
+
177
+ template <class DispatchPolicy>
178
+ constexpr auto
179
+ sm100_dispatch_policy_to_stride_B() {
180
+ return sm90_dispatch_policy_to_stride_B<DispatchPolicy>();
181
+ }
182
+
183
+
184
+ /////////////////////////////////////////////////////////////////////////////////////////////////
185
+
186
+ // Compute the lower/near corner, returning it as a cute::array in [W,H,D] order
187
+ template <conv::Operator ConvOp, int NumSpatialDimensions>
188
+ CUTLASS_HOST_DEVICE
189
+ constexpr auto
190
+ compute_lower_corner_whd(ConvProblemShape<ConvOp, NumSpatialDimensions> const& problem_shape) {
191
+ using cute::for_each;
192
+ using cute::make_seq;
193
+
194
+ cute::array<int, NumSpatialDimensions> lower{};
195
+ if constexpr (ConvOp == conv::Operator::kFprop ||
196
+ ConvOp == conv::Operator::kWgrad) {
197
+ for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
198
+ lower[NumSpatialDimensions-1-i] = -1 * problem_shape.lower_padding[i];
199
+ });
200
+ }
201
+ else if constexpr (ConvOp == conv::Operator::kDgrad) {
202
+ for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
203
+ lower[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] -
204
+ (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i];
205
+ });
206
+ }
207
+ return lower;
208
+ }
209
+
210
+ // Computes the upper/far corner, returning it as a cute::array in [W,H,D] order
211
+ template <conv::Operator ConvOp, int NumSpatialDimensions>
212
+ CUTLASS_HOST_DEVICE
213
+ constexpr auto
214
+ compute_upper_corner_whd(ConvProblemShape<ConvOp, NumSpatialDimensions> const& problem_shape) {
215
+ using cute::for_each;
216
+ using cute::make_seq;
217
+
218
+ cute::array<int, NumSpatialDimensions> upper{};
219
+ if constexpr (ConvOp == conv::Operator::kFprop) {
220
+ for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
221
+ upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] -
222
+ (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i];
223
+ });
224
+ }
225
+ else if constexpr (ConvOp == conv::Operator::kWgrad) {
226
+ for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
227
+ upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] -
228
+ (problem_shape.shape_C[i+1] - 1) * problem_shape.dilation[i];
229
+ });
230
+ }
231
+ else if constexpr (ConvOp == conv::Operator::kDgrad) {
232
+ for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
233
+ upper[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] -
234
+ (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i] + problem_shape.shape_C[i+1] - problem_shape.shape_A[i+1];
235
+ });
236
+ }
237
+ return upper;
238
+ }
239
+
240
+ // Compute the lower/near corner of (t,r,s), returning it as a cute::array in [S,R,T] order
241
+ template <conv::Operator ConvOp, int NumSpatialDimensions>
242
+ CUTLASS_HOST_DEVICE
243
+ constexpr auto
244
+ compute_lower_srt(ConvProblemShape<ConvOp, NumSpatialDimensions> const& problem_shape) {
245
+ using cute::for_each;
246
+ using cute::make_seq;
247
+
248
+ cute::array<int, NumSpatialDimensions> lower{};
249
+ if constexpr (ConvOp == conv::Operator::kFprop ||
250
+ ConvOp == conv::Operator::kWgrad) {
251
+ for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
252
+ lower[NumSpatialDimensions-1-i] = 0;
253
+ });
254
+ }
255
+ else if constexpr (ConvOp == conv::Operator::kDgrad) {
256
+ for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
257
+ lower[NumSpatialDimensions-1-i] = (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i];
258
+ });
259
+ }
260
+ return lower;
261
+ }
262
+
263
+ template <class CopyOp> struct is_im2col_load { static constexpr bool value = false; };
264
+ template <> struct is_im2col_load<cute::SM90_TMA_LOAD_IM2COL > { static constexpr bool value = true; };
265
+ template <> struct is_im2col_load<cute::SM90_TMA_LOAD_IM2COL_MULTICAST> { static constexpr bool value = true; };
266
+ template <> struct is_im2col_load<cute::SM100_TMA_2SM_LOAD_IM2COL > { static constexpr bool value = true; };
267
+ template <> struct is_im2col_load<cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST> { static constexpr bool value = true; };
268
+
269
+ /////////////////////////////////////////////////////////////////////////////////////////////////
270
+
271
+ } // namespace cutlass::conv::collective::detail
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp ADDED
@@ -0,0 +1,917 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+
33
+ #pragma once
34
+
35
+ #include "cutlass/cutlass.h"
36
+ #include "cutlass/gemm/dispatch_policy.hpp"
37
+ #include "cutlass/pipeline/pipeline.hpp"
38
+ #include "cutlass/gemm/gemm.h"
39
+ #include "cutlass/detail/cluster.hpp"
40
+
41
+ #include "cutlass/conv/detail.hpp"
42
+ #include "cute/algorithm/functional.hpp"
43
+ #include "cute/arch/cluster_sm90.hpp"
44
+ #include "cute/atom/mma_atom.hpp"
45
+ #include "cute/algorithm/gemm.hpp"
46
+ #include "cute/numeric/arithmetic_tuple.hpp"
47
+ #include "cutlass/trace.h"
48
+
49
+ #if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0)
50
+ # include <sstream>
51
+ #endif
52
+
53
+ /////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ namespace cutlass::conv::collective {
56
+ using namespace cute;
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ // WarpSpecialized Mainloop
61
+ // Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one
62
+ template <
63
+ conv::Operator ConvOp,
64
+ int Stages,
65
+ int NumSpatialDims,
66
+ int SchedulerPipelineStageCount,
67
+ int AccumulatorPipelineStageCount,
68
+ class ClusterShape, // Static cluster shape or dynamic (int, int, _1)
69
+ class TileShapeMNKL_, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL)
70
+ class ElementA_,
71
+ class ElementB_,
72
+ class TiledMma_,
73
+ class TileTraitsA_,
74
+ class TileTraitsB_>
75
+ struct CollectiveConv<
76
+ MainloopSm100TmaUmmaWarpSpecializedImplicitGemm<
77
+ ConvOp,
78
+ Stages,
79
+ NumSpatialDims,
80
+ SchedulerPipelineStageCount,
81
+ AccumulatorPipelineStageCount,
82
+ ClusterShape>,
83
+ TileShapeMNKL_,
84
+ ElementA_,
85
+ ElementB_,
86
+ TiledMma_,
87
+ TileTraitsA_,
88
+ TileTraitsB_>
89
+ {
90
+ //
91
+ // Type Aliases
92
+ //
93
+ using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedImplicitGemm<
94
+ ConvOp,
95
+ Stages,
96
+ NumSpatialDims,
97
+ SchedulerPipelineStageCount,
98
+ AccumulatorPipelineStageCount,
99
+ ClusterShape>;
100
+ using TileShape = decltype(cute::take<0,3>(TileShapeMNKL_{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK)
101
+ using ElementA = ElementA_;
102
+ using ElementB = ElementB_;
103
+ using TiledMma = TiledMma_;
104
+ using ElementAccumulator = typename TiledMma::ValTypeC;
105
+ using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy;
106
+ using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy;
107
+ using SmemLayoutAtomA = typename TileTraitsA_::SmemLayoutAtom;
108
+ using SmemLayoutAtomB = typename TileTraitsB_::SmemLayoutAtom;
109
+ using ArchTag = typename DispatchPolicy::ArchTag;
110
+ static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions;
111
+ static constexpr int NumTensorDimensions = NumSpatialDimensions + 2;
112
+ // deduce the kernel-facing stride tuple types based on the dispatch policy (spatial dim, algo, etc.)
113
+ using StrideA = decltype(detail::sm100_dispatch_policy_to_stride_A<DispatchPolicy>());
114
+ using StrideB = decltype(detail::sm100_dispatch_policy_to_stride_B<DispatchPolicy>());
115
+
116
+ static constexpr bool IsDynamicCluster = not cute::is_static_v<ClusterShape>;
117
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
118
+ static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
119
+ using TmaInternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>>;
120
+ using TmaInternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, cute::uint_bit_t<cute::sizeof_bits_v<ElementB>>>;
121
+
122
+ using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
123
+ using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
124
+
125
+ // Determine MMA type: MMA_1SM vs MMA_2SM
126
+ using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>;
127
+
128
+ using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
129
+ DispatchPolicy::Stages,
130
+ ClusterShape,
131
+ AtomThrShapeMNK>;
132
+ using MainloopPipelineState = typename MainloopPipeline::PipelineState;
133
+
134
+ using ProblemShape = ConvProblemShape<ConvOp, NumSpatialDimensions>;
135
+
136
+ CUTE_STATIC_ASSERT_V(evenly_divides(shape<0>(TileShape{}), tile_size<0>(TiledMma{})), "TileShape_M should be evenly divided by TiledMma_M");
137
+ CUTE_STATIC_ASSERT_V(evenly_divides(shape<1>(TileShape{}), tile_size<1>(TiledMma{})) || (ConvOp == conv::Operator::kWgrad), "TileShape_N should be evenly divided by TiledMma_N");
138
+
139
+ using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{}));
140
+
141
+ // Define A and B block shapes for reduced size TMA_LOADs
142
+ using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{}))));
143
+ using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{}))));
144
+
145
+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
146
+ static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0,
147
+ "SmemLayoutAtom must evenly divide tile shape.");
148
+ static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0,
149
+ "SmemLayoutAtom must evenly divide tile shape.");
150
+
151
+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
152
+ static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0,
153
+ "SmemLayoutAtom must evenly divide tile shape.");
154
+ static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0,
155
+ "SmemLayoutAtom must evenly divide tile shape.");
156
+
157
+ // Tile along K mode first before tiling over MN. PIPE mode last as usual.
158
+ // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs.
159
+ using SmemLayoutA = decltype(UMMA::tile_to_mma_shape(
160
+ SmemLayoutAtomA{},
161
+ append(MmaShapeA_MK{}, Int<DispatchPolicy::Stages>{}),
162
+ Step<_2,_1,_3>{}));
163
+ using SmemLayoutB = decltype(UMMA::tile_to_mma_shape(
164
+ SmemLayoutAtomB{},
165
+ append(MmaShapeB_NK{}, Int<DispatchPolicy::Stages>{}),
166
+ Step<_2,_1,_3>{}));
167
+
168
+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
169
+ static_assert(cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
170
+ cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
171
+ "MMA atom must source both A and B operand from smem_desc for this mainloop.");
172
+
173
+ static constexpr bool is_im2col_A = detail::is_im2col_load<GmemTiledCopyA>::value;
174
+ static constexpr bool is_im2col_B = detail::is_im2col_load<GmemTiledCopyB>::value;
175
+ static constexpr bool is_strided_dgrad = ConvOp == conv::Operator::kDgrad && not is_im2col_A && not is_im2col_B;
176
+
177
+ static constexpr int TileShapeMNKLRank = rank(TileShapeMNKL_{});
178
+ // If rank > 3, TileL exists and it is GroupsPerTile. The kernel is grouped conv now.
179
+ static constexpr bool is_grouped_wgrad = ConvOp == conv::Operator::kWgrad && TileShapeMNKLRank > 3;
180
+
181
+ struct SharedStorage {
182
+ struct TensorStorage : cute::aligned_struct<128, _0> {
183
+ cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
184
+ cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
185
+ } tensors;
186
+
187
+ using PipelineStorage = typename MainloopPipeline::SharedStorage;
188
+ PipelineStorage pipeline;
189
+ };
190
+
191
+ using TensorStorage = typename SharedStorage::TensorStorage;
192
+ using PipelineStorage = typename SharedStorage::PipelineStorage;
193
+
194
+ // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
195
+ static constexpr uint32_t TmaTransactionBytes =
196
+ size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof(ElementA))) +
197
+ size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof(ElementB)));
198
+
199
+ // Host side kernel arguments
200
+ struct Arguments {
201
+ ElementA const* ptr_A{nullptr};
202
+ ElementB const* ptr_B{nullptr};
203
+ };
204
+
205
+ private:
206
+
207
+ // Note that for fprop and non-strided dgrad kernel, the tma load mode is im2col for tensor A and tiled for
208
+ // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor
209
+ // B since operand A, B is swapped.
210
+ // For strided dgrad A and B are both tma tiled and not im2col
211
+
212
+ template <class TensorA, class ClusterShapeVMNK>
213
+ static constexpr auto
214
+ get_tma_load_a_instance(
215
+ TensorA const& tensor_a,
216
+ ProblemShape const& problem_shape,
217
+ ClusterShapeVMNK const& cluster_shape_vmnk) {
218
+
219
+ if constexpr (is_im2col_A) {
220
+ // compute the upper and lower corners based on the conv padding
221
+ auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
222
+ auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
223
+ auto lower_srt = detail::compute_lower_srt(problem_shape);
224
+
225
+ // gbasis strides for dgrad kernel need to be negated
226
+ cute::array<int32_t, NumSpatialDimensions> stride_srt{};
227
+ for (int i = 0; i < NumSpatialDimensions; ++i) {
228
+ stride_srt[i] = ConvOp == conv::Operator::kDgrad ?
229
+ -problem_shape.dilation[NumSpatialDimensions-1-i] :
230
+ problem_shape.dilation[NumSpatialDimensions-1-i];
231
+ }
232
+
233
+ return make_im2col_tma_atom_A_sm100(
234
+ GmemTiledCopyA{},
235
+ tensor_a,
236
+ SmemLayoutA{}(_,_,_,cute::Int<0>{}),
237
+ TileShape{},
238
+ TiledMma{},
239
+ cluster_shape_vmnk,
240
+ shape(lower_corner_whd),
241
+ shape(upper_corner_whd),
242
+ cute::reverse(shape(problem_shape.lower_padding)),
243
+ cute::reverse(shape(problem_shape.upper_padding)),
244
+ cute::reverse(shape(problem_shape.traversal_stride)),
245
+ shape(lower_srt),
246
+ shape(stride_srt));
247
+ }
248
+ // TMA tiled mode for tensor A in wgrad and strided dgrad
249
+ else {
250
+ return make_tma_atom_A_sm100<TmaInternalElementA>(
251
+ GmemTiledCopyA{},
252
+ tensor_a,
253
+ SmemLayoutA{}(_,_,_,cute::Int<0>{}),
254
+ TileShape{},
255
+ TiledMma{},
256
+ cluster_shape_vmnk);
257
+ }
258
+ }
259
+
260
+ template <class TensorB, class ClusterShapeVMNK>
261
+ static constexpr auto
262
+ get_tma_load_b_instance(
263
+ TensorB const& tensor_b,
264
+ ProblemShape const& problem_shape,
265
+ ClusterShapeVMNK const& cluster_shape_vmnk) {
266
+
267
+ if constexpr (is_im2col_B) {
268
+ // compute the upper and lower corners based on the conv padding
269
+ auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
270
+ auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
271
+ auto lower_srt = detail::compute_lower_srt(problem_shape);
272
+
273
+ return make_im2col_tma_atom_B_sm100(
274
+ GmemTiledCopyB{},
275
+ tensor_b,
276
+ SmemLayoutB{}(_,_,_,cute::Int<0>{}),
277
+ TileShape{},
278
+ TiledMma{},
279
+ cluster_shape_vmnk,
280
+ shape(lower_corner_whd),
281
+ shape(upper_corner_whd),
282
+ cute::reverse(shape(problem_shape.lower_padding)),
283
+ cute::reverse(shape(problem_shape.upper_padding)),
284
+ cute::reverse(shape(problem_shape.traversal_stride)),
285
+ shape(lower_srt),
286
+ cute::reverse(shape(problem_shape.dilation)));
287
+ }
288
+ else {
289
+ return make_tma_atom_B_sm100<TmaInternalElementB>(
290
+ GmemTiledCopyB{},
291
+ tensor_b,
292
+ SmemLayoutB{}(_,_,_,cute::Int<0>{}),
293
+ TileShape{},
294
+ TiledMma{},
295
+ cluster_shape_vmnk);
296
+ }
297
+ }
298
+
299
+ public:
300
+
301
+ // Performs im2col transformations on the input of type ConvProblemShape
302
+ static constexpr auto
303
+ get_problem_shape_MNKL(ProblemShape const& problem_shape) {
304
+ if constexpr (is_im2col_A || is_im2col_B) {
305
+ // transformation + im2col linearization
306
+ return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape);
307
+ }
308
+ else {
309
+ // transformation
310
+ return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape);
311
+ }
312
+ }
313
+
314
+ // Device-side kernel params
315
+ //
316
+ // Arguments has the untransformed problem shape from the user.
317
+ // Params will have the transformed problem shape.
318
+ struct Params {
319
+ using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{}));
320
+
321
+ using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return<IsDynamicCluster>(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})),
322
+ make_tile(typename TiledMma::AtomThrID{})));
323
+
324
+ // Assumption: StrideA is congruent with Problem_MK
325
+ // Select TMA load type according to convolution operator.
326
+ using TensorShapeA = cute::conditional_t<ConvOp == conv::Operator::kWgrad,
327
+ decltype(repeat_like(StrideA{}, int32_t(0))),
328
+ decltype(make_shape(_Submode{}, int32_t(0)))>;
329
+
330
+ using TensorShapeB = cute::conditional_t<ConvOp == conv::Operator::kWgrad,
331
+ decltype(make_shape(int32_t(0), _Submode{})),
332
+ decltype(repeat_like(StrideB{}, int32_t(0)))>;
333
+
334
+ using TMA_A = decltype(get_tma_load_a_instance(
335
+ make_tensor(
336
+ make_gmem_ptr(recast_ptr<TmaInternalElementA>(nullptr)),
337
+ make_layout(TensorShapeA{}, StrideA{})),
338
+ ConvProblemShape<ConvOp, NumSpatialDimensions>{},
339
+ ClusterLayout_VMNK{}));
340
+
341
+ using TMA_B = decltype(get_tma_load_b_instance(
342
+ make_tensor(
343
+ make_gmem_ptr(recast_ptr<TmaInternalElementB>(nullptr)),
344
+ make_layout(TensorShapeB{}, StrideB{})),
345
+ ConvProblemShape<ConvOp, NumSpatialDimensions>{},
346
+ ClusterLayout_VMNK{}));
347
+
348
+ // Members
349
+ TMA_A tma_load_a;
350
+ TMA_B tma_load_b;
351
+ TMA_A tma_load_a_fallback;
352
+ TMA_B tma_load_b_fallback;
353
+ dim3 cluster_shape_fallback;
354
+ };
355
+
356
+ //
357
+ // Constructor
358
+ //
359
+ CUTLASS_DEVICE
360
+ CollectiveConv(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster)
361
+ : cluster_shape_(cluster_shape)
362
+ , block_rank_in_cluster_(block_rank_in_cluster) {
363
+ if constexpr (IsDynamicCluster) {
364
+ const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
365
+ cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y);
366
+ observed_tma_load_a_ = is_fallback_cluster ? &params.tma_load_a_fallback : &params.tma_load_a;
367
+ observed_tma_load_b_ = is_fallback_cluster ? &params.tma_load_b_fallback : &params.tma_load_b;
368
+ }
369
+ else {
370
+ observed_tma_load_a_ = &params.tma_load_a;
371
+ observed_tma_load_b_ = &params.tma_load_b;
372
+ }
373
+ }
374
+
375
+ //
376
+ // Methods
377
+ //
378
+
379
+ static constexpr Params
380
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) {
381
+ (void) workspace;
382
+
383
+ // from the flat problem shape arrays of ConvProblemShape<N>, create a rank-3 MNK problem shape tuple
384
+ // tma desc creation depends on the original untransformed domain.
385
+
386
+ // A extents.
387
+ auto shape_A_orig = problem_shape.get_shape_A();
388
+ // B extents.
389
+ auto shape_B_orig = problem_shape.get_shape_B();
390
+
391
+ // Fill inferred cute strides from flat stride arrays
392
+ auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp);
393
+ auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp);
394
+
395
+ auto ptr_A = recast_ptr<TmaInternalElementA>(args.ptr_A);
396
+ auto ptr_B = recast_ptr<TmaInternalElementB>(args.ptr_B);
397
+
398
+ Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA));
399
+ Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB));
400
+
401
+ auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape);
402
+ // Cluster layout for TMA construction
403
+ auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{}));
404
+ auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback);
405
+
406
+ // Cluster layout for TMA construction
407
+ auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{}));
408
+
409
+ auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk);
410
+ auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk);
411
+ auto tma_load_a_fallback = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk_fallback);
412
+ auto tma_load_b_fallback = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk_fallback);
413
+
414
+ static_assert(size(typename decltype(tma_load_a)::ThrID{}) == size(AtomThrShapeMNK{}));
415
+ static_assert(size(typename decltype(tma_load_b)::ThrID{}) == size(AtomThrShapeMNK{}));
416
+
417
+ return {
418
+ tma_load_a,
419
+ tma_load_b,
420
+ tma_load_a_fallback,
421
+ tma_load_b_fallback,
422
+ hw_info.cluster_shape_fallback
423
+ };
424
+ }
425
+
426
+ template<class ProblemShape>
427
+ static bool
428
+ can_implement(
429
+ ProblemShape const& problem_shape,
430
+ Arguments const& args) {
431
+ // Activation and Filter channel mode extents must match
432
+ bool implementable = true;
433
+ // channel mode is major
434
+ {
435
+ const bool check = problem_shape.stride_A[NumTensorDimensions-1] == 1;
436
+ #if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0)
437
+ if (not check) {
438
+ const auto offending_stride =
439
+ problem_shape.stride_A[NumTensorDimensions-1];
440
+ std::ostringstream os;
441
+ os << "CollectiveConv::can_implement: "
442
+ "problem_shape.stride_A[NumTensorDimensions-1 = "
443
+ << (NumTensorDimensions-1) << "] = "
444
+ << offending_stride << " != 1";
445
+ CUTLASS_TRACE_HOST( os.str() );
446
+ }
447
+ #endif
448
+ implementable &= check;
449
+ }
450
+
451
+ {
452
+ const bool check = problem_shape.stride_B[NumTensorDimensions-1] == 1;
453
+ #if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0)
454
+ if (not check) {
455
+ const auto offending_stride =
456
+ problem_shape.stride_B[NumTensorDimensions-1];
457
+ std::ostringstream os;
458
+ os << "CollectiveConv::can_implement: "
459
+ "problem_shape.stride_B[NumTensorDimensions-1 = "
460
+ << (NumTensorDimensions-1) << "] = "
461
+ << offending_stride << " != 1\n";
462
+ CUTLASS_TRACE_HOST( os.str() );
463
+ }
464
+ #endif
465
+ implementable &= check;
466
+ }
467
+
468
+ {
469
+ const auto & traversal_stride = problem_shape.traversal_stride;
470
+ for (auto stride: traversal_stride) {
471
+ implementable &= (stride >= 1 && stride <= 8);
472
+ }
473
+ }
474
+
475
+ if constexpr (ConvOp == conv::Operator::kDgrad && not is_strided_dgrad) {
476
+ const auto & traversal_stride = problem_shape.traversal_stride;
477
+ for (auto stride: traversal_stride) {
478
+ implementable &= (stride == 1);
479
+ }
480
+ }
481
+
482
+ constexpr int tma_alignment_bits = 128;
483
+ // A extents.
484
+ auto shape_A_orig = problem_shape.get_shape_A();
485
+ // B extents.
486
+ auto shape_B_orig = problem_shape.get_shape_B();
487
+
488
+ constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
489
+ {
490
+ const bool check = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(shape_A_orig, StrideA{});
491
+ if (not check) {
492
+ CUTLASS_TRACE_HOST("A shape and/or strides have alignment issue.");
493
+ }
494
+ implementable &= check;
495
+ }
496
+
497
+ constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
498
+ {
499
+ const bool check = cutlass::detail::check_alignment<min_tma_aligned_elements_B>(shape_B_orig, StrideB{});
500
+ if (not check) {
501
+ CUTLASS_TRACE_HOST("B shape and/or strides have alignment issue.");
502
+ }
503
+ implementable &= check;
504
+ }
505
+
506
+ if (not implementable) {
507
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
508
+ return false;
509
+ }
510
+
511
+ if (is_im2col_A || is_im2col_B) {
512
+ // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1]
513
+ constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1);
514
+ auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
515
+ for (int i = 0; i < problem_shape.RankS; ++i) {
516
+ implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1);
517
+ }
518
+ auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
519
+ for (int i = 0; i < problem_shape.RankS; ++i) {
520
+ implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1);
521
+ }
522
+
523
+ if (!implementable) {
524
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n");
525
+ return false;
526
+ }
527
+ }
528
+
529
+ if (is_im2col_A || is_im2col_B) {
530
+ // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit]
531
+ constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1;
532
+ auto flt_data = (ConvOp == conv::Operator::kWgrad) ? problem_shape.shape_C : problem_shape.shape_B;
533
+ for (int i = 0; i < problem_shape.RankS; ++i) {
534
+ // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array
535
+ implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0)
536
+ && ((flt_data[i+1] - 1) * problem_shape.dilation[i] <= offset_limit);
537
+ }
538
+
539
+ if (!implementable) {
540
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: tensor coordinate offset values don't meet requirements for TMA LOAD IM2COL.\n");
541
+ return false;
542
+ }
543
+ }
544
+
545
+ // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized)
546
+ if constexpr (ConvOp == conv::Operator::kWgrad) {
547
+
548
+ const auto & input_shape = problem_shape.shape_A;
549
+ const auto & input_stride = problem_shape.stride_A;
550
+
551
+ implementable &= input_stride[ProblemShape::RankT - 1] == 1;
552
+ int64_t input_shape_size = 1;
553
+ for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
554
+ input_shape_size *= input_shape[i + 1];
555
+ implementable &= input_stride[i] == input_shape_size;
556
+ }
557
+
558
+ const auto & output_shape = problem_shape.shape_C;
559
+ const auto & output_stride = problem_shape.stride_C;
560
+
561
+ implementable &= output_stride[ProblemShape::RankT - 1] == 1;
562
+ int64_t output_shape_size = 1;
563
+ for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
564
+ output_shape_size *= output_shape[i + 1];
565
+ implementable &= output_stride[i] == output_shape_size;
566
+ }
567
+
568
+ if (!implementable) {
569
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n");
570
+ return false;
571
+ }
572
+ }
573
+
574
+ // Conv kernels only support cross correlation mode currently.
575
+ {
576
+ implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation;
577
+
578
+ if (!implementable) {
579
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n");
580
+ return false;
581
+ }
582
+ }
583
+
584
+ // When groups > 1, it should be a Grouped Conv.
585
+ if (problem_shape.groups > 1) {
586
+ implementable &= TileShapeMNKLRank > 3;
587
+
588
+ if (!implementable) {
589
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Only Grouped Conv can support groups > 1.\n");
590
+ return false;
591
+ }
592
+ }
593
+
594
+ // Only support Grouped Wgrad currently.
595
+ if constexpr (TileShapeMNKLRank > 3) {
596
+ implementable &= ConvOp == conv::Operator::kWgrad;
597
+
598
+ if (!implementable) {
599
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv Only support Grouped Wgrad currently.\n");
600
+ return false;
601
+ }
602
+ }
603
+
604
+ // Grouped Wgrad channel check.
605
+ if constexpr (is_grouped_wgrad) {
606
+
607
+ int input_K = size<0>(problem_shape.get_shape_A());
608
+ int input_C = size<0>(problem_shape.get_shape_B());
609
+
610
+ implementable &= input_K == input_C;
611
+
612
+ if (!implementable) {
613
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv's input K and input C do not match.\n");
614
+ return false;
615
+ }
616
+
617
+ int output_K = size<0>(problem_shape.get_shape_C());
618
+ int output_C = size<1,0>(problem_shape.get_shape_C());
619
+
620
+ implementable &= input_K == output_K;
621
+ implementable &= input_C == output_C * problem_shape.groups;
622
+
623
+ if (!implementable) {
624
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's input and output K,C and groups do not match\n");
625
+ return false;
626
+ }
627
+
628
+ constexpr int Tile_N = size<1>(TileShape{});
629
+ constexpr int GroupsPerTile = size<3>(TileShapeMNKL_{});
630
+
631
+ implementable &= Tile_N / GroupsPerTile == input_C / problem_shape.groups;
632
+
633
+ if (!implementable) {
634
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's Tile_N, GroupsPerTile and input_C, groups do not match.\n");
635
+ return false;
636
+ }
637
+ }
638
+
639
+ // The extents of linearized problem shape should be int32_t type(maximum is 2^31-1).
640
+ if constexpr (is_im2col_A || is_im2col_B) {
641
+ auto [M, N, K, L] = cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape);
642
+ auto to_64b = [](auto S) { return transform_leaf(S, [](auto s) { return static_cast<int64_t>(s); }); };
643
+
644
+ if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) {
645
+ implementable &= (cute::product(to_64b(M)) <= cutlass::platform::numeric_limits<int32_t>::max()) &
646
+ (cute::product(to_64b(L)) <= cutlass::platform::numeric_limits<int32_t>::max());
647
+ }
648
+ else if constexpr (ConvOp == conv::Operator::kWgrad) {
649
+ implementable &= (cute::product(to_64b(K)) <= cutlass::platform::numeric_limits<int32_t>::max());
650
+ }
651
+
652
+ if (!implementable) {
653
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: the extents exceed the maximum number.\n");
654
+ return false;
655
+ }
656
+ }
657
+
658
+ return true;
659
+ }
660
+
661
+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
662
+ CUTLASS_DEVICE void
663
+ prefetch_tma_descriptors() {
664
+ cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor());
665
+ cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor());
666
+ }
667
+
668
+ /// Construct A Single Stage's Accumulator Shape
669
+ CUTLASS_DEVICE static auto
670
+ partition_accumulator_shape() {
671
+ auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N)
672
+
673
+ return acc_shape;
674
+ }
675
+
676
+ /// Perform a collective-scoped matrix multiply-accumulate
677
+ /// Producer Perspective
678
+ template <
679
+ class GTensorA, class GTensorB,
680
+ class GTensorPartitionedA, class GTensorPartitionedB,
681
+ class STensorA, class STensorB,
682
+ class TileCoordMNKL,
683
+ class KTileIterator
684
+ >
685
+ CUTLASS_DEVICE auto
686
+ load(
687
+ Params const& params,
688
+ MainloopPipeline pipeline,
689
+ MainloopPipelineState mainloop_pipe_producer_state,
690
+ cute::tuple<GTensorA, GTensorB,
691
+ GTensorPartitionedA, GTensorPartitionedB,
692
+ STensorA, STensorB,
693
+ uint16_t, uint16_t> const& load_inputs,
694
+ TileCoordMNKL const& cta_coord_mnkl,
695
+ KTileIterator k_tile_iter, int k_tile_count) {
696
+
697
+ auto [unused_gA, unused_gB,
698
+ tAgA_mk, tBgB_nk, tAsA, tBsB,
699
+ mcast_mask_a, mcast_mask_b] = load_inputs;
700
+
701
+ // slice out the work coord from partitioned tensors
702
+ Tensor tAgA = tAgA_mk(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _);
703
+ auto tensor_b_coord = get<1>(cta_coord_mnkl);
704
+ if constexpr (is_grouped_wgrad) {
705
+ // in grouped wgrad, tensor A = NZPQK, tensor B = NDHWC, tensor C = KTRSc, where C = G*c, c = channel_per_group = 8,16,32.
706
+ // CTA Tiling follows output tensor KTRSc. So cta_size_m = K/CTA_TILE_M. cta_size_n = T*R*S*ceil(c/CTA_TILE_N) = T*R*S*1 = T*R*S.
707
+ // tensor_a_coord = K_idx = cta_coord_m.
708
+ // tensor_b_coord = TRS_idx * C/CTA_TILE_N + C_idx = cta_coord_n * get<1,0>(shape(tBgB_nk) + cta_coord_m,
709
+ // because K == C and CTA_TILE_M == CTA_TILE_N => C_idx = K_idx = cta_coord_m.
710
+ tensor_b_coord = get<0>(cta_coord_mnkl) + get<1>(cta_coord_mnkl) * get<1,0>(shape(tBgB_nk));
711
+ }
712
+ Tensor tBgB = tBgB_nk(_, tensor_b_coord, _);
713
+
714
+ auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state);
715
+
716
+ // Issue the Mainloop loads
717
+ CUTLASS_PRAGMA_NO_UNROLL
718
+ while (k_tile_count > 0) {
719
+ // LOCK mainloop_pipe_producer_state for _writing_
720
+ pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
721
+
722
+ using BarrierType = typename MainloopPipeline::ProducerBarrierType;
723
+ BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state);
724
+
725
+ int write_stage = mainloop_pipe_producer_state.index();
726
+ ++mainloop_pipe_producer_state;
727
+ barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state);
728
+
729
+ if constexpr (is_strided_dgrad) {
730
+ // construct gemm-k tile coord for gB
731
+ auto [conv_k, flt_coord, out_coord] = *k_tile_iter;
732
+ auto gemm_k_tile = prepend(flt_coord, conv_k); // (k,s,r,t)
733
+
734
+ // gA doesn't have a gemm-k (k,s,r,t) iterator mode because it's not an im2col tensor
735
+ auto offset_kqpzn = append(prepend(out_coord, _0{}),_0{}); // (k,q,p,z,n)
736
+ auto tAgA_offset = make_tensor(tAgA.data() + offset_kqpzn, tAgA.layout()); // (TMA, k)
737
+
738
+ if (cute::elect_one_sync()) {
739
+ copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA_offset(_,conv_k), tAsA(_,write_stage));
740
+ copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,gemm_k_tile) , tBsB(_,write_stage));
741
+ }
742
+ }
743
+ else {
744
+ if (cute::elect_one_sync()) {
745
+ copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage));
746
+ copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage));
747
+ }
748
+ }
749
+
750
+ --k_tile_count;
751
+ ++k_tile_iter;
752
+ }
753
+
754
+ return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter);
755
+ }
756
+
757
+ /// Set up the data needed by this collective for load.
758
+ /// Return tuple element contain
759
+ /// gA_mk - The tiled tma tensor for input A
760
+ /// gB_nk - The tiled tma tensor for input B
761
+ /// tAsA - partitioned smem tensor for A
762
+ /// tBsB - partitioned smem tensor for B
763
+ /// mcast_mask_a - tma multicast mask for A
764
+ /// mcast_mask_b - tma multicast mask for B
765
+ template <class ProblemShape_MNKL>
766
+ CUTLASS_DEVICE auto
767
+ load_init(
768
+ ProblemShape_MNKL const& problem_shape_MNKL,
769
+ Params const& params,
770
+ TensorStorage& shared_tensors) const {
771
+ using X = Underscore;
772
+
773
+ // Separate out problem shape for convenience
774
+ auto [M,N,K,L] = problem_shape_MNKL;
775
+
776
+ // Represent the full tensors -- get these from TMA
777
+ auto K_A = conditional_return<is_strided_dgrad>(get<0>(K), K);
778
+ Tensor mA_mk = observed_tma_load_a_->get_tma_tensor(make_shape(M, K_A));
779
+ Tensor mB_nk = observed_tma_load_b_->get_tma_tensor(make_shape(N, K));
780
+
781
+ // Tile the tensors and defer the slice
782
+ Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k)
783
+ Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k)
784
+
785
+ // Partition for this CTA
786
+ ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{}));
787
+
788
+ Tensor tCgA_mk = cta_mma.partition_A(gA_mk); // (MMA, MMA_M, MMA_K, m, k)
789
+ Tensor tCgB_nk = cta_mma.partition_B(gB_nk); // (MMA, MMA_N, MMA_K, n, k)
790
+
791
+ Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE)
792
+ Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE)
793
+
794
+ // Define the CTA-in-cluster Layout and Coord
795
+ Layout cta_layout_mnk = make_layout(cluster_shape_);
796
+ Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{}));
797
+ auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_);
798
+
799
+ // Project the cta_layout for tma_a along the n-modes
800
+ auto [tAgA_mk, tAsA] = tma_partition(*observed_tma_load_a_,
801
+ get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
802
+ group_modes<0,3>(sA), group_modes<0,3>(tCgA_mk));
803
+
804
+ // Project the cta_layout for tma_b along the m-modes
805
+ auto [tBgB_nk, tBsB] = tma_partition(*observed_tma_load_b_,
806
+ get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
807
+ group_modes<0,3>(sB), group_modes<0,3>(tCgB_nk));
808
+
809
+ // TMA Multicast Masks
810
+ uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
811
+ uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
812
+
813
+ return cute::make_tuple(
814
+ gA_mk, gB_nk, // for scheduler
815
+ tAgA_mk, tBgB_nk, tAsA, tBsB, // for input tensor values
816
+ mcast_mask_a, mcast_mask_b); // multicast masks
817
+ }
818
+
819
+ /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster
820
+ CUTLASS_DEVICE void
821
+ load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) {
822
+ // Issue the epilogue waits
823
+ /* This helps avoid early exit of ctas in Cluster
824
+ * Waits for all stages to either be released (all
825
+ * Consumer UNLOCKs), or if the stage was never used
826
+ * then would just be acquired since the phase was
827
+ * still inverted from make_producer_start_state
828
+ */
829
+ pipeline.producer_tail(mainloop_pipe_producer_state);
830
+ }
831
+
832
+ /// Perform a collective-scoped matrix multiply-accumulate
833
+ /// Consumer Perspective
834
+ template <
835
+ class FrgEngine, class FrgLayout,
836
+ class FragmentA, class FragmentB
837
+ >
838
+ CUTLASS_DEVICE auto
839
+ mma(MainloopPipeline pipeline,
840
+ MainloopPipelineState mainloop_pipe_consumer_state,
841
+ cute::Tensor<FrgEngine, FrgLayout>& accumulators,
842
+ cute::tuple<TiledMma, FragmentA, FragmentB> const& mma_inputs,
843
+ int k_tile_count)
844
+ {
845
+ static_assert(is_tmem<FrgEngine>::value, "Accumulator must be tmem resident.");
846
+ static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)");
847
+
848
+ auto [tiled_mma, tCrA, tCrB] = mma_inputs;
849
+
850
+ uint32_t skip_wait = k_tile_count <= 0;
851
+ auto barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
852
+
853
+ //
854
+ // PIPELINED MAIN LOOP
855
+ //
856
+ tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
857
+
858
+ CUTLASS_PRAGMA_NO_UNROLL
859
+ while (k_tile_count > 0) {
860
+ // WAIT on mainloop_pipe_consumer_state until its data are available (phase bit flips from mainloop_pipe_consumer_state.phase() value)
861
+ pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
862
+
863
+ // Compute on k_tile
864
+ int read_stage = mainloop_pipe_consumer_state.index();
865
+ // Save current mainlop pipeline read state
866
+ auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
867
+
868
+ // Advance mainloop_pipe
869
+ ++mainloop_pipe_consumer_state;
870
+ --k_tile_count;
871
+ skip_wait = k_tile_count <= 0;
872
+ // Peek at next iteration
873
+ barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
874
+
875
+ // Unroll the K mode manually so we can set scale C to 1
876
+ CUTLASS_PRAGMA_UNROLL
877
+ for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
878
+ // (V,M,K) x (V,N,K) => (V,M,N)
879
+ cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulators);
880
+ tiled_mma.accumulate_ = UMMA::ScaleOut::One;
881
+ }
882
+ pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
883
+ }
884
+
885
+ return mainloop_pipe_consumer_state;
886
+ }
887
+
888
+ CUTLASS_DEVICE auto
889
+ mma_init(TensorStorage& shared_tensors) const {
890
+ Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
891
+ Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
892
+
893
+ TiledMma tiled_mma;
894
+
895
+ // Allocate "fragments/descriptors" for A and B matrices
896
+ Tensor tCrA = tiled_mma.make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
897
+ Tensor tCrB = tiled_mma.make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
898
+
899
+ CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sA)); // PIPE
900
+ CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB)); // PIPE
901
+ return cute::make_tuple(tiled_mma, tCrA, tCrB);
902
+ }
903
+
904
+ private:
905
+
906
+ typename Params::TMA_A const* observed_tma_load_a_ = nullptr;
907
+ typename Params::TMA_B const* observed_tma_load_b_ = nullptr;
908
+
909
+ ClusterShape cluster_shape_;
910
+ uint32_t block_rank_in_cluster_;
911
+ };
912
+
913
+ /////////////////////////////////////////////////////////////////////////////////////////////////
914
+
915
+ } // namespace cutlass::conv::collective
916
+
917
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp ADDED
@@ -0,0 +1,785 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include "cutlass/cutlass.h"
34
+
35
+ #include "cute/arch/cluster_sm90.hpp"
36
+ #include "cute/arch/copy_sm90.hpp"
37
+ #include "cute/atom/mma_atom.hpp"
38
+ #include "cute/atom/copy_traits_sm90_im2col.hpp"
39
+ #include "cute/numeric/arithmetic_tuple.hpp"
40
+ #include "cute/algorithm/functional.hpp"
41
+ #include "cute/algorithm/gemm.hpp"
42
+
43
+ #include "cutlass/conv/detail.hpp"
44
+ #include "cutlass/conv/convolution.h"
45
+ #include "cutlass/conv/dispatch_policy.hpp"
46
+ #include "cutlass/pipeline/pipeline.hpp"
47
+ #include "cutlass/util/packed_stride.hpp"
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
+ namespace cutlass::conv::collective {
52
+ using namespace cute;
53
+
54
+ /////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ template <
57
+ conv::Operator ConvOp,
58
+ int Stages,
59
+ int NumSpatialDims,
60
+ class ClusterShape,
61
+ class KernelSchedule,
62
+ int PipelineAsyncMmaStages,
63
+ class TileShape_,
64
+ class ElementA_,
65
+ class ElementB_,
66
+ class TiledMma_,
67
+ class TileTraitsA_,
68
+ class TileTraitsB_>
69
+ struct CollectiveConv<
70
+ MainloopSm90TmaGmmaWarpSpecializedImplicitGemm<
71
+ ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>,
72
+ TileShape_,
73
+ ElementA_,
74
+ ElementB_,
75
+ TiledMma_,
76
+ TileTraitsA_,
77
+ TileTraitsB_>
78
+ {
79
+ //
80
+ // Type Aliases
81
+ //
82
+ using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm<
83
+ ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>;
84
+ using TileShape = TileShape_;
85
+ using ElementA = ElementA_;
86
+ using ElementB = ElementB_;
87
+ using TiledMma = TiledMma_;
88
+ using ElementAccumulator = typename TiledMma::ValTypeC;
89
+ using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy;
90
+ using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy;
91
+ using SmemLayoutA = typename TileTraitsA_::SmemLayout;
92
+ using SmemLayoutB = typename TileTraitsB_::SmemLayout;
93
+ using ArchTag = typename DispatchPolicy::ArchTag;
94
+ static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions;
95
+ static constexpr int NumTensorDimensions = NumSpatialDimensions + 2;
96
+ // Deduce the kernel-facing stride tuple types based on the dispatch policy
97
+ // (which is a function of the number of spatial dimensions, the algorithm, etc.)
98
+ using StrideA = decltype(detail::sm90_dispatch_policy_to_stride_A<DispatchPolicy>());
99
+ using StrideB = decltype(detail::sm90_dispatch_policy_to_stride_B<DispatchPolicy>());
100
+
101
+ using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
102
+
103
+ using PipelineParams = typename MainloopPipeline::Params;
104
+ using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;
105
+
106
+ using ProblemShape = ConvProblemShape<ConvOp, NumSpatialDimensions>;
107
+
108
+ static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)");
109
+ static_assert((size<0>(TileShape{}) == size<0>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape.");
110
+ static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape.");
111
+
112
+ static_assert(rank(SmemLayoutB{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)");
113
+ static_assert((size<1>(TileShape{}) == size<0>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape.");
114
+ static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape.");
115
+
116
+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
117
+ static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
118
+ cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
119
+ "MMA atom must source both A and B operand from smem_desc for this mainloop.");
120
+
121
+ // The tma load mode of wgrad is tiled for tensor A and im2col for tensor B while the tma load mode of fprop and dgrad
122
+ // kernel is im2col for tensor A and tiled for tensor B.
123
+ static_assert((ConvOp == conv::Operator::kWgrad
124
+ && (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>))
125
+ || (ConvOp != conv::Operator::kWgrad
126
+ && (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_IM2COL> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_IM2COL_MULTICAST>)),
127
+ "GmemTiledCopyA - invalid SM90 TMA copy atom specified.");
128
+ static_assert((ConvOp == conv::Operator::kWgrad
129
+ && (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_IM2COL> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_IM2COL_MULTICAST>))
130
+ || (ConvOp != conv::Operator::kWgrad
131
+ && (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)),
132
+ "GmemTiledCopyB - invalid SM90 TMA copy atom specified.");
133
+
134
+ static constexpr bool is_im2col_A = detail::is_im2col_load<GmemTiledCopyA>::value;
135
+ static constexpr bool is_im2col_B = detail::is_im2col_load<GmemTiledCopyB>::value;
136
+
137
+ // TMA converts f32 input to tf32 when copying from GMEM to SMEM
138
+ // For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
139
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
140
+ static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
141
+ using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
142
+ using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
143
+
144
+ struct SharedStorage
145
+ {
146
+ struct TensorStorage : cute::aligned_struct<128, _0> {
147
+ cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
148
+ cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
149
+ } tensors;
150
+
151
+ using PipelineStorage = typename MainloopPipeline::SharedStorage;
152
+ PipelineStorage pipeline;
153
+ };
154
+ using TensorStorage = typename SharedStorage::TensorStorage;
155
+ using PipelineStorage = typename SharedStorage::PipelineStorage;
156
+
157
+ static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
158
+ static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages;
159
+ static constexpr uint32_t TmaTransactionBytes =
160
+ (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof(InternalElementA)))+
161
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof(InternalElementB)));
162
+
163
+ // Host side kernel arguments
164
+ struct Arguments {
165
+ ElementA const* ptr_A{nullptr};
166
+ ElementB const* ptr_B{nullptr};
167
+ };
168
+
169
+ private:
170
+ // Note that for fprop and dgrad kernel, the tma load mode is im2col for tensor A and tiled for
171
+ // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor
172
+ // B since operand A, B is swapped.
173
+ // Get tma_load_a instantce.
174
+ template <class TensorA>
175
+ static constexpr auto
176
+ get_tma_load_a_instance(TensorA const& tensor_a, ProblemShape const& problem_shape) {
177
+ if constexpr (is_im2col_A) {
178
+ // compute the upper and lower corners based on the conv padding
179
+ auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
180
+ auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
181
+ auto lower_srt = detail::compute_lower_srt(problem_shape);
182
+
183
+ // The calculation of gbasis strides for dgrad kernel needs perform negate for dilation values.
184
+ cute::array<int32_t, NumSpatialDimensions> stride_srt{};
185
+ for (int i = 0; i < NumSpatialDimensions; ++i) {
186
+ stride_srt[i] = ConvOp == conv::Operator::kDgrad ?
187
+ -problem_shape.dilation[NumSpatialDimensions-1-i] :
188
+ problem_shape.dilation[NumSpatialDimensions-1-i];
189
+ }
190
+
191
+ return make_im2col_tma_copy(
192
+ GmemTiledCopyA{},
193
+ tensor_a,
194
+ SmemLayoutA{}(_,_,_0{}),
195
+ product_each(shape(SmemLayoutA{}(_,_,_0{}))),
196
+ size<1>(ClusterShape{}),
197
+ shape(lower_corner_whd),
198
+ shape(upper_corner_whd),
199
+ cute::reverse(shape(problem_shape.lower_padding)),
200
+ cute::reverse(shape(problem_shape.upper_padding)),
201
+ cute::reverse(shape(problem_shape.traversal_stride)),
202
+ shape(lower_srt),
203
+ shape(stride_srt));
204
+ }
205
+ // TMA tiled mode for tensor A in wgrad kernel.
206
+ else {
207
+ return make_tma_copy(
208
+ GmemTiledCopyA{},
209
+ tensor_a,
210
+ SmemLayoutA{}(_,_,_0{}),
211
+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
212
+ size<1>(ClusterShape{}));
213
+ }
214
+ }
215
+
216
+ // Get tma_load_b instantce.
217
+ template <class TensorB>
218
+ static constexpr auto
219
+ get_tma_load_b_instance(TensorB const& tensor_b, ProblemShape const& problem_shape) {
220
+ // TMA im2col mode for tensor B in wgrad kernel.
221
+ if constexpr (is_im2col_B) {
222
+ // compute the upper and lower corners based on the conv padding
223
+ auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
224
+ auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
225
+ auto lower_srt = detail::compute_lower_srt(problem_shape);
226
+
227
+ return make_im2col_tma_copy(
228
+ GmemTiledCopyB{},
229
+ tensor_b,
230
+ SmemLayoutB{}(_,_,_0{}),
231
+ product_each(shape(SmemLayoutB{}(_,_,_0{}))),
232
+ size<0>(ClusterShape{}),
233
+ shape(lower_corner_whd),
234
+ shape(upper_corner_whd),
235
+ cute::reverse(shape(problem_shape.lower_padding)),
236
+ cute::reverse(shape(problem_shape.upper_padding)),
237
+ cute::reverse(shape(problem_shape.traversal_stride)),
238
+ shape(lower_srt),
239
+ cute::reverse(shape(problem_shape.dilation)));
240
+ }
241
+ else {
242
+ return make_tma_copy(
243
+ GmemTiledCopyB{},
244
+ tensor_b,
245
+ SmemLayoutB{}(_,_,_0{}),
246
+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
247
+ size<0>(ClusterShape{}));
248
+ }
249
+ }
250
+
251
+ public:
252
+
253
+ // Performs im2col transformations on the input of type ConvProblemShape
254
+ static constexpr auto
255
+ get_problem_shape_MNKL(ProblemShape const& problem_shape) {
256
+
257
+ if constexpr (is_im2col_A || is_im2col_B) {
258
+ // transformation + im2col linearization
259
+ return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape);
260
+ }
261
+ else {
262
+ // transformation
263
+ return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape);
264
+ }
265
+ }
266
+
267
  // Device side kernel params
  struct Params {
    // Tensor extent with the last (channel) mode dropped; used below to build
    // the GEMM-facing tensor shapes.
    using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{}));

    // Assumption: StrideA is congruent with Problem_MK
    // Select TMA load type according to convolution operator.
    // For Wgrad, A keeps one extent per mode; otherwise the non-channel modes
    // are collapsed into (submodes..., int) for the im2col-style load.
    using TensorShapeA = cute::conditional_t<ConvOp == conv::Operator::kWgrad,
        decltype(repeat_like(StrideA{}, int32_t(0))),
        decltype(make_shape(_Submode{}, int(0)))>;

    // Mirror of TensorShapeA with the roles of the two operands swapped.
    using TensorShapeB = cute::conditional_t<ConvOp == conv::Operator::kWgrad,
        decltype(make_shape(int(0), _Submode{})),
        decltype(repeat_like(StrideB{}, int32_t(0)))>;

    // TMA copy descriptor types, deduced by invoking the factory functions on
    // null-pointer tensors of the shapes declared above.
    using TMA_A = decltype(get_tma_load_a_instance(
        make_tensor(
            make_gmem_ptr(static_cast<InternalElementA const*>(nullptr)),
            make_layout(TensorShapeA{}, StrideA{})),
        ConvProblemShape<ConvOp, NumSpatialDimensions>{}));

    using TMA_B = decltype(get_tma_load_b_instance(
        make_tensor(
            make_gmem_ptr(static_cast<InternalElementB const*>(nullptr)),
            make_layout(TensorShapeB{}, StrideB{})),
        ConvProblemShape<ConvOp, NumSpatialDimensions>{}));

    // Members
    TMA_A tma_load_a;
    TMA_B tma_load_b;
    // Bytes each TMA transaction delivers; the mainloop barrier expects this count.
    uint32_t tma_transaction_bytes = TmaTransactionBytes;
  };
298
+
299
+ //
300
+ // Methods
301
+ //
302
+
303
+ // Lowers the host side user facing arguments to the kernel facing lauch params
304
+ static constexpr Params
305
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
306
+ (void) workspace;
307
+ // from the flat problem shape arrays of ConvProblemShape<ConvOp, N>, create a rank-3 MNK problem shape tuple
308
+ // tma desc creation depends on the original untransformed domain.
309
+
310
+ // A extents.
311
+ auto shape_A_orig = problem_shape.get_shape_A();
312
+ // B extents.
313
+ auto shape_B_orig = problem_shape.get_shape_B();
314
+
315
+ // Fill inferred cute strides from flat stride arrays
316
+ auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp);
317
+ auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp);
318
+
319
+ auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
320
+ auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
321
+
322
+ Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA));
323
+ Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB));
324
+
325
+ auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape);
326
+ auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape);
327
+
328
+ return {
329
+ tma_load_a,
330
+ tma_load_b,
331
+ TmaTransactionBytes
332
+ };
333
+ }
334
+
335
+ template <class ProblemShape>
336
+ static bool
337
+ can_implement(
338
+ ProblemShape const& problem_shape,
339
+ Arguments const& args) {
340
+ // Activation and Filter channel mode extents much match
341
+ bool implementable = true;
342
+ // channel mode is major
343
+ implementable &= problem_shape.stride_A[NumTensorDimensions-1] == 1;
344
+ implementable &= problem_shape.stride_B[NumTensorDimensions-1] == 1;
345
+
346
+ constexpr int tma_alignment_bits = 128;
347
+ // A extents.
348
+ auto shape_A_orig = problem_shape.get_shape_A();
349
+ // B extents.
350
+ auto shape_B_orig = problem_shape.get_shape_B();
351
+ constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
352
+ implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(shape_A_orig, StrideA{});
353
+ constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
354
+ implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(shape_B_orig, StrideB{});
355
+
356
+ if (!implementable) {
357
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
358
+ return false;
359
+ }
360
+
361
+ // Check valid padding values for TMA_LOAD_IM2COL
362
+ constexpr int padding_limit = (ProblemShape::RankS == 1) ? 65536 : (ProblemShape::RankS == 2 ? 256 : 16);
363
+ for (int i = 0; i < problem_shape.RankS; ++i) {
364
+ implementable = implementable && problem_shape.lower_padding[i] <= padding_limit && problem_shape.lower_padding[i] >= 0;
365
+ implementable = implementable && problem_shape.upper_padding[i] <= padding_limit && problem_shape.upper_padding[i] >= 0;
366
+ }
367
+
368
+ if (!implementable) {
369
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n");
370
+ return false;
371
+ }
372
+
373
+ if (is_im2col_A || is_im2col_B) {
374
+ // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1]
375
+ constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1);
376
+ auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
377
+ for (int i = 0; i < problem_shape.RankS; ++i) {
378
+ implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1);
379
+ }
380
+ auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
381
+ for (int i = 0; i < problem_shape.RankS; ++i) {
382
+ implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1);
383
+ }
384
+
385
+ if (!implementable) {
386
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n");
387
+ return false;
388
+ }
389
+ }
390
+
391
+ if (is_im2col_A || is_im2col_B) {
392
+ // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit - 1]
393
+ constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1;
394
+ auto flt_data = (ConvOp == conv::Operator::kWgrad) ? problem_shape.shape_C : problem_shape.shape_B;
395
+ for (int i = 0; i < problem_shape.RankS; ++i) {
396
+ // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array
397
+ implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0)
398
+ && ((flt_data[i+1] - 1) * problem_shape.dilation[i] < offset_limit);
399
+ }
400
+
401
+ if (!implementable) {
402
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: tensor coordinate offset values don't meet requirements for TMA LOAD IM2COL.\n");
403
+ return false;
404
+ }
405
+ }
406
+
407
+ // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized)
408
+ if constexpr (ConvOp == conv::Operator::kWgrad) {
409
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
410
+ std::ostringstream os;
411
+ #endif
412
+ const auto & input_shape = problem_shape.shape_A;
413
+ const auto & input_stride = problem_shape.stride_A;
414
+
415
+ implementable &= input_stride[ProblemShape::RankT - 1] == 1;
416
+ int64_t input_shape_size = 1;
417
+ for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
418
+ input_shape_size *= input_shape[i + 1];
419
+ implementable &= input_stride[i] == input_shape_size;
420
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
421
+ if (input_stride[i] != input_shape_size) {
422
+ os << "\n *** input_stride[" << i << "] = " << input_stride[i] << " != input_shape_size = " << input_shape_size << " ***";
423
+ }
424
+ #endif
425
+ }
426
+
427
+ if (!implementable) {
428
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
429
+ os << "\n input_shape_size: " << input_shape_size
430
+ << "\n input_shape: " << input_shape
431
+ << "\n input_stride: " << input_stride
432
+ << "\n";
433
+ #endif
434
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n");
435
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
436
+ CUTLASS_TRACE_HOST(os.str());
437
+ #endif
438
+ return false;
439
+ }
440
+
441
+ const auto & output_shape = problem_shape.shape_C;
442
+ const auto & output_stride = problem_shape.stride_C;
443
+
444
+ implementable &= output_stride[ProblemShape::RankT - 1] == 1;
445
+ int64_t output_shape_size = 1;
446
+ for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
447
+ output_shape_size *= output_shape[i + 1];
448
+ implementable &= output_stride[i] == output_shape_size;
449
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
450
+ if (output_stride[i] != output_shape_size) {
451
+ os << "\n *** output_stride[" << i << "] = " << output_stride[i] << " != output_shape_size = " << output_shape_size << " ***";
452
+ }
453
+ #endif
454
+ }
455
+
456
+ if (!implementable) {
457
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
458
+ os << "\n output_shape_size: " << input_shape_size
459
+ << "\n output_shape: " << input_shape
460
+ << "\n output_stride: " << input_stride
461
+ << "\n";
462
+ #endif
463
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n");
464
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
465
+ CUTLASS_TRACE_HOST(os.str());
466
+ #endif
467
+ return false;
468
+ }
469
+ }
470
+
471
+ // Conv kernels only support cross correlation mode currently.
472
+ implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation;
473
+
474
+ if (!implementable) {
475
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n");
476
+ return false;
477
+ }
478
+
479
+ if (problem_shape.groups > 1) {
480
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n");
481
+ return false;
482
+ }
483
+
484
+ if constexpr (is_im2col_A || is_im2col_B) {
485
+ auto [M, N, K, L] = cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape);
486
+ auto to_64b = [](auto S) { return transform_leaf(S, [](auto s) { return static_cast<int64_t>(s); }); };
487
+
488
+ if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) {
489
+ implementable &= (cute::product(to_64b(M)) <= cutlass::platform::numeric_limits<int32_t>::max()) &
490
+ (cute::product(to_64b(L)) <= cutlass::platform::numeric_limits<int32_t>::max());
491
+ }
492
+ else if constexpr (ConvOp == conv::Operator::kWgrad) {
493
+ implementable &= (cute::product(to_64b(K)) <= cutlass::platform::numeric_limits<int32_t>::max());
494
+ }
495
+
496
+ if (!implementable) {
497
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: the extents exceed the maximum number.\n");
498
+ return false;
499
+ }
500
+ }
501
+
502
+ return true;
503
+ }
504
+
505
+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
506
+ CUTLASS_DEVICE
507
+ static void prefetch_tma_descriptors(Params const& mainloop_params) {
508
+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
509
+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
510
+ }
511
+
512
+ /// Set up the data needed by this collective for load and mma.
513
+ /// Returns a tuple of tensors. The collective and the kernel layer have the contract
514
+ /// Returned tuple must contain at least two elements, with the first two elements being:
515
+ /// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k)
516
+ /// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k)
517
+ /// The rest of the tensors can be specified as needed by this collective.
518
+ /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with
519
+ /// StrideA and StrideB set up for TMA
520
+ template <class ProblemShapeMNKL>
521
+ CUTLASS_DEVICE auto
522
+ load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){
523
+ //load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
524
+ using X = Underscore;
525
+ // Separate out problem shape for convenience
526
+ auto [M, N, K, L] = problem_shape_MNKL;
527
+
528
+ // TMA requires special handling of strides to deal with coord codomain mapping
529
+ // Represent the full tensors -- get these from TMA
530
+ Tensor mA_mk = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K)); // (m,k)
531
+ Tensor mB_nk = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K)); // (n,k)
532
+
533
+ // Make tiled views, defer the slice
534
+ Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k)
535
+ Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k)
536
+
537
+ return cute::make_tuple(gA_mk, gB_nk);
538
+ }
539
+
540
  /// Perform a collective-scoped matrix multiply-accumulate
  /// Producer Perspective
  template <
    class TensorA, class TensorB,
    class KTileIterator, class BlockCoord
  >
  CUTLASS_DEVICE void
  load(
      Params const& mainloop_params,
      MainloopPipeline pipeline,
      PipelineState smem_pipe_producer_state,
      cute::tuple<TensorA, TensorB> const& load_inputs,
      BlockCoord const& blk_coord,
      KTileIterator k_tile_iter, int k_tile_count,
      int thread_idx,
      uint32_t block_rank_in_cluster,
      TensorStorage& shared_tensors) {

    // Only the single elected lane issues TMA loads.
    int lane_predicate = cute::elect_one_sync();
    if (lane_predicate) {
      Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{});  // (BLK_M,BLK_K,PIPE)
      Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{});  // (BLK_N,BLK_K,PIPE)

      //
      // Prepare the TMA loads for A and B
      //
      constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());

      // Decompose the linear cluster rank into an (x,y) pair used to slice the copies.
      uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
      auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
      auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

      auto [gA_mk, gB_nk] = load_inputs;

      // Partition the inputs based on the current block coordinates.
      auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;

      Tensor gA = gA_mk(_,_,m_coord,_);  // (BLK_M,BLK_K,k)
      Tensor gB = gB_nk(_,_,n_coord,_);  // (BLK_N,BLK_K,k)

      // Applies the mapping from block_tma_a
      Tensor tAgA = block_tma_a.partition_S(gA);  // (TMA,TMA_M,TMA_K,k)
      Tensor tAsA = block_tma_a.partition_D(sA);  // (TMA,TMA_M,TMA_K,PIPE)

      Tensor tBgB = block_tma_b.partition_S(gB);  // (TMA,TMA_N,TMA_K,k)
      Tensor tBsB = block_tma_b.partition_D(sB);  // (TMA,TMA_N,TMA_K,PIPE)

      uint16_t mcast_mask_a = 0;
      uint16_t mcast_mask_b = 0;

      // Issue TmaLoads
      // Maps the tile -> block, value
      // For a multicast A load, set one mask bit per block along the cluster's
      // second mode (fixed x, varying n).
      if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_IM2COL_MULTICAST> ||
                    cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};  // (m,n) -> block_id
        for (int n = 0; n < size<1>(block_layout); ++n) {
          mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
        }
      }

      // Likewise for a multicast B load, along the cluster's first mode
      // (varying m, fixed y).
      if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_IM2COL_MULTICAST> ||
                    cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};  // (m,n) -> block_id
        for (int m = 0; m < size<0>(block_layout); ++m) {
          mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
        }
      }

      // Mainloop
      CUTLASS_PRAGMA_NO_UNROLL
      for ( ; k_tile_count > 0; --k_tile_count) {
        // LOCK smem_pipe_producer_state for _writing_
        pipeline.producer_acquire(smem_pipe_producer_state);

        //
        // Copy gmem to smem for *k_tile_iter
        //

        using BarrierType = typename MainloopPipeline::ProducerBarrierType;
        BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_producer_state);

        int write_stage = smem_pipe_producer_state.index();

        // Issue both TMA copies into the current pipeline stage; the producer
        // barrier tracks arrival of the transaction bytes for the consumer.
        copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
        copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
        ++k_tile_iter;

        // Advance smem_pipe_producer_state
        ++smem_pipe_producer_state;
      }
    }
  }
632
+
633
+ /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
634
+ CUTLASS_DEVICE void
635
+ load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_producer_state) {
636
+ int lane_predicate = cute::elect_one_sync();
637
+
638
+ // Issue the epilogue waits
639
+ if (lane_predicate) {
640
+ /* This helps avoid early exit of blocks in Cluster
641
+ * Waits for all stages to either be released (all
642
+ * Consumer UNLOCKs), or if the stage was never used
643
+ * then would just be acquired since the phase was
644
+ * still inverted from make_producer_start_state
645
+ */
646
+ pipeline.producer_tail(smem_pipe_producer_state);
647
+ }
648
+ }
649
+
650
  /// Perform a collective-scoped matrix multiply-accumulate
  /// Consumer Perspective
  template <class FrgTensorC>
  CUTLASS_DEVICE void
  mma(MainloopPipeline pipeline,
      PipelineState smem_pipe_consumer_state,
      FrgTensorC& accum,
      int k_tile_count,
      int thread_idx,
      TensorStorage& shared_tensors,
      Params const& mainloop_params) {
    static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{});  // (BLK_M,BLK_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{});  // (BLK_N,BLK_K,PIPE)

    //
    // Define C accumulators and A/B partitioning
    //

    TiledMma tiled_mma;
    auto thread_mma = tiled_mma.get_thread_slice(thread_idx);

    Tensor tCsA = thread_mma.partition_A(sA);  // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCsB = thread_mma.partition_B(sB);  // (MMA,MMA_N,MMA_K,PIPE)

    // Allocate "fragments/descriptors"
    Tensor tCrA = thread_mma.make_fragment_A(tCsA);  // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCrB = thread_mma.make_fragment_B(tCsB);  // (MMA,MMA_N,MMA_K,PIPE)

    // Sanity-check that partitioned modes agree across A, B and the accumulators.
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum));               // M
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum));               // N
    CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));                // K
    CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB));                // PIPE
    CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA));  // PIPE
    CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB));  // PIPE

    //
    // PIPELINED MAIN LOOP
    //
    static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
        "ERROR : Incorrect number of MMAs in flight");

    // We release buffers to producer warps(dma load) with some mmas in flight
    PipelineState smem_pipe_release = smem_pipe_consumer_state;

    // Prologue GMMAs
    int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);

    // First GMMA writes (not accumulates) into the accumulators.
    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

    warpgroup_fence_operand(accum);
    CUTLASS_PRAGMA_UNROLL
    for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) {
      // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value)
      pipeline.consumer_wait(smem_pipe_consumer_state);

      int read_stage = smem_pipe_consumer_state.index();
      warpgroup_arrive();
      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        // (V,M,K) x (V,N,K) => (V,M,N)
        cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
      }

      warpgroup_commit_batch();

      // Note: the prologue does NOT release stages; releases start in the
      // mainloop, keeping K_PIPE_MMAS GMMAs in flight.
      ++smem_pipe_consumer_state;
    }

    warpgroup_fence_operand(accum);
    // Mainloop GMMAs
    k_tile_count -= prologue_mma_count;

    CUTLASS_PRAGMA_NO_UNROLL
    for ( ; k_tile_count > 0; --k_tile_count) {
      // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value)
      pipeline.consumer_wait(smem_pipe_consumer_state);

      //
      // Compute on k_tile
      //

      int read_stage = smem_pipe_consumer_state.index();
      warpgroup_fence_operand(accum);
      warpgroup_arrive();
      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        // (V,M) x (V,N) => (V,M,N)
        cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
      }
      warpgroup_commit_batch();

      /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_producer_state is consumed
      warpgroup_wait<K_PIPE_MMAS>();
      warpgroup_fence_operand(accum);

      // UNLOCK smem_pipe_release, done _computing_ on it
      pipeline.consumer_release(smem_pipe_release);

      // Advance smem_pipe_consumer_state and smem_pipe_release
      ++smem_pipe_consumer_state;
      ++smem_pipe_release;
    }

    warpgroup_fence_operand(accum);
  }
761
+
762
+ /// Perform a Consumer Epilogue to release all buffers
763
+ CUTLASS_DEVICE void
764
+ mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
765
+ // Prologue GMMAs
766
+ int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
767
+ k_tile_count -= prologue_mma_count;
768
+
769
+ smem_pipe_release.advance(k_tile_count);
770
+
771
+ // Wait on all GMMAs to complete
772
+ warpgroup_wait<0>();
773
+
774
+ for (int count = 0; count < prologue_mma_count; ++count) {
775
+ pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
776
+ ++smem_pipe_release;
777
+ }
778
+ }
779
+ };
780
+
781
+ /////////////////////////////////////////////////////////////////////////////////////////////////
782
+
783
+ } // namespace cutlass::conv::collective
784
+
785
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief This file contains definitions and utility functions for describing convolution problem sizes.
33
+
34
+ Conv2dProblem description:
35
+ activation (NHWC),
36
+ filter (KRSC),
37
+ output (NPQK),
38
+ padding (pad_h, pad_w),
39
+ stride (stride_h, stride_w),
40
+ dilation (dilation_h, dilation_w).
41
+
42
+ Free functions to map:
43
+ Map tensor extents (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator)
44
+ Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator)
45
+ Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator)
46
+ */
47
+
48
+ #pragma once
49
+
50
+ #include "cutlass/cutlass.h"
51
+ #include "cutlass/tensor_coord.h"
52
+ #include "cutlass/fast_math.h"
53
+ #include "cutlass/gemm/gemm_enumerated_types.h"
54
+ #include "cutlass/matrix_coord.h"
55
+ #include "cutlass/conv/convolution.h"
56
+ #include "cutlass/functional.h"
57
+
58
+ namespace cutlass {
59
+ namespace conv {
60
+
61
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
62
+
63
+ /// Problem size structure
64
+ struct Conv2dProblemSize {
65
+
66
+ // Conv2d strictly problem size parameters
67
+ int N, H, W, C, P, Q, K, R, S;
68
+ int pad_h, pad_w;
69
+ int stride_h, stride_w;
70
+ int dilation_h, dilation_w;
71
+ Mode mode;
72
+
73
+ // Conv2d implementation-related parameters
74
+ int split_k_slices;
75
+ int groups;
76
+
77
+ //
78
+ // Methods
79
+ //
80
+
81
+ public:
82
  /// Default constructor: zero extents, no padding, unit stride/dilation,
  /// convolution mode, single split-K slice and a single group.
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize():
    N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0),
    pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1),
    mode(Mode::kConvolution), split_k_slices(1), groups(1) { }
87
+
88
  /// Constructor for default padding, stride, dilation, and split-K
  /// (padding defaults to R/2 and S/2; stride and dilation default to 1).
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize(
    int N,
    int H,
    int W,
    int C,
    int P,
    int Q,
    int K,
    int R,
    int S,
    Mode mode
  ):
    N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S),
    pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1),
    mode(mode), split_k_slices(1), groups (1) { }
106
+ /// Constructor
107
+ CUTLASS_HOST_DEVICE
108
+ Conv2dProblemSize(
109
+ int N,
110
+ int H,
111
+ int W,
112
+ int C,
113
+ int K,
114
+ int R,
115
+ int S,
116
+ int P,
117
+ int Q,
118
+ int pad_h,
119
+ int pad_w,
120
+ int stride_h,
121
+ int stride_w,
122
+ int dilation_h,
123
+ int dilation_w,
124
+ Mode mode,
125
+ int split_k_slices = 1,
126
+ int groups = 1
127
+ ):
128
+ N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S),
129
+ pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w),
130
+ dilation_h(dilation_h), dilation_w(dilation_w),
131
+ mode(mode), split_k_slices(split_k_slices), groups (groups) { }
132
+
133
  /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord
  // set user-defined output size and sets P and Q (include all data members in ctor)
  // Only padding[0] and padding[2] (the lower pads) are stored; the upper pads
  // are implied by the caller-supplied output extent.
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize(
    cutlass::Tensor4DCoord input_size,    // NHWC
    cutlass::Tensor4DCoord filter_size,   // KRSC
    cutlass::Tensor4DCoord padding,       // pad_h, _, pad_w, _
    cutlass::MatrixCoord stride,          // stride_h, stride_w
    cutlass::MatrixCoord dilation,        // dilation_h, dilation_w
    cutlass::Tensor4DCoord output_size,   // NPQK
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()),
    P(output_size.h()), Q(output_size.w()),
    K(filter_size.n()), R(filter_size.h()), S(filter_size.w()),
    pad_h(padding[0]), pad_w(padding[2]),
    stride_h(stride.row()), stride_w(stride.column()),
    dilation_h(dilation.row()), dilation_w(dilation.column()),
    mode(mode), split_k_slices(split_k_slices), groups(groups) {}
154
+
155
  /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord
  // computes output size and sets P and Q (skip output from ctor arguments)
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize(
    cutlass::Tensor4DCoord input_size,    // NHWC
    cutlass::Tensor4DCoord filter_size,   // KRSC
    cutlass::Tensor4DCoord padding,       // pad_h, upper_pad_h, pad_w, upper_pad_w
    cutlass::MatrixCoord stride,          // stride_h, stride_w
    cutlass::MatrixCoord dilation,        // dilation_h, dilation_w
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()),
    K(filter_size.n()), R(filter_size.h()), S(filter_size.w()),
    pad_h(padding[0]), pad_w(padding[2]),
    stride_h(stride.row()), stride_w(stride.column()),
    dilation_h(dilation.row()), dilation_w(dilation.column()),
    mode(mode), split_k_slices(split_k_slices), groups(groups) {
    // set output P and Q
    // NOTE(review): uses (R * dilation_h) rather than the more common
    // ((R - 1) * dilation_h + 1); the two agree only for unit dilation --
    // assumed to be the intended library convention, TODO confirm.
    P = ((H + pad_h + padding[1] - R * dilation_h) / stride_h) + 1;
    Q = ((W + pad_w + padding[3] - S * dilation_w) / stride_w) + 1;
  }
178
+
179
  /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord
  // set user-defined output size and sets P and Q (skip padding, striding, and dilation)
  // Padding defaults to R/2 and S/2; stride and dilation default to 1.
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize(
    cutlass::Tensor4DCoord input_size,   // NHWC
    cutlass::Tensor4DCoord filter_size,  // KRSC
    cutlass::Tensor4DCoord output_size,  // NPQK
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()),
    P(output_size.h()), Q(output_size.w()),
    K(filter_size.n()), R(filter_size.h()), S(filter_size.w()),
    pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1),
    dilation_h(1), dilation_w(1),
    mode(mode), split_k_slices(split_k_slices), groups(groups) {}
196
+
197
  // Returns a copy of this problem with the convolution mode replaced.
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize reset_mode(cutlass::conv::Mode mode_) {
    Conv2dProblemSize tmp(*this);
    tmp.mode = mode_;
    return tmp;
  }
204
+
205
  // Returns a copy of this problem with split_k_slices replaced.
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize reset_split_k_slices(int split_k_slices_) {
    Conv2dProblemSize tmp(*this);
    tmp.split_k_slices = split_k_slices_;
    return tmp;
  }
212
+
213
+ /// Equality operator (ignores mode and split_k_slice)
214
+ CUTLASS_HOST_DEVICE
215
+ bool operator==(Conv2dProblemSize const &conv) const {
216
+ return (
217
+ (N == conv.N) && (H == conv.H) && (W == conv.W) && (C == conv.C) &&
218
+ (K == conv.K) && (R == conv.R) && (S == conv.S) &&
219
+ (P == conv.P) && (Q == conv.Q) &&
220
+ (pad_h == conv.pad_h) && (pad_w == conv.pad_w) &&
221
+ (stride_h == conv.stride_h) && (stride_w == conv.stride_w) &&
222
+ (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w)
223
+ );
224
+ }
225
+
226
+ /// Inequality operator
227
+ CUTLASS_HOST_DEVICE
228
+ bool operator!=(Conv2dProblemSize const &rhs) const {
229
+ return !(*this == rhs);
230
+ }
231
+
232
+ /// Returns activation extent as Tensor4DCoord
233
+ CUTLASS_HOST_DEVICE
234
+ cutlass::Tensor4DCoord activation_extent() const {
235
+
236
+ return cutlass::Tensor4DCoord ({N, H, W, C});
237
+ }
238
+
239
+ /// Returns filter extent as Tensor4DCoord
240
+ CUTLASS_HOST_DEVICE
241
+ cutlass::Tensor4DCoord filter_extent(bool is_deconv = false) const {
242
+
243
+ return is_deconv ? cutlass::Tensor4DCoord ({C, R, S, K / groups})
244
+ : cutlass::Tensor4DCoord ({K, R, S, C / groups});
245
+ }
246
+
247
+ /// Returns output extent as Tensor4DCoord
248
+ CUTLASS_HOST_DEVICE
249
+ cutlass::Tensor4DCoord output_extent() const {
250
+
251
+ return cutlass::Tensor4DCoord ({N, P, Q, K});
252
+ }
253
+
254
+ /// Returns activation size in number of elements
255
+ CUTLASS_HOST_DEVICE
256
+ int64_t activation_size() const {
257
+
258
+ return static_cast<int64_t>(N) * static_cast<int64_t>(H) *
259
+ static_cast<int64_t>(W) * static_cast<int64_t>(C);
260
+ }
261
+
262
+ /// Returns filter size in number of elements
263
+ CUTLASS_HOST_DEVICE
264
+ int64_t filter_size() const {
265
+
266
+ return static_cast<int64_t>(K) * static_cast<int64_t>(R) *
267
+ static_cast<int64_t>(S) * static_cast<int64_t>(C) /
268
+ static_cast<int64_t>(groups);
269
+ }
270
+
271
+ /// Returns output size in number of elements
272
+ CUTLASS_HOST_DEVICE
273
+ int64_t output_size() const {
274
+
275
+ return static_cast<int64_t>(N) * static_cast<int64_t>(P) *
276
+ static_cast<int64_t>(Q) * static_cast<int64_t>(K);
277
+ }
278
+
279
+ /// Returns padding as Tensor4DCoord
280
+ CUTLASS_HOST_DEVICE
281
+ cutlass::Tensor4DCoord padding() const {
282
+
283
+ return cutlass::Tensor4DCoord ({pad_h, pad_h, pad_w, pad_w});
284
+ }
285
+
286
+ /// Returns stride as MatrixCoord
287
+ CUTLASS_HOST_DEVICE
288
+ cutlass::MatrixCoord stride() const {
289
+
290
+ return cutlass::MatrixCoord ({stride_h, stride_w});
291
+ }
292
+
293
+ /// Returns dilation as MatrixCoord
294
+ CUTLASS_HOST_DEVICE
295
+ cutlass::MatrixCoord dilation() const {
296
+
297
+ return cutlass::MatrixCoord ({dilation_h, dilation_w});
298
+ }
299
+
300
+ /////////////////////////////////////////////////////////////////
301
+ // Methods used for strided dgrad implementation
302
+ /////////////////////////////////////////////////////////////////
303
+ /// Number of filter r positions to accumulate in gemm-k dim
304
+ CUTLASS_HOST_DEVICE
305
+ int num_gemm_k_filter_r(int r) const {
306
+ return ((R - r + stride_h - 1) / stride_h);
307
+ }
308
+
309
+ /// Number of filter s positions to accumulate in gemm-k dim
310
+ CUTLASS_HOST_DEVICE
311
+ int num_gemm_k_filter_s(int s) const {
312
+ return ((S - s + stride_w - 1) / stride_w);
313
+ }
314
+
315
+ /// Number of filter positions to accumulate in gemm-k dim
316
+ CUTLASS_HOST_DEVICE
317
+ int num_gemm_k_filter_positions(int r, int s) const {
318
+ return num_gemm_k_filter_r(r) * num_gemm_k_filter_s(s);
319
+ }
320
+ };
321
+
322
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
323
+ // ImplicitGemm helper functions //
324
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
325
+
326
+ /// Determine the problem size of the implicit GEMM operation
327
+ CUTLASS_HOST_DEVICE
328
+ cutlass::gemm::GemmCoord implicit_gemm_problem_size(
329
+ Operator conv_operator,
330
+ Conv2dProblemSize const &problem_size) {
331
+ // Compute problem size
332
+ switch (conv_operator) {
333
+ case Operator::kFprop:
334
+ return gemm::GemmCoord(
335
+ problem_size.N * problem_size.P * problem_size.Q,
336
+ problem_size.K,
337
+ problem_size.R * problem_size.S * problem_size.C / problem_size.groups
338
+ );
339
+ case Operator::kDeconv:
340
+ case Operator::kDgrad:
341
+ return gemm::GemmCoord(
342
+ problem_size.N * problem_size.H * problem_size.W,
343
+ problem_size.C,
344
+ problem_size.R * problem_size.S * problem_size.K
345
+ );
346
+ case Operator::kWgrad:
347
+ return gemm::GemmCoord(
348
+ problem_size.K,
349
+ problem_size.R * problem_size.S * problem_size.C,
350
+ problem_size.N * problem_size.P * problem_size.Q
351
+ );
352
+ default:
353
+ break;
354
+ }
355
+ return gemm::GemmCoord();
356
+ }
357
+
358
+ // Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm
359
+ CUTLASS_HOST_DEVICE
360
+ int implicit_gemm_k_iterations(
361
+ Operator conv_operator,
362
+ int threadblock_K,
363
+ Conv2dProblemSize const &problem_size,
364
+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
365
+ GroupMode group_mode = GroupMode::kNone,
366
+ int threadblock_N = 0) {
367
+
368
+ int iterations = 0;
369
+
370
+ if (group_mode == GroupMode::kNone) {
371
+
372
+ if (algorithm == IteratorAlgorithm::kFixedChannels) {
373
+
374
+ int positions_per_iteration = threadblock_K / problem_size.C;
375
+ switch (conv_operator) {
376
+ case Operator::kFprop:
377
+ iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration;
378
+ break;
379
+
380
+ default:
381
+ break;
382
+ }
383
+ }
384
+ else if (algorithm == IteratorAlgorithm::kFewChannels) {
385
+
386
+ switch (conv_operator) {
387
+ case Operator::kFprop:
388
+ iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K;
389
+ break;
390
+
391
+ default:
392
+ break;
393
+ }
394
+ }
395
+ else {
396
+ int elements_per_split_k_slice = 0;
397
+
398
+ switch (conv_operator) {
399
+ case Operator::kFprop:
400
+ elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
401
+ iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
402
+ break;
403
+
404
+ case Operator::kDeconv:
405
+ case Operator::kDgrad:
406
+ elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
407
+ iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
408
+ break;
409
+
410
+ case Operator::kWgrad:
411
+ elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
412
+ iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
413
+ break;
414
+
415
+ default:
416
+ break;
417
+ }
418
+ }
419
+
420
+ } else if (group_mode == GroupMode::kDepthwise) {
421
+ int channels_per_cta = threadblock_N;
422
+
423
+ if (algorithm == IteratorAlgorithm::kAnalytic) {
424
+ switch (conv_operator) {
425
+ case Operator::kFprop:
426
+ iterations = problem_size.R * problem_size.S *
427
+ ((channels_per_cta + threadblock_K - 1) / threadblock_K);
428
+ break;
429
+
430
+ default:
431
+ break;
432
+ }
433
+ }
434
+ } else { // Group conv
435
+
436
+ int channels_per_group = problem_size.C / problem_size.groups;
437
+ int k_per_group = problem_size.K / problem_size.groups;
438
+
439
+ if (algorithm == IteratorAlgorithm::kAnalytic) {
440
+ switch (conv_operator) {
441
+ case Operator::kFprop:
442
+ iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
443
+ // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups
444
+ if (problem_size.groups != 1) {
445
+ if (k_per_group < threadblock_N) {
446
+ iterations *= threadblock_N / k_per_group;
447
+ }
448
+ }
449
+ break;
450
+
451
+ default:
452
+ break;
453
+ }
454
+ } else if (algorithm == IteratorAlgorithm::kOptimized) {
455
+ // Current optimized iterator only support GroupMode::kSingleGroup
456
+ if (group_mode == GroupMode::kSingleGroup) {
457
+ switch (conv_operator) {
458
+ case Operator::kFprop:
459
+ iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
460
+ break;
461
+
462
+ default:
463
+ break;
464
+ }
465
+ }
466
+ }
467
+
468
+ }
469
+
470
+ return iterations;
471
+ }
472
+
473
+
474
+ template <int N = 1, int Output_P = 1, int Output_Q = 1>
475
+ CUTLASS_HOST_DEVICE
476
+ int depthwise_gemm_k_iterations(
477
+ Operator conv_operator,
478
+ int threadblock_K,
479
+ Conv2dProblemSize const &problem_size,
480
+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
481
+ GroupMode group_mode = GroupMode::kNone,
482
+ int threadblock_N = 0) {
483
+
484
+ int n = problem_size.N;
485
+ int p = (problem_size.P + Output_P - 1) / Output_P;
486
+ int q = (problem_size.Q + Output_Q - 1) / Output_Q;
487
+
488
+ int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
489
+ return iterations;
490
+ }
491
+
492
+
493
+ CUTLASS_HOST_DEVICE
494
+ int implicit_gemm_k_iterations_per_channel(
495
+ Operator conv_operator,
496
+ Conv2dProblemSize const &problem_size,
497
+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) {
498
+
499
+ int iterations = 0; //0 means not applicable
500
+ if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) {
501
+ switch (conv_operator) {
502
+ case Operator::kFprop:
503
+ iterations = problem_size.R * problem_size.S;
504
+ break;
505
+
506
+ case Operator::kDeconv:
507
+ case Operator::kDgrad:
508
+ iterations = problem_size.R * problem_size.S;
509
+ break;
510
+
511
+ default:
512
+ break;
513
+ }
514
+ }
515
+ return iterations;
516
+ }
517
+
518
+ ////////////////////////////////////////////////////////////////////////////////
519
+ // Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output)
520
+ ////////////////////////////////////////////////////////////////////////////////
521
+ /// Returns ImplicitGemm tensor A extent as Tensor4DCoord
522
+ CUTLASS_HOST_DEVICE
523
+ cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent(
524
+ Operator conv_operator,
525
+ Conv2dProblemSize const &problem_size) {
526
+ switch (conv_operator) {
527
+ case cutlass::conv::Operator::kFprop: return problem_size.activation_extent();
528
+ case cutlass::conv::Operator::kDeconv:
529
+ case cutlass::conv::Operator::kDgrad: return problem_size.output_extent();
530
+ case cutlass::conv::Operator::kWgrad: return problem_size.output_extent();
531
+ default : break;
532
+ }
533
+ return cutlass::Tensor4DCoord();
534
+ }
535
+
536
+ /// Returns ImplicitGemm tensor B extent as Tensor4DCoord
537
+ CUTLASS_HOST_DEVICE
538
+ cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent(
539
+ Operator conv_operator,
540
+ Conv2dProblemSize const &problem_size) {
541
+ switch (conv_operator) {
542
+ case cutlass::conv::Operator::kFprop: return problem_size.filter_extent();
543
+ case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true);
544
+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent();
545
+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent();
546
+ default : break;
547
+ }
548
+ return cutlass::Tensor4DCoord();
549
+ }
550
+
551
+ /// Returns ImplicitGemm tensor C extent as Tensor4DCoord
552
+ CUTLASS_HOST_DEVICE
553
+ cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent(
554
+ Operator conv_operator,
555
+ Conv2dProblemSize const &problem_size) {
556
+ switch (conv_operator) {
557
+ case cutlass::conv::Operator::kFprop: return problem_size.output_extent();
558
+ case cutlass::conv::Operator::kDeconv:
559
+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent();
560
+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent();
561
+ default : break;
562
+ }
563
+ return cutlass::Tensor4DCoord();
564
+ }
565
+
566
+ /// Returns ImplicitGemm tensor A size in number of elements
567
+ CUTLASS_HOST_DEVICE
568
+ int64_t implicit_gemm_tensor_a_size(
569
+ Operator conv_operator,
570
+ Conv2dProblemSize const &problem_size) {
571
+ switch (conv_operator) {
572
+ case cutlass::conv::Operator::kFprop: return problem_size.activation_size();
573
+ case cutlass::conv::Operator::kDeconv:
574
+ case cutlass::conv::Operator::kDgrad: return problem_size.output_size();
575
+ case cutlass::conv::Operator::kWgrad: return problem_size.output_size();
576
+ default : break;
577
+ }
578
+ return 0;
579
+ }
580
+
581
+ /// Returns ImplicitGemm tensor B size in number of elements
582
+ CUTLASS_HOST_DEVICE
583
+ int64_t implicit_gemm_tensor_b_size(
584
+ Operator conv_operator,
585
+ Conv2dProblemSize const &problem_size) {
586
+ switch (conv_operator) {
587
+ case cutlass::conv::Operator::kFprop: return problem_size.filter_size();
588
+ case cutlass::conv::Operator::kDeconv:
589
+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_size();
590
+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_size();
591
+ default : break;
592
+ }
593
+ return 0;
594
+ }
595
+
596
+ /// Returns ImplicitGemm tensor C size in number of elements
597
+ CUTLASS_HOST_DEVICE
598
+ int64_t implicit_gemm_tensor_c_size(
599
+ Operator conv_operator,
600
+ Conv2dProblemSize const &problem_size) {
601
+ switch (conv_operator) {
602
+ case cutlass::conv::Operator::kFprop: return problem_size.output_size();
603
+ case cutlass::conv::Operator::kDeconv:
604
+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_size();
605
+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_size();
606
+ default : break;
607
+ }
608
+ return 0;
609
+ }
610
+
611
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
612
+
613
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
614
+ // Strided dgrad helper functions //
615
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
616
+ // Returns number of CTAs tile M to cover valid MMAs per starting filter postion
617
+ CUTLASS_HOST_DEVICE
618
+ int strided_dgrad_tile_m_per_filter(
619
+ Conv2dProblemSize const &problem_size,
620
+ int tile_size_m) {
621
+
622
+ // Compute NHW rows in Dx output that needs MMA per starting filter position
623
+ int rows_h_per_filter = (problem_size.H + problem_size.stride_h - 1) / problem_size.stride_h;
624
+ int rows_w_per_filter = (problem_size.W + problem_size.stride_w - 1) / problem_size.stride_w;
625
+ int rows_nhw_per_filter = problem_size.N * rows_h_per_filter * rows_w_per_filter;
626
+
627
+ // Number of CTAs tile M to cover valid MMAs per starting filter postion
628
+ int tile_m_per_filter = (rows_nhw_per_filter + tile_size_m - 1) / tile_size_m;
629
+
630
+ return tile_m_per_filter;
631
+ }
632
+
633
+ // Computes starting Dx coord (h, w) for given starting filter postion
634
+ CUTLASS_HOST_DEVICE
635
+ void strided_dgrad_starting_coords(
636
+ Conv2dProblemSize const &problem_size,
637
+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
638
+ int r, int s,
639
+ int &start_h, int &start_w) {
640
+
641
+ // function locals for remainder by fast divmod
642
+ int pad_h_rem_, pad_w_rem_;
643
+
644
+ // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
645
+ stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
646
+ int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r));
647
+ stride_h_divmod.divmod(start_h, r_);
648
+
649
+ //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
650
+ stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
651
+ int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s));
652
+ stride_w_divmod.divmod(start_w, s_);
653
+ }
654
+
655
+ } // namespace conv
656
+ } // namespace cutlass
657
+
658
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief This file contains definitions and utility functions for describing convolution problem sizes.
33
+
34
+ Conv3dProblem desciption:
35
+ activation (NDHWC),
36
+ filter (KTRSC),
37
+ output (NZPQK),
38
+ pading (pad_d, pad_h, pad_w),
39
+ stride (stride_d, stride_h, stride_w),
40
+ dilation (dilation_d, dilation_h, dilation_w).
41
+
42
+ Free functions to map:
43
+ Map tensor extents (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator)
44
+ Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator)
45
+ Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator)
46
+ */
47
+
48
+ #pragma once
49
+
50
+ #include "cutlass/conv/convolution.h"
51
+ #include "cutlass/conv/conv2d_problem_size.h"
52
+
53
+ namespace cutlass {
54
+ namespace conv {
55
+
56
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
57
+
58
+ /// Problem size structure
59
+ struct Conv3dProblemSize : public Conv2dProblemSize {
60
+ //
61
+ // Type definitions
62
+ //
63
+
64
+ // 3D coordinate for padding, stride, and dilation in (d, h, w) dimensions
65
+ using Coord3D = Coord<3>;
66
+
67
+ //
68
+ // Data members
69
+ //
70
+
71
+ // Conv3d strictly problem size parameters
72
+ int D, T, Z; // input depth, filter depth, output depth
73
+ int pad_d; // padding in depth dimension
74
+ int stride_d; // stride in depth dimension
75
+ int dilation_d; // dilation in depth dimension
76
+
77
+ //
78
+ // Methods
79
+ //
80
+ public:
81
+ CUTLASS_HOST_DEVICE
82
+ Conv3dProblemSize():
83
+ Conv2dProblemSize(),
84
+ D(0), T(0), Z(0),
85
+ pad_d(0),
86
+ stride_d(1),
87
+ dilation_d(1) { }
88
+
89
+ /// Constructor for default padding, stride, dilation, and split-K
90
+ CUTLASS_HOST_DEVICE
91
+ Conv3dProblemSize(
92
+ int N,
93
+ int D,
94
+ int H,
95
+ int W,
96
+ int C,
97
+ int Z,
98
+ int P,
99
+ int Q,
100
+ int K,
101
+ int T,
102
+ int R,
103
+ int S,
104
+ Mode mode
105
+ ):
106
+ Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode),
107
+ D(D), T(T), Z(Z),
108
+ pad_d(T / 2), stride_d(1), dilation_d(1) { }
109
+
110
+ /// Constructor
111
+ CUTLASS_HOST_DEVICE
112
+ Conv3dProblemSize(
113
+ int N,
114
+ int D,
115
+ int H,
116
+ int W,
117
+ int C,
118
+ int K,
119
+ int T,
120
+ int R,
121
+ int S,
122
+ int Z,
123
+ int P,
124
+ int Q,
125
+ int pad_d,
126
+ int pad_h,
127
+ int pad_w,
128
+ int stride_d,
129
+ int stride_h,
130
+ int stride_w,
131
+ int dilation_d,
132
+ int dilation_h,
133
+ int dilation_w,
134
+ Mode mode,
135
+ int split_k_slices = 1,
136
+ int groups = 1
137
+ ):
138
+ Conv2dProblemSize(
139
+ N, H, W, C, K, R, S, P, Q,
140
+ pad_h, pad_w,
141
+ stride_h, stride_w,
142
+ dilation_h, dilation_w,
143
+ mode, split_k_slices, groups),
144
+ D(D), T(T), Z(Z),
145
+ pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d) { }
146
+
147
+ /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D
148
+ // set *user-defined* output size and sets Z, P, and Q (include all data members in ctor)
149
+ CUTLASS_HOST_DEVICE
150
+ Conv3dProblemSize(
151
+ cutlass::Tensor5DCoord input_size, // NDHWC
152
+ cutlass::Tensor5DCoord filter_size, // KTRSC
153
+ Coord3D padding, // pad_d, pad_h, pad_w
154
+ Coord3D stride, // stride_d, stride_h, stride_w
155
+ Coord3D dilation, // dilation_d, dilation_h, dilation_w
156
+ cutlass::Tensor5DCoord output_size, // NZPQK
157
+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
158
+ int split_k_slices = 1,
159
+ int groups = 1
160
+ ):
161
+ Conv2dProblemSize(
162
+ {input_size.n(), input_size.h(), input_size.w(), input_size.c()},
163
+ {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
164
+ {padding[1], padding[1], padding[2], padding[2]},
165
+ {stride[1], stride[2]},
166
+ {dilation[1], dilation[2]},
167
+ {output_size.n(), output_size.h(), output_size.w(), output_size.c()},
168
+ mode, split_k_slices, groups),
169
+ D(input_size.d()), T(filter_size.d()), Z(output_size.d()),
170
+ pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) { }
171
+
172
+ /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D
173
+ // *computes* output size and sets Z, P and Q (include all data members in ctor)
174
+ CUTLASS_HOST_DEVICE
175
+ Conv3dProblemSize(
176
+ cutlass::Tensor5DCoord input_size, // NDHWC
177
+ cutlass::Tensor5DCoord filter_size, // KTRSC
178
+ Coord3D padding, // pad_d, pad_h, pad_w
179
+ Coord3D stride, // stride_d, stride_h, stride_w
180
+ Coord3D dilation, // dilation_d, dilation_h, dilation_w
181
+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
182
+ int split_k_slices = 1,
183
+ int groups = 1
184
+ ):
185
+ Conv2dProblemSize(
186
+ {input_size.n(), input_size.h(), input_size.w(), input_size.c()},
187
+ {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
188
+ {padding[1], padding[1], padding[2], padding[2]},
189
+ {stride[1], stride[2]},
190
+ {dilation[1], dilation[2]},
191
+ mode, split_k_slices, groups),
192
+ D(input_size.d()), T(filter_size.d()),
193
+ pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0])
194
+ {
195
+ // set output Z
196
+ Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1;
197
+ }
198
+
199
+ /// Constructs convolution problem size from cutlass Tensor5DCoord, Coord3D
200
+ // *computes* output size and sets Z, P and Q (include all data members in ctor)
201
+ CUTLASS_HOST_DEVICE
202
+ Conv3dProblemSize(
203
+ cutlass::Tensor5DCoord input_size, // NDHWC
204
+ cutlass::Tensor5DCoord filter_size, // KTRSC
205
+ CUTLASS_STL_NAMESPACE::tuple<Coord3D, Coord3D> padding, // Coord3D {pad_d, pad_h, pad_w} & Coord3D {far pad_d, pad_h, pad_w} to calculate o/p/q
206
+ Coord3D stride, // stride_d, stride_h, stride_w
207
+ Coord3D dilation, // dilation_d, dilation_h, dilation_w
208
+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
209
+ int split_k_slices = 1,
210
+ int groups = 1
211
+ ):
212
+ Conv2dProblemSize(
213
+ {input_size.n(), input_size.h(), input_size.w(), input_size.c()},
214
+ {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
215
+ {CUTLASS_STL_NAMESPACE::get<0>(padding)[1], CUTLASS_STL_NAMESPACE::get<1>(padding)[1],
216
+ CUTLASS_STL_NAMESPACE::get<0>(padding)[2], CUTLASS_STL_NAMESPACE::get<1>(padding)[2]},
217
+ {stride[1], stride[2]},
218
+ {dilation[1], dilation[2]},
219
+ mode, split_k_slices, groups),
220
+ D(input_size.d()), T(filter_size.d()),
221
+ pad_d(CUTLASS_STL_NAMESPACE::get<0>(padding)[0]), stride_d(stride[0]), dilation_d(dilation[0])
222
+ {
223
+ // set output Z
224
+ Z = ((D + pad_d + CUTLASS_STL_NAMESPACE::get<1>(padding)[0] - T * dilation_d) / stride_d) + 1;
225
+ }
226
+
227
+ /// Equality operator (ignores mode and split_k_slice)
228
+ CUTLASS_HOST_DEVICE
229
+ bool operator==(Conv3dProblemSize const &conv) const {
230
+ return (
231
+ (N == conv.N) && (D == conv.D) && (H == conv.H) && (W == conv.W) && (C == conv.C) &&
232
+ (K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) &&
233
+ (Z == conv.Z) &&(P == conv.P) && (Q == conv.Q) &&
234
+ (pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) &&
235
+ (stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) &&
236
+ (dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w)
237
+ );
238
+ }
239
+
240
+ /// Inequality operator
241
+ CUTLASS_HOST_DEVICE
242
+ bool operator!=(Conv3dProblemSize const &rhs) const {
243
+ return !(*this == rhs);
244
+ }
245
+
246
+ // Reset covolution mode in the problem
247
+ CUTLASS_HOST_DEVICE
248
+ Conv3dProblemSize reset_mode(cutlass::conv::Mode mode_) {
249
+ Conv3dProblemSize tmp(*this);
250
+ tmp.mode = mode_;
251
+ return tmp;
252
+ }
253
+
254
+ // Reset covolution mode in the problem
255
+ CUTLASS_HOST_DEVICE
256
+ Conv3dProblemSize reset_split_k_slices(int split_k_slices_) {
257
+ Conv3dProblemSize tmp(*this);
258
+ tmp.split_k_slices = split_k_slices_;
259
+ return tmp;
260
+ }
261
+
262
+ /// Returns activation extent as Tensor5DCoord
263
+ CUTLASS_HOST_DEVICE
264
+ cutlass::Tensor5DCoord activation_extent() const {
265
+
266
+ return cutlass::Tensor5DCoord ({N, D, H, W, C});
267
+ }
268
+
269
+ /// Returns filter extent as Tensor5DCoord
270
+ CUTLASS_HOST_DEVICE
271
+ cutlass::Tensor5DCoord filter_extent(bool is_deconv = false) const {
272
+
273
+ return is_deconv ? cutlass::Tensor5DCoord ({C, T, R, S, K})
274
+ : cutlass::Tensor5DCoord ({K, T, R, S, C});
275
+ }
276
+
277
+ /// Returns output extent as Tensor5DCoord
278
+ CUTLASS_HOST_DEVICE
279
+ cutlass::Tensor5DCoord output_extent() const {
280
+
281
+ return cutlass::Tensor5DCoord ({N, Z, P, Q, K});
282
+ }
283
+
284
+ /// Returns activation size in number of elements
285
+ CUTLASS_HOST_DEVICE
286
+ int64_t activation_size() const {
287
+
288
+ return static_cast<int64_t>(N) * static_cast<int64_t>(D) *
289
+ static_cast<int64_t>(H) * static_cast<int64_t>(W) *
290
+ static_cast<int64_t>(C);
291
+ }
292
+
293
+ /// Returns filter size in number of elements
294
+ CUTLASS_HOST_DEVICE
295
+ int64_t filter_size() const {
296
+
297
+ return static_cast<int64_t>(K) * static_cast<int64_t>(T) *
298
+ static_cast<int64_t>(R) * static_cast<int64_t>(S) *
299
+ static_cast<int64_t>(C);
300
+ }
301
+
302
+ /// Returns output size in number of elements
303
+ CUTLASS_HOST_DEVICE
304
+ int64_t output_size() const {
305
+
306
+ return static_cast<int64_t>(N) * static_cast<int64_t>(Z) *
307
+ static_cast<int64_t>(P) * static_cast<int64_t>(Q) *
308
+ static_cast<int64_t>(K);
309
+ }
310
+
311
+ /// Returns padding as Coord3D
312
+ CUTLASS_HOST_DEVICE
313
+ Coord3D padding() const {
314
+
315
+ return Coord3D ({pad_d, pad_h, pad_w});
316
+ }
317
+
318
+ /// Returns stride as MatrixCoord
319
+ CUTLASS_HOST_DEVICE
320
+ Coord3D stride() const {
321
+
322
+ return Coord3D ({stride_d, stride_h, stride_w});
323
+ }
324
+
325
+ /// Returns dilation as MatrixCoord
326
+ CUTLASS_HOST_DEVICE
327
+ Coord3D dilation() const {
328
+
329
+ return Coord3D ({dilation_d, dilation_h, dilation_w});
330
+ }
331
+
332
+ };
333
+
334
+
335
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
336
+ // ImplicitGemm helper functions //
337
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
338
+
339
+ /// Determine the problem size of the implicit GEMM operation
340
+ CUTLASS_HOST_DEVICE
341
+ cutlass::gemm::GemmCoord implicit_gemm_problem_size(
342
+ Operator conv_operator,
343
+ Conv3dProblemSize const &problem_size) {
344
+ // Compute problem size
345
+ switch (conv_operator) {
346
+ case Operator::kFprop:
347
+ return gemm::GemmCoord(
348
+ problem_size.N * problem_size.Z * problem_size.P * problem_size.Q,
349
+ problem_size.K,
350
+ problem_size.T * problem_size.R * problem_size.S * problem_size.C
351
+ );
352
+ case Operator::kDeconv:
353
+ case Operator::kDgrad:
354
+ return gemm::GemmCoord(
355
+ problem_size.N * problem_size.D * problem_size.H * problem_size.W,
356
+ problem_size.C,
357
+ problem_size.T * problem_size.R * problem_size.S * problem_size.K
358
+ );
359
+ case Operator::kWgrad:
360
+ return gemm::GemmCoord(
361
+ problem_size.K,
362
+ problem_size.T * problem_size.R * problem_size.S * problem_size.C,
363
+ problem_size.N * problem_size.Z * problem_size.P * problem_size.Q
364
+ );
365
+ default:
366
+ break;
367
+ }
368
+ return gemm::GemmCoord();
369
+ }
370
+
371
+ // Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm
372
+ CUTLASS_HOST_DEVICE
373
+ int implicit_gemm_k_iterations(
374
+ Operator conv_operator,
375
+ int threadblock_K,
376
+ Conv3dProblemSize const &problem_size,
377
+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
378
+ GroupMode group_mode = GroupMode::kNone,
379
+ int threadblock_N = 0) {
380
+
381
+ int iterations = 0;
382
+ int elements_per_split_k_slice = 0;
383
+ if (group_mode == GroupMode::kNone) {
384
+ switch (conv_operator) {
385
+ case Operator::kFprop:
386
+ elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
387
+ iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
388
+ break;
389
+
390
+ case Operator::kDeconv:
391
+ case Operator::kDgrad:
392
+ elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
393
+ iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
394
+ break;
395
+
396
+ case Operator::kWgrad:
397
+ elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
398
+ iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
399
+ break;
400
+
401
+ default:
402
+ break;
403
+ }
404
+ } else if (group_mode == GroupMode::kDepthwise) {
405
+ int channels_per_cta = threadblock_N;
406
+
407
+ if (algorithm == IteratorAlgorithm::kAnalytic) {
408
+ switch (conv_operator) {
409
+ case Operator::kFprop:
410
+ iterations = problem_size.T * problem_size.R * problem_size.S *
411
+ ((channels_per_cta + threadblock_K - 1) / threadblock_K);
412
+ break;
413
+
414
+ default:
415
+ break;
416
+ }
417
+ }
418
+ }
419
+
420
+ return iterations;
421
+ }
422
+
423
+ ////////////////////////////////////////////////////////////////////////////////
424
+ // Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output)
425
+ ////////////////////////////////////////////////////////////////////////////////
426
+ /// Returns ImplicitGemm tensor A extent as Tensor5DCoord
427
+ CUTLASS_HOST_DEVICE
428
+ cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent(
429
+ Operator conv_operator,
430
+ Conv3dProblemSize const &problem_size) {
431
+ switch (conv_operator) {
432
+ case cutlass::conv::Operator::kFprop: return problem_size.activation_extent();
433
+ case cutlass::conv::Operator::kDeconv:
434
+ case cutlass::conv::Operator::kDgrad: return problem_size.output_extent();
435
+ case cutlass::conv::Operator::kWgrad: return problem_size.output_extent();
436
+ default : break;
437
+ }
438
+ return cutlass::Tensor5DCoord();
439
+ }
440
+
441
+ /// Returns ImplicitGemm tensor B extent as Tensor5DCoord
442
+ CUTLASS_HOST_DEVICE
443
+ cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent(
444
+ Operator conv_operator,
445
+ Conv3dProblemSize const &problem_size) {
446
+ switch (conv_operator) {
447
+ case cutlass::conv::Operator::kFprop: return problem_size.filter_extent();
448
+ case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true);
449
+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent();
450
+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent();
451
+ default : break;
452
+ }
453
+ return cutlass::Tensor5DCoord();
454
+ }
455
+
456
+ /// Returns ImplicitGemm tensor C extent as Tensor5DCoord
457
+ CUTLASS_HOST_DEVICE
458
+ cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent(
459
+ Operator conv_operator,
460
+ Conv3dProblemSize const &problem_size) {
461
+ switch (conv_operator) {
462
+ case cutlass::conv::Operator::kFprop: return problem_size.output_extent();
463
+ case cutlass::conv::Operator::kDeconv:
464
+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent();
465
+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent();
466
+ default : break;
467
+ }
468
+ return cutlass::Tensor5DCoord();
469
+ }
470
+
471
+ /// Returns ImplicitGemm tensor A size in number of elements
472
+ CUTLASS_HOST_DEVICE
473
+ int64_t implicit_gemm_tensor_a_size(
474
+ Operator conv_operator,
475
+ Conv3dProblemSize const &problem_size) {
476
+ switch (conv_operator) {
477
+ case cutlass::conv::Operator::kFprop: return problem_size.activation_size();
478
+ case cutlass::conv::Operator::kDeconv:
479
+ case cutlass::conv::Operator::kDgrad: return problem_size.output_size();
480
+ case cutlass::conv::Operator::kWgrad: return problem_size.output_size();
481
+ default : break;
482
+ }
483
+ return 0;
484
+ }
485
+
486
+ /// Returns ImplicitGemm tensor B size in number of elements
487
+ CUTLASS_HOST_DEVICE
488
+ int64_t implicit_gemm_tensor_b_size(
489
+ Operator conv_operator,
490
+ Conv3dProblemSize const &problem_size) {
491
+ switch (conv_operator) {
492
+ case cutlass::conv::Operator::kFprop: return problem_size.filter_size();
493
+ case cutlass::conv::Operator::kDeconv:
494
+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_size();
495
+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_size();
496
+ default : break;
497
+ }
498
+ return 0;
499
+ }
500
+
501
+ /// Returns ImplicitGemm tensor C size in number of elements
502
+ CUTLASS_HOST_DEVICE
503
+ int64_t implicit_gemm_tensor_c_size(
504
+ Operator conv_operator,
505
+ Conv3dProblemSize const &problem_size) {
506
+ switch (conv_operator) {
507
+ case cutlass::conv::Operator::kFprop: return problem_size.output_size();
508
+ case cutlass::conv::Operator::kDeconv:
509
+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_size();
510
+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_size();
511
+ default : break;
512
+ }
513
+ return 0;
514
+ }
515
+
516
+ } // namespace conv
517
+ } // namespace cutlass
518
+
519
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief This file contains definitions and utility functions for describing convolution problem shapes.
33
+ */
34
+ #pragma once
35
+
36
+ #include "cutlass/cutlass.h"
37
+ #include "cutlass/tensor_coord.h"
38
+ #include "cutlass/conv/convolution.h"
39
+
40
+ #include "cute/container/array.hpp"
41
+
42
+ #if ! defined(__CUDACC_RTC__)
43
+ #include <initializer_list>
44
+ #endif
45
+
46
+
47
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ namespace cutlass::conv {
50
+
51
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ // Implements the user facing argument for all CUTLASS 3.x convolutions in a rank agnostic fashion.
54
+ // All tensors are flat and by default treated as layout right (NDHWC, KTRSC, NZPQK)
55
+ // Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types.
56
+ template <
57
+ conv::Operator ConvOp_,
58
+ int NumSpatialDimensions_
59
+ >
60
+ struct ConvProblemShape {
61
+ //
62
+ // Alias types for members
63
+ //
64
+
65
+ static constexpr int RankS = NumSpatialDimensions_;
66
+ static constexpr int RankT = NumSpatialDimensions_ + 2;
67
+ static constexpr conv::Operator ConvOp = ConvOp_;
68
+ static constexpr int NumSpatialDimensions = NumSpatialDimensions_;
69
+ using SpatialExtent = cute::array<int, RankS>;
70
+ using TensorExtent = cute::array<int, RankT>;
71
+ using TensorStride = cute::array<int64_t, RankT>;
72
+ using ShapePadding = SpatialExtent;
73
+ using TraversalStride = SpatialExtent;
74
+ using ShapeDilation = SpatialExtent;
75
+ using Corner = SpatialExtent;
76
+
77
+ //
78
+ // Members
79
+ //
80
+ cutlass::conv::Mode mode{};
81
+ TensorExtent shape_A{};
82
+ TensorStride stride_A{};
83
+ TensorExtent shape_B{};
84
+ TensorStride stride_B{};
85
+ TensorExtent shape_C{};
86
+ TensorStride stride_C{};
87
+
88
+ // asymmetric padding, both upper and lower padding must be >= 0
89
+ ShapePadding lower_padding{};
90
+ ShapePadding upper_padding{};
91
+ TraversalStride traversal_stride{};
92
+ ShapeDilation dilation{};
93
+ int groups = 1;
94
+
95
+ //
96
+ // Methods
97
+ //
98
+
99
+ ConvProblemShape() = default;
100
+
101
+ // Constructor accepts user facing arguments and computes to stores the corners as its internal state
102
+ ConvProblemShape(
103
+ conv::Mode mode, // convolution/cross-correlation
104
+ TensorExtent shape_act, // [n,d,h,w,c]
105
+ TensorStride stride_act, // [n,d,h,w,c]
106
+ TensorExtent shape_flt, // [k,t,r,s,c]
107
+ TensorStride stride_flt, // [k,t,r,s,c]
108
+ ShapePadding lower_padding, // [pad_d, pad_h, pad_w]
109
+ ShapePadding upper_padding, // [pad_d, pad_h, pad_w]
110
+ TraversalStride tstride, // [stride_d, stride_h, stride_w]
111
+ ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w]
112
+ int groups)
113
+ : mode(mode)
114
+ , lower_padding(lower_padding)
115
+ , upper_padding(upper_padding)
116
+ , traversal_stride(tstride)
117
+ , dilation(dilation)
118
+ , groups(groups) {
119
+
120
+ auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
121
+ set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
122
+ }
123
+
124
+ // Allow user input of xformed activation stride to support non-packed strides.
125
+ ConvProblemShape(
126
+ conv::Mode mode, // convolution/cross-correlation
127
+ TensorExtent shape_act, // [n,d,h,w,c]
128
+ TensorStride stride_act, // [n,d,h,w,c]
129
+ TensorExtent shape_flt, // [k,t,r,s,c]
130
+ TensorStride stride_flt, // [k,t,r,s,c]
131
+ TensorStride stride_xformed_act, // [n,z,p,q,k]
132
+ ShapePadding lower_padding, // [pad_d, pad_h, pad_w]
133
+ ShapePadding upper_padding, // [pad_d, pad_h, pad_w]
134
+ TraversalStride tstride, // [stride_d, stride_h, stride_w]
135
+ ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w]
136
+ int groups)
137
+ : mode(mode)
138
+ , lower_padding(lower_padding)
139
+ , upper_padding(upper_padding)
140
+ , traversal_stride(tstride)
141
+ , dilation(dilation)
142
+ , groups(groups) {
143
+
144
+ CUTLASS_ASSERT(stride_act[RankT - 1] == 1);
145
+ CUTLASS_ASSERT(stride_flt[RankT - 1] == 1);
146
+ CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1);
147
+
148
+ auto stride_act_packed = packed_stride_right_major(shape_act);
149
+ auto stride_flt_packed = packed_stride_right_major(shape_flt);
150
+ auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt);
151
+
152
+ CUTLASS_PRAGMA_UNROLL
153
+ for(int i = 0; i < RankT - 1; ++i) {
154
+ CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]);
155
+ CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]);
156
+ CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]);
157
+ }
158
+
159
+ set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
160
+ }
161
+
162
+ // Constructor accepts user facing arguments and presume packed tensor strides in canonical (CWHDN) order.
163
+ ConvProblemShape(
164
+ conv::Mode mode,
165
+ TensorExtent shape_act,
166
+ TensorExtent shape_flt,
167
+ ShapePadding lower_padding,
168
+ ShapePadding upper_padding,
169
+ TraversalStride tstride,
170
+ ShapeDilation dilation,
171
+ int groups)
172
+ : ConvProblemShape(
173
+ mode,
174
+ shape_act,
175
+ packed_stride_right_major(shape_act),
176
+ shape_flt,
177
+ packed_stride_right_major(shape_flt),
178
+ lower_padding,
179
+ upper_padding,
180
+ tstride,
181
+ dilation,
182
+ groups) {
183
+ }
184
+
185
+ #if ! defined(__CUDACC_RTC__)
186
+ // Constructor accepts user facing arguments and computes to stores the corners as its internal state
187
+ ConvProblemShape(
188
+ conv::Mode mode,
189
+ std::initializer_list<int> shape_act_,
190
+ std::initializer_list<int64_t> stride_act_,
191
+ std::initializer_list<int> shape_flt_,
192
+ std::initializer_list<int64_t> stride_flt_,
193
+ std::initializer_list<int> lower_padding_,
194
+ std::initializer_list<int> upper_padding_,
195
+ std::initializer_list<int> traversal_stride_,
196
+ std::initializer_list<int> dilation_,
197
+ int groups)
198
+ : mode(mode)
199
+ , groups(groups) {
200
+
201
+ TensorExtent shape_act{};
202
+ TensorStride stride_act{};
203
+ TensorExtent shape_flt{};
204
+ TensorStride stride_flt{};
205
+
206
+ assert(shape_act_.size() == shape_act.size());
207
+ assert(stride_act_.size() == stride_act.size());
208
+ assert(shape_flt_.size() == shape_flt.size());
209
+ assert(stride_flt_.size() == stride_flt.size());
210
+ assert(lower_padding_.size() == lower_padding.size());
211
+ assert(upper_padding_.size() == upper_padding.size());
212
+ assert(traversal_stride_.size() == traversal_stride.size());
213
+ assert(dilation_.size() == dilation.size());
214
+
215
+ std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin());
216
+ std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin());
217
+ std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin());
218
+ std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin());
219
+ std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin());
220
+ std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin());
221
+ std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin());
222
+ std::copy(dilation_.begin(), dilation_.end(), dilation.begin());
223
+
224
+ auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
225
+ set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
226
+ }
227
+
228
+ // Allow user input of xformed activation stride to support non-packed strides.
229
+ ConvProblemShape(
230
+ conv::Mode mode,
231
+ std::initializer_list<int> shape_act_,
232
+ std::initializer_list<int64_t> stride_act_,
233
+ std::initializer_list<int> shape_flt_,
234
+ std::initializer_list<int64_t> stride_flt_,
235
+ std::initializer_list<int64_t> stride_xformed_act_,
236
+ std::initializer_list<int> lower_padding_,
237
+ std::initializer_list<int> upper_padding_,
238
+ std::initializer_list<int> traversal_stride_,
239
+ std::initializer_list<int> dilation_,
240
+ int groups)
241
+ : mode(mode)
242
+ , groups(groups) {
243
+ TensorExtent shape_act{};
244
+ TensorStride stride_act{};
245
+ TensorExtent shape_flt{};
246
+ TensorStride stride_flt{};
247
+ TensorStride stride_xformed_act{};
248
+
249
+ std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin());
250
+ std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin());
251
+ std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin());
252
+ std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin());
253
+ std::copy(stride_xformed_act_.begin(), stride_xformed_act_.end(), stride_xformed_act.begin());
254
+ std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin());
255
+ std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin());
256
+ std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin());
257
+ std::copy(dilation_.begin(), dilation_.end(), dilation.begin());
258
+
259
+ CUTLASS_ASSERT(stride_act[RankT - 1] == 1);
260
+ CUTLASS_ASSERT(stride_flt[RankT - 1] == 1);
261
+ CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1);
262
+
263
+ auto stride_act_packed = packed_stride_right_major(shape_act);
264
+ auto stride_flt_packed = packed_stride_right_major(shape_flt);
265
+ auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt);
266
+
267
+ CUTLASS_PRAGMA_UNROLL
268
+ for(int i = 0; i < RankT - 1; ++i) {
269
+ CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]);
270
+ CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]);
271
+ CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]);
272
+ }
273
+
274
+ set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
275
+ }
276
+
277
+ // Constructor accepts user facing arguments and computes to stores the corners as its internal state
278
+ ConvProblemShape(
279
+ conv::Mode mode,
280
+ std::initializer_list<int> shape_act_,
281
+ std::initializer_list<int> shape_flt_,
282
+ std::initializer_list<int> lower_padding_,
283
+ std::initializer_list<int> upper_padding_,
284
+ std::initializer_list<int> traversal_stride_,
285
+ std::initializer_list<int> dilation_,
286
+ int groups)
287
+ : mode(mode)
288
+ , groups(groups) {
289
+ TensorExtent shape_act{};
290
+ TensorStride stride_act{};
291
+ TensorExtent shape_flt{};
292
+ TensorStride stride_flt{};
293
+
294
+ assert(shape_act_.size() == shape_act.size());
295
+ assert(shape_flt_.size() == shape_flt.size());
296
+ assert(lower_padding_.size() == lower_padding.size());
297
+ assert(upper_padding_.size() == upper_padding.size());
298
+ assert(traversal_stride_.size() == traversal_stride.size());
299
+ assert(dilation_.size() == dilation.size());
300
+
301
+ std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin());
302
+ std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin());
303
+ std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin());
304
+ std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin());
305
+ std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin());
306
+ std::copy(dilation_.begin(), dilation_.end(), dilation.begin());
307
+ stride_act = packed_stride_right_major(shape_act);
308
+ stride_flt = packed_stride_right_major(shape_flt);
309
+
310
+ auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
311
+ set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
312
+ }
313
+ #endif // not defined(__CUDACC_RTC__)
314
+
315
+ // Set shape and stride of tensor A/B/C according to following table:
316
+ // | | Fprop | Dgrad | Wgrad |
317
+ // | ------ | ------ | ------ | ------|
318
+ // | ShapeA | NDHWC | NZPQK | NZPQK |
319
+ // | ShapeB | KTRSC | KTRSC | NDHWC |
320
+ // | ShapeC | NZPQK | NDHWC | KTRSC |
321
+ //
322
+ // Input comes from calculate_xformed_act, which does NOT depend on ConvOp.
323
+ CUTLASS_HOST_DEVICE
324
+ constexpr void
325
+ set_shape_stride_ABC(
326
+ TensorExtent shape_act,
327
+ TensorStride stride_act,
328
+ TensorExtent shape_flt,
329
+ TensorStride stride_flt,
330
+ TensorExtent shape_xformed_act,
331
+ TensorStride stride_xformed_act) {
332
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
333
+ printf("*** set_shape_stride_ABC ***");
334
+ printf("\n shape_act: ");
335
+ print(shape_act);
336
+ printf("\n stride_act: ");
337
+ print(stride_act);
338
+ printf("\n shape_flt: ");
339
+ print(shape_flt);
340
+ printf("\n stride_flt: ");
341
+ print(stride_flt);
342
+ printf("\n shape_xformed_act: ");
343
+ print(shape_xformed_act);
344
+ printf("\n stride_xformed_act: ");
345
+ print(stride_xformed_act);
346
+ if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
347
+ printf("\n ConvOp: Fprop");
348
+ }
349
+ if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
350
+ printf("\n ConvOp: Dgrad");
351
+ }
352
+ if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) {
353
+ printf("\n ConvOp: Wgrad");
354
+ }
355
+ printf("\n");
356
+ #endif
357
+
358
+ if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
359
+ shape_A = shape_act;
360
+ stride_A = stride_act;
361
+ shape_B = shape_flt;
362
+ stride_B = stride_flt;
363
+ shape_C = shape_xformed_act;
364
+ stride_C = stride_xformed_act;
365
+ }
366
+ else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
367
+ shape_A = shape_xformed_act;
368
+ stride_A = stride_xformed_act;
369
+ shape_B = shape_flt;
370
+ stride_B = stride_flt;
371
+ shape_C = shape_act;
372
+ stride_C = stride_act;
373
+ }
374
+ else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) {
375
+ shape_A = shape_xformed_act;
376
+ stride_A = stride_xformed_act;
377
+ shape_B = shape_act;
378
+ stride_B = stride_act;
379
+ shape_C = shape_flt;
380
+ stride_C = stride_flt;
381
+ }
382
+ #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
383
+ printf("\n shape_A: ");
384
+ print(shape_A);
385
+ printf("\n stride_A: ");
386
+ print(stride_A);
387
+ printf("\n shape_B: ");
388
+ print(shape_B);
389
+ printf("\n stride_B: ");
390
+ print(stride_B);
391
+ printf("\n shape_C: ");
392
+ print(shape_C);
393
+ printf("\n stride_C: ");
394
+ print(stride_C);
395
+ #endif
396
+ }
397
+
398
+ // Get A extents.
399
+ // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C))
400
+ // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K))
401
+ // wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N))
402
+ CUTLASS_HOST_DEVICE
403
+ constexpr auto
404
+ get_shape_A() const {
405
+ using cute::make_shape;
406
+ using cute::take;
407
+
408
+ if constexpr (ConvOp == conv::Operator::kFprop ||
409
+ ConvOp == conv::Operator::kDgrad) {
410
+ return make_shape(
411
+ cute::reverse(take<0, RankT - 1>(shape_A)),
412
+ shape_A[RankT - 1]);
413
+ }
414
+ // For wgrad kernel, we need to linearize NZPQ for tensor A
415
+ else if constexpr (ConvOp == conv::Operator::kWgrad) {
416
+ return make_shape(
417
+ shape_A[RankT - 1],
418
+ cute::product(take<0, RankT - 1>(shape_A)));
419
+ }
420
+ }
421
+
422
+ // Get B extents.
423
+ // fprop: B extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T))
424
+ // dgrad: B extents array contains [K,T,R,S,C]. Turn that into ((C), (K,S,R,T))
425
+ // wgrad: B extents array contains [N,D,H,W,C]. Turn that into ((C), (W,H,D,N))
426
+ CUTLASS_HOST_DEVICE
427
+ constexpr auto
428
+ get_shape_B() const {
429
+ using cute::make_shape;
430
+ using cute::reverse;
431
+ using cute::take;
432
+
433
+ if constexpr (ConvOp == conv::Operator::kFprop) {
434
+ return make_shape(
435
+ shape_B[0],
436
+ reverse(take<1, RankT>(shape_B)));
437
+ }
438
+ else if constexpr (ConvOp == conv::Operator::kWgrad) {
439
+ return make_shape(
440
+ shape_B[RankT - 1],
441
+ reverse(take<0, RankT - 1>(shape_B)));
442
+ }
443
+ else if constexpr (ConvOp == conv::Operator::kDgrad) {
444
+ // shape_B: [K,T,R,S,C], return: [(C),(K,S,R,T)]
445
+ return make_shape(
446
+ shape_B[RankT - 1],
447
+ cute::insert<0>(
448
+ reverse(take<1, RankT - 1>(shape_B)),
449
+ shape_B[0]));
450
+ }
451
+ }
452
+
453
+ // Get C extents.
454
+ // fprop: C extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K))
455
+ // dgrad: C extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C))
456
+ // wgrad: C extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T))
457
+ CUTLASS_HOST_DEVICE
458
+ constexpr auto
459
+ get_shape_C() const {
460
+ using cute::make_shape;
461
+ using cute::reverse;
462
+ using cute::take;
463
+
464
+ if constexpr (ConvOp == conv::Operator::kFprop ||
465
+ ConvOp == conv::Operator::kDgrad) {
466
+ return make_shape(
467
+ reverse(take<0, RankT - 1>(shape_C)),
468
+ shape_C[RankT - 1]);
469
+ }
470
+ else if constexpr (ConvOp == conv::Operator::kWgrad) {
471
+ return make_shape(
472
+ shape_C[0],
473
+ reverse(take<1, RankT>(shape_C)));
474
+ }
475
+ }
476
+
477
+ // Static method that returns the canonical strides of tensors (layouts are right major and compact)
478
+ CUTLASS_HOST_DEVICE
479
+ static constexpr TensorStride
480
+ packed_stride_right_major(TensorExtent const& extents) {
481
+ TensorStride strides{};
482
+ strides[RankT-1] = 1;
483
+ cute::for_each(cute::make_rseq<RankT-1>{}, [&](auto i) {
484
+ strides[i] = extents[i+1] * strides[i+1];
485
+ });
486
+ return strides;
487
+ }
488
+
489
+ // Static method that returns the packed logical size of any TensorExtent
490
+ CUTLASS_HOST_DEVICE
491
+ static constexpr size_t
492
+ size(TensorExtent const& extents) {
493
+ size_t size = 1;
494
+ cute::for_each(cute::make_seq<RankT>{}, [&](auto i) {
495
+ size *= extents[i];
496
+ });
497
+ return size;
498
+ }
499
+
500
+ CUTLASS_HOST_DEVICE
501
+ constexpr size_t
502
+ size_A() const {
503
+ return shape_A[0] * stride_A[0];
504
+ }
505
+
506
+ CUTLASS_HOST_DEVICE
507
+ constexpr size_t
508
+ size_B() const {
509
+ return shape_B[0] * stride_B[0];
510
+ }
511
+
512
+ CUTLASS_HOST_DEVICE
513
+ constexpr size_t
514
+ size_C() const {
515
+ return shape_C[0] * stride_C[0];
516
+ }
517
+
518
+ // Equality operator
519
+ CUTLASS_HOST_DEVICE
520
+ bool operator==(ConvProblemShape<ConvOp, NumSpatialDimensions> const& rhs) const {
521
+ using cute::for_each;
522
+ using cute::make_seq;
523
+
524
+ bool is_equal = true;
525
+
526
+ // Compare all tensor extents
527
+ for_each(make_seq<RankT>{}, [&](auto i) {
528
+ is_equal = is_equal
529
+ && (shape_A[i] == rhs.shape_A[i])
530
+ && (shape_B[i] == rhs.shape_B[i]);
531
+ });
532
+
533
+ // Compare all spatial extents
534
+ for_each(make_seq<RankS>{}, [&](auto i) {
535
+ is_equal = is_equal
536
+ && (lower_padding[i] == rhs.lower_padding[i])
537
+ && (upper_padding[i] == rhs.upper_padding[i])
538
+ && (traversal_stride[i] == rhs.traversal_stride[i])
539
+ && (dilation[i] == rhs.dilation[i]);
540
+ });
541
+
542
+ return is_equal;
543
+ }
544
+
545
+ /// Inequality operator
546
+ CUTLASS_HOST_DEVICE
547
+ bool operator!=(ConvProblemShape<ConvOp, NumSpatialDimensions> const &rhs) const {
548
+ return !(*this == rhs);
549
+ }
550
+
551
+ private:
552
+ CUTLASS_HOST_DEVICE
553
+ constexpr auto
554
+ calculate_xformed_act(TensorExtent shape_act, TensorExtent shape_flt) {
555
+ TensorExtent shape_xformed_act{};
556
+ // calculate n,z,p,q,k.
557
+ // a helper lambda to compute a single spatial extent of the nzpqk tensor
558
+ auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) {
559
+ return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride;
560
+ };
561
+
562
+ shape_xformed_act[0] = shape_act[0]; // Activation N extent
563
+ cute::for_each(cute::make_seq<RankS>{}, [&](auto i) {
564
+ shape_xformed_act[i+1] = nzpqk_extent(
565
+ shape_act[i+1], shape_flt[i+1], upper_padding[i] + lower_padding[i], dilation[i], traversal_stride[i]);
566
+ });
567
+ shape_xformed_act[RankT-1] = shape_flt[0]; // Filter K extent
568
+
569
+ TensorStride stride_xformed_act = packed_stride_right_major(shape_xformed_act);
570
+
571
+ return cute::make_tuple(shape_xformed_act, stride_xformed_act);
572
+ }
573
+ };
574
+
575
+ template<
576
+ conv::Operator ConvOp,
577
+ int SpatialDim
578
+ >
579
+ void print(ConvProblemShape<ConvOp, SpatialDim> const& problem) {
580
+ printf("ConvProblemShape with %d spatial dimensions implementing cutlass::conv::Operator::%d\n",
581
+ SpatialDim, int(ConvOp));
582
+ printf("\tTensorA: ");
583
+ cute::print(problem.shape_A); printf(":");
584
+ cute::print(problem.stride_A); printf("\n");
585
+ printf("\tTensorB: ");
586
+ cute::print(problem.shape_B); printf(":");
587
+ cute::print(problem.stride_B); printf("\n");
588
+ printf("\tTensorC: ");
589
+ cute::print(problem.shape_C); printf(":");
590
+ cute::print(problem.stride_C); printf("\n");
591
+ printf("\tLower padding: "); print(problem.lower_padding); printf("\n");
592
+ printf("\tUpper padding: "); print(problem.upper_padding); printf("\n");
593
+ printf("\tTraversal strides: "); print(problem.traversal_stride); printf("\n");
594
+ printf("\tDilation: "); print(problem.dilation); printf("\n");
595
+ }
596
+
597
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
598
+
599
+ } // namespace cutlass::conv
600
+
601
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief
33
+
34
+ This file contains definitions and utility functions for describing convolution problem sizes in terms of
35
+ activation (NHWC), filter (KRSC), output (NPQK), padding (pad_h, pad_w), stride (stride_h, stride_w), and
36
+ dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map CUTLASS's implicit gemm
37
+ tensor extents, sizes, and data types to that of the convolution's extents, sizes, and data types.
38
+
39
+ * Mapping convolutions to Gemm computation *
40
+
41
+ Cutlass implements convolutions with the Implicit Gemm algorithm. This algorithm performs a gemm
42
+ (general matrix-matrix multiply) on the convolution tensors Activation, Filter, and Output.
43
+ The underlying gemm operation follows the standard gemm definition:
44
+
45
+ C = A * B + C
46
+
47
+ A and B are input matrices
48
+ C is source and output matrix
49
+
50
+
51
+ For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped
52
+ to convolution tensors Activation, Filter and Output as described in the table below.
53
+
54
+ ___________________________________________________________________________
55
+ ConvolutionalOperator | A | B | C
56
+ ___________________________________________________________________________
57
+ | | | | |
58
+ | Fprop | Activation | Filter | Output |
59
+ | Dgrad | Output | Filter | Activation |
60
+ | Wgrad | Output | Activation | Filter |
61
+ ___________________________________________________________________________
62
+
63
+ In convolution codebase, DO NOT mix using (A, B, C) with (Activation, Filter, Output).
64
+
65
+ For example, it's confusing and error prone to document a convolution class or function
66
+ as operating on "A, B, Output." Instead, use the mapping functions below,
67
+ and adhere to using either A, B, C or Activation, Filter, Output.
68
+
69
+ Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap
70
+ Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap
71
+ */
72
+
73
+ #pragma once
74
+
75
+ #include "cutlass/cutlass.h"
76
+ #include "cutlass/layout/tensor.h"
77
+ #include "cutlass/tensor_coord.h"
78
+ #include "cutlass/fast_math.h"
79
+ #include "cutlass/gemm/gemm_enumerated_types.h"
80
+ #include "cutlass/matrix_coord.h"
81
+
82
+ namespace cutlass {
83
+ namespace conv {
84
+
85
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
86
+
87
/// Convolutional operator.
/// Operand mapping to the implicit gemm (see file header table):
///   Fprop: A = Activation, B = Filter,     C = Output
///   Dgrad: A = Output,     B = Filter,     C = Activation
///   Wgrad: A = Output,     B = Activation, C = Filter
enum class Operator {
  kFprop,   ///< forward propagation
  kDgrad,   ///< data (activation) gradient
  kWgrad,   ///< weight (filter) gradient
  kDeconv   ///< deconvolution (transposed convolution)
};

/// Distinguishes convolution from cross correlation
/// (i.e. whether the filter is logically reversed before being applied).
enum class Mode {
  kCrossCorrelation,
  kConvolution
};

/// Selects among several implementation variants trading off performance with simplicity
enum class IteratorAlgorithm {
  kAnalytic,           ///< functionally correct in all cases but lower performance
  kOptimized,          ///< optimized for R <= 32, S <= 32 and unity-stride dgrad
  kFixedChannels,      ///< Analytic algorithm optimized for fixed channel count (C == AccessSize)
  kFewChannels,        ///< Analytic algorithm optimized for few channels (C divisible by AccessSize)
  kFixedStrideDilation ///< Optimized for fixed stride and dilation
};

/// Distinguishes among partial specializations that accelerate certain problems where convolution
/// stride is unit.
enum class StrideSupport {
  kStrided, ///< arbitrary convolution stride
  kUnity,   ///< unit convolution stride
  kFixed    ///< fixed convolution stride
};

/// Identifies split-K mode
enum class SplitKMode {
  kNone,     ///< no split-K decomposition
  kSerial,   ///< split-K partials reduced serially
  kParallel  ///< split-K partials reduced in parallel
};

/// Identifies group mode
enum class GroupMode {
  kNone,
  kSingleGroup,   ///< One CTA calculates one group or less
  kMultipleGroup, ///< One CTA calculates multiple groups
  kDepthwise      ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups)
};
132
+
133
+ /////////////////////////////////////////////////////////////////////////////////////////////////
134
+
135
+ /// Shape of a tensor
136
+ template <
137
+ int N = 1,
138
+ int H = 1,
139
+ int W = 1,
140
+ int C = 1
141
+ >
142
+ struct TensorNHWCShape {
143
+ static int const kN = N;
144
+ static int const kH = H;
145
+ static int const kW = W;
146
+ static int const kC = C;
147
+
148
+ static int const kHW = H * W;
149
+ static int const kNHW = N * kHW;
150
+ static int const kNHWC = N * H * W * C;
151
+
152
+ static int const kCount = kNHWC;
153
+
154
+ //
155
+ // Static member functions
156
+ //
157
+
158
+ /// Returns a Coord object
159
+ CUTLASS_HOST_DEVICE
160
+ static Coord<4> toCoord() {
161
+ return make_Coord(kN, kH, kW, kC);
162
+ }
163
+ };
164
+
165
+ /////////////////////////////////////////////////////////////////////////////////////////////////
166
+
167
+ /// Shape of a conv2d stride, which controls how the filter convolves around the input volume
168
+ template <
169
+ /// Stride in horizontal direction
170
+ int u = 1,
171
+ /// Stride in vertical direction
172
+ int v = 1
173
+ >
174
+ struct Stride2D {
175
+ static int const kU = u;
176
+ static int const kV = v;
177
+
178
+ //
179
+ // Static member functions
180
+ //
181
+
182
+ /// Returns a Coord object
183
+ CUTLASS_HOST_DEVICE
184
+ static Coord<2> toCoord() {
185
+ return make_Coord(kU, kV);
186
+ }
187
+ };
188
+
189
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
190
+
191
+ } // namespace conv
192
+ } // namespace cutlass
193
+
194
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ /***************************************************************************************************
3
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ *
6
+ * Redistribution and use in source and binary forms, with or without
7
+ * modification, are permitted provided that the following conditions are met:
8
+ *
9
+ * 1. Redistributions of source code must retain the above copyright notice, this
10
+ * list of conditions and the following disclaimer.
11
+ *
12
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ * this list of conditions and the following disclaimer in the documentation
14
+ * and/or other materials provided with the distribution.
15
+ *
16
+ * 3. Neither the name of the copyright holder nor the names of its
17
+ * contributors may be used to endorse or promote products derived from
18
+ * this software without specific prior written permission.
19
+ *
20
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ *
31
+ **************************************************************************************************/
32
+ #pragma once
33
+
34
+ #include "cutlass/conv/convnd_problem_shape.hpp"
35
+
36
+ /////////////////////////////////////////////////////////////////////////////////////////////////
37
+
38
+ namespace cutlass::conv::detail {
39
+
40
+ /////////////////////////////////////////////////////////////////////////////////////////////////
41
+
42
+ // Helper function to get the problem shape
43
+ template <typename T, class ProblemShape>
44
+ auto get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::true_type) {
45
+ return T::get_problem_shape_MNKL(problem_shape);
46
+ }
47
+
48
+ template <typename T, class ProblemShape>
49
+ ProblemShape get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::false_type) {
50
+ return problem_shape;
51
+ }
52
+
53
+ // Get problem shape MNKL according to following table:
54
+ // | | Fprop | Dgrad | Wgrad |
55
+ // | ---- | --------- | -------- | -------- |
56
+ // | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) |
57
+ // | Shape_N | (K) | (C) | (C,S,R,T) |
58
+ // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) |
59
+ // | Shape_L | _1 | (V,U,O) | _1 |
60
+
61
/// Generic fallback: shapes that are not a ConvProblemShape are assumed to
/// already be in MNKL form and are returned unchanged.
template <class ProblemShape>
CUTLASS_HOST_DEVICE
constexpr auto
get_transformed_problem_shape_MNKL(ProblemShape const& problem_shape) {
  return problem_shape;
}
67
+
68
+
69
+ template <conv::Operator ConvOp, int SpatialDim>
70
+ CUTLASS_HOST_DEVICE
71
+ constexpr auto
72
+ get_transformed_problem_shape_MNKL(ConvProblemShape<ConvOp, SpatialDim> const& problem_shape) {
73
+ using cute::insert;
74
+ using cute::make_shape;
75
+ using cute::reverse;
76
+ using cute::take;
77
+
78
+ constexpr int RankT = SpatialDim + 2;
79
+
80
+ if constexpr (ConvOp == conv::Operator::kWgrad) {
81
+ auto M_xformed = problem_shape.shape_C[0];
82
+ auto N_xformed = reverse(take<1, RankT>(problem_shape.shape_C));
83
+ auto K_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_A));
84
+ auto L_xformed = cute::Int<1>{};
85
+
86
+ return make_shape(M_xformed, N_xformed, K_xformed, L_xformed);
87
+ }
88
+ else if constexpr (ConvOp == conv::Operator::kFprop){
89
+ auto M_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_C));
90
+ auto N_xformed = problem_shape.shape_C[RankT - 1];
91
+ auto K_xformed = reverse(take<1, RankT>(problem_shape.shape_B));
92
+ auto L_xformed = cute::Int<1>{};
93
+
94
+ return make_shape(M_xformed, N_xformed, K_xformed, L_xformed);
95
+ }
96
+ else if constexpr (ConvOp == conv::Operator::kDgrad) {
97
+ auto L_xformed = reverse(problem_shape.traversal_stride); // (V,U,O)
98
+ auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(problem_shape.shape_C)), L_xformed);
99
+ auto N_xformed = problem_shape.shape_C[RankT - 1];
100
+ // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T]
101
+ auto K_xformed = insert<0>(
102
+ (reverse(take<1,RankT - 1>(problem_shape.shape_B))),
103
+ problem_shape.shape_B[0]);
104
+
105
+ return make_shape(M_xformed, N_xformed, K_xformed, L_xformed);
106
+ }
107
+ }
108
+
109
// Assuming im2col linearization
// Get problem shape MNKL according to following table:
// |         | Fprop     | Dgrad                 | Wgrad     |
// | ----    | --------- | --------------------- | --------- |
// | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K)       |
// | Shape_N | (K)       | (C)                   | (C,S,R,T) |
// | Shape_K | (C,S,R,T) | (K,S,R,T)             | (Q*P*Z*N) |
// | Shape_L | _1        | (V*U*O)               | _1        |
template <conv::Operator ConvOp, int SpatialDim>
CUTLASS_HOST_DEVICE
constexpr auto
get_linearized_problem_shape_MNKL(ConvProblemShape<ConvOp, SpatialDim> const& problem_shape) {

  // Start from the hierarchical MNKL view, then collapse the modes the table
  // above calls for as products; the remaining modes stay hierarchical.
  auto [M, N, K, L] = get_transformed_problem_shape_MNKL(problem_shape);

  if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) {
    return cute::make_shape(cute::product(M), N, K, cute::product(L));
  }
  else if constexpr (ConvOp == conv::Operator::kWgrad) {
    return cute::make_shape(M, N, cute::product(K), L);
  }

  // NOTE(review): operators not covered above (e.g. kDeconv) would fall off the
  // end and deduce a void return — confirm callers never instantiate this for
  // operators outside {kFprop, kDgrad, kWgrad}.
}
132
+
133
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
134
+
135
+ } // namespace cutlass::conv::detail
136
+
137
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ // common
34
+ #include "cutlass/arch/mma.h"
35
+ #include "cutlass/cutlass.h"
36
+ #include "cutlass/arch/mma.h"
37
+ #include "cutlass/trace.h"
38
+ #include "cutlass/cluster_launch.hpp"
39
+ #include "cutlass/device_kernel.h"
40
+
41
+ #include "cutlass/conv/kernel/conv_universal.hpp"
42
+ #include "cutlass/gemm/gemm.h"
43
+ #include "cutlass/detail/layout.hpp"
44
+ #include "cutlass/cuda_host_adapter.hpp"
45
+
46
+ ////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass::conv::device {
49
+
50
+ ////////////////////////////////////////////////////////////////////////////////
51
+
52
/*!
  ConvUniversalAdapter is a stateful, reusable handle built around a kernel
  of type cutlass::conv::kernel::ConvUniversal.

  It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs
  to create it from the host facing arguments. For power users, static methods
  are exposed that bypass the stateful methods or args->params lowering.
*/
template <class ConvKernel_>
class ConvUniversalAdapter
{
public:
  using ConvKernel = GetUnderlyingKernel_t<ConvKernel_>;
  using TileShape = typename ConvKernel::TileShape;
  using ElementA = typename ConvKernel::ElementA;
  using ElementB = typename ConvKernel::ElementB;
  using ElementC = typename ConvKernel::ElementC;
  using ElementD = typename ConvKernel::ElementD;
  using ElementAccumulator = typename ConvKernel::TiledMma::ValTypeC;
  using DispatchPolicy = typename ConvKernel::DispatchPolicy;
  using CollectiveMainloop = typename ConvKernel::CollectiveMainloop;
  using CollectiveEpilogue = typename ConvKernel::CollectiveEpilogue;

  // Compile-time switch: when set, every launch must go through a caller-provided
  // CudaHostAdapter (see initialize()/run() below, which assert on it).
  static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;

  // Tease out meta-information about the conv algorithm
  static constexpr conv::Operator kConvolutionalOperator = DispatchPolicy::ConvOp;
  static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions;

  // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop!
  using OperatorClass = cute::conditional_t<
    (cute::size(typename ConvKernel::TiledMma::AtomThrID{}) > 1),
    cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>;

  using ArchTag = typename ConvKernel::ArchTag;

  // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape
  using ThreadblockShape = cutlass::gemm::GemmShape<
    cute::size<0>(TileShape{}),
    cute::size<1>(TileShape{}),
    cute::size<2>(TileShape{})>;

  using ClusterShape = cutlass::gemm::GemmShape<
    cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
    cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
    cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})>;

  // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape
  using InstructionShape = cutlass::gemm::GemmShape<
    cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}),
    cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}),
    cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>;

  // Legacy: provide a correct warp count, but no reliable warp shape
  static int const kThreadCount = ConvKernel::MaxThreadsPerBlock;

  // Warp shape is not a primary API type in 3.x
  // But we can best approximate it by inspecting the TiledMma
  // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K
  // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads
  static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename ConvKernel::TiledMma{})) / 32);
  static constexpr int WarpsInMmaM = 4;
  static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM);
  using WarpCount = cutlass::gemm::GemmShape<WarpsInMmaM, WarpsInMmaN, 1>;
  using WarpShape = cutlass::gemm::GemmShape<
    CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM,
    CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN,
    CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>;

  static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages;

  // Inspect TiledCopy for A and B to compute the alignment size
  static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
    typename CollectiveMainloop::GmemTiledCopyA, ElementA>();
  static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
    typename CollectiveMainloop::GmemTiledCopyB, ElementB>();
  static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
    typename CollectiveEpilogue::GmemTiledCopyC, ElementC>();
  static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
    typename CollectiveEpilogue::GmemTiledCopyD, ElementD>();

  using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp;

  /// Argument structure: User API
  using Arguments = typename ConvKernel::Arguments;
  /// Argument structure: Kernel API
  using Params = typename ConvKernel::Params;

private:

  /// Kernel API parameters object (lowered from Arguments by initialize()/update())
  Params params_;

public:

  /// Access the Params structure
  Params const& params() const {
    return params_;
  }

  /// Determines whether the conv can execute the given problem.
  /// Returns kSuccess if the kernel accepts the arguments, kInvalid otherwise.
  static Status
  can_implement(Arguments const& args) {
    if (ConvKernel::can_implement(args)) {
      return Status::kSuccess;
    }
    else {
      return Status::kInvalid;
    }
  }

  /// Gets the workspace size, in bytes, required for these arguments.
  static size_t
  get_workspace_size(Arguments const& args) {
    size_t workspace_bytes = 0;
    CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);

    workspace_bytes += ConvKernel::get_workspace_size(args);
    return workspace_bytes;
  }

  /// Computes the grid shape from user-facing arguments.
  /// Note: lowers args to Params internally, so this is not free.
  static dim3
  get_grid_shape(Arguments const& args, void* workspace = nullptr) {
    auto tmp_params = ConvKernel::to_underlying_arguments(args, workspace);
    return ConvKernel::get_grid_shape(tmp_params);
  }

  /// Computes the grid shape from an already-lowered Params struct.
  static dim3
  get_grid_shape(Params const& params) {
    return ConvKernel::get_grid_shape(params);
  }

  /// Computes the maximum number of active blocks per multiprocessor.
  /// Returns -1 on any CUDA API failure (the error state is cleared first).
  static int maximum_active_blocks(int /* smem_capacity */ = -1) {
    CUTLASS_TRACE_HOST("ConvUniversal::maximum_active_blocks()");
    int max_active_blocks = -1;
    int smem_size = ConvKernel::SharedStorageSize;

    // first, account for dynamic smem capacity if needed
    cudaError_t result;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
      result = cudaFuncSetAttribute(
        device_kernel<ConvKernel>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(
          " cudaFuncSetAttribute() returned error: "
          << cudaGetErrorString(result));
        return -1;
      }
    }

    // query occupancy after setting smem size
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks,
      device_kernel<ConvKernel>,
      ConvKernel::MaxThreadsPerBlock,
      smem_size);

    if (cudaSuccess != result) {
      result = cudaGetLastError(); // to clear the error bit
      CUTLASS_TRACE_HOST(
        " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
        << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }

  /// Initializes conv state from arguments: sets up the kernel workspace,
  /// lowers Arguments to the internal Params, and (unless a CudaHostAdapter
  /// is mandated at compile time) configures dynamic shared memory.
  Status
  initialize(
    Arguments const& args,
    void* workspace = nullptr,
    cudaStream_t stream = nullptr,
    CudaHostAdapter *cuda_adapter = nullptr) {

    CUTLASS_TRACE_HOST("ConvUniversal::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace
    Status status = ConvKernel::initialize_workspace(args, workspace, stream, cuda_adapter);
    if (status != Status::kSuccess) {
      return status;
    }

    // Initialize the Params structure
    params_ = ConvKernel::to_underlying_arguments(args, workspace);

    // Don't set the function attributes - require the CudaHostAdapter to set it.
    if constexpr (kEnableCudaHostAdapter) {
      CUTLASS_ASSERT(cuda_adapter);
      return Status::kSuccess;
    }
    else {
      // account for dynamic smem capacity if needed
      int smem_size = ConvKernel::SharedStorageSize;
      if (smem_size >= (48 << 10)) {
        CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
        cudaError_t result = cudaFuncSetAttribute(
          device_kernel<ConvKernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
        if (cudaSuccess != result) {
          result = cudaGetLastError(); // to clear the error bit
          CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
          return Status::kErrorInternal;
        }
      }
    }
    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
  /// (It re-runs the full args->params lowering.)
  Status
  update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("ConvUniversal()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    params_ = ConvKernel::to_underlying_arguments(args, workspace);
    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static allowing users to create and manage their own params.
  /// Supplied params struct must be constructed by calling ConvKernel::to_underlying_arguments()
  static Status
  run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
    CUTLASS_TRACE_HOST("ConvUniversal::run()");
    dim3 const block = ConvKernel::get_block_shape();
    dim3 const grid = get_grid_shape(params);

    // configure smem size and carveout
    int smem_size = ConvKernel::SharedStorageSize;

    // NOTE(review): launch_result is only assigned on the paths below; on an
    // ArchTag outside the handled compute capabilities it is read
    // uninitialized at the end of this function — confirm all supported
    // ArchTags reach an assignment.
    Status launch_result;
    // Use extended launch API only for mainloops that use it
    if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) {
      [[maybe_unused]] constexpr bool is_static_1x1x1 =
        cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape> and
        cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1;
      dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
                   cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
                   cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
      // Dynamic cluster support
      [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0};
      if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 ||
                    ConvKernel::ArchTag::kMinComputeCapability == 101) {
        if constexpr (!cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape>) {
          // Runtime-chosen cluster shape (with fallback) overrides the static one.
          fallback_cluster = params.hw_info.cluster_shape_fallback;
          cluster = params.hw_info.cluster_shape;
        }
      }

      void* kernel_params[] = {&params};
      if constexpr (kEnableCudaHostAdapter) {
        //
        // Use the cuda host adapter
        //
        CUTLASS_ASSERT(cuda_adapter);
        if (cuda_adapter) {

          launch_result = cuda_adapter->launch(grid,
                                               cluster,
                                               fallback_cluster,
                                               block,
                                               smem_size,
                                               stream,
                                               kernel_params,
                                               kernel_index);
        }
        else {
          return Status::kErrorInternal;
        }
      }
      else {
        CUTLASS_ASSERT(cuda_adapter == nullptr);
        void const* kernel = (void const*) device_kernel<ConvKernel>;
        if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90
          || ConvKernel::ArchTag::kMinComputeCapability == 100
        ) {
          if constexpr (is_static_1x1x1) {
            // Trivial 1x1x1 cluster: a plain triple-chevron launch suffices.
            device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params);
            launch_result = Status::kSuccess;
          }
          else {
            launch_result = ClusterLauncher::launch(
              grid, cluster, block, smem_size, stream, kernel, kernel_params);
          }
        }
        else {
          if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 ||
                        ConvKernel::ArchTag::kMinComputeCapability == 101) {
            launch_result = ClusterLauncher::launch_with_fallback_cluster(
              grid,
              cluster,
              fallback_cluster,
              block,
              smem_size,
              stream,
              kernel,
              kernel_params);
          }
        }
      }
    }
    else {
      launch_result = Status::kSuccess;

      if constexpr (kEnableCudaHostAdapter) {
        CUTLASS_ASSERT(cuda_adapter);
        if (cuda_adapter) {
          void* kernel_params[] = {&params};

          launch_result = cuda_adapter->launch(
            grid, block, smem_size, stream, kernel_params, 0
          );

        }
        else {
          return Status::kErrorInternal;
        }
      }
      else {
        CUTLASS_ASSERT(cuda_adapter == nullptr);
        device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params);
      }
    }

    // cudaGetLastError both reports and clears any launch error raised above.
    cudaError_t result = cudaGetLastError();
    if (cudaSuccess == result && Status::kSuccess == launch_result) {
      return Status::kSuccess;
    }
    else {
      CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
      return Status::kErrorInternal;
    }
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  run(
    Arguments const& args,
    void* workspace = nullptr,
    cudaStream_t stream = nullptr,
    CudaHostAdapter *cuda_adapter = nullptr,
    int32_t kernel_index = 0
  ) {
    Status status = initialize(args, workspace, stream, cuda_adapter);
    if (Status::kSuccess == status) {
      status = run(params_, stream, cuda_adapter, kernel_index);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  operator()(
    Arguments const& args,
    void* workspace = nullptr,
    cudaStream_t stream = nullptr,
    CudaHostAdapter *cuda_adapter = nullptr) {
    return run(args, workspace, stream, cuda_adapter);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};
443
+
444
+ ////////////////////////////////////////////////////////////////////////////////
445
+
446
+ } // namespace cutlass::conv::device
447
+
448
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Template for device-level Depthwise Convolution
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include <limits>
38
+
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/device_kernel.h"
41
+ #include "cutlass/conv/convolution.h"
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ namespace cutlass {
46
+ namespace conv {
47
+ namespace device {
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
/// Device-level handle for launching a depthwise (direct) convolution kernel.
///
/// Wraps a kernel-level template (DirectConvolutionKernel_) with the usual
/// CUTLASS device-API surface: can_implement() / get_workspace_size() /
/// initialize() / update() / run(). Unlike the implicit-GEMM handles, run()
/// may launch TWO kernels: an optional filter-reorder kernel followed by the
/// main convolution kernel.
template<typename DirectConvolutionKernel_>
class DirectConvolution {
public:

  using UnderlyingKernel = DirectConvolutionKernel_;

  // Re-exported kernel traits so callers can introspect the configuration
  // without naming the kernel type directly.
  using ElementA = typename UnderlyingKernel::ElementA;
  using LayoutA = typename UnderlyingKernel::LayoutA;
  using ElementB = typename UnderlyingKernel::ElementB;
  using LayoutB = typename UnderlyingKernel::LayoutB;
  using ElementC = typename UnderlyingKernel::ElementC;
  using LayoutC = typename UnderlyingKernel::LayoutC;
  using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
  using ElementCompute = typename UnderlyingKernel::ElementCompute;
  using OperatorClass = typename UnderlyingKernel::OperatorClass;
  using ArchTag = typename UnderlyingKernel::ArchTag;
  using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
  using WarpShape = typename UnderlyingKernel::WarpShape;
  using InstructionShape = typename UnderlyingKernel::InstructionShape;
  using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
  using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
  static int const kStages = UnderlyingKernel::kStages;
  static int const kConvDim = UnderlyingKernel::kConvDim;
  using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
  using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
  using MathOperator = typename UnderlyingKernel::MathOperator;

  static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
  static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
  static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
  static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;

  // Warps per threadblock; run() launches blocks of (32 * kWarpCount) threads.
  static int const kWarpCount =
    (ThreadblockShape::kM / WarpShape::kM) *
    (ThreadblockShape::kN / WarpShape::kN) *
    (ThreadblockShape::kK / WarpShape::kK);

  /// Argument structure
  using Arguments = typename UnderlyingKernel::Arguments;

  // Auxiliary kernel that reorders the filter tensor before the main launch.
  using ReorderKernel = typename UnderlyingKernel::ReorderKernel;

private:

  /// Kernel parameters object
  typename UnderlyingKernel::Params params_;

public:

  /// Constructs Implicit GEMM
  DirectConvolution() { }

  /// Determines whether the Implicit GEMM can execute the given problem.
  ///
  /// Checks, in order: iterator support, depthwise group mode, group-count
  /// divisibility, epilogue alignment, and CUDA grid-dimension limits.
  /// Returns kSuccess only if every check passes.
  static Status can_implement(Arguments const &args) {

    // dispatch to iterators
    Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size);
    if (Status::kSuccess != status) {
      return status;
    }

    status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size);
    if (Status::kSuccess != status) {
      return status;
    }

    // This handle only supports depthwise convolution.
    if (kGroupMode != conv::GroupMode::kDepthwise) {
      return Status::kErrorInvalidProblem;
    }

    // C and K should be multiple of groups
    if (args.problem_size.K != args.problem_size.groups &&
        args.problem_size.C != args.problem_size.groups) {
      return Status::kErrorInvalidProblem;
    }


    // Output channel dimension must be aligned to the epilogue's vector width.
    static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
    if (kConvolutionalOperator == conv::Operator::kFprop) {
      if (args.problem_size.K % kAlignmentC)
        return Status::kErrorMisalignedOperand;
    } else if (kConvolutionalOperator == conv::Operator::kDgrad) {
      if (args.problem_size.C % kAlignmentC)
        return Status::kErrorMisalignedOperand;
    } else if (kConvolutionalOperator == conv::Operator::kWgrad) {
      if (args.problem_size.C % kAlignmentC)
        return Status::kErrorMisalignedOperand;
    }

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(
      threadblock_swizzle.get_tiled_shape(
        kConvolutionalOperator,
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.problem_size.split_k_slices));

    // CUDA limits gridDim.y and gridDim.z to 65535; reject larger launches.
    if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
          grid.z <= std::numeric_limits<uint16_t>::max())) {

      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }

  /// Gets the workspace size
  ///
  /// This kernel requires no device workspace, so the size is always zero.
  static size_t get_workspace_size(Arguments const &args) {
    return 0;
  }

  /// Initializes GEMM state from arguments.
  ///
  /// Builds the internal Params from `args` and, if the kernel's static shared
  /// storage exceeds the default 48 KiB opt-in limit, raises the kernel's
  /// dynamic shared-memory attribute. `stream` is accepted for API symmetry
  /// but not used here.
  Status initialize(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    // initialize the params structure from the arguments
    params_ = typename UnderlyingKernel::Params(
      args,
      static_cast<int *>(workspace)
    );

    int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));

    // Shared memory above 48 KiB requires an explicit opt-in per kernel.
    if (smem_size >= (48 << 10)) {
      cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>,
                                                cudaFuncAttributeMaxDynamicSharedMemorySize,
                                                smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }

  /// Initializes GEMM state from arguments.
  ///
  /// Cheap re-initialization path: only swaps tensor pointers, the epilogue
  /// functor, and the semaphore workspace, leaving the rest of Params intact.
  Status update(Arguments const &args, void *workspace = nullptr) {

    // update the params structure from the arguments
    params_.ptr_A = args.ref_A.data();
    params_.ptr_B = args.ref_B.data();
    params_.ptr_C = args.ref_C.data();
    params_.ptr_D = args.ref_D.data();
    params_.output_op = args.output_op;
    params_.ptr_reordered_B = args.ref_reordered_B.data();
    params_.semaphore = static_cast<int *>(workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  ///
  /// Optionally launches the filter-reorder kernel first (when a reordered-B
  /// buffer is provided), then launches the main convolution kernel with a
  /// dynamic shared-memory size computed from the params. Both launches go to
  /// the same stream, so they are implicitly ordered.
  Status run(cudaStream_t stream = nullptr) {

    // Launch reorder kernel
    if (params_.ptr_reordered_B != nullptr) {
      dim3 grid = ReorderKernel::get_grid_shape(params_);
      dim3 block = ReorderKernel::get_block_shape();

      cutlass::arch::synclog_setup();
      cutlass::Kernel<ReorderKernel><<<grid, block, 0, stream>>>(params_);
    }

    // Launch main kernel
    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(32 * kWarpCount, 1, 1);

    // Dynamic SMEM size based on input params.
    int smem_size = int(params_.get_smem_size());

    // Make sure we can use that much shared memory.
    cudaError_t status =
        cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (status != cudaSuccess)
      return Status::kErrorInternal;

    cutlass::arch::synclog_setup();
    cutlass::Kernel<UnderlyingKernel><<<grid, block, smem_size, stream>>>(params_);

    // cudaGetLastError reports failures from either launch above.
    cudaError_t result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  ///
  /// Convenience overload: initialize() followed by run().
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }

  /// Returns the dynamic shared-memory size (bytes) the main kernel will use.
  int get_smem_size() { return int(params_.get_smem_size()); }
};
263
+
264
+ /////////////////////////////////////////////////////////////////////////////////////////////////
265
+
266
+ }
267
+ }
268
+ }
269
+
270
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Template for device-level Implicit GEMM Convolution
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include <limits>
38
+
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/device_kernel.h"
41
+ #include "cutlass/conv/convolution.h"
42
+ #include "cutlass/cuda_host_adapter.hpp"
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ namespace cutlass {
47
+ namespace conv {
48
+ namespace device {
49
+
50
+ /////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ template<typename ImplicitGemmKernel_>
53
+ class ImplicitGemmConvolution {
54
+ public:
55
+
56
+ using UnderlyingKernel = GetUnderlyingKernel_t<ImplicitGemmKernel_>;
57
+
58
+ using ElementA = typename UnderlyingKernel::ElementA;
59
+ using LayoutA = typename UnderlyingKernel::LayoutA;
60
+ using ElementB = typename UnderlyingKernel::ElementB;
61
+ using LayoutB = typename UnderlyingKernel::LayoutB;
62
+ using ElementC = typename UnderlyingKernel::ElementC;
63
+ using LayoutC = typename UnderlyingKernel::LayoutC;
64
+ using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
65
+ using ElementCompute = typename UnderlyingKernel::ElementCompute;
66
+ using OperatorClass = typename UnderlyingKernel::OperatorClass;
67
+ using ArchTag = typename UnderlyingKernel::ArchTag;
68
+ using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
69
+ using WarpShape = typename UnderlyingKernel::WarpShape;
70
+ using InstructionShape = typename UnderlyingKernel::InstructionShape;
71
+ using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
72
+ using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
73
+ static int const kStages = UnderlyingKernel::kStages;
74
+ static int const kConvDim = UnderlyingKernel::kConvDim;
75
+ using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
76
+ using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
77
+ using MathOperator = typename UnderlyingKernel::MathOperator;
78
+
79
+ static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
80
+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
81
+ static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
82
+ static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;
83
+
84
+ static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
85
+
86
+ static int const kWarpCount =
87
+ (ThreadblockShape::kM / WarpShape::kM) *
88
+ (ThreadblockShape::kN / WarpShape::kN) *
89
+ (ThreadblockShape::kK / WarpShape::kK);
90
+
91
+ /// Argument structure
92
+ using Arguments = typename UnderlyingKernel::Arguments;
93
+
94
+ private:
95
+
96
+ /// Kernel parameters object
97
+ typename UnderlyingKernel::Params params_;
98
+
99
+ public:
100
+
101
+ /// Constructs Implicit GEMM
102
+ ImplicitGemmConvolution() { }
103
+
104
+ /// Determines whether the Implicit GEMM can execute the given problem.
105
+ static Status can_implement(Arguments const &args) {
106
+ // dispatch to iterators
107
+ Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size);
108
+ if (Status::kSuccess != status) {
109
+ return status;
110
+ }
111
+
112
+ status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size);
113
+ if (Status::kSuccess != status) {
114
+ return status;
115
+ }
116
+
117
+ // Check that tensor sizes don't exceed maximum supported size
118
+ if (kConvolutionalOperator == conv::Operator::kFprop) {
119
+ if (args.problem_size.activation_size() * sizeof(ElementA) >=
120
+ (1ull << 31) ||
121
+ args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) ||
122
+ args.problem_size.output_size() * sizeof(ElementC) >= (1ull << 31)) {
123
+ return Status::kErrorInvalidProblem;
124
+ }
125
+ }
126
+ else if (kConvolutionalOperator == conv::Operator::kDgrad ||
127
+ kConvolutionalOperator == conv::Operator::kDeconv) {
128
+ if (args.problem_size.activation_size() * sizeof(ElementC) >=
129
+ (1ull << 31) ||
130
+ args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) ||
131
+ args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) {
132
+ return Status::kErrorInvalidProblem;
133
+ }
134
+ }
135
+ else if (kConvolutionalOperator == conv::Operator::kWgrad) {
136
+ if (args.problem_size.activation_size() * sizeof(ElementB) >=
137
+ (1ull << 31) ||
138
+ args.problem_size.filter_size() * sizeof(ElementC) >= (1ull << 31) ||
139
+ args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) {
140
+ return Status::kErrorInvalidProblem;
141
+ }
142
+ }
143
+
144
+ // check group conv constraint
145
+ if (args.problem_size.groups != 1) {
146
+ if (kGroupMode == conv::GroupMode::kNone) {
147
+ return Status::kErrorInvalidProblem;
148
+ }
149
+
150
+ // C and K should be multiple of groups
151
+ if (args.problem_size.K % args.problem_size.groups ||
152
+ args.problem_size.C % args.problem_size.groups) {
153
+ return Status::kErrorInvalidProblem;
154
+ }
155
+
156
+ // split-k is not supported
157
+ if (args.problem_size.split_k_slices != 1) {
158
+ return Status::kErrorInvalidProblem;
159
+ }
160
+
161
+ int k_per_group = args.problem_size.K / args.problem_size.groups;
162
+ // k_per_group should be multiple of ThreadblockShape N, one CTA calculate one group
163
+ if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) {
164
+ return Status::kErrorInvalidProblem;
165
+ }
166
+ // ThreadblockShape::kN should be divisible by k_per_group, one CTA calculate multiple groups
167
+ if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) {
168
+ return Status::kErrorInvalidProblem;
169
+ }
170
+
171
+ // current optimized iterator algo only supports SingleGroup mode
172
+ if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized &&
173
+ kGroupMode != conv::GroupMode::kSingleGroup) {
174
+ return Status::kErrorInvalidProblem;
175
+ }
176
+ }
177
+
178
+ static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
179
+ if (kConvolutionalOperator == conv::Operator::kFprop) {
180
+ if (args.problem_size.K % kAlignmentC)
181
+ return Status::kErrorMisalignedOperand;
182
+ } else if (kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) {
183
+ if (args.problem_size.C % kAlignmentC)
184
+ return Status::kErrorMisalignedOperand;
185
+ } else if (kConvolutionalOperator == conv::Operator::kWgrad) {
186
+ if (args.problem_size.C % kAlignmentC)
187
+ return Status::kErrorMisalignedOperand;
188
+ }
189
+
190
+ // check for unsupported problem sizes for strided dgrad / deconv implementation
191
+ if ((kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) &&
192
+ kStrideSupport == conv::StrideSupport::kStrided) {
193
+ // split-k (serial or parallel) is not supported for strided dgrad / deconv
194
+ if(args.problem_size.split_k_slices > 1 && (args.problem_size.stride().at(args.problem_size.stride().max_dim_index()) > 1)) {
195
+ return Status::kErrorNotSupported;
196
+ }
197
+
198
+ // dilation > {1x1} is not supported for strided dgrad / deconv
199
+ if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) {
200
+ return Status::kErrorNotSupported;
201
+ }
202
+ }
203
+
204
+ // Determine grid shape
205
+ ThreadblockSwizzle threadblock_swizzle;
206
+
207
+ dim3 grid = threadblock_swizzle.get_grid_shape(
208
+ threadblock_swizzle.get_tiled_shape(
209
+ kConvolutionalOperator,
210
+ args.problem_size,
211
+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
212
+ args.problem_size.split_k_slices));
213
+
214
+ if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
215
+ grid.z <= std::numeric_limits<uint16_t>::max())) {
216
+
217
+ return Status::kErrorInvalidProblem;
218
+ }
219
+
220
+ return Status::kSuccess;
221
+ }
222
+
223
+ /// Gets the workspace size
224
+ static size_t get_workspace_size(Arguments const &args) {
225
+
226
+ size_t workspace_bytes = 0;
227
+
228
+ // Determine grid shape
229
+ ThreadblockSwizzle threadblock_swizzle;
230
+
231
+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
232
+ kConvolutionalOperator,
233
+ args.problem_size,
234
+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
235
+ args.problem_size.split_k_slices);
236
+
237
+ if(args.split_k_mode == SplitKMode::kParallel) {
238
+
239
+ // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
240
+ // The user needs to call a reduction operator to optain the final output tensor
241
+ workspace_bytes =
242
+ sizeof(ElementAccumulator) *
243
+ size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) *
244
+ size_t(grid_tiled_shape.k());
245
+ }
246
+
247
+ else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) {
248
+
249
+ // Split-K serial: The user workspace is used to store semaphore and serialize writing the
250
+ // final reduced output to user's output tensor
251
+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
252
+ }
253
+
254
+ return workspace_bytes;
255
+ }
256
+
257
+ /// Initializes GEMM state from arguments.
258
+ Status initialize(
259
+ Arguments const &args,
260
+ void *workspace = nullptr,
261
+ cudaStream_t stream = nullptr,
262
+ CudaHostAdapter *cuda_adapter = nullptr) {
263
+
264
+ if (args.problem_size.split_k_slices > 1) {
265
+
266
+ if (!workspace) {
267
+ return Status::kErrorWorkspaceNull;
268
+ }
269
+
270
+ cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
271
+
272
+ if (status != cudaSuccess) {
273
+ return Status::kErrorInternal;
274
+ }
275
+ }
276
+
277
+ // initialize the params structure from the arguments
278
+ params_ = typename UnderlyingKernel::Params(
279
+ args,
280
+ static_cast<int *>(workspace)
281
+ );
282
+
283
+ if constexpr (kEnableCudaHostAdapter) {
284
+ CUTLASS_ASSERT(cuda_adapter);
285
+ return Status::kSuccess;
286
+ }
287
+ else {
288
+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
289
+
290
+ if (smem_size >= (48 << 10)) {
291
+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>,
292
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
293
+ smem_size);
294
+
295
+ if (result != cudaSuccess) {
296
+ return Status::kErrorInternal;
297
+ }
298
+ }
299
+ }
300
+
301
+ return Status::kSuccess;
302
+ }
303
+
304
+ /// Initializes GEMM state from arguments.
305
+ Status update(Arguments const &args, void *workspace = nullptr) {
306
+
307
+ // update the params structure from the arguments
308
+ params_.ptr_A = args.ref_A.data();
309
+ params_.ptr_B = args.ref_B.data();
310
+ params_.ptr_C = args.ref_C.data();
311
+ params_.ptr_D = args.ref_D.data();
312
+ params_.output_op = args.output_op;
313
+ params_.semaphore = static_cast<int *>(workspace);
314
+
315
+ return Status::kSuccess;
316
+ }
317
+
318
+ /// Runs the kernel using initialized state.
319
+ Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
320
+
321
+
322
+ ThreadblockSwizzle threadblock_swizzle;
323
+
324
+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
325
+ dim3 block(32 * kWarpCount, 1, 1);
326
+
327
+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
328
+ cutlass::Status launch_result = cutlass::Status::kSuccess ;
329
+
330
+ if constexpr (kEnableCudaHostAdapter) {
331
+ //
332
+ // Use the cuda host adapter
333
+ //
334
+ CUTLASS_ASSERT(cuda_adapter);
335
+ if (cuda_adapter) {
336
+
337
+ void* kernel_params[] = {&params_};
338
+ launch_result = cuda_adapter->launch(
339
+ grid, dim3(1,1,1), block, smem_size, stream, kernel_params, kernel_index
340
+ );
341
+ }
342
+ else {
343
+ launch_result = Status::kErrorInternal;
344
+ }
345
+ }
346
+ else {
347
+ cutlass::arch::synclog_setup();
348
+ cutlass::Kernel<UnderlyingKernel><<<grid, block, smem_size, stream>>>(params_);
349
+ }
350
+
351
+ cudaError_t result = cudaGetLastError();
352
+ if (cudaSuccess == result && Status::kSuccess == launch_result) {
353
+ return Status::kSuccess;
354
+ }
355
+ else {
356
+ CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
357
+ return Status::kErrorInternal;
358
+ }
359
+ }
360
+
361
+ /// Runs the kernel using initialized state.
362
+ Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
363
+ return run(stream, cuda_adapter, kernel_index);
364
+ }
365
+
366
+ /// Runs the kernel using initialized state.
367
+ Status operator()(
368
+ Arguments const &args,
369
+ void *workspace = nullptr,
370
+ cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
371
+
372
+ Status status = initialize(args, workspace, stream, cuda_adapter);
373
+
374
+ if (status == Status::kSuccess) {
375
+ status = run(stream, cuda_adapter, kernel_index);
376
+ }
377
+
378
+ return status;
379
+ }
380
+ };
381
+
382
+ /////////////////////////////////////////////////////////////////////////////////////////////////
383
+
384
+ }
385
+ }
386
+ }
387
+
388
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Template for device-level fused activation's scale+bias+relu and Implicit GEMM Convolution
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include <limits>
38
+
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/device_kernel.h"
41
+ #include "cutlass/conv/convolution.h"
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ namespace cutlass {
46
+ namespace conv {
47
+ namespace device {
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
+ template<typename ImplicitGemmFusionKernel_>
52
+ class ImplicitGemmConvolutionFusion {
53
+ public:
54
+
55
+ using ImplicitGemmFusionKernel = ImplicitGemmFusionKernel_;
56
+
57
+ using ElementA = typename ImplicitGemmFusionKernel::ElementA;
58
+ using LayoutA = typename ImplicitGemmFusionKernel::LayoutA;
59
+ using ElementB = typename ImplicitGemmFusionKernel::ElementB;
60
+ using LayoutB = typename ImplicitGemmFusionKernel::LayoutB;
61
+
62
+ // using ElementScaleBias = typename ImplicitGemmFusionKernel::ElementScaleBias;
63
+ // using LayoutScaleBias = typename ImplicitGemmFusionKernel::LayoutScaleBias;
64
+
65
+ using ElementC = typename ImplicitGemmFusionKernel::ElementC;
66
+ using LayoutC = typename ImplicitGemmFusionKernel::LayoutC;
67
+ using ElementAccumulator = typename ImplicitGemmFusionKernel::ElementAccumulator;
68
+ using ElementCompute = typename ImplicitGemmFusionKernel::ElementCompute;
69
+ using OperatorClass = typename ImplicitGemmFusionKernel::OperatorClass;
70
+ using ArchTag = typename ImplicitGemmFusionKernel::ArchTag;
71
+ using ThreadblockShape = typename ImplicitGemmFusionKernel::ThreadblockShape;
72
+ using WarpShape = typename ImplicitGemmFusionKernel::WarpShape;
73
+ using InstructionShape = typename ImplicitGemmFusionKernel::InstructionShape;
74
+ using ThreadblockSwizzle = typename ImplicitGemmFusionKernel::ThreadblockSwizzle;
75
+ using EpilogueOutputOp = typename ImplicitGemmFusionKernel::EpilogueOutputOp;
76
+ static int const kStages = ImplicitGemmFusionKernel::kStages;
77
+ static int const kConvDim = ImplicitGemmFusionKernel::kConvDim;
78
+ using WarpMmaOperator = typename ImplicitGemmFusionKernel::WarpMmaOperator;
79
+ using ArchMmaOperator = typename ImplicitGemmFusionKernel::ArchMmaOperator;
80
+ using MathOperator = typename ImplicitGemmFusionKernel::MathOperator;
81
+
82
+ static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmFusionKernel::kConvolutionalOperator;
83
+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmFusionKernel::kIteratorAlgorithm;
84
+
85
+ static int const kWarpCount =
86
+ (ThreadblockShape::kM / WarpShape::kM) *
87
+ (ThreadblockShape::kN / WarpShape::kN) *
88
+ (ThreadblockShape::kK / WarpShape::kK);
89
+
90
+ /// Argument structure
91
+ using Arguments = typename ImplicitGemmFusionKernel::Arguments;
92
+
93
+ private:
94
+
95
+ /// Kernel parameters object
96
+ typename ImplicitGemmFusionKernel::Params params_;
97
+
98
+ public:
99
+
100
+ /// Constructs Implicit GEMM
101
+ ImplicitGemmConvolutionFusion() { }
102
+
103
+ /// Determines whether the Implicit GEMM can execute the given problem.
104
+ static Status can_implement(Arguments const &args) {
105
+
106
+ // dispatch to iterators
107
+ Status status = ImplicitGemmFusionKernel::Mma::IteratorA::can_implement(args.problem_size);
108
+ if (Status::kSuccess != status) {
109
+ return status;
110
+ }
111
+
112
+ status = ImplicitGemmFusionKernel::Mma::IteratorB::can_implement(args.problem_size);
113
+ if (Status::kSuccess != status) {
114
+ return status;
115
+ }
116
+
117
+ // Determine grid shape
118
+ ThreadblockSwizzle threadblock_swizzle;
119
+
120
+ dim3 grid = threadblock_swizzle.get_grid_shape(
121
+ threadblock_swizzle.get_tiled_shape(
122
+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
123
+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
124
+ args.problem_size.split_k_slices));
125
+
126
+ if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
127
+ grid.z <= std::numeric_limits<uint16_t>::max())) {
128
+
129
+ return Status::kErrorInvalidProblem;
130
+ }
131
+
132
+ return Status::kSuccess;
133
+ }
134
+
135
+ /// Gets the workspace size
136
+ static size_t get_workspace_size(Arguments const &args) {
137
+
138
+ size_t workspace_bytes = 0;
139
+
140
+ // Determine grid shape
141
+ ThreadblockSwizzle threadblock_swizzle;
142
+
143
+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
144
+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
145
+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
146
+ args.problem_size.split_k_slices);
147
+
148
+ if(args.split_k_mode == SplitKMode::kParallel) {
149
+
150
+ // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
151
+ // The user needs to call a reduction operator to optain the final output tensor
152
+ workspace_bytes =
153
+ sizeof(ElementAccumulator) *
154
+ size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) *
155
+ size_t(grid_tiled_shape.k());
156
+ }
157
+
158
+ else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) {
159
+
160
+ // Split-K serial: The user workspace is used to store semaphore and serialize writing the
161
+ // final reduced output to user's output tensor
162
+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
163
+ }
164
+
165
+ return workspace_bytes;
166
+ }
167
+
168
+ /// Initializes GEMM state from arguments.
169
+ Status initialize(
170
+ Arguments const &args,
171
+ void *workspace = nullptr,
172
+ cudaStream_t stream = nullptr) {
173
+
174
+ if (args.problem_size.split_k_slices > 1) {
175
+
176
+ if (!workspace) {
177
+ return Status::kErrorWorkspaceNull;
178
+ }
179
+
180
+ cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
181
+
182
+ if (status != cudaSuccess) {
183
+ return Status::kErrorInternal;
184
+ }
185
+ }
186
+
187
+ // initialize the params structure from the arguments
188
+ params_ = typename ImplicitGemmFusionKernel::Params(
189
+ args,
190
+ static_cast<int *>(workspace)
191
+ );
192
+
193
+ int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage));
194
+
195
+ if (smem_size >= (48 << 10)) {
196
+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<ImplicitGemmFusionKernel>,
197
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
198
+ smem_size);
199
+
200
+ if (result != cudaSuccess) {
201
+ return Status::kErrorInternal;
202
+ }
203
+ }
204
+
205
+ return Status::kSuccess;
206
+ }
207
+
208
+ /// Initializes Impicit GEMM state from arguments.
209
+ Status update(Arguments const &args, void *workspace = nullptr) {
210
+
211
+ // update the params structure from the arguments
212
+ params_.ptr_A = args.ref_A.data();
213
+ params_.ptr_B = args.ref_B.data();
214
+ params_.ptr_scale = args.ref_A_scale.data();
215
+ params_.ptr_bias = args.ref_A_bias.data();
216
+ params_.ptr_C = args.ref_C.data();
217
+ params_.ptr_D = args.ref_D.data();
218
+ params_.output_op = args.output_op;
219
+ params_.semaphore = static_cast<int *>(workspace);
220
+
221
+ return Status::kSuccess;
222
+ }
223
+
224
+ /// Runs the kernel using initialized state.
225
+ Status run(cudaStream_t stream = nullptr) {
226
+
227
+ ThreadblockSwizzle threadblock_swizzle;
228
+
229
+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
230
+ dim3 block(32 * kWarpCount, 1, 1);
231
+
232
+ int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage));
233
+
234
+ cutlass::arch::synclog_setup();
235
+ cutlass::Kernel<ImplicitGemmFusionKernel><<<grid, block, smem_size, stream>>>(params_);
236
+
237
+ cudaError_t result = cudaGetLastError();
238
+
239
+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
240
+ }
241
+
242
+ /// Runs the kernel using initialized state.
243
+ Status operator()(cudaStream_t stream = nullptr) {
244
+ return run(stream);
245
+ }
246
+
247
+ /// Runs the kernel using initialized state.
248
+ Status operator()(
249
+ Arguments const &args,
250
+ void *workspace = nullptr,
251
+ cudaStream_t stream = nullptr) {
252
+
253
+ Status status = initialize(args, workspace, stream);
254
+
255
+ if (status == Status::kSuccess) {
256
+ status = run(stream);
257
+ }
258
+
259
+ return status;
260
+ }
261
+ };
262
+
263
+ /////////////////////////////////////////////////////////////////////////////////////////////////
264
+
265
+ }
266
+ }
267
+ }
268
+
269
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include "cutlass/conv/convolution.h"
34
+ #include "cutlass/epilogue/thread/activation.h"
35
+ #include "cutlass/arch/arch.h"
36
+
37
+ #include "cute/layout.hpp"
38
+ #include "cute/numeric/integral_constant.hpp"
39
+
40
+ #include "cutlass/gemm/dispatch_policy.hpp"
41
+
42
+ //////////////////////////////////////////////////////////////////////////////
43
+
44
+ //////////////////////////////////////////////////////////////////////////////
45
+
46
+ namespace cutlass::conv {
47
+
48
+ //////////////////////////////////////////////////////////////////////////////
49
+
50
+ //
51
+ // Policies for categorical dispatch of mainloop against kernel grid schedules
52
+ //
53
+ struct KernelImplicitTmaWarpSpecializedSm90 : cutlass::gemm::KernelTmaWarpSpecialized { };
54
+ struct KernelImplicitTmaWarpSpecializedSm90Cooperative { };
55
+ struct KernelImplicitTmaWarpSpecializedSm90Pingpong { };
56
+
57
+ //
58
+ // Collective Mainloop Policies
59
+ //
60
+
61
+ // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA
62
+ // for fprop
63
+ template<
64
+ conv::Operator ConvOp_,
65
+ int Stages_,
66
+ int NumSpatialDimensions_,
67
+ class ClusterShape_ = cute::Shape<cute::C<1>,cute::C<1>,cute::C<1>>,
68
+ class KernelSchedule = KernelImplicitTmaWarpSpecializedSm90,
69
+ int PipelineAsyncMmaStages_ = 1
70
+ >
71
+ struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm {
72
+ static constexpr int Stages = Stages_;
73
+ static constexpr int NumSpatialDimensions = NumSpatialDimensions_;
74
+ static constexpr Operator ConvOp = ConvOp_;
75
+ static constexpr int PipelineAsyncMmaStages = PipelineAsyncMmaStages_;
76
+ using ClusterShape = ClusterShape_;
77
+ using ArchTag = arch::Sm90;
78
+ using Schedule = KernelSchedule;
79
+
80
+ static_assert(NumSpatialDimensions >= 1);
81
+ static_assert(! (cute::is_same_v<KernelSchedule,KernelImplicitTmaWarpSpecializedSm90Cooperative> ||
82
+ cute::is_same_v<KernelSchedule,KernelImplicitTmaWarpSpecializedSm90Pingpong>),
83
+ "Persistent schedules not support for conv yet.");
84
+ };
85
+
86
+
87
+
88
+ // SM100 tensor op kernel schedule
89
+ struct KernelImplicitTmaWarpSpecializedSm100 {
90
+ static constexpr int SchedulerPipelineStageCount = 0;
91
+ static constexpr int AccumulatorPipelineStageCount = 0;
92
+ };
93
+
94
+ // Pseudo-policies for builder auto override that dispatches to the KernelImplicitTmaWarpSpecializedSm100
95
+ // but for opting into 1 or 2 SM atoms
96
+ struct KernelImplicitTmaWarpSpecialized1SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { };
97
+ struct KernelImplicitTmaWarpSpecialized2SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { };
98
+
99
+ struct KernelStridedDgradTmaWs1SmSm100 { };
100
+ struct KernelStridedDgradTmaWs2SmSm100 { };
101
+
102
+ // Policy for implicit gemm kernel
103
+ template<
104
+ int SchedulerPipelineStageCount_,
105
+ int AccumulatorPipelineStageCount_
106
+ >
107
+ struct KernelScheduleImplicitTmaWarpSpecializedSm100 : KernelImplicitTmaWarpSpecializedSm100 {
108
+ static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
109
+ static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
110
+ };
111
+
112
+ // n-buffer in smem (Blackwell TMA), pipelined with Blackwell UMMA and TMA, fprop
113
+ template<
114
+ conv::Operator ConvOp_,
115
+ int Stages_,
116
+ int NumSpatialDimensions_,
117
+ int SchedulerPipelineStageCount_,
118
+ int AccumulatorPipelineStageCount_,
119
+ class ClusterShape_ = cute::Shape<cute::C<1>,cute::C<1>,cute::C<1>>
120
+ >
121
+ struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm {
122
+ static constexpr int Stages = Stages_;
123
+ static constexpr int NumSpatialDimensions = NumSpatialDimensions_;
124
+ static constexpr Operator ConvOp = ConvOp_;
125
+ using ClusterShape = ClusterShape_;
126
+ using ArchTag = arch::Sm100;
127
+ using Schedule = KernelScheduleImplicitTmaWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
128
+
129
+ static_assert(NumSpatialDimensions >= 1);
130
+ };
131
+
132
+ //////////////////////////////////////////////////////////////////////////////
133
+
134
+ } // namespace cutlass::conv
135
+
136
+ //////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include "cutlass/conv/convnd_problem_shape.hpp"
34
+ #include "cutlass/detail/dependent_false.hpp"
35
+
36
+ ////////////////////////////////////////////////////////////////////////////////
37
+
38
+ namespace cutlass::conv::kernel {
39
+
40
+ ////////////////////////////////////////////////////////////////////////////////
41
+
42
+ /*
43
+ * Stateless universal device CONV kernel type that treats CONV as
44
+ * a composition of a collective mainloop and a collective epilogue.
45
+ **/
46
+ template <
47
+ class ProblemShape_,
48
+ class CollectiveMainloop_,
49
+ class CollectiveEpilogue_,
50
+ class TileSchedulerTag_ = void,
51
+ class Enable = void
52
+ >
53
+ class ConvUniversal {
54
+ static_assert(cutlass::detail::dependent_false<Enable>,
55
+ "Could not find a valid specialization at the kernel layer to dispatch against.");
56
+ };
57
+
58
+ ////////////////////////////////////////////////////////////////////////////////
59
+
60
+ } // namespace cutlass::conv::kernel
61
+
62
+ ////////////////////////////////////////////////////////////////////////////////
63
+ #include "cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp"
64
+ #include "cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp"
65
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level implicit GEMM convolution definitions for threadblock-scoped epilogue.
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/gemm/threadblock/default_mma.h"
41
+ #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
42
+ #include "cutlass/conv/threadblock/threadblock_swizzle.h"
43
+ #include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
44
+ #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
45
+ #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
46
+ #include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h"
47
+ #include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h"
48
+ #include "cutlass/conv/convolution.h"
49
+ #include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
50
+ #include "cutlass/conv/threadblock/implicit_gemm_pipelined.h"
51
+ #include "cutlass/conv/threadblock/implicit_gemm_multistage.h"
52
+ #include "cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h"
53
+ #include "cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h"
54
+ #include "cutlass/conv/kernel/implicit_gemm_convolution.h"
55
+ #include "cutlass/conv/kernel/implicit_gemm_convolution_fusion.h"
56
+ #include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h"
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace cutlass {
61
+ namespace conv {
62
+ namespace kernel {
63
+
64
+ /////////////////////////////////////////////////////////////////////////////////////////////////
65
+
66
+ namespace detail {
67
+
68
+ template <
69
+ typename ArchTag,
70
+ typename Shape,
71
+ typename WarpMmaTensorOp,
72
+ int PartitionsK,
73
+ typename OutputOp
74
+ >
75
+ struct DefaultConvEpilogue {
76
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
77
+ Shape,
78
+ WarpMmaTensorOp,
79
+ PartitionsK,
80
+ OutputOp,
81
+ OutputOp::kCount
82
+ >::Epilogue;
83
+ };
84
+
85
+ template <
86
+ typename Shape,
87
+ typename WarpMmaTensorOp,
88
+ int PartitionsK,
89
+ typename OutputOp
90
+ >
91
+ struct DefaultConvEpilogue<
92
+ arch::Sm70,
93
+ Shape,
94
+ WarpMmaTensorOp,
95
+ PartitionsK,
96
+ OutputOp
97
+ > {
98
+
99
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp<
100
+ Shape,
101
+ WarpMmaTensorOp,
102
+ PartitionsK,
103
+ OutputOp,
104
+ OutputOp::kCount
105
+ >::Epilogue;
106
+ };
107
+
108
+ /////////////////////////////////////////////////////////////////////////////////////////////////
109
+ template <
110
+ typename ArchTag,
111
+ typename Shape,
112
+ typename WarpMmaSimt,
113
+ typename ElementOutput,
114
+ typename ElementTensor,
115
+ typename ElementVector,
116
+ typename OutputOp,
117
+ int ElementsPerAccess,
118
+ typename PermuteDLayout = layout::NoPermute,
119
+ conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity,
120
+ int Rank = 4
121
+ >
122
+ struct DefaultConvEpilogueWithBroadcastSimt {
123
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimt<
124
+ Shape,
125
+ WarpMmaSimt,
126
+ ElementOutput,
127
+ ElementTensor,
128
+ ElementVector,
129
+ OutputOp,
130
+ ElementsPerAccess,
131
+ false,
132
+ PermuteDLayout,
133
+ StrideSupport,
134
+ Rank
135
+ >::Epilogue;
136
+ };
137
+
138
+ template <
139
+ typename ArchTag,
140
+ typename Shape,
141
+ typename WarpMmaSimt,
142
+ typename ElementOutput,
143
+ typename ElementTensor,
144
+ typename ElementVector,
145
+ typename OutputOp,
146
+ int ElementsPerAccess
147
+ >
148
+ struct DefaultConvEpilogueWithBroadcastSimtStridedDgrad {
149
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimtStridedDgrad<
150
+ Shape,
151
+ WarpMmaSimt,
152
+ ElementOutput,
153
+ ElementTensor,
154
+ ElementVector,
155
+ OutputOp,
156
+ ElementsPerAccess
157
+ >::Epilogue;
158
+ };
159
+
160
+ template <
161
+ typename ArchTag,
162
+ typename Shape,
163
+ typename WarpMmaTensorOp,
164
+ int PartitionsK,
165
+ typename ElementOutput,
166
+ typename ElementTensor,
167
+ typename ElementVector,
168
+ typename OutputOp,
169
+ int ElementsPerAccess
170
+ >
171
+ struct DefaultConvEpilogueWithBroadcastTensorOp {
172
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp<
173
+ Shape,
174
+ WarpMmaTensorOp,
175
+ PartitionsK,
176
+ ElementOutput,
177
+ ElementTensor,
178
+ ElementVector,
179
+ OutputOp,
180
+ ElementsPerAccess
181
+ >::Epilogue;
182
+ };
183
+
184
+ template <
185
+ typename Shape,
186
+ typename WarpMmaTensorOp,
187
+ int PartitionsK,
188
+ typename ElementOutput,
189
+ typename ElementTensor,
190
+ typename ElementVector,
191
+ typename OutputOp,
192
+ int ElementsPerAccess
193
+ >
194
+ struct DefaultConvEpilogueWithBroadcastTensorOp<
195
+ arch::Sm70,
196
+ Shape,
197
+ WarpMmaTensorOp,
198
+ PartitionsK,
199
+ ElementOutput,
200
+ ElementTensor,
201
+ ElementVector,
202
+ OutputOp,
203
+ ElementsPerAccess
204
+ > {
205
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp<
206
+ Shape,
207
+ WarpMmaTensorOp,
208
+ PartitionsK,
209
+ ElementOutput,
210
+ ElementTensor,
211
+ ElementVector,
212
+ OutputOp,
213
+ ElementsPerAccess
214
+ >::Epilogue;
215
+ };
216
+
217
+ /////////////////////////////////////////////////////////////////////////////////////////////////
218
+
219
+ template <
220
+ typename ArchTag,
221
+ typename Shape,
222
+ typename WarpMmaTensorOp,
223
+ int PartitionsK,
224
+ typename ElementOutput,
225
+ typename OutputOp,
226
+ typename ReductionOp,
227
+ int ElementsPerAccess
228
+ >
229
+ struct DefaultConvEpilogueWithReductionTensorOp {
230
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionTensorOp<
231
+ Shape,
232
+ WarpMmaTensorOp,
233
+ PartitionsK,
234
+ ElementOutput,
235
+ OutputOp,
236
+ ReductionOp,
237
+ ElementsPerAccess
238
+ >::Epilogue;
239
+ };
240
+
241
+ template <
242
+ typename Shape,
243
+ typename WarpMmaTensorOp,
244
+ int PartitionsK,
245
+ typename ElementOutput,
246
+ typename OutputOp,
247
+ typename ReductionOp,
248
+ int ElementsPerAccess
249
+ >
250
+ struct DefaultConvEpilogueWithReductionTensorOp<
251
+ arch::Sm70,
252
+ Shape,
253
+ WarpMmaTensorOp,
254
+ PartitionsK,
255
+ ElementOutput,
256
+ OutputOp,
257
+ ReductionOp,
258
+ ElementsPerAccess
259
+ > {
260
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp<
261
+ Shape,
262
+ WarpMmaTensorOp,
263
+ PartitionsK,
264
+ ElementOutput,
265
+ OutputOp,
266
+ ReductionOp,
267
+ ElementsPerAccess
268
+ >::Epilogue;
269
+ };
270
+
271
+ /////////////////////////////////////////////////////////////////////////////////////////////////
272
+
273
+ // Defaults for strided Dgrad
274
+ template <
275
+ typename ArchTag,
276
+ typename Shape,
277
+ typename WarpMmaTensorOp,
278
+ int PartitionsK,
279
+ typename OutputOp
280
+ >
281
+ struct DefaultConvEpilogueStridedDgrad {
282
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
283
+ Shape,
284
+ WarpMmaTensorOp,
285
+ PartitionsK,
286
+ OutputOp,
287
+ OutputOp::kCount
288
+ >::Epilogue;
289
+ };
290
+
291
+ template <
292
+ typename Shape,
293
+ typename WarpMmaTensorOp,
294
+ int PartitionsK,
295
+ typename OutputOp
296
+ >
297
+ struct DefaultConvEpilogueStridedDgrad<
298
+ arch::Sm70,
299
+ Shape,
300
+ WarpMmaTensorOp,
301
+ PartitionsK,
302
+ OutputOp
303
+ > {
304
+
305
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOpStridedDgrad<
306
+ Shape,
307
+ WarpMmaTensorOp,
308
+ PartitionsK,
309
+ OutputOp,
310
+ OutputOp::kCount
311
+ >::Epilogue;
312
+ };
313
+
314
+ } // namespace detail
315
+
316
+ /////////////////////////////////////////////////////////////////////////////////////////////////
317
+
318
+ } // namespace kernel
319
+ } // namespace conv
320
+ } // namespace cutlass
321
+
322
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h ADDED
@@ -0,0 +1,1927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
35
+ matrix multiply-add with the appropriate threadblock-scoped epilogue.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/conv/kernel/default_conv2d.h"
42
+
43
+ #include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h"
44
+ #include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h"
45
+ #include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h"
46
+ #include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h"
47
+ #include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
+ namespace cutlass {
52
+ namespace conv {
53
+ namespace kernel {
54
+
55
+ /////////////////////////////////////////////////////////////////////////////////////////////////
56
+ /// Defines a kernel for Conv2dDgrad
57
+ template <
58
+ typename ElementA,
59
+ typename LayoutA,
60
+ typename ElementB,
61
+ typename LayoutB,
62
+ typename ElementC,
63
+ typename LayoutC,
64
+ typename ElementAccumulator,
65
+ typename OperatorClass,
66
+ typename ArchTag,
67
+ typename ThreadblockShape,
68
+ typename WarpShape,
69
+ typename InstructionShape,
70
+ typename EpilogueOutputOp,
71
+ typename ThreadblockSwizzle,
72
+ int Stages,
73
+ typename MathOperatorTag,
74
+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
75
+ conv::StrideSupport StrideSupport = StrideSupport::kStrided,
76
+ /// Access granularity of A matrix in units of elements
77
+ int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
78
+ /// Access granularity of B matrix in units of elements
79
+ int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
80
+ > struct DefaultConv2dDgrad;
81
+
82
+ /////////////////////////////////////////////////////////////////////////////////////////////////
83
+ // OpClassTensorOp convolutions
84
+ /////////////////////////////////////////////////////////////////////////////////////////////////
85
+
86
+ /// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided and
87
+ // multistage pipeline.
88
+ template <
89
+ typename ElementA,
90
+ typename LayoutA,
91
+ typename ElementB,
92
+ typename LayoutB,
93
+ typename ElementC,
94
+ typename LayoutC,
95
+ typename ElementAccumulator,
96
+ typename ArchTag,
97
+ typename ThreadblockShape,
98
+ typename WarpShape,
99
+ typename InstructionShape,
100
+ typename EpilogueOutputOp,
101
+ typename ThreadblockSwizzle,
102
+ int Stages,
103
+ typename MathOperatorTag,
104
+ int AlignmentA,
105
+ int AlignmentB
106
+ >
107
+ struct DefaultConv2dDgrad <
108
+ ElementA,
109
+ LayoutA,
110
+ ElementB,
111
+ LayoutB,
112
+ ElementC,
113
+ LayoutC,
114
+ ElementAccumulator,
115
+ arch::OpClassTensorOp,
116
+ ArchTag,
117
+ ThreadblockShape,
118
+ WarpShape,
119
+ InstructionShape,
120
+ EpilogueOutputOp,
121
+ ThreadblockSwizzle,
122
+ Stages,
123
+ MathOperatorTag,
124
+ IteratorAlgorithm::kAnalytic,
125
+ StrideSupport::kStrided,
126
+ AlignmentA,
127
+ AlignmentB
128
+ > {
129
+
130
+ // Define the core components from GEMM
131
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
132
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
133
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
134
+ Stages, MathOperatorTag>;
135
+
136
+ // Define iterators over tiles from the A operand
137
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
138
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
139
+ using IteratorA =
140
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
141
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
142
+ ElementA,
143
+ ThreadMapA,
144
+ StrideSupport::kStrided,
145
+ AccessTypeA
146
+ >;
147
+
148
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
149
+
150
+ // Define iterators over tiles from the B operand
151
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
152
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
153
+ using IteratorB =
154
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
155
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
156
+ ElementB,
157
+ ThreadMapB,
158
+ StrideSupport::kStrided,
159
+ AccessTypeB
160
+ >;
161
+
162
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
163
+
164
+ // Warp-level GEMM components
165
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
166
+ using MmaPolicy = typename MmaCore::MmaPolicy;
167
+
168
+ static cutlass::arch::CacheOperation::Kind const CacheOpB =
169
+ ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
170
+ ? cutlass::arch::CacheOperation::Global
171
+ : cutlass::arch::CacheOperation::Always;
172
+
173
+ // Define the Mma
174
+ using Mma = threadblock::ImplicitGemmMultistage<
175
+ ThreadblockShape,
176
+ IteratorA,
177
+ SmemIteratorA,
178
+ arch::CacheOperation::Always,
179
+ IteratorB,
180
+ SmemIteratorB,
181
+ CacheOpB,
182
+ MmaPolicy,
183
+ Stages
184
+ >;
185
+
186
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
187
+
188
+ // Define the epilogue
189
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
190
+ ThreadblockShape,
191
+ WarpMmaTensorOp,
192
+ kPartitionsK,
193
+ EpilogueOutputOp,
194
+ EpilogueOutputOp::kCount
195
+ >::Epilogue;
196
+
197
+ // Define the kernel
198
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
199
+ Mma,
200
+ Epilogue,
201
+ ThreadblockSwizzle,
202
+ conv::Operator::kDgrad
203
+ >;
204
+ };
205
+
206
+ /// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided
207
+ // and 2 stage pipeline.
208
+ template <
209
+ typename ElementA,
210
+ typename LayoutA,
211
+ typename ElementB,
212
+ typename LayoutB,
213
+ typename ElementC,
214
+ typename LayoutC,
215
+ typename ElementAccumulator,
216
+ typename ArchTag,
217
+ typename ThreadblockShape,
218
+ typename WarpShape,
219
+ typename InstructionShape,
220
+ typename EpilogueOutputOp,
221
+ typename ThreadblockSwizzle,
222
+ typename MathOperatorTag,
223
+ int AlignmentA,
224
+ int AlignmentB
225
+ >
226
+ struct DefaultConv2dDgrad <
227
+ ElementA,
228
+ LayoutA,
229
+ ElementB,
230
+ LayoutB,
231
+ ElementC,
232
+ LayoutC,
233
+ ElementAccumulator,
234
+ arch::OpClassTensorOp,
235
+ ArchTag,
236
+ ThreadblockShape,
237
+ WarpShape,
238
+ InstructionShape,
239
+ EpilogueOutputOp,
240
+ ThreadblockSwizzle,
241
+ 2,
242
+ MathOperatorTag,
243
+ IteratorAlgorithm::kAnalytic,
244
+ StrideSupport::kStrided,
245
+ AlignmentA,
246
+ AlignmentB
247
+ > {
248
+
249
+ // Define the core components from GEMM
250
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
251
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
252
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
253
+ 2, MathOperatorTag>;
254
+
255
+ // Define iterators over tiles from the A operand
256
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
257
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
258
+ using IteratorA =
259
+ cutlass::conv::threadblock::TileIteratorStridedDgrad<
260
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
261
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
262
+ ElementA,
263
+ ThreadMapA,
264
+ StrideSupport::kStrided,
265
+ AccessTypeA
266
+ >
267
+ >;
268
+
269
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
270
+
271
+ // Define iterators over tiles from the B operand
272
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
273
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
274
+ using IteratorB =
275
+ cutlass::conv::threadblock::TileIteratorStridedDgrad<
276
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
277
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
278
+ ElementB,
279
+ ThreadMapB,
280
+ StrideSupport::kStrided,
281
+ AccessTypeB
282
+ >
283
+ >;
284
+
285
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
286
+
287
+ // Warp-level GEMM components
288
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
289
+ using MmaPolicy = typename MmaCore::MmaPolicy;
290
+
291
+ // Define the Mma
292
+ using Mma = threadblock::ImplicitGemmPipelined<
293
+ ThreadblockShape,
294
+ IteratorA,
295
+ SmemIteratorA,
296
+ IteratorB,
297
+ SmemIteratorB,
298
+ ElementC,
299
+ LayoutC,
300
+ MmaPolicy
301
+ >;
302
+
303
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
304
+
305
+ // Define the epilogue
306
+ using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad<
307
+ ArchTag,
308
+ ThreadblockShape,
309
+ WarpMmaTensorOp,
310
+ kPartitionsK,
311
+ EpilogueOutputOp
312
+ >::Epilogue;
313
+
314
+ // Define the kernel
315
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
316
+ Mma,
317
+ Epilogue,
318
+ ThreadblockSwizzle,
319
+ conv::Operator::kDgrad
320
+ >;
321
+ };
322
+
323
+ /////////////////////////////////////////////////////////////////////////////////////////////////
324
+
325
+ /// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity Strided
326
+ // and multistage pipeline.
327
+ template <
328
+ typename ElementA,
329
+ typename LayoutA,
330
+ typename ElementB,
331
+ typename LayoutB,
332
+ typename ElementC,
333
+ typename LayoutC,
334
+ typename ElementAccumulator,
335
+ typename ArchTag,
336
+ typename ThreadblockShape,
337
+ typename WarpShape,
338
+ typename InstructionShape,
339
+ typename EpilogueOutputOp,
340
+ typename ThreadblockSwizzle,
341
+ int Stages,
342
+ typename MathOperatorTag,
343
+ int AlignmentA,
344
+ int AlignmentB
345
+ >
346
+ struct DefaultConv2dDgrad <
347
+ ElementA,
348
+ LayoutA,
349
+ ElementB,
350
+ LayoutB,
351
+ ElementC,
352
+ LayoutC,
353
+ ElementAccumulator,
354
+ arch::OpClassTensorOp,
355
+ ArchTag,
356
+ ThreadblockShape,
357
+ WarpShape,
358
+ InstructionShape,
359
+ EpilogueOutputOp,
360
+ ThreadblockSwizzle,
361
+ Stages,
362
+ MathOperatorTag,
363
+ IteratorAlgorithm::kAnalytic,
364
+ StrideSupport::kUnity,
365
+ AlignmentA,
366
+ AlignmentB
367
+ > {
368
+
369
+ // Define the core components from GEMM
370
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
371
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
372
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
373
+ Stages, MathOperatorTag>;
374
+
375
+ // Define iterators over tiles from the A operand
376
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
377
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
378
+ using IteratorA =
379
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
380
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
381
+ ElementA,
382
+ ThreadMapA,
383
+ StrideSupport::kUnity,
384
+ AccessTypeA
385
+ >;
386
+
387
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
388
+
389
+ // Define iterators over tiles from the B operand
390
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
391
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
392
+ using IteratorB =
393
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
394
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
395
+ ElementB,
396
+ ThreadMapB,
397
+ StrideSupport::kUnity,
398
+ AccessTypeB
399
+ >;
400
+
401
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
402
+
403
+ // Warp-level GEMM components
404
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
405
+ using MmaPolicy = typename MmaCore::MmaPolicy;
406
+
407
+ static cutlass::arch::CacheOperation::Kind const CacheOpB =
408
+ ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
409
+ ? cutlass::arch::CacheOperation::Global
410
+ : cutlass::arch::CacheOperation::Always;
411
+
412
+ // Define the Mma
413
+ using Mma = threadblock::ImplicitGemmMultistage<
414
+ ThreadblockShape,
415
+ IteratorA,
416
+ SmemIteratorA,
417
+ arch::CacheOperation::Always,
418
+ IteratorB,
419
+ SmemIteratorB,
420
+ CacheOpB,
421
+ MmaPolicy,
422
+ Stages
423
+ >;
424
+
425
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
426
+
427
+ // Define the epilogue
428
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
429
+ ThreadblockShape,
430
+ WarpMmaTensorOp,
431
+ kPartitionsK,
432
+ EpilogueOutputOp,
433
+ EpilogueOutputOp::kCount
434
+ >::Epilogue;
435
+
436
+ // Define the kernel
437
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
438
+ Mma,
439
+ Epilogue,
440
+ ThreadblockSwizzle,
441
+ conv::Operator::kDgrad
442
+ >;
443
+ };
444
+
445
+ /// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity
446
+ // 2 stage pipeline.
447
+ template <
448
+ typename ElementA,
449
+ typename LayoutA,
450
+ typename ElementB,
451
+ typename LayoutB,
452
+ typename ElementC,
453
+ typename LayoutC,
454
+ typename ElementAccumulator,
455
+ typename ArchTag,
456
+ typename ThreadblockShape,
457
+ typename WarpShape,
458
+ typename InstructionShape,
459
+ typename EpilogueOutputOp,
460
+ typename ThreadblockSwizzle,
461
+ typename MathOperatorTag,
462
+ int AlignmentA,
463
+ int AlignmentB
464
+ >
465
+ struct DefaultConv2dDgrad <
466
+ ElementA,
467
+ LayoutA,
468
+ ElementB,
469
+ LayoutB,
470
+ ElementC,
471
+ LayoutC,
472
+ ElementAccumulator,
473
+ arch::OpClassTensorOp,
474
+ ArchTag,
475
+ ThreadblockShape,
476
+ WarpShape,
477
+ InstructionShape,
478
+ EpilogueOutputOp,
479
+ ThreadblockSwizzle,
480
+ 2,
481
+ MathOperatorTag,
482
+ IteratorAlgorithm::kAnalytic,
483
+ StrideSupport::kUnity,
484
+ AlignmentA,
485
+ AlignmentB
486
+ > {
487
+
488
+ // Define the core components from GEMM
489
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
490
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
491
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
492
+ 2, MathOperatorTag>;
493
+
494
+ // Define iterators over tiles from the A operand
495
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
496
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
497
+ using IteratorA =
498
+ cutlass::conv::threadblock::TileIterator<
499
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
500
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
501
+ ElementA,
502
+ ThreadMapA,
503
+ StrideSupport::kUnity,
504
+ AccessTypeA
505
+ >
506
+ >;
507
+
508
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
509
+
510
+ // Define iterators over tiles from the B operand
511
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
512
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
513
+ using IteratorB =
514
+ cutlass::conv::threadblock::TileIterator<
515
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
516
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
517
+ ElementB,
518
+ ThreadMapB,
519
+ StrideSupport::kUnity,
520
+ AccessTypeB
521
+ >
522
+ >;
523
+
524
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
525
+
526
+ // Warp-level GEMM components
527
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
528
+ using MmaPolicy = typename MmaCore::MmaPolicy;
529
+
530
+ // Define the Mma
531
+ using Mma = threadblock::ImplicitGemmPipelined<
532
+ ThreadblockShape,
533
+ IteratorA,
534
+ SmemIteratorA,
535
+ IteratorB,
536
+ SmemIteratorB,
537
+ ElementC,
538
+ LayoutC,
539
+ MmaPolicy
540
+ >;
541
+
542
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
543
+
544
+ // Define the epilogue
545
+ using Epilogue = typename detail::DefaultConvEpilogue<
546
+ ArchTag,
547
+ ThreadblockShape,
548
+ WarpMmaTensorOp,
549
+ kPartitionsK,
550
+ EpilogueOutputOp
551
+ >::Epilogue;
552
+
553
+ // Define the kernel
554
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
555
+ Mma,
556
+ Epilogue,
557
+ ThreadblockSwizzle,
558
+ conv::Operator::kDgrad
559
+ >;
560
+ };
561
+
562
+ /////////////////////////////////////////////////////////////////////////////////////////////////
563
+
564
+ /// Defines a kernel for Conv2dDgrad specialization for optimized IteratorAlgorithm Dgrad Unity Strided
565
+ // and multistage pipeline.
566
+ template <
567
+ typename ElementA,
568
+ typename LayoutA,
569
+ typename ElementB,
570
+ typename LayoutB,
571
+ typename ElementC,
572
+ typename LayoutC,
573
+ typename ElementAccumulator,
574
+ typename ArchTag,
575
+ typename ThreadblockShape,
576
+ typename WarpShape,
577
+ typename InstructionShape,
578
+ typename EpilogueOutputOp,
579
+ typename ThreadblockSwizzle,
580
+ int Stages,
581
+ typename MathOperatorTag,
582
+ int AlignmentA,
583
+ int AlignmentB
584
+ >
585
+ struct DefaultConv2dDgrad <
586
+ ElementA,
587
+ LayoutA,
588
+ ElementB,
589
+ LayoutB,
590
+ ElementC,
591
+ LayoutC,
592
+ ElementAccumulator,
593
+ arch::OpClassTensorOp,
594
+ ArchTag,
595
+ ThreadblockShape,
596
+ WarpShape,
597
+ InstructionShape,
598
+ EpilogueOutputOp,
599
+ ThreadblockSwizzle,
600
+ Stages,
601
+ MathOperatorTag,
602
+ IteratorAlgorithm::kOptimized,
603
+ StrideSupport::kUnity,
604
+ AlignmentA,
605
+ AlignmentB
606
+ > {
607
+
608
+ // Define the core components from GEMM
609
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
610
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
611
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
612
+ Stages, MathOperatorTag>;
613
+
614
+ // Define iterators over tiles from the A operand
615
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
616
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
617
+ using IteratorA =
618
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
619
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
620
+ ElementA,
621
+ ThreadMapA,
622
+ StrideSupport::kUnity,
623
+ AccessTypeA
624
+ >;
625
+
626
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
627
+
628
+ // Define iterators over tiles from the B operand
629
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
630
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
631
+ using IteratorB =
632
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
633
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
634
+ ElementB,
635
+ ThreadMapB,
636
+ StrideSupport::kUnity,
637
+ AccessTypeB
638
+ >;
639
+
640
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
641
+
642
+ // Warp-level GEMM components
643
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
644
+ using MmaPolicy = typename MmaCore::MmaPolicy;
645
+
646
+ static cutlass::arch::CacheOperation::Kind const CacheOpB =
647
+ ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
648
+ ? cutlass::arch::CacheOperation::Global
649
+ : cutlass::arch::CacheOperation::Always;
650
+
651
+ // Define the Mma
652
+ using Mma = threadblock::ImplicitGemmMultistage<
653
+ ThreadblockShape,
654
+ IteratorA,
655
+ SmemIteratorA,
656
+ arch::CacheOperation::Always,
657
+ IteratorB,
658
+ SmemIteratorB,
659
+ CacheOpB,
660
+ MmaPolicy,
661
+ Stages
662
+ >;
663
+
664
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
665
+
666
+ // Define the epilogue
667
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
668
+ ThreadblockShape,
669
+ WarpMmaTensorOp,
670
+ kPartitionsK,
671
+ EpilogueOutputOp,
672
+ EpilogueOutputOp::kCount
673
+ >::Epilogue;
674
+
675
+ // Define the kernel
676
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
677
+ Mma,
678
+ Epilogue,
679
+ ThreadblockSwizzle,
680
+ conv::Operator::kDgrad
681
+ >;
682
+ };
683
+
684
+ /// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided and
685
+ // multistage pipeline.
686
+ template <
687
+ typename ElementA,
688
+ typename LayoutA,
689
+ typename ElementB,
690
+ typename LayoutB,
691
+ typename ElementC,
692
+ typename LayoutC,
693
+ typename ElementAccumulator,
694
+ typename ArchTag,
695
+ typename ThreadblockShape,
696
+ typename WarpShape,
697
+ typename InstructionShape,
698
+ typename EpilogueOutputOp,
699
+ typename ThreadblockSwizzle,
700
+ int Stages,
701
+ typename MathOperatorTag,
702
+ int AlignmentA,
703
+ int AlignmentB
704
+ >
705
+ struct DefaultConv2dDgrad <
706
+ ElementA,
707
+ LayoutA,
708
+ ElementB,
709
+ LayoutB,
710
+ ElementC,
711
+ LayoutC,
712
+ ElementAccumulator,
713
+ arch::OpClassTensorOp,
714
+ ArchTag,
715
+ ThreadblockShape,
716
+ WarpShape,
717
+ InstructionShape,
718
+ EpilogueOutputOp,
719
+ ThreadblockSwizzle,
720
+ Stages,
721
+ MathOperatorTag,
722
+ IteratorAlgorithm::kOptimized,
723
+ StrideSupport::kStrided,
724
+ AlignmentA,
725
+ AlignmentB
726
+ > {
727
+
728
+ // Define the core components from GEMM
729
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
730
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
731
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
732
+ Stages, MathOperatorTag>;
733
+
734
+ // Define iterators over tiles from the A operand
735
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
736
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
737
+ using IteratorA =
738
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
739
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
740
+ ElementA,
741
+ ThreadMapA,
742
+ StrideSupport::kStrided,
743
+ AccessTypeA
744
+ >;
745
+
746
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
747
+
748
+ // Define iterators over tiles from the B operand
749
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
750
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
751
+ using IteratorB =
752
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
753
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
754
+ ElementB,
755
+ ThreadMapB,
756
+ StrideSupport::kStrided,
757
+ AccessTypeB
758
+ >;
759
+
760
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
761
+
762
+ // Warp-level GEMM components
763
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
764
+ using MmaPolicy = typename MmaCore::MmaPolicy;
765
+
766
+ static cutlass::arch::CacheOperation::Kind const CacheOpB =
767
+ ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
768
+ ? cutlass::arch::CacheOperation::Global
769
+ : cutlass::arch::CacheOperation::Always;
770
+
771
+ // Define the Mma
772
+ using Mma = threadblock::ImplicitGemmMultistage<
773
+ ThreadblockShape,
774
+ IteratorA,
775
+ SmemIteratorA,
776
+ arch::CacheOperation::Always,
777
+ IteratorB,
778
+ SmemIteratorB,
779
+ CacheOpB,
780
+ MmaPolicy,
781
+ Stages
782
+ >;
783
+
784
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
785
+
786
+ // Define the epilogue
787
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
788
+ ThreadblockShape,
789
+ WarpMmaTensorOp,
790
+ kPartitionsK,
791
+ EpilogueOutputOp,
792
+ EpilogueOutputOp::kCount
793
+ >::Epilogue;
794
+
795
+ // Define the kernel
796
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
797
+ Mma,
798
+ Epilogue,
799
+ ThreadblockSwizzle,
800
+ conv::Operator::kDgrad
801
+ >;
802
+ };
803
+
804
+ /// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided
805
+ // and 2 stage pipeline.
806
+ template <
807
+ typename ElementA,
808
+ typename LayoutA,
809
+ typename ElementB,
810
+ typename LayoutB,
811
+ typename ElementC,
812
+ typename LayoutC,
813
+ typename ElementAccumulator,
814
+ typename ArchTag,
815
+ typename ThreadblockShape,
816
+ typename WarpShape,
817
+ typename InstructionShape,
818
+ typename EpilogueOutputOp,
819
+ typename ThreadblockSwizzle,
820
+ typename MathOperatorTag,
821
+ int AlignmentA,
822
+ int AlignmentB
823
+ >
824
+ struct DefaultConv2dDgrad <
825
+ ElementA,
826
+ LayoutA,
827
+ ElementB,
828
+ LayoutB,
829
+ ElementC,
830
+ LayoutC,
831
+ ElementAccumulator,
832
+ arch::OpClassTensorOp,
833
+ ArchTag,
834
+ ThreadblockShape,
835
+ WarpShape,
836
+ InstructionShape,
837
+ EpilogueOutputOp,
838
+ ThreadblockSwizzle,
839
+ 2,
840
+ MathOperatorTag,
841
+ IteratorAlgorithm::kOptimized,
842
+ StrideSupport::kStrided,
843
+ AlignmentA,
844
+ AlignmentB
845
+ > {
846
+
847
+ // Define the core components from GEMM
848
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
849
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
850
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
851
+ 2, MathOperatorTag>;
852
+
853
+ // Define iterators over tiles from the A operand
854
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
855
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
856
+ using IteratorA =
857
+ cutlass::conv::threadblock::TileIteratorStridedDgrad<
858
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
859
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
860
+ ElementA,
861
+ ThreadMapA,
862
+ StrideSupport::kStrided,
863
+ AccessTypeA
864
+ >
865
+ >;
866
+
867
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
868
+
869
+ // Define iterators over tiles from the B operand
870
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
871
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
872
+ using IteratorB =
873
+ cutlass::conv::threadblock::TileIteratorStridedDgrad<
874
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
875
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
876
+ ElementB,
877
+ ThreadMapB,
878
+ StrideSupport::kStrided,
879
+ AccessTypeB
880
+ >
881
+ >;
882
+
883
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
884
+
885
+ // Warp-level GEMM components
886
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
887
+ using MmaPolicy = typename MmaCore::MmaPolicy;
888
+
889
+ // Define the Mma
890
+ using Mma = threadblock::ImplicitGemmPipelined<
891
+ ThreadblockShape,
892
+ IteratorA,
893
+ SmemIteratorA,
894
+ IteratorB,
895
+ SmemIteratorB,
896
+ ElementC,
897
+ LayoutC,
898
+ MmaPolicy
899
+ >;
900
+
901
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
902
+
903
+ // Define the epilogue
904
+ using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad<
905
+ ArchTag,
906
+ ThreadblockShape,
907
+ WarpMmaTensorOp,
908
+ kPartitionsK,
909
+ EpilogueOutputOp
910
+ >::Epilogue;
911
+
912
+ // Define the kernel
913
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
914
+ Mma,
915
+ Epilogue,
916
+ ThreadblockSwizzle,
917
+ conv::Operator::kDgrad
918
+ >;
919
+ };
920
+
921
+ /// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity
922
+ // 2 stage pipeline
923
+ template <
924
+ typename ElementA,
925
+ typename LayoutA,
926
+ typename ElementB,
927
+ typename LayoutB,
928
+ typename ElementC,
929
+ typename LayoutC,
930
+ typename ElementAccumulator,
931
+ typename ArchTag,
932
+ typename ThreadblockShape,
933
+ typename WarpShape,
934
+ typename InstructionShape,
935
+ typename EpilogueOutputOp,
936
+ typename ThreadblockSwizzle,
937
+ typename MathOperatorTag,
938
+ int AlignmentA,
939
+ int AlignmentB
940
+ >
941
+ struct DefaultConv2dDgrad <
942
+ ElementA,
943
+ LayoutA,
944
+ ElementB,
945
+ LayoutB,
946
+ ElementC,
947
+ LayoutC,
948
+ ElementAccumulator,
949
+ arch::OpClassTensorOp,
950
+ ArchTag,
951
+ ThreadblockShape,
952
+ WarpShape,
953
+ InstructionShape,
954
+ EpilogueOutputOp,
955
+ ThreadblockSwizzle,
956
+ 2,
957
+ MathOperatorTag,
958
+ IteratorAlgorithm::kOptimized,
959
+ StrideSupport::kUnity,
960
+ AlignmentA,
961
+ AlignmentB
962
+ > {
963
+
964
+ // Define the core components from GEMM
965
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
966
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
967
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
968
+ 2, MathOperatorTag>;
969
+
970
+ // Define iterators over tiles from the A operand
971
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
972
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
973
+ using IteratorA =
974
+ cutlass::conv::threadblock::TileIterator<
975
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
976
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
977
+ ElementA,
978
+ ThreadMapA,
979
+ StrideSupport::kUnity,
980
+ AccessTypeA
981
+ >
982
+ >;
983
+
984
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
985
+
986
+ // Define iterators over tiles from the B operand
987
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
988
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
989
+ using IteratorB =
990
+ cutlass::conv::threadblock::TileIterator<
991
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
992
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
993
+ ElementB,
994
+ ThreadMapB,
995
+ StrideSupport::kUnity,
996
+ AccessTypeB
997
+ >
998
+ >;
999
+
1000
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1001
+
1002
+ // Warp-level GEMM components
1003
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
1004
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1005
+
1006
+ // Define the Mma
1007
+ using Mma = threadblock::ImplicitGemmPipelined<
1008
+ ThreadblockShape,
1009
+ IteratorA,
1010
+ SmemIteratorA,
1011
+ IteratorB,
1012
+ SmemIteratorB,
1013
+ ElementC,
1014
+ LayoutC,
1015
+ MmaPolicy
1016
+ >;
1017
+
1018
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
1019
+
1020
+ // Define the epilogue
1021
+ using Epilogue = typename detail::DefaultConvEpilogue<
1022
+ ArchTag,
1023
+ ThreadblockShape,
1024
+ WarpMmaTensorOp,
1025
+ kPartitionsK,
1026
+ EpilogueOutputOp
1027
+ >::Epilogue;
1028
+
1029
+ // Define the kernel
1030
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1031
+ Mma,
1032
+ Epilogue,
1033
+ ThreadblockSwizzle,
1034
+ conv::Operator::kDgrad
1035
+ >;
1036
+ };
1037
+
1038
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1039
+ // OpClassSimt convolutions
1040
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1041
+ /// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm,
1042
+ /// multi-stage pipeline, and FFMA-based mainloop for SM80
1043
+
1044
+ template <
1045
+ typename ElementA,
1046
+ typename LayoutA,
1047
+ typename ElementB,
1048
+ typename LayoutB,
1049
+ typename ElementC,
1050
+ typename LayoutC,
1051
+ typename ElementAccumulator,
1052
+ typename ArchTag,
1053
+ typename ThreadblockShape,
1054
+ typename WarpShape,
1055
+ typename InstructionShape,
1056
+ typename EpilogueOutputOp,
1057
+ typename ThreadblockSwizzle,
1058
+ int Stages,
1059
+ typename MathOperatorTag,
1060
+ int AlignmentA,
1061
+ int AlignmentB
1062
+ >
1063
+ struct DefaultConv2dDgrad <
1064
+ ElementA,
1065
+ LayoutA,
1066
+ ElementB,
1067
+ LayoutB,
1068
+ ElementC,
1069
+ LayoutC,
1070
+ ElementAccumulator,
1071
+ arch::OpClassSimt,
1072
+ ArchTag,
1073
+ ThreadblockShape,
1074
+ WarpShape,
1075
+ InstructionShape,
1076
+ EpilogueOutputOp,
1077
+ ThreadblockSwizzle,
1078
+ Stages,
1079
+ MathOperatorTag,
1080
+ IteratorAlgorithm::kAnalytic,
1081
+ conv::StrideSupport::kUnity,
1082
+ AlignmentA,
1083
+ AlignmentB
1084
+ > {
1085
+
1086
+ // Define the core components from GEMM
1087
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1088
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1089
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1090
+ Stages, MathOperatorTag>;
1091
+
1092
+ // Define iterators over tiles from the A operand
1093
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1094
+ using IteratorA =
1095
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
1096
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1097
+ ElementA,
1098
+ ThreadMapA,
1099
+ conv::StrideSupport::kUnity
1100
+ >;
1101
+
1102
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1103
+
1104
+ // Define iterators over tiles from the B operand
1105
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1106
+ using IteratorB =
1107
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
1108
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1109
+ ElementB,
1110
+ ThreadMapB,
1111
+ conv::StrideSupport::kUnity
1112
+ >;
1113
+
1114
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1115
+
1116
+ // Warp-level GEMM components
1117
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1118
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1119
+
1120
+ // Define the Mma
1121
+ using Mma = threadblock::ImplicitGemmMultistage<
1122
+ ThreadblockShape,
1123
+ IteratorA,
1124
+ SmemIteratorA,
1125
+ arch::CacheOperation::Always,
1126
+ IteratorB,
1127
+ SmemIteratorB,
1128
+ arch::CacheOperation::Always,
1129
+ MmaPolicy,
1130
+ Stages
1131
+ >;
1132
+
1133
+ // Define the epilogue
1134
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
1135
+ ThreadblockShape,
1136
+ WarpMmaSimtOp,
1137
+ EpilogueOutputOp,
1138
+ EpilogueOutputOp::kCount
1139
+ >::Epilogue;
1140
+
1141
+ // Define the kernel
1142
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1143
+ Mma,
1144
+ Epilogue,
1145
+ ThreadblockSwizzle,
1146
+ conv::Operator::kDgrad
1147
+ >;
1148
+
1149
+ };
1150
+
1151
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1152
+
1153
+ template <
1154
+ typename ElementA,
1155
+ typename LayoutA,
1156
+ typename ElementB,
1157
+ typename LayoutB,
1158
+ typename ElementC,
1159
+ typename LayoutC,
1160
+ typename ElementAccumulator,
1161
+ typename ArchTag,
1162
+ typename ThreadblockShape,
1163
+ typename WarpShape,
1164
+ typename InstructionShape,
1165
+ typename EpilogueOutputOp,
1166
+ typename ThreadblockSwizzle,
1167
+ int Stages,
1168
+ typename MathOperatorTag,
1169
+ int AlignmentA,
1170
+ int AlignmentB
1171
+ >
1172
+ struct DefaultConv2dDgrad <
1173
+ ElementA,
1174
+ LayoutA,
1175
+ ElementB,
1176
+ LayoutB,
1177
+ ElementC,
1178
+ LayoutC,
1179
+ ElementAccumulator,
1180
+ arch::OpClassSimt,
1181
+ ArchTag,
1182
+ ThreadblockShape,
1183
+ WarpShape,
1184
+ InstructionShape,
1185
+ EpilogueOutputOp,
1186
+ ThreadblockSwizzle,
1187
+ Stages,
1188
+ MathOperatorTag,
1189
+ IteratorAlgorithm::kAnalytic,
1190
+ conv::StrideSupport::kStrided,
1191
+ AlignmentA,
1192
+ AlignmentB
1193
+ > {
1194
+
1195
+ // Define the core components from GEMM
1196
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1197
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1198
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1199
+ Stages, MathOperatorTag>;
1200
+
1201
+ // Define iterators over tiles from the A operand
1202
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1203
+ using IteratorA =
1204
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
1205
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1206
+ ElementA,
1207
+ ThreadMapA,
1208
+ conv::StrideSupport::kStrided
1209
+ >;
1210
+
1211
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1212
+
1213
+ // Define iterators over tiles from the B operand
1214
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1215
+ using IteratorB =
1216
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
1217
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1218
+ ElementB,
1219
+ ThreadMapB,
1220
+ conv::StrideSupport::kStrided
1221
+ >;
1222
+
1223
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1224
+
1225
+ // Warp-level GEMM components
1226
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1227
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1228
+
1229
+ // Define the Mma
1230
+ using Mma = threadblock::ImplicitGemmMultistage<
1231
+ ThreadblockShape,
1232
+ IteratorA,
1233
+ SmemIteratorA,
1234
+ arch::CacheOperation::Always,
1235
+ IteratorB,
1236
+ SmemIteratorB,
1237
+ arch::CacheOperation::Always,
1238
+ MmaPolicy,
1239
+ Stages
1240
+ >;
1241
+
1242
+ // Define the epilogue
1243
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
1244
+ ThreadblockShape,
1245
+ WarpMmaSimtOp,
1246
+ EpilogueOutputOp,
1247
+ EpilogueOutputOp::kCount
1248
+ >::Epilogue;
1249
+
1250
+ // Define the kernel
1251
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
1252
+ Mma,
1253
+ Epilogue,
1254
+ ThreadblockSwizzle,
1255
+ conv::Operator::kDgrad
1256
+ >;
1257
+
1258
+ };
1259
+
1260
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1261
+
1262
+ /// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm,
1263
+ /// multi-stage pipeline, and FFMA-based mainloop for SM80
1264
+
1265
+ template <
1266
+ typename ElementA,
1267
+ typename LayoutA,
1268
+ typename ElementB,
1269
+ typename LayoutB,
1270
+ typename ElementC,
1271
+ typename LayoutC,
1272
+ typename ElementAccumulator,
1273
+ typename ArchTag,
1274
+ typename ThreadblockShape,
1275
+ typename WarpShape,
1276
+ typename InstructionShape,
1277
+ typename EpilogueOutputOp,
1278
+ typename ThreadblockSwizzle,
1279
+ int Stages,
1280
+ typename MathOperatorTag,
1281
+ int AlignmentA,
1282
+ int AlignmentB
1283
+ >
1284
+ struct DefaultConv2dDgrad <
1285
+ ElementA,
1286
+ LayoutA,
1287
+ ElementB,
1288
+ LayoutB,
1289
+ ElementC,
1290
+ LayoutC,
1291
+ ElementAccumulator,
1292
+ arch::OpClassSimt,
1293
+ ArchTag,
1294
+ ThreadblockShape,
1295
+ WarpShape,
1296
+ InstructionShape,
1297
+ EpilogueOutputOp,
1298
+ ThreadblockSwizzle,
1299
+ Stages,
1300
+ MathOperatorTag,
1301
+ IteratorAlgorithm::kOptimized,
1302
+ StrideSupport::kUnity,
1303
+ AlignmentA,
1304
+ AlignmentB
1305
+ > {
1306
+
1307
+ // Define the core components from GEMM
1308
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1309
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1310
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1311
+ Stages, MathOperatorTag>;
1312
+
1313
+ // Define iterators over tiles from the A operand
1314
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1315
+ using IteratorA =
1316
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
1317
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1318
+ ElementA,
1319
+ ThreadMapA,
1320
+ StrideSupport::kUnity
1321
+ >;
1322
+
1323
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1324
+
1325
+ // Define iterators over tiles from the B operand
1326
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1327
+ using IteratorB =
1328
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
1329
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1330
+ ElementB,
1331
+ ThreadMapB,
1332
+ StrideSupport::kUnity
1333
+ >;
1334
+
1335
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1336
+
1337
+ // Warp-level GEMM components
1338
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1339
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1340
+
1341
+ // Define the Mma
1342
+ using Mma = threadblock::ImplicitGemmMultistage<
1343
+ ThreadblockShape,
1344
+ IteratorA,
1345
+ SmemIteratorA,
1346
+ arch::CacheOperation::Always,
1347
+ IteratorB,
1348
+ SmemIteratorB,
1349
+ arch::CacheOperation::Always,
1350
+ MmaPolicy,
1351
+ Stages
1352
+ >;
1353
+
1354
+ // Define the epilogue
1355
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
1356
+ ThreadblockShape,
1357
+ WarpMmaSimtOp,
1358
+ EpilogueOutputOp,
1359
+ EpilogueOutputOp::kCount
1360
+ >::Epilogue;
1361
+
1362
+ // Define the kernel
1363
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1364
+ Mma,
1365
+ Epilogue,
1366
+ ThreadblockSwizzle,
1367
+ conv::Operator::kDgrad
1368
+ >;
1369
+ };
1370
+
1371
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1372
+ template <
1373
+ typename ElementA,
1374
+ typename LayoutA,
1375
+ typename ElementB,
1376
+ typename LayoutB,
1377
+ typename ElementC,
1378
+ typename LayoutC,
1379
+ typename ElementAccumulator,
1380
+ typename ArchTag,
1381
+ typename ThreadblockShape,
1382
+ typename WarpShape,
1383
+ typename InstructionShape,
1384
+ typename EpilogueOutputOp,
1385
+ typename ThreadblockSwizzle,
1386
+ int Stages,
1387
+ typename MathOperatorTag,
1388
+ int AlignmentA,
1389
+ int AlignmentB
1390
+ >
1391
+ struct DefaultConv2dDgrad <
1392
+ ElementA,
1393
+ LayoutA,
1394
+ ElementB,
1395
+ LayoutB,
1396
+ ElementC,
1397
+ LayoutC,
1398
+ ElementAccumulator,
1399
+ arch::OpClassSimt,
1400
+ ArchTag,
1401
+ ThreadblockShape,
1402
+ WarpShape,
1403
+ InstructionShape,
1404
+ EpilogueOutputOp,
1405
+ ThreadblockSwizzle,
1406
+ Stages,
1407
+ MathOperatorTag,
1408
+ IteratorAlgorithm::kOptimized,
1409
+ conv::StrideSupport::kStrided,
1410
+ AlignmentA,
1411
+ AlignmentB
1412
+ > {
1413
+
1414
+ // Define the core components from GEMM
1415
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1416
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1417
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1418
+ Stages, MathOperatorTag>;
1419
+
1420
+ // Define iterators over tiles from the A operand
1421
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1422
+ using IteratorA =
1423
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
1424
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1425
+ ElementA,
1426
+ ThreadMapA,
1427
+ conv::StrideSupport::kStrided
1428
+ >;
1429
+
1430
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1431
+
1432
+ // Define iterators over tiles from the B operand
1433
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1434
+ using IteratorB =
1435
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
1436
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1437
+ ElementB,
1438
+ ThreadMapB,
1439
+ conv::StrideSupport::kStrided
1440
+ >;
1441
+
1442
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1443
+
1444
+ // Warp-level GEMM components
1445
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1446
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1447
+
1448
+ // Define the Mma
1449
+ using Mma = threadblock::ImplicitGemmMultistage<
1450
+ ThreadblockShape,
1451
+ IteratorA,
1452
+ SmemIteratorA,
1453
+ arch::CacheOperation::Always,
1454
+ IteratorB,
1455
+ SmemIteratorB,
1456
+ arch::CacheOperation::Always,
1457
+ MmaPolicy,
1458
+ Stages
1459
+ >;
1460
+
1461
+ // Define the epilogue
1462
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
1463
+ ThreadblockShape,
1464
+ WarpMmaSimtOp,
1465
+ EpilogueOutputOp,
1466
+ EpilogueOutputOp::kCount
1467
+ >::Epilogue;
1468
+
1469
+ // Define the kernel
1470
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
1471
+ Mma,
1472
+ Epilogue,
1473
+ ThreadblockSwizzle,
1474
+ conv::Operator::kDgrad
1475
+ >;
1476
+
1477
+ };
1478
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1479
+
1480
+ /// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm,
1481
+ /// 2 stage pipeline, and FFMA-based mainloop for SM50
1482
+ template <
1483
+ typename ElementA,
1484
+ typename LayoutA,
1485
+ typename ElementB,
1486
+ typename LayoutB,
1487
+ typename ElementC,
1488
+ typename LayoutC,
1489
+ typename ElementAccumulator,
1490
+ typename ArchTag,
1491
+ typename ThreadblockShape,
1492
+ typename WarpShape,
1493
+ typename InstructionShape,
1494
+ typename EpilogueOutputOp,
1495
+ typename ThreadblockSwizzle,
1496
+ typename MathOperatorTag,
1497
+ int AlignmentA,
1498
+ int AlignmentB
1499
+ >
1500
+ struct DefaultConv2dDgrad <
1501
+ ElementA,
1502
+ LayoutA,
1503
+ ElementB,
1504
+ LayoutB,
1505
+ ElementC,
1506
+ LayoutC,
1507
+ ElementAccumulator,
1508
+ arch::OpClassSimt,
1509
+ ArchTag,
1510
+ ThreadblockShape,
1511
+ WarpShape,
1512
+ InstructionShape,
1513
+ EpilogueOutputOp,
1514
+ ThreadblockSwizzle,
1515
+ 2,
1516
+ MathOperatorTag,
1517
+ IteratorAlgorithm::kAnalytic,
1518
+ conv::StrideSupport::kUnity,
1519
+ AlignmentA,
1520
+ AlignmentB
1521
+ > {
1522
+
1523
+ // Define the core components from GEMM
1524
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1525
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1526
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1527
+ 2, MathOperatorTag>;
1528
+
1529
+ // Define iterators over tiles from the A operand
1530
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1531
+ using IteratorA =
1532
+ cutlass::conv::threadblock::TileIterator<
1533
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
1534
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1535
+ ElementA,
1536
+ ThreadMapA,
1537
+ conv::StrideSupport::kUnity
1538
+ >
1539
+ >;
1540
+
1541
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1542
+
1543
+ // Define iterators over tiles from the B operand
1544
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1545
+ using IteratorB =
1546
+ cutlass::conv::threadblock::TileIterator<
1547
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
1548
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1549
+ ElementB,
1550
+ ThreadMapB,
1551
+ conv::StrideSupport::kUnity
1552
+ >
1553
+ >;
1554
+
1555
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1556
+
1557
+ // Warp-level GEMM components
1558
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1559
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1560
+
1561
+ // Define the Mma
1562
+ using Mma = threadblock::ImplicitGemmPipelined<
1563
+ ThreadblockShape,
1564
+ IteratorA,
1565
+ SmemIteratorA,
1566
+ IteratorB,
1567
+ SmemIteratorB,
1568
+ ElementC,
1569
+ LayoutC,
1570
+ MmaPolicy
1571
+ >;
1572
+
1573
+ // Define the epilogue
1574
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
1575
+ ThreadblockShape,
1576
+ WarpMmaSimtOp,
1577
+ EpilogueOutputOp,
1578
+ EpilogueOutputOp::kCount
1579
+ >::Epilogue;
1580
+
1581
+ // Define the kernel
1582
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1583
+ Mma,
1584
+ Epilogue,
1585
+ ThreadblockSwizzle,
1586
+ conv::Operator::kDgrad
1587
+ >;
1588
+
1589
+ };
1590
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1591
+
1592
+ template <
1593
+ typename ElementA,
1594
+ typename LayoutA,
1595
+ typename ElementB,
1596
+ typename LayoutB,
1597
+ typename ElementC,
1598
+ typename LayoutC,
1599
+ typename ElementAccumulator,
1600
+ typename ArchTag,
1601
+ typename ThreadblockShape,
1602
+ typename WarpShape,
1603
+ typename InstructionShape,
1604
+ typename EpilogueOutputOp,
1605
+ typename ThreadblockSwizzle,
1606
+ typename MathOperatorTag,
1607
+ int AlignmentA,
1608
+ int AlignmentB
1609
+ >
1610
+ struct DefaultConv2dDgrad <
1611
+ ElementA,
1612
+ LayoutA,
1613
+ ElementB,
1614
+ LayoutB,
1615
+ ElementC,
1616
+ LayoutC,
1617
+ ElementAccumulator,
1618
+ arch::OpClassSimt,
1619
+ ArchTag,
1620
+ ThreadblockShape,
1621
+ WarpShape,
1622
+ InstructionShape,
1623
+ EpilogueOutputOp,
1624
+ ThreadblockSwizzle,
1625
+ 2,
1626
+ MathOperatorTag,
1627
+ IteratorAlgorithm::kAnalytic,
1628
+ conv::StrideSupport::kStrided,
1629
+ AlignmentA,
1630
+ AlignmentB
1631
+ > {
1632
+
1633
+ // Define the core components from GEMM
1634
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1635
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1636
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1637
+ 2, MathOperatorTag>;
1638
+
1639
+ // Define iterators over tiles from the A operand
1640
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1641
+ using IteratorA =
1642
+ cutlass::conv::threadblock::TileIteratorStridedDgrad<
1643
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
1644
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1645
+ ElementA,
1646
+ ThreadMapA,
1647
+ conv::StrideSupport::kStrided
1648
+ >
1649
+ >;
1650
+
1651
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1652
+
1653
+ // Define iterators over tiles from the B operand
1654
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1655
+ using IteratorB =
1656
+ cutlass::conv::threadblock::TileIteratorStridedDgrad<
1657
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
1658
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1659
+ ElementB,
1660
+ ThreadMapB,
1661
+ conv::StrideSupport::kStrided
1662
+ >
1663
+ >;
1664
+
1665
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1666
+
1667
+ // Warp-level GEMM components
1668
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1669
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1670
+
1671
+ // Define the Mma
1672
+ using Mma = threadblock::ImplicitGemmPipelined<
1673
+ ThreadblockShape,
1674
+ IteratorA,
1675
+ SmemIteratorA,
1676
+ IteratorB,
1677
+ SmemIteratorB,
1678
+ ElementC,
1679
+ LayoutC,
1680
+ MmaPolicy
1681
+ >;
1682
+
1683
+ // Define the epilogue
1684
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
1685
+ ThreadblockShape,
1686
+ WarpMmaSimtOp,
1687
+ EpilogueOutputOp,
1688
+ EpilogueOutputOp::kCount
1689
+ >::Epilogue;
1690
+
1691
+ // Define the kernel
1692
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
1693
+ Mma,
1694
+ Epilogue,
1695
+ ThreadblockSwizzle,
1696
+ conv::Operator::kDgrad
1697
+ >;
1698
+ };
1699
+
1700
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1701
+
1702
+ /// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm,
1703
+ /// 2 stage pipeline, and FFMA-based mainloop for SM50
1704
+ template <
1705
+ typename ElementA,
1706
+ typename LayoutA,
1707
+ typename ElementB,
1708
+ typename LayoutB,
1709
+ typename ElementC,
1710
+ typename LayoutC,
1711
+ typename ElementAccumulator,
1712
+ typename ArchTag,
1713
+ typename ThreadblockShape,
1714
+ typename WarpShape,
1715
+ typename InstructionShape,
1716
+ typename EpilogueOutputOp,
1717
+ typename ThreadblockSwizzle,
1718
+ typename MathOperatorTag,
1719
+ int AlignmentA,
1720
+ int AlignmentB
1721
+ >
1722
+ struct DefaultConv2dDgrad <
1723
+ ElementA,
1724
+ LayoutA,
1725
+ ElementB,
1726
+ LayoutB,
1727
+ ElementC,
1728
+ LayoutC,
1729
+ ElementAccumulator,
1730
+ arch::OpClassSimt,
1731
+ ArchTag,
1732
+ ThreadblockShape,
1733
+ WarpShape,
1734
+ InstructionShape,
1735
+ EpilogueOutputOp,
1736
+ ThreadblockSwizzle,
1737
+ 2,
1738
+ MathOperatorTag,
1739
+ IteratorAlgorithm::kOptimized,
1740
+ StrideSupport::kUnity,
1741
+ AlignmentA,
1742
+ AlignmentB
1743
+ > {
1744
+
1745
+ // Define the core components from GEMM
1746
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1747
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1748
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1749
+ 2, MathOperatorTag>;
1750
+
1751
+ // Define iterators over tiles from the A operand
1752
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1753
+ using IteratorA =
1754
+ cutlass::conv::threadblock::TileIterator<
1755
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
1756
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1757
+ ElementA,
1758
+ ThreadMapA,
1759
+ StrideSupport::kUnity
1760
+ >
1761
+ >;
1762
+
1763
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1764
+
1765
+ // Define iterators over tiles from the B operand
1766
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1767
+ using IteratorB =
1768
+ cutlass::conv::threadblock::TileIterator<
1769
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
1770
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1771
+ ElementB,
1772
+ ThreadMapB,
1773
+ StrideSupport::kUnity
1774
+ >
1775
+ >;
1776
+
1777
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1778
+
1779
+ // Warp-level GEMM components
1780
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1781
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1782
+
1783
+ // Define the Mma
1784
+ using Mma = threadblock::ImplicitGemmPipelined<
1785
+ ThreadblockShape,
1786
+ IteratorA,
1787
+ SmemIteratorA,
1788
+ IteratorB,
1789
+ SmemIteratorB,
1790
+ ElementC,
1791
+ LayoutC,
1792
+ MmaPolicy
1793
+ >;
1794
+
1795
+ // Define the epilogue
1796
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
1797
+ ThreadblockShape,
1798
+ WarpMmaSimtOp,
1799
+ EpilogueOutputOp,
1800
+ EpilogueOutputOp::kCount
1801
+ >::Epilogue;
1802
+
1803
+ // Define the kernel
1804
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1805
+ Mma,
1806
+ Epilogue,
1807
+ ThreadblockSwizzle,
1808
+ conv::Operator::kDgrad
1809
+ >;
1810
+
1811
+ };
1812
+
1813
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1814
+ template <
1815
+ typename ElementA,
1816
+ typename LayoutA,
1817
+ typename ElementB,
1818
+ typename LayoutB,
1819
+ typename ElementC,
1820
+ typename LayoutC,
1821
+ typename ElementAccumulator,
1822
+ typename ArchTag,
1823
+ typename ThreadblockShape,
1824
+ typename WarpShape,
1825
+ typename InstructionShape,
1826
+ typename EpilogueOutputOp,
1827
+ typename ThreadblockSwizzle,
1828
+ typename MathOperatorTag,
1829
+ int AlignmentA,
1830
+ int AlignmentB
1831
+ >
1832
+ struct DefaultConv2dDgrad <
1833
+ ElementA,
1834
+ LayoutA,
1835
+ ElementB,
1836
+ LayoutB,
1837
+ ElementC,
1838
+ LayoutC,
1839
+ ElementAccumulator,
1840
+ arch::OpClassSimt,
1841
+ ArchTag,
1842
+ ThreadblockShape,
1843
+ WarpShape,
1844
+ InstructionShape,
1845
+ EpilogueOutputOp,
1846
+ ThreadblockSwizzle,
1847
+ 2,
1848
+ MathOperatorTag,
1849
+ IteratorAlgorithm::kOptimized,
1850
+ conv::StrideSupport::kStrided,
1851
+ AlignmentA,
1852
+ AlignmentB
1853
+ > {
1854
+
1855
+ // Define the core components from GEMM
1856
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1857
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1858
+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1859
+ 2, MathOperatorTag>;
1860
+
1861
+ // Define iterators over tiles from the A operand
1862
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1863
+ using IteratorA =
1864
+ cutlass::conv::threadblock::TileIteratorStridedDgrad<
1865
+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
1866
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1867
+ ElementA,
1868
+ ThreadMapA,
1869
+ conv::StrideSupport::kStrided
1870
+ >
1871
+ >;
1872
+
1873
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1874
+
1875
+ // Define iterators over tiles from the B operand
1876
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1877
+ using IteratorB =
1878
+ cutlass::conv::threadblock::TileIteratorStridedDgrad<
1879
+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
1880
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1881
+ ElementB,
1882
+ ThreadMapB,
1883
+ conv::StrideSupport::kStrided
1884
+ >
1885
+ >;
1886
+
1887
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1888
+
1889
+ // Warp-level GEMM components
1890
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1891
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1892
+
1893
+ // Define the Mma
1894
+ using Mma = threadblock::ImplicitGemmPipelined<
1895
+ ThreadblockShape,
1896
+ IteratorA,
1897
+ SmemIteratorA,
1898
+ IteratorB,
1899
+ SmemIteratorB,
1900
+ ElementC,
1901
+ LayoutC,
1902
+ MmaPolicy
1903
+ >;
1904
+
1905
+ // Define the epilogue
1906
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
1907
+ ThreadblockShape,
1908
+ WarpMmaSimtOp,
1909
+ EpilogueOutputOp,
1910
+ EpilogueOutputOp::kCount
1911
+ >::Epilogue;
1912
+
1913
+ // Define the kernel
1914
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
1915
+ Mma,
1916
+ Epilogue,
1917
+ ThreadblockSwizzle,
1918
+ conv::Operator::kDgrad
1919
+ >;
1920
+
1921
+ };
1922
+
1923
+ } // namespace kernel
1924
+ } // namespace conv
1925
+ } // namespace cutlass
1926
+
1927
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h ADDED
@@ -0,0 +1,2007 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
35
+ matrix multiply-add with the appropriate threadblock-scoped epilogue.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/conv/kernel/default_conv2d.h"
42
+
43
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
44
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
45
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h"
46
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h"
47
+
48
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
49
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
50
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h"
51
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h"
52
+
53
+ /////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ namespace cutlass {
56
+ namespace conv {
57
+ namespace kernel {
58
+
59
+ /////////////////////////////////////////////////////////////////////////////////////////////////
60
+ /// Defines a kernel for Conv2dFprop
61
+ template <
62
+ typename ElementA,
63
+ typename LayoutA,
64
+ typename ElementB,
65
+ typename LayoutB,
66
+ typename ElementC,
67
+ typename LayoutC,
68
+ typename ElementAccumulator,
69
+ typename OperatorClass,
70
+ typename ArchTag,
71
+ typename ThreadblockShape,
72
+ typename WarpShape,
73
+ typename InstructionShape,
74
+ typename EpilogueOutputOp,
75
+ typename ThreadblockSwizzle,
76
+ int Stages,
77
+ typename MathOperatorTag,
78
+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
79
+ conv::StrideSupport StrideSupport = StrideSupport::kUnity,
80
+ /// Access granularity of A matrix in units of elements
81
+ int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
82
+ /// Access granularity of B matrix in units of elements
83
+ int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
84
+ > struct DefaultConv2dFprop;
85
+
86
+ /////////////////////////////////////////////////////////////////////////////////////////////////
87
+ // OpClassTensorOp convolutions
88
+ /////////////////////////////////////////////////////////////////////////////////////////////////
89
+
90
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
91
+ /// pipeline.
92
+ template <
93
+ typename ElementA,
94
+ typename LayoutA,
95
+ typename ElementB,
96
+ typename LayoutB,
97
+ typename ElementC,
98
+ typename LayoutC,
99
+ typename ElementAccumulator,
100
+ typename ArchTag,
101
+ typename ThreadblockShape,
102
+ typename WarpShape,
103
+ typename InstructionShape,
104
+ typename EpilogueOutputOp,
105
+ typename ThreadblockSwizzle,
106
+ int Stages,
107
+ typename MathOperatorTag,
108
+ conv::StrideSupport StrideSupport,
109
+ int AlignmentA,
110
+ int AlignmentB
111
+ >
112
+ struct DefaultConv2dFprop <
113
+ ElementA,
114
+ LayoutA,
115
+ ElementB,
116
+ LayoutB,
117
+ ElementC,
118
+ LayoutC,
119
+ ElementAccumulator,
120
+ arch::OpClassTensorOp,
121
+ ArchTag,
122
+ ThreadblockShape,
123
+ WarpShape,
124
+ InstructionShape,
125
+ EpilogueOutputOp,
126
+ ThreadblockSwizzle,
127
+ Stages,
128
+ MathOperatorTag,
129
+ IteratorAlgorithm::kAnalytic,
130
+ StrideSupport,
131
+ AlignmentA,
132
+ AlignmentB
133
+ > {
134
+
135
+ // Define the core components from GEMM
136
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
137
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
138
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
139
+ Stages, MathOperatorTag>;
140
+
141
+ // Define iterators over tiles from the A operand
142
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
143
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
144
+ using IteratorA =
145
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
146
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
147
+ ElementA, LayoutA,
148
+ ThreadMapA,
149
+ AccessTypeA
150
+ >;
151
+
152
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
153
+
154
+ // Define iterators over tiles from the B operand
155
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
156
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
157
+ using IteratorB =
158
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
159
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
160
+ ElementB, LayoutB,
161
+ ThreadMapB,
162
+ AccessTypeB
163
+ >;
164
+
165
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
166
+
167
+ // Warp-level GEMM components
168
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
169
+ using MmaPolicy = typename MmaCore::MmaPolicy;
170
+
171
+ static cutlass::arch::CacheOperation::Kind const CacheOpB =
172
+ ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
173
+ ? cutlass::arch::CacheOperation::Global
174
+ : cutlass::arch::CacheOperation::Always;
175
+
176
+ // Define the Mma
177
+ using Mma = threadblock::ImplicitGemmMultistage<
178
+ ThreadblockShape,
179
+ IteratorA,
180
+ SmemIteratorA,
181
+ arch::CacheOperation::Always,
182
+ IteratorB,
183
+ SmemIteratorB,
184
+ CacheOpB,
185
+ MmaPolicy,
186
+ Stages
187
+ >;
188
+
189
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
190
+
191
+ // Define the epilogue
192
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
193
+ ThreadblockShape,
194
+ WarpMmaTensorOp,
195
+ kPartitionsK,
196
+ EpilogueOutputOp,
197
+ EpilogueOutputOp::kCount
198
+ >::Epilogue;
199
+
200
+ // Define the kernel
201
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
202
+ Mma,
203
+ Epilogue,
204
+ ThreadblockSwizzle,
205
+ conv::Operator::kFprop
206
+ >;
207
+ };
208
+
209
+ /////////////////////////////////////////////////////////////////////////////////////////////////
210
+
211
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
212
+ /// pipeline.
213
+ template <
214
+ typename ElementA,
215
+ typename LayoutA,
216
+ typename ElementB,
217
+ typename LayoutB,
218
+ typename ElementC,
219
+ typename LayoutC,
220
+ typename ElementAccumulator,
221
+ typename ArchTag,
222
+ typename ThreadblockShape,
223
+ typename WarpShape,
224
+ typename InstructionShape,
225
+ typename EpilogueOutputOp,
226
+ typename ThreadblockSwizzle,
227
+ int Stages,
228
+ typename MathOperatorTag,
229
+ conv::StrideSupport StrideSupport,
230
+ int AlignmentA,
231
+ int AlignmentB
232
+ >
233
+ struct DefaultConv2dFprop <
234
+ ElementA,
235
+ LayoutA,
236
+ ElementB,
237
+ LayoutB,
238
+ ElementC,
239
+ LayoutC,
240
+ ElementAccumulator,
241
+ arch::OpClassTensorOp,
242
+ ArchTag,
243
+ ThreadblockShape,
244
+ WarpShape,
245
+ InstructionShape,
246
+ EpilogueOutputOp,
247
+ ThreadblockSwizzle,
248
+ Stages,
249
+ MathOperatorTag,
250
+ IteratorAlgorithm::kFixedChannels,
251
+ StrideSupport,
252
+ AlignmentA,
253
+ AlignmentB
254
+ > {
255
+
256
+ // Define the core components from GEMM
257
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
258
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
259
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
260
+ Stages, MathOperatorTag>;
261
+
262
+ // Define iterators over tiles from the A operand
263
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
264
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
265
+ using IteratorA =
266
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels<
267
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
268
+ ElementA, LayoutA,
269
+ ThreadMapA,
270
+ AccessTypeA
271
+ >;
272
+
273
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
274
+
275
+ // Define iterators over tiles from the B operand
276
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
277
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
278
+ using IteratorB =
279
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels<
280
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
281
+ ElementB, LayoutB,
282
+ ThreadMapB,
283
+ AccessTypeB
284
+ >;
285
+
286
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
287
+
288
+ // Warp-level GEMM components
289
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
290
+ using MmaPolicy = typename MmaCore::MmaPolicy;
291
+
292
+ static cutlass::arch::CacheOperation::Kind const CacheOpB =
293
+ ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
294
+ ? cutlass::arch::CacheOperation::Global
295
+ : cutlass::arch::CacheOperation::Always;
296
+
297
+ // Define the Mma
298
+ using Mma = threadblock::ImplicitGemmMultistage<
299
+ ThreadblockShape,
300
+ IteratorA,
301
+ SmemIteratorA,
302
+ arch::CacheOperation::Always,
303
+ IteratorB,
304
+ SmemIteratorB,
305
+ CacheOpB,
306
+ MmaPolicy,
307
+ Stages
308
+ >;
309
+
310
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
311
+
312
+ // Define the epilogue
313
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
314
+ ThreadblockShape,
315
+ WarpMmaTensorOp,
316
+ kPartitionsK,
317
+ EpilogueOutputOp,
318
+ EpilogueOutputOp::kCount
319
+ >::Epilogue;
320
+
321
+ // Define the kernel
322
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
323
+ Mma,
324
+ Epilogue,
325
+ ThreadblockSwizzle,
326
+ conv::Operator::kFprop
327
+ >;
328
+ };
329
+
330
+ /////////////////////////////////////////////////////////////////////////////////////////////////
331
+
332
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and two stage
333
+ /// pipeline.
334
+ template <
335
+ typename ElementA,
336
+ typename LayoutA,
337
+ typename ElementB,
338
+ typename LayoutB,
339
+ typename ElementC,
340
+ typename LayoutC,
341
+ typename ElementAccumulator,
342
+ typename ArchTag,
343
+ typename ThreadblockShape,
344
+ typename WarpShape,
345
+ typename InstructionShape,
346
+ typename EpilogueOutputOp,
347
+ typename ThreadblockSwizzle,
348
+ typename MathOperatorTag,
349
+ conv::StrideSupport StrideSupport,
350
+ int AlignmentA,
351
+ int AlignmentB
352
+ >
353
+ struct DefaultConv2dFprop <
354
+ ElementA,
355
+ LayoutA,
356
+ ElementB,
357
+ LayoutB,
358
+ ElementC,
359
+ LayoutC,
360
+ ElementAccumulator,
361
+ arch::OpClassTensorOp,
362
+ ArchTag,
363
+ ThreadblockShape,
364
+ WarpShape,
365
+ InstructionShape,
366
+ EpilogueOutputOp,
367
+ ThreadblockSwizzle,
368
+ 2,
369
+ MathOperatorTag,
370
+ IteratorAlgorithm::kFixedChannels,
371
+ StrideSupport,
372
+ AlignmentA,
373
+ AlignmentB
374
+ > {
375
+
376
+ // Define the core components from GEMM
377
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
378
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
379
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
380
+ 2, MathOperatorTag>;
381
+
382
+ // Define iterators over tiles from the A operand
383
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
384
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
385
+ using IteratorA =
386
+ cutlass::conv::threadblock::TileIterator<
387
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels<
388
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
389
+ ElementA, LayoutA,
390
+ ThreadMapA,
391
+ AccessTypeA
392
+ >
393
+ >;
394
+
395
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
396
+
397
+ // Define iterators over tiles from the B operand
398
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
399
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
400
+ using IteratorB =
401
+ cutlass::conv::threadblock::TileIterator<
402
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels<
403
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
404
+ ElementB, LayoutB,
405
+ ThreadMapB,
406
+ AccessTypeB
407
+ >
408
+ >;
409
+
410
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
411
+
412
+ // Warp-level GEMM components
413
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
414
+ using MmaPolicy = typename MmaCore::MmaPolicy;
415
+
416
+ // Define the Mma
417
+ using Mma = threadblock::ImplicitGemmPipelined<
418
+ ThreadblockShape,
419
+ IteratorA,
420
+ SmemIteratorA,
421
+ IteratorB,
422
+ SmemIteratorB,
423
+ ElementC,
424
+ LayoutC,
425
+ MmaPolicy
426
+ >;
427
+
428
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
429
+
430
+ // Define the epilogue
431
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
432
+ ThreadblockShape,
433
+ WarpMmaTensorOp,
434
+ kPartitionsK,
435
+ EpilogueOutputOp,
436
+ EpilogueOutputOp::kCount
437
+ >::Epilogue;
438
+
439
+ // Define the kernel
440
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
441
+ Mma,
442
+ Epilogue,
443
+ ThreadblockSwizzle,
444
+ conv::Operator::kFprop
445
+ >;
446
+ };
447
+
448
+ /////////////////////////////////////////////////////////////////////////////////////////////////
449
+
450
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
451
+ /// pipeline.
452
+ template <
453
+ typename ElementA,
454
+ typename LayoutA,
455
+ typename ElementB,
456
+ typename LayoutB,
457
+ typename ElementC,
458
+ typename LayoutC,
459
+ typename ElementAccumulator,
460
+ typename ArchTag,
461
+ typename ThreadblockShape,
462
+ typename WarpShape,
463
+ typename InstructionShape,
464
+ typename EpilogueOutputOp,
465
+ typename ThreadblockSwizzle,
466
+ int Stages,
467
+ typename MathOperatorTag,
468
+ conv::StrideSupport StrideSupport,
469
+ int AlignmentA,
470
+ int AlignmentB
471
+ >
472
+ struct DefaultConv2dFprop <
473
+ ElementA,
474
+ LayoutA,
475
+ ElementB,
476
+ LayoutB,
477
+ ElementC,
478
+ LayoutC,
479
+ ElementAccumulator,
480
+ arch::OpClassTensorOp,
481
+ ArchTag,
482
+ ThreadblockShape,
483
+ WarpShape,
484
+ InstructionShape,
485
+ EpilogueOutputOp,
486
+ ThreadblockSwizzle,
487
+ Stages,
488
+ MathOperatorTag,
489
+ IteratorAlgorithm::kFewChannels,
490
+ StrideSupport,
491
+ AlignmentA,
492
+ AlignmentB
493
+ > {
494
+
495
+ // Define the core components from GEMM
496
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
497
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
498
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
499
+ Stages, MathOperatorTag>;
500
+
501
+ // Define iterators over tiles from the A operand
502
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
503
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
504
+ using IteratorA =
505
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels<
506
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
507
+ ElementA, LayoutA,
508
+ ThreadMapA,
509
+ AccessTypeA
510
+ >;
511
+
512
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
513
+
514
+ // Define iterators over tiles from the B operand
515
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
516
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
517
+ using IteratorB =
518
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels<
519
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
520
+ ElementB, LayoutB,
521
+ ThreadMapB,
522
+ AccessTypeB
523
+ >;
524
+
525
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
526
+
527
+ // Warp-level GEMM components
528
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
529
+ using MmaPolicy = typename MmaCore::MmaPolicy;
530
+
531
+ static cutlass::arch::CacheOperation::Kind const CacheOpB =
532
+ ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
533
+ ? cutlass::arch::CacheOperation::Global
534
+ : cutlass::arch::CacheOperation::Always;
535
+
536
+ // Define the Mma
537
+ using Mma = threadblock::ImplicitGemmMultistage<
538
+ ThreadblockShape,
539
+ IteratorA,
540
+ SmemIteratorA,
541
+ arch::CacheOperation::Always,
542
+ IteratorB,
543
+ SmemIteratorB,
544
+ CacheOpB,
545
+ MmaPolicy,
546
+ Stages
547
+ >;
548
+
549
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
550
+
551
+ // Define the epilogue
552
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
553
+ ThreadblockShape,
554
+ WarpMmaTensorOp,
555
+ kPartitionsK,
556
+ EpilogueOutputOp,
557
+ EpilogueOutputOp::kCount
558
+ >::Epilogue;
559
+
560
+ // Define the kernel
561
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
562
+ Mma,
563
+ Epilogue,
564
+ ThreadblockSwizzle,
565
+ conv::Operator::kFprop
566
+ >;
567
+ };
568
+
569
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
570
+ /// pipeline.
571
+ template <
572
+ typename ElementA,
573
+ typename LayoutA,
574
+ typename ElementB,
575
+ typename LayoutB,
576
+ typename ElementC,
577
+ typename LayoutC,
578
+ typename ElementAccumulator,
579
+ typename ArchTag,
580
+ typename ThreadblockShape,
581
+ typename WarpShape,
582
+ typename InstructionShape,
583
+ typename EpilogueOutputOp,
584
+ typename ThreadblockSwizzle,
585
+ typename MathOperatorTag,
586
+ conv::StrideSupport StrideSupport,
587
+ int AlignmentA,
588
+ int AlignmentB
589
+ >
590
+ struct DefaultConv2dFprop <
591
+ ElementA,
592
+ LayoutA,
593
+ ElementB,
594
+ LayoutB,
595
+ ElementC,
596
+ LayoutC,
597
+ ElementAccumulator,
598
+ arch::OpClassTensorOp,
599
+ ArchTag,
600
+ ThreadblockShape,
601
+ WarpShape,
602
+ InstructionShape,
603
+ EpilogueOutputOp,
604
+ ThreadblockSwizzle,
605
+ 2,
606
+ MathOperatorTag,
607
+ IteratorAlgorithm::kFewChannels,
608
+ StrideSupport,
609
+ AlignmentA,
610
+ AlignmentB
611
+ > {
612
+
613
+ // Define the core components from GEMM
614
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
615
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
616
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
617
+ 2, MathOperatorTag>;
618
+
619
+ // Define iterators over tiles from the A operand
620
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
621
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
622
+ using IteratorA =
623
+ cutlass::conv::threadblock::TileIterator<
624
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels<
625
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
626
+ ElementA, LayoutA,
627
+ ThreadMapA,
628
+ AccessTypeA
629
+ >
630
+ >;
631
+
632
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
633
+
634
+ // Define iterators over tiles from the B operand
635
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
636
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
637
+ using IteratorB =
638
+
639
+ cutlass::conv::threadblock::TileIterator<
640
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels<
641
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
642
+ ElementB, LayoutB,
643
+ ThreadMapB,
644
+ AccessTypeB
645
+ >
646
+ >;
647
+
648
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
649
+
650
+ // Warp-level GEMM components
651
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
652
+ using MmaPolicy = typename MmaCore::MmaPolicy;
653
+
654
+ static cutlass::arch::CacheOperation::Kind const CacheOpB =
655
+ ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
656
+ ? cutlass::arch::CacheOperation::Global
657
+ : cutlass::arch::CacheOperation::Always;
658
+
659
+ // Define the Mma
660
+ using Mma = threadblock::ImplicitGemmPipelined<
661
+ ThreadblockShape,
662
+ IteratorA,
663
+ SmemIteratorA,
664
+ IteratorB,
665
+ SmemIteratorB,
666
+ ElementC,
667
+ LayoutC,
668
+ MmaPolicy
669
+ >;
670
+
671
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
672
+
673
+ // Define the epilogue
674
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
675
+ ThreadblockShape,
676
+ WarpMmaTensorOp,
677
+ kPartitionsK,
678
+ EpilogueOutputOp,
679
+ EpilogueOutputOp::kCount
680
+ >::Epilogue;
681
+
682
+ // Define the kernel
683
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
684
+ Mma,
685
+ Epilogue,
686
+ ThreadblockSwizzle,
687
+ conv::Operator::kFprop
688
+ >;
689
+ };
690
+
691
+ /////////////////////////////////////////////////////////////////////////////////////////////////
692
+
693
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
694
+ /// pipeline with interleaved layout.
695
+ template <
696
+ typename ElementA,
697
+ typename ElementB,
698
+ typename ElementC,
699
+ typename LayoutC,
700
+ typename ElementAccumulator,
701
+ typename ArchTag,
702
+ typename ThreadblockShape,
703
+ typename WarpShape,
704
+ typename InstructionShape,
705
+ typename EpilogueOutputOp,
706
+ typename ThreadblockSwizzle,
707
+ int Stages,
708
+ typename MathOperatorTag,
709
+ conv::StrideSupport StrideSupport,
710
+ int AlignmentA,
711
+ int AlignmentB,
712
+ int InterleavedK
713
+ >
714
+ struct DefaultConv2dFprop <
715
+ ElementA,
716
+ layout::TensorNCxHWx<InterleavedK>,
717
+ ElementB,
718
+ layout::TensorCxRSKx<InterleavedK>,
719
+ ElementC,
720
+ LayoutC,
721
+ ElementAccumulator,
722
+ arch::OpClassTensorOp,
723
+ ArchTag,
724
+ ThreadblockShape,
725
+ WarpShape,
726
+ InstructionShape,
727
+ EpilogueOutputOp,
728
+ ThreadblockSwizzle,
729
+ Stages,
730
+ MathOperatorTag,
731
+ IteratorAlgorithm::kAnalytic,
732
+ StrideSupport,
733
+ AlignmentA,
734
+ AlignmentB
735
+ > {
736
+
737
+ // Define the core components from GEMM
738
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
739
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
740
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
741
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
742
+ Stages, MathOperatorTag, true>;
743
+
744
+ // Define iterators over tiles from the A operand
745
+ // Note GEMM shared memory threadmap is used here because conv global memory
746
+ // layout needs to be mapped to fprop which is similar to the crosswise
747
+ // layout which is used by the interleaved GEMM shared memory threadmap.
748
+ // The Interleaved GEMM global memory layout is similar to the congruous
749
+ // layout.
750
+ using ThreadMapA = typename MmaCore::SmemThreadMapA;
751
+ using IteratorA =
752
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
753
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
754
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
755
+ ThreadMapA
756
+ >;
757
+
758
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
759
+
760
+ // Define iterators over tiles from the B operand
761
+ // Note GEMM shared memory threadmap is used here because conv global memory
762
+ // layout needs to be mapped to fprop which is similar to the crosswise
763
+ // layout which is used by the interleaved GEMM shared memory threadmap.
764
+ // The Interleaved GEMM global memory layout is similar to the congruous
765
+ // layout.
766
+ using ThreadMapB = typename MmaCore::SmemThreadMapB;
767
+ using IteratorB =
768
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
769
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
770
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
771
+ ThreadMapB
772
+ >;
773
+
774
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
775
+
776
+ // Warp-level GEMM components
777
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
778
+ using MmaPolicy = typename MmaCore::MmaPolicy;
779
+
780
+ // Define the Mma
781
+ using Mma = threadblock::ImplicitGemmMultistage<
782
+ ThreadblockShape,
783
+ IteratorA,
784
+ SmemIteratorA,
785
+ arch::CacheOperation::Always,
786
+ IteratorB,
787
+ SmemIteratorB,
788
+ arch::CacheOperation::Global,
789
+ MmaPolicy,
790
+ Stages
791
+ >;
792
+
793
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
794
+
795
+ // Define the epilogue
796
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
797
+ ThreadblockShape,
798
+ WarpMmaTensorOp,
799
+ kPartitionsK,
800
+ EpilogueOutputOp,
801
+ EpilogueOutputOp::kCount,
802
+ InterleavedK
803
+ >::Epilogue;
804
+
805
+ // Define the kernel
806
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
807
+ Mma,
808
+ Epilogue,
809
+ ThreadblockSwizzle,
810
+ conv::Operator::kFprop
811
+ >;
812
+ };
813
+
814
+ /////////////////////////////////////////////////////////////////////////////////////////////////
815
+
816
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
817
+ /// and 2 stage pipeline.
818
+ template <
819
+ typename ElementA,
820
+ typename LayoutA,
821
+ typename ElementB,
822
+ typename LayoutB,
823
+ typename ElementC,
824
+ typename LayoutC,
825
+ typename ElementAccumulator,
826
+ typename ArchTag,
827
+ typename ThreadblockShape,
828
+ typename WarpShape,
829
+ typename InstructionShape,
830
+ typename EpilogueOutputOp,
831
+ typename ThreadblockSwizzle,
832
+ typename MathOperatorTag,
833
+ conv::StrideSupport StrideSupport,
834
+ int AlignmentA,
835
+ int AlignmentB
836
+ >
837
+ struct DefaultConv2dFprop <
838
+ ElementA,
839
+ LayoutA,
840
+ ElementB,
841
+ LayoutB,
842
+ ElementC,
843
+ LayoutC,
844
+ ElementAccumulator,
845
+ arch::OpClassTensorOp,
846
+ ArchTag,
847
+ ThreadblockShape,
848
+ WarpShape,
849
+ InstructionShape,
850
+ EpilogueOutputOp,
851
+ ThreadblockSwizzle,
852
+ 2,
853
+ MathOperatorTag,
854
+ IteratorAlgorithm::kAnalytic,
855
+ StrideSupport,
856
+ AlignmentA,
857
+ AlignmentB
858
+ > {
859
+
860
+ // Define the core components from GEMM
861
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
862
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
863
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
864
+ 2, MathOperatorTag>;
865
+
866
+ // Define iterators over tiles from the A operand
867
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
868
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
869
+ using IteratorA =
870
+ cutlass::conv::threadblock::TileIterator<
871
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
872
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
873
+ ElementA, LayoutA,
874
+ ThreadMapA,
875
+ AccessTypeA
876
+ >
877
+ >;
878
+
879
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
880
+
881
+ // Define iterators over tiles from the B operand
882
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
883
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
884
+ using IteratorB =
885
+ cutlass::conv::threadblock::TileIterator<
886
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
887
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
888
+ ElementB, LayoutB,
889
+ ThreadMapB,
890
+ AccessTypeB
891
+ >
892
+ >;
893
+
894
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
895
+
896
+ // Warp-level GEMM components
897
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
898
+ using MmaPolicy = typename MmaCore::MmaPolicy;
899
+
900
+ // Define the Mma
901
+ using Mma = threadblock::ImplicitGemmPipelined<
902
+ ThreadblockShape,
903
+ IteratorA,
904
+ SmemIteratorA,
905
+ IteratorB,
906
+ SmemIteratorB,
907
+ ElementC,
908
+ LayoutC,
909
+ MmaPolicy
910
+ >;
911
+
912
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
913
+
914
+ // Define the epilogue
915
+ using Epilogue = typename detail::DefaultConvEpilogue<
916
+ ArchTag,
917
+ ThreadblockShape,
918
+ WarpMmaTensorOp,
919
+ kPartitionsK,
920
+ EpilogueOutputOp
921
+ >::Epilogue;
922
+
923
+ // Define the kernel
924
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
925
+ Mma,
926
+ Epilogue,
927
+ ThreadblockSwizzle,
928
+ conv::Operator::kFprop
929
+ >;
930
+ };
931
+
932
+ /////////////////////////////////////////////////////////////////////////////////////////////////
933
+
934
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
935
+ /// pipeline with interleaved layout.
936
+ template <
937
+ typename ElementA,
938
+ typename ElementB,
939
+ typename ElementC,
940
+ typename LayoutC,
941
+ typename ElementAccumulator,
942
+ typename ArchTag,
943
+ typename ThreadblockShape,
944
+ typename WarpShape,
945
+ typename InstructionShape,
946
+ typename EpilogueOutputOp,
947
+ typename ThreadblockSwizzle,
948
+ typename MathOperatorTag,
949
+ conv::StrideSupport StrideSupport,
950
+ int AlignmentA,
951
+ int AlignmentB,
952
+ int InterleavedK
953
+ >
954
+ struct DefaultConv2dFprop <
955
+ ElementA,
956
+ layout::TensorNCxHWx<InterleavedK>,
957
+ ElementB,
958
+ layout::TensorCxRSKx<InterleavedK>,
959
+ ElementC,
960
+ LayoutC,
961
+ ElementAccumulator,
962
+ arch::OpClassTensorOp,
963
+ ArchTag,
964
+ ThreadblockShape,
965
+ WarpShape,
966
+ InstructionShape,
967
+ EpilogueOutputOp,
968
+ ThreadblockSwizzle,
969
+ 2,
970
+ MathOperatorTag,
971
+ IteratorAlgorithm::kAnalytic,
972
+ StrideSupport,
973
+ AlignmentA,
974
+ AlignmentB
975
+ > {
976
+
977
+ // Define the core components from GEMM
978
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
979
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
980
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
981
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
982
+ 2, MathOperatorTag, true>;
983
+
984
+ // Define iterators over tiles from the A operand
985
+ // Note GEMM shared memory threadmap is used here because conv global memory
986
+ // layout needs to be mapped to fprop which is similar to the crosswise
987
+ // layout which is used by the interleaved GEMM shared memory threadmap.
988
+ // The Interleaved GEMM global memory layout is similar to the congruous
989
+ // layout.
990
+ using ThreadMapA = typename MmaCore::SmemThreadMapA;
991
+ using IteratorA =
992
+ cutlass::conv::threadblock::TileIterator<
993
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
994
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
995
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
996
+ ThreadMapA
997
+ >
998
+ >;
999
+
1000
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1001
+
1002
+ // Define iterators over tiles from the B operand
1003
+ // Note GEMM shared memory threadmap is used here because conv global memory
1004
+ // layout needs to be mapped to fprop which is similar to the crosswise
1005
+ // layout which is used by the interleaved GEMM shared memory threadmap.
1006
+ // The Interleaved GEMM global memory layout is similar to the congruous
1007
+ // layout.
1008
+ using ThreadMapB = typename MmaCore::SmemThreadMapB;
1009
+ using IteratorB =
1010
+ cutlass::conv::threadblock::TileIterator<
1011
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
1012
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1013
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
1014
+ ThreadMapB
1015
+ >
1016
+ >;
1017
+
1018
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1019
+
1020
+ // Warp-level GEMM components
1021
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
1022
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1023
+
1024
+ // Define the Mma
1025
+ using Mma = threadblock::ImplicitGemmPipelined<
1026
+ ThreadblockShape,
1027
+ IteratorA,
1028
+ SmemIteratorA,
1029
+ IteratorB,
1030
+ SmemIteratorB,
1031
+ ElementC,
1032
+ LayoutC,
1033
+ MmaPolicy
1034
+ >;
1035
+
1036
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
1037
+
1038
+ // Define the epilogue
1039
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
1040
+ ThreadblockShape,
1041
+ WarpMmaTensorOp,
1042
+ kPartitionsK,
1043
+ EpilogueOutputOp,
1044
+ EpilogueOutputOp::kCount,
1045
+ InterleavedK
1046
+ >::Epilogue;
1047
+
1048
+ // Define the kernel
1049
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1050
+ Mma,
1051
+ Epilogue,
1052
+ ThreadblockSwizzle,
1053
+ conv::Operator::kFprop
1054
+ >;
1055
+ };
1056
+
1057
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1058
+
1059
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
1060
+ /// multistage pipeline.
1061
+ template <
1062
+ typename ElementA,
1063
+ typename LayoutA,
1064
+ typename ElementB,
1065
+ typename LayoutB,
1066
+ typename ElementC,
1067
+ typename LayoutC,
1068
+ typename ElementAccumulator,
1069
+ typename ArchTag,
1070
+ typename ThreadblockShape,
1071
+ typename WarpShape,
1072
+ typename InstructionShape,
1073
+ typename EpilogueOutputOp,
1074
+ typename ThreadblockSwizzle,
1075
+ int Stages,
1076
+ typename MathOperatorTag,
1077
+ conv::StrideSupport StrideSupport,
1078
+ int AlignmentA,
1079
+ int AlignmentB
1080
+ >
1081
+ struct DefaultConv2dFprop <
1082
+ ElementA,
1083
+ LayoutA,
1084
+ ElementB,
1085
+ LayoutB,
1086
+ ElementC,
1087
+ LayoutC,
1088
+ ElementAccumulator,
1089
+ arch::OpClassTensorOp,
1090
+ ArchTag,
1091
+ ThreadblockShape,
1092
+ WarpShape,
1093
+ InstructionShape,
1094
+ EpilogueOutputOp,
1095
+ ThreadblockSwizzle,
1096
+ Stages,
1097
+ MathOperatorTag,
1098
+ IteratorAlgorithm::kOptimized,
1099
+ StrideSupport,
1100
+ AlignmentA,
1101
+ AlignmentB
1102
+ > {
1103
+
1104
+ // Define the core components from GEMM
1105
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1106
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1107
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
1108
+ Stages, MathOperatorTag
1109
+ >;
1110
+
1111
+ // Define iterators over tiles from the A operand
1112
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1113
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
1114
+ using IteratorA =
1115
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
1116
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1117
+ ElementA,
1118
+ LayoutA,
1119
+ ThreadMapA,
1120
+ AccessTypeA
1121
+ >;
1122
+
1123
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1124
+
1125
+ // Define iterators over tiles from the B operand
1126
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1127
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
1128
+ using IteratorB =
1129
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
1130
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1131
+ ElementB,
1132
+ LayoutB,
1133
+ ThreadMapB,
1134
+ AccessTypeB
1135
+ >;
1136
+
1137
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1138
+
1139
+ // Warp-level GEMM components
1140
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
1141
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1142
+
1143
+ static cutlass::arch::CacheOperation::Kind const CacheOpB =
1144
+ ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
1145
+ ? cutlass::arch::CacheOperation::Global
1146
+ : cutlass::arch::CacheOperation::Always;
1147
+
1148
+ // Define the Mma
1149
+ using Mma = threadblock::ImplicitGemmMultistage<
1150
+ ThreadblockShape,
1151
+ IteratorA,
1152
+ SmemIteratorA,
1153
+ arch::CacheOperation::Always,
1154
+ IteratorB,
1155
+ SmemIteratorB,
1156
+ CacheOpB,
1157
+ MmaPolicy,
1158
+ Stages
1159
+ >;
1160
+
1161
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
1162
+
1163
+ // Define the epilogue
1164
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
1165
+ ThreadblockShape,
1166
+ WarpMmaTensorOp,
1167
+ kPartitionsK,
1168
+ EpilogueOutputOp,
1169
+ EpilogueOutputOp::kCount,
1170
+ false,
1171
+ layout::NoPermute,
1172
+ StrideSupport,
1173
+ 4
1174
+ >::Epilogue;
1175
+
1176
+ // Define the kernel
1177
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1178
+ Mma,
1179
+ Epilogue,
1180
+ ThreadblockSwizzle,
1181
+ conv::Operator::kFprop
1182
+ >;
1183
+ };
1184
+
1185
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1186
+
1187
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
1188
+ // multistage pipeline with interleaved layout.
1189
+ template <
1190
+ typename ElementA,
1191
+ typename ElementB,
1192
+ typename ElementC,
1193
+ typename LayoutC,
1194
+ typename ElementAccumulator,
1195
+ typename ArchTag,
1196
+ typename ThreadblockShape,
1197
+ typename WarpShape,
1198
+ typename InstructionShape,
1199
+ typename EpilogueOutputOp,
1200
+ typename ThreadblockSwizzle,
1201
+ int Stages,
1202
+ typename MathOperatorTag,
1203
+ conv::StrideSupport StrideSupport,
1204
+ int AlignmentA,
1205
+ int AlignmentB,
1206
+ int InterleavedK
1207
+ >
1208
+ struct DefaultConv2dFprop <
1209
+ ElementA,
1210
+ layout::TensorNCxHWx<InterleavedK>,
1211
+ ElementB,
1212
+ layout::TensorCxRSKx<InterleavedK>,
1213
+ ElementC,
1214
+ LayoutC,
1215
+ ElementAccumulator,
1216
+ arch::OpClassTensorOp,
1217
+ ArchTag,
1218
+ ThreadblockShape,
1219
+ WarpShape,
1220
+ InstructionShape,
1221
+ EpilogueOutputOp,
1222
+ ThreadblockSwizzle,
1223
+ Stages,
1224
+ MathOperatorTag,
1225
+ IteratorAlgorithm::kOptimized,
1226
+ StrideSupport,
1227
+ AlignmentA,
1228
+ AlignmentB
1229
+ > {
1230
+
1231
+ // Define the core components from GEMM
1232
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1233
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
1234
+ ElementB, layout::RowMajorInterleaved<InterleavedK>, ElementAccumulator, LayoutC, arch::OpClassTensorOp,
1235
+ Stages, MathOperatorTag, true
1236
+ >;
1237
+
1238
+ // Define iterators over tiles from the A operand
1239
+ using ThreadMapA = typename MmaCore::SmemThreadMapA;
1240
+ using IteratorA =
1241
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
1242
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1243
+ ElementA,
1244
+ layout::TensorNCxHWx<InterleavedK>,
1245
+ ThreadMapA
1246
+ >;
1247
+
1248
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1249
+
1250
+ // Define iterators over tiles from the B operand
1251
+ using ThreadMapB = typename MmaCore::SmemThreadMapB;
1252
+ using IteratorB =
1253
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
1254
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1255
+ ElementB,
1256
+ layout::TensorCxRSKx<InterleavedK>,
1257
+ ThreadMapB
1258
+ >;
1259
+
1260
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1261
+
1262
+ // Warp-level GEMM components
1263
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
1264
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1265
+
1266
+ // Define the Mma
1267
+ using Mma = threadblock::ImplicitGemmMultistage<
1268
+ ThreadblockShape,
1269
+ IteratorA,
1270
+ SmemIteratorA,
1271
+ arch::CacheOperation::Always,
1272
+ IteratorB,
1273
+ SmemIteratorB,
1274
+ arch::CacheOperation::Global,
1275
+ MmaPolicy,
1276
+ Stages
1277
+ >;
1278
+
1279
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
1280
+
1281
+ // Define the epilogue
1282
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
1283
+ ThreadblockShape,
1284
+ WarpMmaTensorOp,
1285
+ kPartitionsK,
1286
+ EpilogueOutputOp,
1287
+ EpilogueOutputOp::kCount,
1288
+ InterleavedK
1289
+ >::Epilogue;
1290
+
1291
+ // Define the kernel
1292
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1293
+ Mma,
1294
+ Epilogue,
1295
+ ThreadblockSwizzle,
1296
+ conv::Operator::kFprop
1297
+ >;
1298
+ };
1299
+
1300
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1301
+
1302
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
1303
+ /// and 2 stage pipeline.
1304
+ template <
1305
+ typename ElementA,
1306
+ typename LayoutA,
1307
+ typename ElementB,
1308
+ typename LayoutB,
1309
+ typename ElementC,
1310
+ typename LayoutC,
1311
+ typename ElementAccumulator,
1312
+ typename ArchTag,
1313
+ typename ThreadblockShape,
1314
+ typename WarpShape,
1315
+ typename InstructionShape,
1316
+ typename EpilogueOutputOp,
1317
+ typename ThreadblockSwizzle,
1318
+ typename MathOperatorTag,
1319
+ conv::StrideSupport StrideSupport,
1320
+ int AlignmentA,
1321
+ int AlignmentB
1322
+ >
1323
+ struct DefaultConv2dFprop <
1324
+ ElementA,
1325
+ LayoutA,
1326
+ ElementB,
1327
+ LayoutB,
1328
+ ElementC,
1329
+ LayoutC,
1330
+ ElementAccumulator,
1331
+ arch::OpClassTensorOp,
1332
+ ArchTag,
1333
+ ThreadblockShape,
1334
+ WarpShape,
1335
+ InstructionShape,
1336
+ EpilogueOutputOp,
1337
+ ThreadblockSwizzle,
1338
+ 2,
1339
+ MathOperatorTag,
1340
+ IteratorAlgorithm::kOptimized,
1341
+ StrideSupport,
1342
+ AlignmentA,
1343
+ AlignmentB
1344
+ > {
1345
+
1346
+ // Define the core components from GEMM
1347
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1348
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1349
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
1350
+ 2, MathOperatorTag>;
1351
+
1352
+ // Define iterators over tiles from the A operand
1353
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1354
+ using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
1355
+ using IteratorA =
1356
+ cutlass::conv::threadblock::TileIterator<
1357
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
1358
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1359
+ ElementA,
1360
+ LayoutA,
1361
+ ThreadMapA,
1362
+ AccessTypeA
1363
+ >
1364
+ >;
1365
+
1366
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1367
+
1368
+ // Define iterators over tiles from the B operand
1369
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1370
+ using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
1371
+ using IteratorB =
1372
+ cutlass::conv::threadblock::TileIterator<
1373
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
1374
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1375
+ ElementB,
1376
+ LayoutB,
1377
+ ThreadMapB,
1378
+ AccessTypeB
1379
+ >
1380
+ >;
1381
+
1382
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1383
+
1384
+ // Warp-level GEMM components
1385
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
1386
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1387
+
1388
+ // Define the Mma
1389
+ using Mma = threadblock::ImplicitGemmPipelined<
1390
+ ThreadblockShape,
1391
+ IteratorA,
1392
+ SmemIteratorA,
1393
+ IteratorB,
1394
+ SmemIteratorB,
1395
+ ElementC,
1396
+ LayoutC,
1397
+ MmaPolicy
1398
+ >;
1399
+
1400
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
1401
+
1402
+ // Define the epilogue
1403
+ using Epilogue = typename detail::DefaultConvEpilogue<
1404
+ ArchTag,
1405
+ ThreadblockShape,
1406
+ WarpMmaTensorOp,
1407
+ kPartitionsK,
1408
+ EpilogueOutputOp
1409
+ >::Epilogue;
1410
+
1411
+ // Define the kernel
1412
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1413
+ Mma,
1414
+ Epilogue,
1415
+ ThreadblockSwizzle,
1416
+ conv::Operator::kFprop
1417
+ >;
1418
+ };
1419
+
1420
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1421
+
1422
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
1423
+ /// pipeline with interleaved layout.
1424
+ template <
1425
+ typename ElementA,
1426
+ typename ElementB,
1427
+ typename ElementC,
1428
+ typename LayoutC,
1429
+ typename ElementAccumulator,
1430
+ typename ArchTag,
1431
+ typename ThreadblockShape,
1432
+ typename WarpShape,
1433
+ typename InstructionShape,
1434
+ typename EpilogueOutputOp,
1435
+ typename ThreadblockSwizzle,
1436
+ typename MathOperatorTag,
1437
+ conv::StrideSupport StrideSupport,
1438
+ int AlignmentA,
1439
+ int AlignmentB,
1440
+ int InterleavedK
1441
+ >
1442
+ struct DefaultConv2dFprop <
1443
+ ElementA,
1444
+ layout::TensorNCxHWx<InterleavedK>,
1445
+ ElementB,
1446
+ layout::TensorCxRSKx<InterleavedK>,
1447
+ ElementC,
1448
+ LayoutC,
1449
+ ElementAccumulator,
1450
+ arch::OpClassTensorOp,
1451
+ ArchTag,
1452
+ ThreadblockShape,
1453
+ WarpShape,
1454
+ InstructionShape,
1455
+ EpilogueOutputOp,
1456
+ ThreadblockSwizzle,
1457
+ 2,
1458
+ MathOperatorTag,
1459
+ IteratorAlgorithm::kOptimized,
1460
+ StrideSupport,
1461
+ AlignmentA,
1462
+ AlignmentB
1463
+ > {
1464
+
1465
+ // Define the core components from GEMM
1466
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1467
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
1468
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
1469
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
1470
+ 2, MathOperatorTag, true>;
1471
+
1472
+ // Define iterators over tiles from the A operand
1473
+ using ThreadMapA = typename MmaCore::SmemThreadMapA;
1474
+ using IteratorA =
1475
+ cutlass::conv::threadblock::TileIterator<
1476
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
1477
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1478
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
1479
+ ThreadMapA
1480
+ >
1481
+ >;
1482
+
1483
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1484
+
1485
+ // Define iterators over tiles from the B operand
1486
+ using ThreadMapB = typename MmaCore::SmemThreadMapB;
1487
+ using IteratorB =
1488
+ cutlass::conv::threadblock::TileIterator<
1489
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
1490
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1491
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
1492
+ ThreadMapB
1493
+ >
1494
+ >;
1495
+
1496
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1497
+
1498
+ // Warp-level GEMM components
1499
+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
1500
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1501
+
1502
+ // Define the Mma
1503
+ using Mma = threadblock::ImplicitGemmPipelined<
1504
+ ThreadblockShape,
1505
+ IteratorA,
1506
+ SmemIteratorA,
1507
+ IteratorB,
1508
+ SmemIteratorB,
1509
+ ElementC,
1510
+ LayoutC,
1511
+ MmaPolicy
1512
+ >;
1513
+
1514
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
1515
+
1516
+ // Define the epilogue
1517
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
1518
+ ThreadblockShape,
1519
+ WarpMmaTensorOp,
1520
+ kPartitionsK,
1521
+ EpilogueOutputOp,
1522
+ EpilogueOutputOp::kCount,
1523
+ InterleavedK
1524
+ >::Epilogue;
1525
+
1526
+ // Define the kernel
1527
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1528
+ Mma,
1529
+ Epilogue,
1530
+ ThreadblockSwizzle,
1531
+ conv::Operator::kFprop
1532
+ >;
1533
+ };
1534
+
1535
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1536
+ // OpClassSimt convolutions
1537
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1538
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm,
1539
+ /// multi-stage pipeline, and FFMA-based mainloop for SM80
1540
+
1541
+ template <
1542
+ typename ElementA,
1543
+ typename LayoutA,
1544
+ typename ElementB,
1545
+ typename LayoutB,
1546
+ typename ElementC,
1547
+ typename LayoutC,
1548
+ typename ElementAccumulator,
1549
+ typename ArchTag,
1550
+ typename ThreadblockShape,
1551
+ typename WarpShape,
1552
+ typename InstructionShape,
1553
+ typename EpilogueOutputOp,
1554
+ typename ThreadblockSwizzle,
1555
+ int Stages,
1556
+ typename MathOperatorTag,
1557
+ conv::StrideSupport StrideSupport,
1558
+ int AlignmentA,
1559
+ int AlignmentB
1560
+ >
1561
+ struct DefaultConv2dFprop <
1562
+ ElementA,
1563
+ LayoutA,
1564
+ ElementB,
1565
+ LayoutB,
1566
+ ElementC,
1567
+ LayoutC,
1568
+ ElementAccumulator,
1569
+ arch::OpClassSimt,
1570
+ ArchTag,
1571
+ ThreadblockShape,
1572
+ WarpShape,
1573
+ InstructionShape,
1574
+ EpilogueOutputOp,
1575
+ ThreadblockSwizzle,
1576
+ Stages,
1577
+ MathOperatorTag,
1578
+ IteratorAlgorithm::kAnalytic,
1579
+ StrideSupport,
1580
+ AlignmentA,
1581
+ AlignmentB
1582
+ > {
1583
+
1584
+ // Define the core components from GEMM
1585
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1586
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1587
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1588
+ Stages, MathOperatorTag>;
1589
+
1590
+ // Define iterators over tiles from the A operand
1591
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1592
+ using IteratorA =
1593
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
1594
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1595
+ ElementA, LayoutA,
1596
+ ThreadMapA
1597
+ >;
1598
+
1599
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1600
+
1601
+ // Define iterators over tiles from the B operand
1602
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1603
+ using IteratorB =
1604
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
1605
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1606
+ ElementB, LayoutB,
1607
+ ThreadMapB
1608
+ >;
1609
+
1610
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1611
+
1612
+ // Warp-level GEMM components
1613
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1614
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1615
+
1616
+ // Define the Mma
1617
+ using Mma = threadblock::ImplicitGemmMultistage<
1618
+ ThreadblockShape,
1619
+ IteratorA,
1620
+ SmemIteratorA,
1621
+ arch::CacheOperation::Always,
1622
+ IteratorB,
1623
+ SmemIteratorB,
1624
+ arch::CacheOperation::Always,
1625
+ MmaPolicy,
1626
+ Stages
1627
+ >;
1628
+
1629
+ // Define the epilogue
1630
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
1631
+ ThreadblockShape,
1632
+ WarpMmaSimtOp,
1633
+ EpilogueOutputOp,
1634
+ EpilogueOutputOp::kCount,
1635
+ false,
1636
+ layout::NoPermute,
1637
+ StrideSupport,
1638
+ 4
1639
+ >::Epilogue;
1640
+
1641
+ // Define the kernel
1642
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1643
+ Mma,
1644
+ Epilogue,
1645
+ ThreadblockSwizzle,
1646
+ conv::Operator::kFprop
1647
+ >;
1648
+
1649
+ };
1650
+
1651
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1652
+
1653
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm,
1654
+ /// multi-stage pipeline, and FFMA-based mainloop for SM80
1655
+
1656
+ template <
1657
+ typename ElementA,
1658
+ typename LayoutA,
1659
+ typename ElementB,
1660
+ typename LayoutB,
1661
+ typename ElementC,
1662
+ typename LayoutC,
1663
+ typename ElementAccumulator,
1664
+ typename ArchTag,
1665
+ typename ThreadblockShape,
1666
+ typename WarpShape,
1667
+ typename InstructionShape,
1668
+ typename EpilogueOutputOp,
1669
+ typename ThreadblockSwizzle,
1670
+ int Stages,
1671
+ typename MathOperatorTag,
1672
+ conv::StrideSupport StrideSupport,
1673
+ int AlignmentA,
1674
+ int AlignmentB
1675
+ >
1676
+ struct DefaultConv2dFprop <
1677
+ ElementA,
1678
+ LayoutA,
1679
+ ElementB,
1680
+ LayoutB,
1681
+ ElementC,
1682
+ LayoutC,
1683
+ ElementAccumulator,
1684
+ arch::OpClassSimt,
1685
+ ArchTag,
1686
+ ThreadblockShape,
1687
+ WarpShape,
1688
+ InstructionShape,
1689
+ EpilogueOutputOp,
1690
+ ThreadblockSwizzle,
1691
+ Stages,
1692
+ MathOperatorTag,
1693
+ IteratorAlgorithm::kOptimized,
1694
+ StrideSupport,
1695
+ AlignmentA,
1696
+ AlignmentB
1697
+ > {
1698
+
1699
+ // Define the core components from GEMM
1700
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1701
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1702
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1703
+ Stages, MathOperatorTag>;
1704
+
1705
+ // Define iterators over tiles from the A operand
1706
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1707
+ using IteratorA =
1708
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
1709
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1710
+ ElementA,
1711
+ LayoutA,
1712
+ ThreadMapA
1713
+ >;
1714
+
1715
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1716
+
1717
+ // Define iterators over tiles from the B operand
1718
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1719
+ using IteratorB =
1720
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
1721
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1722
+ ElementB,
1723
+ LayoutB,
1724
+ ThreadMapB
1725
+ >;
1726
+
1727
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1728
+
1729
+ // Warp-level GEMM components
1730
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1731
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1732
+
1733
+ // Define the Mma
1734
+ using Mma = threadblock::ImplicitGemmMultistage<
1735
+ ThreadblockShape,
1736
+ IteratorA,
1737
+ SmemIteratorA,
1738
+ arch::CacheOperation::Always,
1739
+ IteratorB,
1740
+ SmemIteratorB,
1741
+ arch::CacheOperation::Always,
1742
+ MmaPolicy,
1743
+ Stages
1744
+ >;
1745
+
1746
+ // Define the epilogue
1747
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
1748
+ ThreadblockShape,
1749
+ WarpMmaSimtOp,
1750
+ EpilogueOutputOp,
1751
+ EpilogueOutputOp::kCount,
1752
+ false,
1753
+ layout::NoPermute,
1754
+ StrideSupport,
1755
+ 4
1756
+ >::Epilogue;
1757
+
1758
+ // Define the kernel
1759
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1760
+ Mma,
1761
+ Epilogue,
1762
+ ThreadblockSwizzle,
1763
+ conv::Operator::kFprop
1764
+ >;
1765
+ };
1766
+
1767
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1768
+
1769
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm,
1770
+ /// 2 stage pipeline, and FFMA-based mainloop for SM50
1771
+ template <
1772
+ typename ElementA,
1773
+ typename LayoutA,
1774
+ typename ElementB,
1775
+ typename LayoutB,
1776
+ typename ElementC,
1777
+ typename LayoutC,
1778
+ typename ElementAccumulator,
1779
+ typename ArchTag,
1780
+ typename ThreadblockShape,
1781
+ typename WarpShape,
1782
+ typename InstructionShape,
1783
+ typename EpilogueOutputOp,
1784
+ typename ThreadblockSwizzle,
1785
+ typename MathOperatorTag,
1786
+ conv::StrideSupport StrideSupport,
1787
+ int AlignmentA,
1788
+ int AlignmentB
1789
+ >
1790
+ struct DefaultConv2dFprop <
1791
+ ElementA,
1792
+ LayoutA,
1793
+ ElementB,
1794
+ LayoutB,
1795
+ ElementC,
1796
+ LayoutC,
1797
+ ElementAccumulator,
1798
+ arch::OpClassSimt,
1799
+ ArchTag,
1800
+ ThreadblockShape,
1801
+ WarpShape,
1802
+ InstructionShape,
1803
+ EpilogueOutputOp,
1804
+ ThreadblockSwizzle,
1805
+ 2,
1806
+ MathOperatorTag,
1807
+ IteratorAlgorithm::kAnalytic,
1808
+ StrideSupport,
1809
+ AlignmentA,
1810
+ AlignmentB
1811
+ > {
1812
+
1813
+ // Define the core components from GEMM
1814
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1815
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1816
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1817
+ 2, MathOperatorTag>;
1818
+
1819
+ // Define iterators over tiles from the A operand
1820
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1821
+ using IteratorA =
1822
+ cutlass::conv::threadblock::TileIterator<
1823
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
1824
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1825
+ ElementA, LayoutA,
1826
+ ThreadMapA
1827
+ >
1828
+ >;
1829
+
1830
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1831
+
1832
+ // Define iterators over tiles from the B operand
1833
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1834
+ using IteratorB =
1835
+ cutlass::conv::threadblock::TileIterator<
1836
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
1837
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1838
+ ElementB, LayoutB,
1839
+ ThreadMapB
1840
+ >
1841
+ >;
1842
+
1843
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1844
+
1845
+ // Warp-level GEMM components
1846
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1847
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1848
+
1849
+ // Define the Mma
1850
+ using Mma = threadblock::ImplicitGemmPipelined<
1851
+ ThreadblockShape,
1852
+ IteratorA,
1853
+ SmemIteratorA,
1854
+ IteratorB,
1855
+ SmemIteratorB,
1856
+ ElementC,
1857
+ LayoutC,
1858
+ MmaPolicy
1859
+ >;
1860
+
1861
+ // Define the epilogue
1862
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
1863
+ ThreadblockShape,
1864
+ WarpMmaSimtOp,
1865
+ EpilogueOutputOp,
1866
+ EpilogueOutputOp::kCount,
1867
+ false,
1868
+ layout::NoPermute,
1869
+ StrideSupport,
1870
+ 4
1871
+ >::Epilogue;
1872
+
1873
+ // Define the kernel
1874
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1875
+ Mma,
1876
+ Epilogue,
1877
+ ThreadblockSwizzle,
1878
+ conv::Operator::kFprop
1879
+ >;
1880
+
1881
+ };
1882
+
1883
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1884
+
1885
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm,
1886
+ /// 2 stage pipeline, and FFMA-based mainloop for SM50
1887
+ template <
1888
+ typename ElementA,
1889
+ typename LayoutA,
1890
+ typename ElementB,
1891
+ typename LayoutB,
1892
+ typename ElementC,
1893
+ typename LayoutC,
1894
+ typename ElementAccumulator,
1895
+ typename ArchTag,
1896
+ typename ThreadblockShape,
1897
+ typename WarpShape,
1898
+ typename InstructionShape,
1899
+ typename EpilogueOutputOp,
1900
+ typename ThreadblockSwizzle,
1901
+ typename MathOperatorTag,
1902
+ conv::StrideSupport StrideSupport,
1903
+ int AlignmentA,
1904
+ int AlignmentB
1905
+ >
1906
+ struct DefaultConv2dFprop <
1907
+ ElementA,
1908
+ LayoutA,
1909
+ ElementB,
1910
+ LayoutB,
1911
+ ElementC,
1912
+ LayoutC,
1913
+ ElementAccumulator,
1914
+ arch::OpClassSimt,
1915
+ ArchTag,
1916
+ ThreadblockShape,
1917
+ WarpShape,
1918
+ InstructionShape,
1919
+ EpilogueOutputOp,
1920
+ ThreadblockSwizzle,
1921
+ 2,
1922
+ MathOperatorTag,
1923
+ IteratorAlgorithm::kOptimized,
1924
+ StrideSupport,
1925
+ AlignmentA,
1926
+ AlignmentB
1927
+ > {
1928
+
1929
+ // Define the core components from GEMM
1930
+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
1931
+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
1932
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
1933
+ 2, MathOperatorTag>;
1934
+
1935
+ // Define iterators over tiles from the A operand
1936
+ using ThreadMapA = typename MmaCore::IteratorThreadMapA;
1937
+ using IteratorA =
1938
+ cutlass::conv::threadblock::TileIterator<
1939
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
1940
+ cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
1941
+ ElementA,
1942
+ LayoutA,
1943
+ ThreadMapA
1944
+ >
1945
+ >;
1946
+
1947
+ using SmemIteratorA = typename MmaCore::SmemIteratorA;
1948
+
1949
+ // Define iterators over tiles from the B operand
1950
+ using ThreadMapB = typename MmaCore::IteratorThreadMapB;
1951
+ using IteratorB =
1952
+ cutlass::conv::threadblock::TileIterator<
1953
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
1954
+ cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
1955
+ ElementB,
1956
+ LayoutB,
1957
+ ThreadMapB
1958
+ >
1959
+ >;
1960
+
1961
+ using SmemIteratorB = typename MmaCore::SmemIteratorB;
1962
+
1963
+ // Warp-level GEMM components
1964
+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
1965
+ using MmaPolicy = typename MmaCore::MmaPolicy;
1966
+
1967
+ // Define the Mma
1968
+ using Mma = threadblock::ImplicitGemmPipelined<
1969
+ ThreadblockShape,
1970
+ IteratorA,
1971
+ SmemIteratorA,
1972
+ IteratorB,
1973
+ SmemIteratorB,
1974
+ ElementC,
1975
+ LayoutC,
1976
+ MmaPolicy
1977
+ >;
1978
+
1979
+ // Define the epilogue
1980
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
1981
+ ThreadblockShape,
1982
+ WarpMmaSimtOp,
1983
+ EpilogueOutputOp,
1984
+ EpilogueOutputOp::kCount,
1985
+ false,
1986
+ layout::NoPermute,
1987
+ StrideSupport,
1988
+ 4
1989
+ >::Epilogue;
1990
+
1991
+ // Define the kernel
1992
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
1993
+ Mma,
1994
+ Epilogue,
1995
+ ThreadblockSwizzle,
1996
+ conv::Operator::kFprop
1997
+ >;
1998
+
1999
+ };
2000
+
2001
+ /////////////////////////////////////////////////////////////////////////////////////////////////
2002
+
2003
+ } // namespace kernel
2004
+ } // namespace conv
2005
+ } // namespace cutlass
2006
+
2007
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief
33
+ Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution
34
+ definitions that combine threadblock-scoped matrix multiply-add with the
35
+ appropriate threadblock-scoped epilogue.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/conv/kernel/default_conv2d.h"
42
+
43
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
44
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
45
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
46
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
47
+ #include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h"
48
+ #include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h"
49
+ #include "cutlass/gemm/warp/scale_bias_tile_iterator.h"
50
+
51
+ /////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ namespace cutlass {
54
+ namespace conv {
55
+ namespace kernel {
56
+
57
+ /////////////////////////////////////////////////////////////////////////////////////////////////
58
+ /// Defines a kernel for fused batch norm and Conv2dFprop
59
/// Primary template. Left undefined on purpose: only the partial
/// specializations (selected on OperatorClass / IteratorAlgorithm) provide a
/// definition, so instantiating an unsupported combination fails at compile
/// time. The specializations visible in this header target
/// arch::OpClassTensorOp with the Analytic and Optimized iterator algorithms.
template <
  /// Element type of the activation (A) operand
  typename ElementA,
  /// Layout of the activation operand
  typename LayoutA,
  /// Element type of the filter (B) operand
  typename ElementB,
  /// Layout of the filter operand
  typename LayoutB,
  /// Element type of the fused scale/bias vectors
  typename ElementScaleBias,
  /// Layout of the scale/bias vectors
  typename LayoutScaleBias,
  /// Element type of the output (C) tensor
  typename ElementC,
  /// Layout of the output tensor
  typename LayoutC,
  /// Element type used for internal accumulation
  typename ElementAccumulator,
  /// Operator class (e.g. arch::OpClassTensorOp)
  typename OperatorClass,
  /// SM architecture tag
  typename ArchTag,
  /// Threadblock-level tile shape
  typename ThreadblockShape,
  /// Warp-level tile shape
  typename WarpShape,
  /// Instruction-level tile shape
  typename InstructionShape,
  /// Elementwise epilogue output operator
  typename EpilogueOutputOp,
  /// Threadblock swizzling function
  typename ThreadblockSwizzle,
  /// Number of mainloop pipeline stages
  int Stages,
  /// Math operator tag forwarded to the MMA core
  typename MathOperatorTag,
  /// Global-memory tile iterator algorithm (defaults to the optimized path)
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
  /// Stride support policy (defaults to unit stride)
  conv::StrideSupport StrideSupport = StrideSupport::kUnity
> struct DefaultConv2dFpropFusion;
81
+
82
+ /////////////////////////////////////////////////////////////////////////////////////////////////
83
+ // OpClassTensorOp convolutions
84
+ /////////////////////////////////////////////////////////////////////////////////////////////////
85
+
86
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
87
+ /// pipeline.
88
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementScaleBias,
  typename LayoutScaleBias,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag
>
struct DefaultConv2dFpropFusion <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementScaleBias,
  LayoutScaleBias,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM. Implicit GEMM maps Fprop onto a
  // RowMajor(A) x ColumnMajor(B) problem; the convolution tensor layouts
  // (LayoutA/LayoutB) are handled by the conv tile-access iterators below.
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand (activations),
  // using the Analytic (fully general, coordinate-computing) algorithm
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
      ElementA, LayoutA,
      ThreadMapA
    >;

  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand (filters)
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
      ElementB, LayoutB,
      ThreadMapB
    >;

  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors: a single row of
  /// length ThreadblockShape::kK, i.e. one scale/bias pair per K-slice
  using IteratorScaleBias =
      cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
          cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
          LayoutScaleBias>;

  // Shared-memory counterpart of IteratorScaleBias
  using SmemIteratorScaleBias =
      cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
          cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
          LayoutScaleBias>;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Threads cooperating in the warp-level scale/bias load (one warp)
  static int const kThreadCount = 32;

  // Warp-level iterator to load scale and bias vectors; reuses the policy of
  // the warp MMA's A-operand iterator so fragments line up with the MMA
  using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
      MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
      LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
      typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
      MmaCore::WarpCount::kK>;

  // Define the Mma: multistage fused mainloop that stages A, B and the
  // scale/bias vectors through shared memory with per-operand cache policies
  using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
    ThreadblockShape,
    IteratorA,
    SmemIteratorA,
    arch::CacheOperation::Always,   // cache policy for A loads
    IteratorB,
    SmemIteratorB,
    arch::CacheOperation::Global,   // cache policy for B loads
    IteratorScaleBias,
    SmemIteratorScaleBias,
    arch::CacheOperation::Always,   // cache policy for scale/bias loads
    MmaPolicy,
    WarpIteratorScaleBias,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape,
    WarpMmaTensorOp,
    1,                       // presumably PartitionsK — confirm against DefaultEpilogueTensorOp
    EpilogueOutputOp,
    EpilogueOutputOp::kCount
  >::Epilogue;

  // Define the kernel: fused implicit-GEMM convolution (Fprop direction)
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
    Mma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};
215
+
216
+ /////////////////////////////////////////////////////////////////////////////////////////////////
217
+
218
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
219
+ /// multistage pipeline.
220
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementScaleBias,
  typename LayoutScaleBias,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag
>
struct DefaultConv2dFpropFusion <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementScaleBias,
  LayoutScaleBias,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM. Implicit GEMM maps Fprop onto a
  // RowMajor(A) x ColumnMajor(B) problem; the convolution tensor layouts
  // (LayoutA/LayoutB) are handled by the conv tile-access iterators below.
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    Stages, MathOperatorTag
  >;

  // Define iterators over tiles from the A operand (activations),
  // using the Optimized (precomputed-delta) iterator algorithm
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
      ElementA,
      LayoutA,
      ThreadMapA
    >;

  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand (filters)
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
      ElementB,
      LayoutB,
      ThreadMapB
    >;

  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors: a single row of
  /// length ThreadblockShape::kK, i.e. one scale/bias pair per K-slice
  using IteratorScaleBias =
      cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
          cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
          LayoutScaleBias>;

  // Shared-memory counterpart of IteratorScaleBias
  using SmemIteratorScaleBias =
      cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
          cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
          LayoutScaleBias>;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Threads cooperating in the warp-level scale/bias load (one warp)
  static int const kThreadCount = 32;

  // Warp-level iterator to load scale and bias vectors; reuses the policy of
  // the warp MMA's A-operand iterator so fragments line up with the MMA
  using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
      MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
      LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
      typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
      MmaCore::WarpCount::kK>;

  // Define the Mma: multistage fused mainloop that stages A, B and the
  // scale/bias vectors through shared memory with per-operand cache policies
  using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
    ThreadblockShape,
    IteratorA,
    SmemIteratorA,
    arch::CacheOperation::Always,   // cache policy for A loads
    IteratorB,
    SmemIteratorB,
    arch::CacheOperation::Global,   // cache policy for B loads
    IteratorScaleBias,
    SmemIteratorScaleBias,
    arch::CacheOperation::Always,   // cache policy for scale/bias loads
    MmaPolicy,
    WarpIteratorScaleBias,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape,
    WarpMmaTensorOp,
    1,                       // presumably PartitionsK — confirm against DefaultEpilogueTensorOp
    EpilogueOutputOp,
    EpilogueOutputOp::kCount
  >::Epilogue;

  // Define the kernel: fused implicit-GEMM convolution (Fprop direction)
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
    Mma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};
350
+
351
+ /////////////////////////////////////////////////////////////////////////////////////////////////
352
+
353
+ } // namespace kernel
354
+ } // namespace conv
355
+ } // namespace cutlass
356
+
357
+ /////////////////////////////////////////////////////////////////////////////////////////////////