kernels-bot committed on
Commit
92f2707
·
verified ·
1 Parent(s): 96ee59b

Uploaded using `kernel-builder` (batch 11/32).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h +0 -2119
  2. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h +0 -55
  3. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h +0 -283
  4. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h +0 -0
  5. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h +0 -1322
  6. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py +0 -144
  7. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h +0 -90
  8. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h +0 -154
  9. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h +0 -113
  10. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h +0 -213
  11. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h +0 -311
  12. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h +0 -189
  13. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h +0 -427
  14. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py +0 -129
  15. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py +0 -131
  16. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py +0 -120
  17. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py +0 -469
  18. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py +0 -249
  19. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py +0 -476
  20. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py +0 -232
  21. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py +0 -1013
  22. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py +0 -456
  23. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py +0 -92
  24. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py +0 -135
  25. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py +0 -67
  26. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h +0 -292
  27. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h +0 -94
  28. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h +0 -499
  29. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h +0 -52
  30. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h +0 -938
  31. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h +0 -545
  32. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h +0 -95
  33. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h +0 -150
  34. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h +0 -424
  35. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h +0 -232
  36. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h +0 -775
  37. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh +0 -139
  38. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp +0 -421
  39. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh +0 -136
  40. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp +0 -222
  41. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh +0 -92
  42. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp +0 -274
  43. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp +0 -129
  44. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp +0 -246
  45. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h +0 -320
  46. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp +0 -242
  47. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp +0 -61
  48. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp +0 -871
  49. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp +0 -117
  50. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp +0 -561
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h DELETED
@@ -1,2119 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Templates implementing loading of tiles from pitch-linear rank=2
33
- tensors.
34
-
35
- This iterator uses masks to guard out-of-bounds accesses. The first tile
36
- this iterator visits may be partial, then the remaining tiles are complete.
37
- So, we only need to compute the predicates twice, once before the first tile
38
- and once for the remaining full tiles which can share the same predicates.
39
-
40
- A precomputed "Params" object minimizes the amount of state that must be
41
- stored in registers, and integer addition is used to advance the pointer
42
- through memory.
43
- */
44
-
45
- #pragma once
46
-
47
- #include "cutlass/arch/memory.h"
48
- #include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
49
-
50
- ////////////////////////////////////////////////////////////////////////////////
51
-
52
- namespace cutlass {
53
- namespace transform {
54
- namespace threadblock {
55
-
56
- ////////////////////////////////////////////////////////////////////////////////
57
-
58
- /// PredicatedTileIteratorResidualLast
59
- ///
60
- /// Satisfies: ForwardTileIteratorConcept |
61
- /// ReadableContiguousTileIteratorConcept |
62
- /// WriteableContiguousTileIteratorConcept |
63
- /// MaskedTileIteratorConcept
64
- ///
65
- /// Regular tile iterator using a precomputed control structure to minimize
66
- /// register liveness and integer arithmetic.
67
- ///
68
- /// Layout is assumed to be invariant at the time the precomputed "Params"
69
- /// object is constructed.
70
- ///
71
- /// Base pointer and tensor extents may be specified at the time the iterator is
72
- /// constructed. Subsequently, they are assumed to be immutable.
73
- ///
74
- /// Adding a logical coordinate offset may be performed at the time the iterator
75
- /// is constructed. Subsequent additions to logical coordinate offset may be
76
- /// performed but are relatively expensive.
77
- ///
78
- /// Visitation order is intended to first visit a "residual" tile that may be
79
- /// partially full in both the advance dimension and the steady-state dimension.
80
- /// This is assumed to be the last tile in the iteration sequence. Advancing an
81
- /// iterator that has just been constructed moves to the first tile that is full
82
- /// in the advance dimension and recomputes predicates. Subsequent accesses may
83
- /// be performed without updating internal predicates and are efficient in terms
84
- /// of live register state and pointer arithmetic instructions.
85
- ///
86
- /// To be efficient, this assumes the iterator will be dereferenced and advanced
87
- /// at least once outside any looping structure to minimize integer arithmetic.
88
- ///
89
- /// Accesses out of bounds are safe so long as `clear_mask()` is called prior to
90
- /// dereferencing the iterator.
91
- ///
92
- ///
93
- /// Example:
94
- ///
95
- /// An efficient pipeline structure may be constructed as follows:
96
- ///
97
- // template <typename Iterator>
98
- // __global__ void kernel(
99
- // typename Iterator::Params params,
100
- // typename Iterator::Element *ptr,
101
- // TensorCoord extent) {
102
- //
103
- // typename Iterator::Fragment fragment;
104
- //
105
- // TensorCoord threadblock_offset(0, 0);
106
- //
107
- // Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
108
- //
109
- //
110
- // fragment = *iter; // load "residue" tile first
111
- // ++iter; // advance to first "steady state" tile and update
112
- // internal masks
113
- //
114
- //
115
- // #pragma unroll
116
- // for (int i = Remaining - 1; i >= 0; --i) {
117
- //
118
- // f(fragment);
119
- //
120
- // if (!i) {
121
- // iter.clear_mask(); // light-weight operation to clear masks -
122
- // subsequent loads become NO-OPs.
123
- // }
124
- //
125
- // fragment = *iter; // load tile during "steady state" phase
126
- // ++iter; // advance to next tile - lightweight due to
127
- // steady-state masks
128
- // }
129
- // }
130
- //
131
- // void host(TensorView<Element, 2, layout::PitchLinear> view) {
132
- //
133
- // using Iterator =
134
- // transform::threadblock::PredicatedTileIteratorResidualLast;
135
- //
136
- // typename Iterator::Params params(view.layout());
137
- //
138
- // kernel<Iterator>(params, view.data());
139
- // }
140
- ///
141
- ///
142
- template <
143
- typename Shape,
144
- typename Element,
145
- typename Layout,
146
- int AdvanceRank,
147
- typename ThreadMap,
148
- int AccessSize = ThreadMap::kElementsPerAccess,
149
- bool Gather = false>
150
- class PredicatedTileIteratorResidualLast;
151
-
152
- ////////////////////////////////////////////////////////////////////////////////
153
-
154
- /// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data.
155
- ///
156
- /// Satisfies: ForwardTileIteratorConcept |
157
- /// ReadableContiguousTileIteratorConcept |
158
- /// WriteableContiguousTileIteratorConcept |
159
- /// MaskedTileIteratorConcept
160
- ///
161
- template <
162
- typename Shape_,
163
- typename Element_,
164
- int AdvanceRank,
165
- typename ThreadMap_,
166
- int AccessSize,
167
- bool Gather>
168
- class PredicatedTileIteratorResidualLast<
169
- Shape_,
170
- Element_,
171
- layout::PitchLinear,
172
- AdvanceRank,
173
- ThreadMap_,
174
- AccessSize,
175
- Gather> {
176
- public:
177
- static_assert(
178
- AdvanceRank == 0 || AdvanceRank == 1,
179
- "Specialization for pitch-linear iterator may advance along the "
180
- "contiguous(rank=0) or strided(rank=1) dimension.");
181
-
182
- using Shape = Shape_;
183
- using Element = Element_;
184
- using Layout = layout::PitchLinear;
185
- static int const kAdvanceRank = AdvanceRank;
186
- using ThreadMap = ThreadMap_;
187
-
188
- using Index = typename Layout::Index;
189
- using LongIndex = typename Layout::LongIndex;
190
-
191
- using TensorRef = TensorRef<Element, Layout>;
192
- using TensorView = TensorView<Element, Layout>;
193
- using TensorCoord = typename Layout::TensorCoord;
194
-
195
- using Pointer = Element*;
196
- using NonConstPointer = typename platform::remove_const<Element>::type*;
197
-
198
- /// Type used for internal memory accesses
199
- using AccessType = AlignedArray<
200
- Element,
201
- AccessSize,
202
- (AccessSize * sizeof_bits<Element>::value / 8)>;
203
-
204
- /// Underlying iterator to compute the addresses
205
- using TileAccessIterator = PredicatedTileAccessIteratorResidualLast<
206
- Shape,
207
- Element,
208
- Layout,
209
- kAdvanceRank,
210
- ThreadMap,
211
- AccessType,
212
- Gather>;
213
-
214
- static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
215
-
216
- /// Fragment object to be loaded or stored
217
- using Fragment = cutlass::Array<
218
- Element,
219
- ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
220
-
221
- /// Predicate vector stores mask to guard accesses
222
- using Mask = typename TileAccessIterator::Mask;
223
-
224
- /// Parameters object is precomputed state and is host-constructible
225
- class Params {
226
- public:
227
- using Base = typename TileAccessIterator::Params::Base;
228
-
229
- friend PredicatedTileIteratorResidualLast;
230
-
231
- private:
232
- /// Parameters object
233
- typename TileAccessIterator::Params params_;
234
-
235
- public:
236
- /// Construct the Params object given a pitch-linear tensor's layout
237
- CUTLASS_HOST_DEVICE
238
- Params(Layout const& layout) : params_(layout) {}
239
-
240
- CUTLASS_HOST_DEVICE
241
- Params() {}
242
-
243
- CUTLASS_HOST_DEVICE
244
- Params(Base const& base) : params_(base) {}
245
- };
246
-
247
- private:
248
- /// Internal pointer type permits fast address arithmetic
249
- using BytePointer = char*;
250
-
251
- private:
252
- //
253
- // Data members
254
- //
255
-
256
- /// Data member to the tile access iterator
257
- TileAccessIterator address_iterator_;
258
-
259
- public:
260
- /// Constructs a TileIterator from its precomputed state, threadblock offset,
261
- /// and thread ID
262
- CUTLASS_HOST_DEVICE
263
- PredicatedTileIteratorResidualLast(
264
- /// Precomputed parameters object
265
- Params const& params,
266
- /// Pointer to start of tensor
267
- Pointer pointer,
268
- /// Extent of tensor
269
- TensorCoord extent,
270
- /// ID of each participating thread
271
- int thread_id,
272
- /// Initial offset of threadblock
273
- TensorCoord const& threadblock_offset,
274
- /// Gather indices
275
- int const* indices = nullptr)
276
- : address_iterator_(
277
- params.params_,
278
- pointer,
279
- extent,
280
- thread_id,
281
- threadblock_offset,
282
- indices) {}
283
-
284
- /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
285
- /// offset
286
- CUTLASS_HOST_DEVICE
287
- PredicatedTileIteratorResidualLast(
288
- Params const& params, ///< Precomputed parameters object
289
- Pointer pointer, ///< Pointer to start of tensor
290
- TensorCoord extent, ///< Extent of tensor
291
- int thread_id ///< ID of each participating thread
292
- )
293
- : PredicatedTileIteratorResidualLast(
294
- params,
295
- pointer,
296
- extent,
297
- thread_id,
298
- make_Coord(0, 0)) {}
299
-
300
- /// Adds a pointer offset in units of Element
301
- CUTLASS_HOST_DEVICE
302
- void add_pointer_offset(LongIndex pointer_offset) {
303
- address_iterator_.add_pointer_offset(pointer_offset);
304
- }
305
-
306
- /// Advances to the next tile in memory.
307
- ///
308
- /// The first time this method is called, predicates are updated, and the
309
- /// iterator's internal pointer is reverted to the first "steady state" tile.
310
- /// Subsequent calls are lightweight and must only update the internal
311
- /// pointer.
312
- CUTLASS_HOST_DEVICE
313
- PredicatedTileIteratorResidualLast& operator++() {
314
- if (kAdvanceRank)
315
- address_iterator_.add_tile_offset({0, 1});
316
- else
317
- address_iterator_.add_tile_offset({1, 0});
318
-
319
- return *this;
320
- }
321
-
322
- /// Advances to the next tile in memory.
323
- ///
324
- /// The first time this method is called, predicates are updated, and the
325
- /// iterator's internal pointer is reverted to the first "steady state" tile.
326
- /// Subsequent calls are lightweight and must only update the internal
327
- /// pointer.
328
- CUTLASS_HOST_DEVICE
329
- PredicatedTileIteratorResidualLast operator++(int) {
330
- PredicatedTileIteratorResidualLast self(*this);
331
- operator++();
332
- return self;
333
- }
334
-
335
- /// Clears the predicate set efficiently
336
- CUTLASS_HOST_DEVICE
337
- void clear_mask(bool enable = true) {
338
- address_iterator_.clear_mask(enable);
339
- }
340
-
341
- CUTLASS_HOST_DEVICE
342
- void set_residual_tile(bool enable) {
343
- address_iterator_.set_residual_tile(enable);
344
- }
345
-
346
- /// Clears the predicate set efficiently
347
- CUTLASS_HOST_DEVICE
348
- void enable_mask() {
349
- address_iterator_.enable_mask();
350
- }
351
-
352
- /// Sets the predicate mask, overriding value stored in predicate iterator
353
- CUTLASS_HOST_DEVICE
354
- void set_mask(Mask const& mask) {
355
- address_iterator_.set_mask(mask);
356
- }
357
-
358
- /// Gets the mask
359
- CUTLASS_HOST_DEVICE
360
- void get_mask(Mask& mask) {
361
- address_iterator_.get_mask(mask);
362
- }
363
-
364
- CUTLASS_DEVICE
365
- void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
366
- load_with_byte_offset(
367
- frag, pointer_offset * sizeof_bits<Element>::value / 8);
368
- }
369
-
370
- CUTLASS_DEVICE
371
- void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
372
- AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
373
-
374
- CUTLASS_PRAGMA_UNROLL
375
- for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
376
- CUTLASS_PRAGMA_UNROLL
377
- for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
378
- CUTLASS_PRAGMA_UNROLL
379
- for (int v = 0; v < kAccessesPerVector; ++v) {
380
- int idx = v +
381
- kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
382
-
383
- address_iterator_.set_iteration_index(idx);
384
- char const* byte_ptr =
385
- reinterpret_cast<char const*>(address_iterator_.get()) +
386
- byte_offset;
387
-
388
- AccessType const* access_ptr =
389
- reinterpret_cast<AccessType const*>(byte_ptr);
390
-
391
- cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
392
- frag_ptr[idx], access_ptr, address_iterator_.valid());
393
-
394
- ++address_iterator_;
395
- }
396
- }
397
- }
398
- }
399
-
400
- /// Loads a fragment from memory
401
- CUTLASS_DEVICE
402
- void load(Fragment& frag) {
403
- load_with_byte_offset(frag, 0);
404
- }
405
-
406
- /// Store a fragment to memory
407
- CUTLASS_DEVICE
408
- void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
409
- store_with_byte_offset(
410
- frag, pointer_offset * sizeof_bits<Element>::value / 8);
411
- }
412
-
413
- /// Store a fragment to memory
414
- CUTLASS_DEVICE
415
- void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
416
- address_iterator_.set_iteration_index(0);
417
- AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
418
-
419
- CUTLASS_PRAGMA_UNROLL
420
- for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
421
- CUTLASS_PRAGMA_UNROLL
422
- for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
423
- CUTLASS_PRAGMA_UNROLL
424
- for (int v = 0; v < kAccessesPerVector; ++v) {
425
- int idx = v +
426
- kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
427
-
428
- char* byte_ptr =
429
- reinterpret_cast<char*>(address_iterator_.get()) + byte_offset;
430
- AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr);
431
-
432
- if (address_iterator_.valid()) {
433
- *access_ptr = frag_ptr[idx];
434
- }
435
- ++address_iterator_;
436
- }
437
- }
438
- }
439
- }
440
-
441
- /// Store a fragment to memory
442
- CUTLASS_DEVICE
443
- void store(Fragment const& frag) {
444
- store_with_byte_offset(frag, 0);
445
- }
446
- };
447
-
448
- ////////////////////////////////////////////////////////////////////////////////
449
-
450
- /// Specialization of PredicatedTileIteratorResidualLast for column-major data.
451
- ///
452
- /// Satisfies: ForwardTileIteratorConcept |
453
- /// ReadableContiguousTileIteratorConcept |
454
- /// WriteableContiguousTileIteratorConcept |
455
- /// MaskedTileIteratorConcept
456
- ///
457
- template <
458
- typename Shape_,
459
- typename Element_,
460
- int AdvanceRank,
461
- typename ThreadMap_,
462
- int AccessSize,
463
- bool Gather>
464
- class PredicatedTileIteratorResidualLast<
465
- Shape_,
466
- Element_,
467
- layout::ColumnMajor,
468
- AdvanceRank,
469
- ThreadMap_,
470
- AccessSize,
471
- Gather> {
472
- public:
473
- static_assert(
474
- AdvanceRank == 0 || AdvanceRank == 1,
475
- "Specialization for pitch-linear iterator may along advance along the "
476
- "contiguous(rank=0) or strided(rank=1) dimension.");
477
-
478
- using Shape = Shape_;
479
- using Element = Element_;
480
- using Layout = layout::ColumnMajor;
481
- static int const kAdvanceRank = AdvanceRank;
482
- using ThreadMap = ThreadMap_;
483
-
484
- using Index = typename Layout::Index;
485
- using LongIndex = typename Layout::LongIndex;
486
-
487
- using TensorRef = TensorRef<Element, Layout>;
488
- using TensorView = TensorView<Element, Layout>;
489
- using TensorCoord = typename Layout::TensorCoord;
490
-
491
- using Pointer = Element*;
492
- using NonConstPointer = typename platform::remove_const<Element>::type*;
493
-
494
- using UnderlyingIterator = PredicatedTileIteratorResidualLast<
495
- layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
496
- Element,
497
- layout::PitchLinear,
498
- (kAdvanceRank == 0 ? 0 : 1),
499
- ThreadMap,
500
- AccessSize,
501
- Gather>;
502
-
503
- using AccessType = typename UnderlyingIterator::AccessType;
504
-
505
- /// Fragment object to be loaded or stored
506
- using Fragment = cutlass::Array<
507
- Element,
508
- ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
509
-
510
- /// Predicate vector stores mask to guard accesses
511
- using Mask = typename UnderlyingIterator::Mask;
512
-
513
- /// Parameters object is precomputed state and is host-constructible
514
- class Params {
515
- private:
516
- friend PredicatedTileIteratorResidualLast;
517
-
518
- /// Parameters object
519
- typename UnderlyingIterator::Params params_;
520
-
521
- public:
522
- CUTLASS_HOST_DEVICE
523
- Params() {}
524
-
525
- /// Construct the Params object given a pitch-linear tensor's layout
526
- CUTLASS_HOST_DEVICE
527
- Params(Layout const& layout)
528
- : params_(layout::PitchLinear(layout.stride(0))) {}
529
-
530
- CUTLASS_HOST_DEVICE
531
- Params(typename UnderlyingIterator::Params::Base const& base)
532
- : params_(base) {}
533
- };
534
-
535
- private:
536
- //
537
- // Data members
538
- //
539
-
540
- /// Underlying pitch-linear tile iterator
541
- UnderlyingIterator iterator_;
542
-
543
- public:
544
- /// Constructs a TileIterator from its precomputed state, threadblock offset,
545
- /// and thread ID
546
- CUTLASS_HOST_DEVICE
547
- PredicatedTileIteratorResidualLast(
548
- Params const& params, ///< Precomputed parameters object
549
- Pointer pointer, ///< Pointer to start of tensor
550
- TensorCoord extent, ///< Extent of tensor
551
- int thread_id, ///< ID of each participating thread
552
- TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
553
- int const* indices =
554
- nullptr ///< gather/scatter indices, note no support for
555
- ///< gather/scatter at this specialization
556
- )
557
- : iterator_(
558
- params.params_,
559
- pointer,
560
- layout::PitchLinearCoord(extent.row(), extent.column()),
561
- thread_id,
562
- layout::PitchLinearCoord(
563
- threadblock_offset.row(),
564
- threadblock_offset.column()),
565
- indices) {}
566
-
567
- /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
568
- /// offset
569
- CUTLASS_HOST_DEVICE
570
- PredicatedTileIteratorResidualLast(
571
- Params const& params, ///< Precomputed parameters object
572
- Pointer pointer, ///< Pointer to start of tensor
573
- TensorCoord extent, ///< Extent of tensor
574
- int thread_id ///< ID of each participating thread
575
- )
576
- : PredicatedTileIteratorResidualLast(
577
- params,
578
- pointer,
579
- extent,
580
- thread_id,
581
- make_Coord(0, 0)) {}
582
-
583
- /// Adds a pointer offset in units of Element
584
- CUTLASS_HOST_DEVICE
585
- void add_pointer_offset(LongIndex pointer_offset) {
586
- iterator_.add_pointer_offset(pointer_offset);
587
- }
588
-
589
- /// Advances to the next tile in memory.
590
- ///
591
- /// The first time this method is called, predicates are updated, and the
592
- /// iterator's internal pointer is reverted to the first "steady state" tile.
593
- /// Subsequent calls are lightweight and must only update the internal
594
- /// pointer.
595
- CUTLASS_HOST_DEVICE
596
- PredicatedTileIteratorResidualLast& operator++() {
597
- ++iterator_;
598
- return *this;
599
- }
600
-
601
- /// Advances to the next tile in memory.
602
- ///
603
- /// The first time this method is called, predicates are updated, and the
604
- /// iterator's internal pointer is reverted to the first "steady state" tile.
605
- /// Subsequent calls are lightweight and must only update the internal
606
- /// pointer.
607
- CUTLASS_HOST_DEVICE
608
- PredicatedTileIteratorResidualLast operator++(int) {
609
- PredicatedTileIteratorResidualLast self(*this);
610
- operator++();
611
- return self;
612
- }
613
-
614
- /// Clears the predicate set efficiently
615
- CUTLASS_HOST_DEVICE
616
- void clear_mask(bool enable = true) {
617
- iterator_.clear_mask(enable);
618
- }
619
-
620
- CUTLASS_HOST_DEVICE
621
- void set_residual_tile(bool enable) {
622
- iterator_.set_residual_tile(enable);
623
- }
624
-
625
- /// Clears the predicate set efficiently
626
- CUTLASS_HOST_DEVICE
627
- void enable_mask() {
628
- iterator_.enable_mask();
629
- }
630
-
631
- /// Sets the predicate mask, overriding value stored in predicate iterator
632
- CUTLASS_HOST_DEVICE
633
- void set_mask(Mask const& mask) {
634
- iterator_.set_mask(mask);
635
- }
636
-
637
- /// Gets the mask
638
- CUTLASS_HOST_DEVICE
639
- void get_mask(Mask& mask) {
640
- iterator_.get_mask(mask);
641
- }
642
-
643
- /// Loads a fragment from memory
644
- CUTLASS_DEVICE
645
- void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
646
- iterator_.load_with_pointer_offset(frag, pointer_offset);
647
- }
648
-
649
- /// Loads a fragment from memory
650
- CUTLASS_DEVICE
651
- void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
652
- iterator_.load_with_byte_offset(frag, byte_offset);
653
- }
654
-
655
- /// Loads a fragment from memory
656
- CUTLASS_DEVICE
657
- void load(Fragment& frag) {
658
- load_with_pointer_offset(frag, 0);
659
- }
660
-
661
- /// Store a fragment to memory
662
- CUTLASS_DEVICE
663
- void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
664
- iterator_.store_with_pointer_offset(frag, pointer_offset);
665
- }
666
-
667
- /// Store a fragment to memory
668
- CUTLASS_DEVICE
669
- void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
670
- iterator_.store_with_byte_offset(frag, byte_offset);
671
- }
672
-
673
- /// Store a fragment to memory
674
- CUTLASS_DEVICE
675
- void store(Fragment const& frag) {
676
- store_with_pointer_offset(frag, 0);
677
- }
678
- };
679
-
680
- ////////////////////////////////////////////////////////////////////////////////
681
-
682
- /// Specialization of PredicatedTileIteratorResidualLast for row-major data.
683
- ///
684
- /// Satisfies: ForwardTileIteratorConcept |
685
- /// ReadableContiguousTileIteratorConcept |
686
- /// WriteableContiguousTileIteratorConcept |
687
- /// MaskedTileIteratorConcept
688
- ///
689
- template <
690
- typename Shape_,
691
- typename Element_,
692
- int AdvanceRank,
693
- typename ThreadMap_,
694
- int AccessSize,
695
- bool Gather>
696
- class PredicatedTileIteratorResidualLast<
697
- Shape_,
698
- Element_,
699
- layout::RowMajor,
700
- AdvanceRank,
701
- ThreadMap_,
702
- AccessSize,
703
- Gather> {
704
- public:
705
- static_assert(
706
- AdvanceRank == 0 || AdvanceRank == 1,
707
- "Specialization for pitch-linear iterator may along advance along the "
708
- "contiguous(rank=0) or strided(rank=1) dimension.");
709
-
710
- using Shape = Shape_;
711
- using Element = Element_;
712
- using Layout = layout::RowMajor;
713
- static int const kAdvanceRank = AdvanceRank;
714
- using ThreadMap = ThreadMap_;
715
-
716
- using Index = typename Layout::Index;
717
- using LongIndex = typename Layout::LongIndex;
718
-
719
- using TensorRef = TensorRef<Element, Layout>;
720
- using TensorView = TensorView<Element, Layout>;
721
- using TensorCoord = typename Layout::TensorCoord;
722
-
723
- using Pointer = Element*;
724
- using NonConstPointer = typename platform::remove_const<Element>::type*;
725
-
726
- using UnderlyingIterator = PredicatedTileIteratorResidualLast<
727
- layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
728
- Element,
729
- layout::PitchLinear,
730
- (kAdvanceRank == 0 ? 1 : 0),
731
- ThreadMap,
732
- AccessSize,
733
- Gather>;
734
-
735
- using AccessType = typename UnderlyingIterator::AccessType;
736
-
737
- /// Fragment object to be loaded or stored
738
- using Fragment = cutlass::Array<
739
- Element,
740
- ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
741
-
742
- /// Predicate vector stores mask to guard accesses
743
- using Mask = typename UnderlyingIterator::Mask;
744
-
745
- /// Parameters object is precomputed state and is host-constructible
746
- class Params {
747
- private:
748
- friend PredicatedTileIteratorResidualLast;
749
-
750
- /// Parameters object
751
- typename UnderlyingIterator::Params params_;
752
-
753
- public:
754
- CUTLASS_HOST_DEVICE
755
- Params() {}
756
-
757
- /// Construct the Params object given a row-major tensor's layout
758
- CUTLASS_HOST_DEVICE
759
- Params(Layout const& layout)
760
- : params_(layout::PitchLinear(layout.stride(0))) {}
761
-
762
- CUTLASS_HOST_DEVICE
763
- Params(typename UnderlyingIterator::Params::Base const& base)
764
- : params_(base) {}
765
- };
766
-
767
- private:
768
- //
769
- // Data members
770
- //
771
-
772
- /// Underlying pitch-linear tile iterator
773
- UnderlyingIterator iterator_;
774
-
775
- public:
776
- /// Constructs a TileIterator from its precomputed state, threadblock offset,
777
- /// and thread ID
778
- CUTLASS_HOST_DEVICE
779
- PredicatedTileIteratorResidualLast(
780
- Params const& params, ///< Precomputed parameters object
781
- Pointer pointer, ///< Pointer to start of tensor
782
- TensorCoord extent, ///< Extent of tensor
783
- int thread_id, ///< ID of each participating thread
784
- TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
785
- int const* indices = nullptr ///< Gather indices
786
- )
787
- : iterator_(
788
- params.params_,
789
- pointer,
790
- layout::PitchLinearCoord(extent.column(), extent.row()),
791
- thread_id,
792
- layout::PitchLinearCoord(
793
- threadblock_offset.column(),
794
- threadblock_offset.row()),
795
- indices) {}
796
-
797
- /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
798
- /// offset
799
- CUTLASS_HOST_DEVICE
800
- PredicatedTileIteratorResidualLast(
801
- Params const& params, ///< Precomputed parameters object
802
- Pointer pointer, ///< Pointer to start of tensor
803
- TensorCoord extent, ///< Extent of tensor
804
- int thread_id ///< ID of each participating thread
805
- )
806
- : PredicatedTileIteratorResidualLast(
807
- params,
808
- pointer,
809
- extent,
810
- thread_id,
811
- make_Coord(0, 0)) {}
812
-
813
- /// Adds a pointer offset in units of Element
814
- CUTLASS_HOST_DEVICE
815
- void add_pointer_offset(LongIndex pointer_offset) {
816
- iterator_.add_pointer_offset(pointer_offset);
817
- }
818
-
819
- /// Advances to the next tile in memory.
820
- ///
821
- /// The first time this method is called, predicates are updated, and the
822
- /// iterator's internal pointer is reverted to the first "steady state" tile.
823
- /// Subsequent calls are lightweight and must only update the internal
824
- /// pointer.
825
- CUTLASS_HOST_DEVICE
826
- PredicatedTileIteratorResidualLast& operator++() {
827
- ++iterator_;
828
- return *this;
829
- }
830
-
831
- /// Advances to the next tile in memory.
832
- ///
833
- /// The first time this method is called, predicates are updated, and the
834
- /// iterator's internal pointer is reverted to the first "steady state" tile.
835
- /// Subsequent calls are lightweight and must only update the internal
836
- /// pointer.
837
- CUTLASS_HOST_DEVICE
838
- PredicatedTileIteratorResidualLast operator++(int) {
839
- PredicatedTileIteratorResidualLast self(*this);
840
- operator++();
841
- return self;
842
- }
843
-
844
- /// Clears the predicate set efficiently
845
- CUTLASS_HOST_DEVICE
846
- void clear_mask(bool enable = true) {
847
- iterator_.clear_mask(enable);
848
- }
849
-
850
- CUTLASS_HOST_DEVICE
851
- void set_residual_tile(bool enable) {
852
- iterator_.set_residual_tile(enable);
853
- }
854
-
855
- /// Enables the predicate mask (forwards to the underlying iterator)
856
- CUTLASS_HOST_DEVICE
857
- void enable_mask() {
858
- iterator_.enable_mask();
859
- }
860
-
861
- /// Sets the predicate mask, overriding value stored in predicate iterator
862
- CUTLASS_HOST_DEVICE
863
- void set_mask(Mask const& mask) {
864
- iterator_.set_mask(mask);
865
- }
866
-
867
- /// Gets the mask
868
- CUTLASS_HOST_DEVICE
869
- void get_mask(Mask& mask) {
870
- iterator_.get_mask(mask);
871
- }
872
-
873
- /// Loads a fragment from memory
874
- CUTLASS_DEVICE
875
- void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
876
- iterator_.load_with_pointer_offset(frag, pointer_offset);
877
- }
878
-
879
- /// Loads a fragment from memory
880
- CUTLASS_DEVICE
881
- void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
882
- iterator_.load_with_byte_offset(frag, byte_offset);
883
- }
884
-
885
- /// Loads a fragment from memory
886
- CUTLASS_DEVICE
887
- void load(Fragment& frag) {
888
- load_with_pointer_offset(frag, 0);
889
- }
890
-
891
- /// Store a fragment to memory
892
- CUTLASS_DEVICE
893
- void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
894
- iterator_.store_with_pointer_offset(frag, pointer_offset);
895
- }
896
-
897
- /// Store a fragment to memory
898
- CUTLASS_DEVICE
899
- void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
900
- iterator_.store_with_byte_offset(frag, byte_offset);
901
- }
902
-
903
- /// Store a fragment to memory
904
- CUTLASS_DEVICE
905
- void store(Fragment const& frag) {
906
- store_with_pointer_offset(frag, 0);
907
- }
908
- };
909
-
910
- ////////////////////////////////////////////////////////////////////////////////
911
-
912
- /// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data.
913
- ///
914
- /// Satisfies: ForwardTileIteratorConcept |
915
- /// ReadableContiguousTileIteratorConcept |
916
- /// WriteableContiguousTileIteratorConcept |
917
- /// MaskedTileIteratorConcept
918
- ///
919
- template <
920
- typename Shape_,
921
- typename Element_,
922
- int AdvanceRank,
923
- typename ThreadMap_,
924
- int AccessSize>
925
- class PredicatedTileIteratorResidualLast<
926
- Shape_,
927
- Element_,
928
- layout::AffineRankN<2>,
929
- AdvanceRank,
930
- ThreadMap_,
931
- AccessSize,
932
- false> {
933
- public:
934
- static_assert(
935
- AdvanceRank == 0 || AdvanceRank == 1,
936
- "Specialization for pitch-linear iterator may advance along the "
937
- "contiguous(rank=0) or strided(rank=1) dimension.");
938
-
939
- using Shape = Shape_;
940
- using Element = Element_;
941
- using Layout = layout::AffineRankN<2>;
942
- static int const kAdvanceRank = AdvanceRank;
943
- using ThreadMap = ThreadMap_;
944
-
945
- using Index = typename Layout::Index;
946
- using LongIndex = typename Layout::LongIndex;
947
-
948
- using TensorRef = TensorRef<Element, Layout>;
949
- using TensorView = TensorView<Element, Layout>;
950
- using TensorCoord = typename Layout::TensorCoord;
951
-
952
- using Pointer = Element*;
953
- using NonConstPointer = typename platform::remove_const<Element>::type*;
954
-
955
- /// Type used for internal memory accesses
956
- using AccessType = AlignedArray<
957
- Element,
958
- AccessSize,
959
- (AccessSize * sizeof_bits<Element>::value / 8)>;
960
-
961
- /// Underlying iterator to compute the addresses
962
- using TileAccessIterator = PredicatedTileAccessIteratorResidualLast<
963
- Shape,
964
- Element,
965
- Layout,
966
- kAdvanceRank,
967
- ThreadMap,
968
- AccessType>;
969
-
970
- static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
971
-
972
- /// Fragment object to be loaded or stored
973
- using Fragment = cutlass::Array<
974
- Element,
975
- ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
976
-
977
- /// Predicate vector stores mask to guard accesses
978
- using Mask = typename TileAccessIterator::Mask;
979
-
980
- /// Parameters object is precomputed state and is host-constructible
981
- class Params {
982
- public:
983
- friend PredicatedTileIteratorResidualLast;
984
-
985
- private:
986
- /// Parameters object
987
- typename TileAccessIterator::Params params_;
988
-
989
- public:
990
- /// Construct the Params object given an AffineRankN<2> tensor's layout
991
- CUTLASS_HOST_DEVICE
992
- Params(Layout const& layout) : params_(layout) {}
993
-
994
- CUTLASS_HOST_DEVICE
995
- Params() {}
996
- };
997
-
998
- private:
999
- /// Internal pointer type permits fast address arithmetic
1000
- using BytePointer = char*;
1001
-
1002
- private:
1003
- //
1004
- // Data members
1005
- //
1006
-
1007
- /// Data member to the tile access iterator
1008
- TileAccessIterator address_iterator_;
1009
-
1010
- public:
1011
- /// Constructs a TileIterator from its precomputed state, threadblock offset,
1012
- /// and thread ID
1013
- CUTLASS_HOST_DEVICE
1014
- PredicatedTileIteratorResidualLast(
1015
- /// Precomputed parameters object
1016
- Params const& params,
1017
- /// Pointer to start of tensor
1018
- Pointer pointer,
1019
- /// Extent of tensor
1020
- TensorCoord extent,
1021
- /// ID of each participating thread
1022
- int thread_id,
1023
- /// Initial offset of threadblock
1024
- TensorCoord const& threadblock_offset,
1025
- int const* indices =
1026
- nullptr ///< gather/scatter indices, note no support for
1027
- ///< gather/scatter at this specialization
1028
- )
1029
- : address_iterator_(
1030
- params.params_,
1031
- pointer,
1032
- extent,
1033
- thread_id,
1034
- threadblock_offset) {}
1035
-
1036
- /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
1037
- /// offset
1038
- CUTLASS_HOST_DEVICE
1039
- PredicatedTileIteratorResidualLast(
1040
- Params const& params, ///< Precomputed parameters object
1041
- Pointer pointer, ///< Pointer to start of tensor
1042
- TensorCoord extent, ///< Extent of tensor
1043
- int thread_id ///< ID of each participating thread
1044
- )
1045
- : PredicatedTileIteratorResidualLast(
1046
- params,
1047
- pointer,
1048
- extent,
1049
- thread_id,
1050
- make_Coord(0, 0)) {}
1051
-
1052
- /// Adds a pointer offset in units of Element
1053
- CUTLASS_HOST_DEVICE
1054
- void add_pointer_offset(LongIndex pointer_offset) {
1055
- address_iterator_.add_pointer_offset(pointer_offset);
1056
- }
1057
-
1058
- /// Advances to the next tile in memory.
1059
- ///
1060
- /// The first time this method is called, predicates are updated, and the
1061
- /// iterator's internal pointer is reverted to the first "steady state" tile.
1062
- /// Subsequent calls are lightweight and must only update the internal
1063
- /// pointer.
1064
- CUTLASS_HOST_DEVICE
1065
- PredicatedTileIteratorResidualLast& operator++() {
1066
- if (kAdvanceRank)
1067
- address_iterator_.add_tile_offset(make_Coord(0, 1));
1068
- else
1069
- address_iterator_.add_tile_offset(make_Coord(1, 0));
1070
-
1071
- return *this;
1072
- }
1073
-
1074
- /// Advances to the next tile in memory.
1075
- ///
1076
- /// The first time this method is called, predicates are updated, and the
1077
- /// iterator's internal pointer is reverted to the first "steady state" tile.
1078
- /// Subsequent calls are lightweight and must only update the internal
1079
- /// pointer.
1080
- CUTLASS_HOST_DEVICE
1081
- PredicatedTileIteratorResidualLast operator++(int) {
1082
- PredicatedTileIteratorResidualLast self(*this);
1083
- operator++();
1084
- return self;
1085
- }
1086
-
1087
- /// Clears the predicate set efficiently
1088
- CUTLASS_HOST_DEVICE
1089
- void clear_mask(bool enable = true) {
1090
- address_iterator_.clear_mask(enable);
1091
- }
1092
-
1093
- CUTLASS_HOST_DEVICE
1094
- void set_residual_tile(bool enable) {
1095
- address_iterator_.set_residual_tile(enable);
1096
- }
1097
-
1098
- /// Enables the predicate mask (forwards to the tile access iterator)
1099
- CUTLASS_HOST_DEVICE
1100
- void enable_mask() {
1101
- address_iterator_.enable_mask();
1102
- }
1103
-
1104
- /// Sets the predicate mask, overriding value stored in predicate iterator
1105
- CUTLASS_HOST_DEVICE
1106
- void set_mask(Mask const& mask) {
1107
- address_iterator_.set_mask(mask);
1108
- }
1109
-
1110
- /// Gets the mask
1111
- CUTLASS_HOST_DEVICE
1112
- void get_mask(Mask& mask) {
1113
- address_iterator_.get_mask(mask);
1114
- }
1115
-
1116
- CUTLASS_DEVICE
1117
- void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
1118
- load_with_byte_offset(
1119
- frag, pointer_offset * sizeof_bits<Element>::value / 8);
1120
- }
1121
-
1122
- CUTLASS_DEVICE
1123
- void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
1124
- AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
1125
-
1126
- CUTLASS_PRAGMA_UNROLL
1127
- for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
1128
- CUTLASS_PRAGMA_UNROLL
1129
- for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
1130
- CUTLASS_PRAGMA_UNROLL
1131
- for (int v = 0; v < kAccessesPerVector; ++v) {
1132
- int idx = v +
1133
- kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
1134
-
1135
- address_iterator_.set_iteration_index(idx);
1136
- char const* byte_ptr =
1137
- reinterpret_cast<char const*>(address_iterator_.get()) +
1138
- byte_offset;
1139
-
1140
- AccessType const* access_ptr =
1141
- reinterpret_cast<AccessType const*>(byte_ptr);
1142
-
1143
- cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
1144
- frag_ptr[idx], access_ptr, address_iterator_.valid());
1145
-
1146
- ++address_iterator_;
1147
- }
1148
- }
1149
- }
1150
- }
1151
-
1152
- /// Loads a fragment from memory
1153
- CUTLASS_DEVICE
1154
- void load(Fragment& frag) {
1155
- load_with_byte_offset(frag, 0);
1156
- }
1157
-
1158
- /// Store a fragment to memory
1159
- CUTLASS_DEVICE
1160
- void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
1161
- store_with_byte_offset(
1162
- frag, pointer_offset * sizeof_bits<Element>::value / 8);
1163
- }
1164
-
1165
- /// Store a fragment to memory
1166
- CUTLASS_DEVICE
1167
- void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
1168
- address_iterator_.set_iteration_index(0);
1169
- AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
1170
-
1171
- CUTLASS_PRAGMA_UNROLL
1172
- for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
1173
- CUTLASS_PRAGMA_UNROLL
1174
- for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
1175
- CUTLASS_PRAGMA_UNROLL
1176
- for (int v = 0; v < kAccessesPerVector; ++v) {
1177
- int idx = v +
1178
- kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
1179
-
1180
- char* byte_ptr =
1181
- reinterpret_cast<char*>(address_iterator_.get()) + byte_offset;
1182
- AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr);
1183
-
1184
- if (address_iterator_.valid()) {
1185
- *access_ptr = frag_ptr[idx];
1186
- }
1187
- ++address_iterator_;
1188
- }
1189
- }
1190
- }
1191
- }
1192
-
1193
- /// Store a fragment to memory
1194
- CUTLASS_DEVICE
1195
- void store(Fragment const& frag) {
1196
- store_with_byte_offset(frag, 0);
1197
- }
1198
- };
1199
-
1200
- ////////////////////////////////////////////////////////////////////////////////
1201
-
1202
- /// Specialization of PredicatedTileIteratorResidualLast for affine rank 2
1203
- /// column-major data.
1204
- ///
1205
- /// Satisfies: ForwardTileIteratorConcept |
1206
- /// ReadableContiguousTileIteratorConcept |
1207
- /// WriteableContiguousTileIteratorConcept |
1208
- /// MaskedTileIteratorConcept
1209
- ///
1210
- template <
1211
- typename Shape_,
1212
- typename Element_,
1213
- int AdvanceRank,
1214
- typename ThreadMap_,
1215
- int AccessSize>
1216
- class PredicatedTileIteratorResidualLast<
1217
- Shape_,
1218
- Element_,
1219
- layout::AffineRank2ColumnMajor,
1220
- AdvanceRank,
1221
- ThreadMap_,
1222
- AccessSize,
1223
- false> {
1224
- public:
1225
- static_assert(
1226
- AdvanceRank == 0 || AdvanceRank == 1,
1227
- "Specialization for pitch-linear iterator may along advance along the "
1228
- "contiguous(rank=0) or strided(rank=1) dimension.");
1229
-
1230
- using Shape = Shape_;
1231
- using Element = Element_;
1232
- using Layout = layout::AffineRank2ColumnMajor;
1233
- static int const kAdvanceRank = AdvanceRank;
1234
- using ThreadMap = ThreadMap_;
1235
-
1236
- using Index = typename Layout::Index;
1237
- using LongIndex = typename Layout::LongIndex;
1238
-
1239
- using TensorRef = TensorRef<Element, Layout>;
1240
- using TensorView = TensorView<Element, Layout>;
1241
- using TensorCoord = typename Layout::TensorCoord;
1242
-
1243
- using Pointer = Element*;
1244
- using NonConstPointer = typename platform::remove_const<Element>::type*;
1245
-
1246
- // Map to the underlying AffineRankN<2> layout
1247
- using UnderlyingIterator = PredicatedTileIteratorResidualLast<
1248
- layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
1249
- Element,
1250
- layout::AffineRankN<2>,
1251
- (kAdvanceRank == 0 ? 0 : 1),
1252
- ThreadMap,
1253
- AccessSize>;
1254
-
1255
- using AccessType = typename UnderlyingIterator::AccessType;
1256
-
1257
- /// Fragment object to be loaded or stored
1258
- using Fragment = cutlass::Array<
1259
- Element,
1260
- ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
1261
-
1262
- /// Predicate vector stores mask to guard accesses
1263
- using Mask = typename UnderlyingIterator::Mask;
1264
-
1265
- /// Parameters object is precomputed state and is host-constructible
1266
- class Params {
1267
- private:
1268
- friend PredicatedTileIteratorResidualLast;
1269
-
1270
- /// Parameters object
1271
- typename UnderlyingIterator::Params params_;
1272
-
1273
- public:
1274
- CUTLASS_HOST_DEVICE
1275
- Params() {}
1276
-
1277
- /// Construct the Params object given an AffineRankN<2> tensor's layout
1278
- CUTLASS_HOST_DEVICE
1279
- Params(Layout const& layout)
1280
- : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {}
1281
- };
1282
-
1283
- private:
1284
- //
1285
- // Data members
1286
- //
1287
-
1288
- /// Underlying AffineRankN<2> tile iterator
1289
- UnderlyingIterator iterator_;
1290
-
1291
- public:
1292
- /// Constructs a TileIterator from its precomputed state, threadblock offset,
1293
- /// and thread ID
1294
- CUTLASS_HOST_DEVICE
1295
- PredicatedTileIteratorResidualLast(
1296
- Params const& params, ///< Precomputed parameters object
1297
- Pointer pointer, ///< Pointer to start of tensor
1298
- TensorCoord extent, ///< Extent of tensor
1299
- int thread_id, ///< ID of each participating thread
1300
- TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
1301
- int const* indices =
1302
- nullptr ///< gather/scatter indices, note no support for
1303
- ///< gather/scatter at this specialization
1304
- )
1305
- : iterator_(
1306
- params.params_,
1307
- pointer,
1308
- layout::PitchLinearCoord(extent.row(), extent.column()),
1309
- thread_id,
1310
- layout::PitchLinearCoord(
1311
- threadblock_offset.row(),
1312
- threadblock_offset.column())) {}
1313
-
1314
- /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
1315
- /// offset
1316
- CUTLASS_HOST_DEVICE
1317
- PredicatedTileIteratorResidualLast(
1318
- Params const& params, ///< Precomputed parameters object
1319
- Pointer pointer, ///< Pointer to start of tensor
1320
- TensorCoord extent, ///< Extent of tensor
1321
- int thread_id ///< ID of each participating thread
1322
- )
1323
- : PredicatedTileIteratorResidualLast(
1324
- params,
1325
- pointer,
1326
- extent,
1327
- thread_id,
1328
- make_Coord(0, 0)) {}
1329
-
1330
- /// Adds a pointer offset in units of Element
1331
- CUTLASS_HOST_DEVICE
1332
- void add_pointer_offset(LongIndex pointer_offset) {
1333
- iterator_.add_pointer_offset(pointer_offset);
1334
- }
1335
-
1336
- /// Advances to the next tile in memory.
1337
- ///
1338
- /// The first time this method is called, predicates are updated, and the
1339
- /// iterator's internal pointer is reverted to the first "steady state" tile.
1340
- /// Subsequent calls are lightweight and must only update the internal
1341
- /// pointer.
1342
- CUTLASS_HOST_DEVICE
1343
- PredicatedTileIteratorResidualLast& operator++() {
1344
- ++iterator_;
1345
- return *this;
1346
- }
1347
-
1348
- /// Advances to the next tile in memory.
1349
- ///
1350
- /// The first time this method is called, predicates are updated, and the
1351
- /// iterator's internal pointer is reverted to the first "steady state" tile.
1352
- /// Subsequent calls are lightweight and must only update the internal
1353
- /// pointer.
1354
- CUTLASS_HOST_DEVICE
1355
- PredicatedTileIteratorResidualLast operator++(int) {
1356
- PredicatedTileIteratorResidualLast self(*this);
1357
- operator++();
1358
- return self;
1359
- }
1360
-
1361
- /// Clears the predicate set efficiently
1362
- CUTLASS_HOST_DEVICE
1363
- void clear_mask(bool enable = true) {
1364
- iterator_.clear_mask(enable);
1365
- }
1366
-
1367
- CUTLASS_HOST_DEVICE
1368
- void set_residual_tile(bool enable) {
1369
- iterator_.set_residual_tile(enable);
1370
- }
1371
-
1372
- /// Enables the predicate mask (forwards to the underlying iterator)
1373
- CUTLASS_HOST_DEVICE
1374
- void enable_mask() {
1375
- iterator_.enable_mask();
1376
- }
1377
-
1378
- /// Sets the predicate mask, overriding value stored in predicate iterator
1379
- CUTLASS_HOST_DEVICE
1380
- void set_mask(Mask const& mask) {
1381
- iterator_.set_mask(mask);
1382
- }
1383
-
1384
- /// Gets the mask
1385
- CUTLASS_HOST_DEVICE
1386
- void get_mask(Mask& mask) {
1387
- iterator_.get_mask(mask);
1388
- }
1389
-
1390
- /// Loads a fragment from memory
1391
- CUTLASS_DEVICE
1392
- void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
1393
- iterator_.load_with_pointer_offset(frag, pointer_offset);
1394
- }
1395
-
1396
- /// Loads a fragment from memory
1397
- CUTLASS_DEVICE
1398
- void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
1399
- iterator_.load_with_byte_offset(frag, byte_offset);
1400
- }
1401
-
1402
- /// Loads a fragment from memory
1403
- CUTLASS_DEVICE
1404
- void load(Fragment& frag) {
1405
- load_with_pointer_offset(frag, 0);
1406
- }
1407
-
1408
- /// Store a fragment to memory
1409
- CUTLASS_DEVICE
1410
- void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
1411
- iterator_.store_with_pointer_offset(frag, pointer_offset);
1412
- }
1413
-
1414
- /// Store a fragment to memory
1415
- CUTLASS_DEVICE
1416
- void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
1417
- iterator_.store_with_byte_offset(frag, byte_offset);
1418
- }
1419
-
1420
- /// Store a fragment to memory
1421
- CUTLASS_DEVICE
1422
- void store(Fragment const& frag) {
1423
- store_with_pointer_offset(frag, 0);
1424
- }
1425
- };
1426
-
1427
- ////////////////////////////////////////////////////////////////////////////////
1428
-
1429
- /// Specialization of PredicatedTileIteratorResidualLast for affine rank 2
1430
- /// row-major data.
1431
- ///
1432
- /// Satisfies: ForwardTileIteratorConcept |
1433
- /// ReadableContiguousTileIteratorConcept |
1434
- /// WriteableContiguousTileIteratorConcept |
1435
- /// MaskedTileIteratorConcept
1436
- ///
1437
- template <
1438
- typename Shape_,
1439
- typename Element_,
1440
- int AdvanceRank,
1441
- typename ThreadMap_,
1442
- int AccessSize>
1443
- class PredicatedTileIteratorResidualLast<
1444
- Shape_,
1445
- Element_,
1446
- layout::AffineRank2RowMajor,
1447
- AdvanceRank,
1448
- ThreadMap_,
1449
- AccessSize,
1450
- false> {
1451
- public:
1452
- static_assert(
1453
- AdvanceRank == 0 || AdvanceRank == 1,
1454
- "Specialization for pitch-linear iterator may along advance along the "
1455
- "contiguous(rank=0) or strided(rank=1) dimension.");
1456
-
1457
- using Shape = Shape_;
1458
- using Element = Element_;
1459
- using Layout = layout::AffineRank2RowMajor;
1460
- static int const kAdvanceRank = AdvanceRank;
1461
- using ThreadMap = ThreadMap_;
1462
-
1463
- using Index = typename Layout::Index;
1464
- using LongIndex = typename Layout::LongIndex;
1465
-
1466
- using TensorRef = TensorRef<Element, Layout>;
1467
- using TensorView = TensorView<Element, Layout>;
1468
- using TensorCoord = typename Layout::TensorCoord;
1469
-
1470
- using Pointer = Element*;
1471
- using NonConstPointer = typename platform::remove_const<Element>::type*;
1472
-
1473
- // Map to the underlying AffineRankN<2> layout
1474
- using UnderlyingIterator = PredicatedTileIteratorResidualLast<
1475
- layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
1476
- Element,
1477
- layout::AffineRankN<2>,
1478
- (kAdvanceRank == 0 ? 1 : 0),
1479
- ThreadMap,
1480
- AccessSize>;
1481
-
1482
- using AccessType = typename UnderlyingIterator::AccessType;
1483
-
1484
- /// Fragment object to be loaded or stored
1485
- using Fragment = cutlass::Array<
1486
- Element,
1487
- ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
1488
-
1489
- /// Predicate vector stores mask to guard accesses
1490
- using Mask = typename UnderlyingIterator::Mask;
1491
-
1492
- /// Parameters object is precomputed state and is host-constructible
1493
- class Params {
1494
- private:
1495
- friend PredicatedTileIteratorResidualLast;
1496
-
1497
- /// Parameters object
1498
- typename UnderlyingIterator::Params params_;
1499
-
1500
- public:
1501
- CUTLASS_HOST_DEVICE
1502
- Params() {}
1503
-
1504
- /// Construct the Params object given an AffineRankN<2> tensor's layout
1505
- CUTLASS_HOST_DEVICE
1506
- Params(Layout const& layout)
1507
- : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}
1508
- };
1509
-
1510
- private:
1511
- //
1512
- // Data members
1513
- //
1514
-
1515
- /// Underlying AffineRankN<2> tile iterator
1516
- UnderlyingIterator iterator_;
1517
-
1518
- public:
1519
- /// Constructs a TileIterator from its precomputed state, threadblock offset,
1520
- /// and thread ID
1521
- CUTLASS_HOST_DEVICE
1522
- PredicatedTileIteratorResidualLast(
1523
- Params const& params, ///< Precomputed parameters object
1524
- Pointer pointer, ///< Pointer to start of tensor
1525
- TensorCoord extent, ///< Extent of tensor
1526
- int thread_id, ///< ID of each participating thread
1527
- TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
1528
- int const* indices =
1529
- nullptr ///< gather/scatter indices, note no support for
1530
- ///< gather/scatter at this specialization
1531
- )
1532
- : iterator_(
1533
- params.params_,
1534
- pointer,
1535
- layout::PitchLinearCoord(extent.column(), extent.row()),
1536
- thread_id,
1537
- layout::PitchLinearCoord(
1538
- threadblock_offset.column(),
1539
- threadblock_offset.row())) {}
1540
-
1541
- /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
1542
- /// offset
1543
- CUTLASS_HOST_DEVICE
1544
- PredicatedTileIteratorResidualLast(
1545
- Params const& params, ///< Precomputed parameters object
1546
- Pointer pointer, ///< Pointer to start of tensor
1547
- TensorCoord extent, ///< Extent of tensor
1548
- int thread_id ///< ID of each participating thread
1549
- )
1550
- : PredicatedTileIteratorResidualLast(
1551
- params,
1552
- pointer,
1553
- extent,
1554
- thread_id,
1555
- make_Coord(0, 0)) {}
1556
-
1557
- /// Adds a pointer offset in units of Element
1558
- CUTLASS_HOST_DEVICE
1559
- void add_pointer_offset(LongIndex pointer_offset) {
1560
- iterator_.add_pointer_offset(pointer_offset);
1561
- }
1562
-
1563
- /// Advances to the next tile in memory.
1564
- ///
1565
- /// The first time this method is called, predicates are updated, and the
1566
- /// iterator's internal pointer is reverted to the first "steady state" tile.
1567
- /// Subsequent calls are lightweight and must only update the internal
1568
- /// pointer.
1569
- CUTLASS_HOST_DEVICE
1570
- PredicatedTileIteratorResidualLast& operator++() {
1571
- ++iterator_;
1572
- return *this;
1573
- }
1574
-
1575
- /// Advances to the next tile in memory.
1576
- ///
1577
- /// The first time this method is called, predicates are updated, and the
1578
- /// iterator's internal pointer is reverted to the first "steady state" tile.
1579
- /// Subsequent calls are lightweight and must only update the internal
1580
- /// pointer.
1581
- CUTLASS_HOST_DEVICE
1582
- PredicatedTileIteratorResidualLast operator++(int) {
1583
- PredicatedTileIteratorResidualLast self(*this);
1584
- operator++();
1585
- return self;
1586
- }
1587
-
1588
- /// Clears the predicate set efficiently
1589
- CUTLASS_HOST_DEVICE
1590
- void clear_mask(bool enable = true) {
1591
- iterator_.clear_mask(enable);
1592
- }
1593
-
1594
- CUTLASS_HOST_DEVICE
1595
- void set_residual_tile(bool enable) {
1596
- iterator_.set_residual_tile(enable);
1597
- }
1598
-
1599
- /// Enables the predicate mask (forwards to the underlying iterator)
1600
- CUTLASS_HOST_DEVICE
1601
- void enable_mask() {
1602
- iterator_.enable_mask();
1603
- }
1604
-
1605
- /// Sets the predicate mask, overriding value stored in predicate iterator
1606
- CUTLASS_HOST_DEVICE
1607
- void set_mask(Mask const& mask) {
1608
- iterator_.set_mask(mask);
1609
- }
1610
-
1611
- /// Gets the mask
1612
- CUTLASS_HOST_DEVICE
1613
- void get_mask(Mask& mask) {
1614
- iterator_.get_mask(mask);
1615
- }
1616
-
1617
- /// Loads a fragment from memory
1618
- CUTLASS_DEVICE
1619
- void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
1620
- iterator_.load_with_pointer_offset(frag, pointer_offset);
1621
- }
1622
-
1623
- /// Loads a fragment from memory
1624
- CUTLASS_DEVICE
1625
- void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
1626
- iterator_.load_with_byte_offset(frag, byte_offset);
1627
- }
1628
-
1629
- /// Loads a fragment from memory
1630
- CUTLASS_DEVICE
1631
- void load(Fragment& frag) {
1632
- load_with_pointer_offset(frag, 0);
1633
- }
1634
-
1635
- /// Store a fragment to memory
1636
- CUTLASS_DEVICE
1637
- void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
1638
- iterator_.store_with_pointer_offset(frag, pointer_offset);
1639
- }
1640
-
1641
- /// Store a fragment to memory
1642
- CUTLASS_DEVICE
1643
- void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
1644
- iterator_.store_with_byte_offset(frag, byte_offset);
1645
- }
1646
-
1647
- /// Store a fragment to memory
1648
- CUTLASS_DEVICE
1649
- void store(Fragment const& frag) {
1650
- store_with_pointer_offset(frag, 0);
1651
- }
1652
- };
1653
-
1654
- ////////////////////////////////////////////////////////////////////////////////
1655
-
1656
- /// Specialization of PredicatedTileIteratorResidualLast for interleaved data.
1657
- /// It is mapped to the congruous layout.
1658
- ///
1659
- /// Satisfies: ForwardTileIteratorConcept |
1660
- /// ReadableContiguousTileIteratorConcept |
1661
- /// WriteableContiguousTileIteratorConcept |
1662
- /// MaskedTileIteratorConcept
1663
- ///
1664
-
1665
- template <
1666
- typename Shape_,
1667
- typename Element_,
1668
- int AdvanceRank,
1669
- typename ThreadMap_,
1670
- int AccessSize,
1671
- int InterleavedK>
1672
- class PredicatedTileIteratorResidualLast<
1673
- Shape_,
1674
- Element_,
1675
- layout::ColumnMajorInterleaved<InterleavedK>,
1676
- AdvanceRank,
1677
- ThreadMap_,
1678
- AccessSize,
1679
- false> {
1680
- public:
1681
- static_assert(
1682
- AdvanceRank == 0 || AdvanceRank == 1,
1683
- "Specialization for pitch-linear iterator may along advance along the "
1684
- "contiguous(rank=0) or strided(rank=1) dimension.");
1685
-
1686
- using Shape = Shape_;
1687
- using Element = Element_;
1688
- static int const kInterleavedK = InterleavedK;
1689
- using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
1690
- static int const kAdvanceRank = AdvanceRank;
1691
- using ThreadMap = ThreadMap_;
1692
-
1693
- using Index = typename Layout::Index;
1694
- using LongIndex = typename Layout::LongIndex;
1695
-
1696
- using TensorRef = TensorRef<Element, Layout>;
1697
- using TensorView = TensorView<Element, Layout>;
1698
- using TensorCoord = typename Layout::TensorCoord;
1699
-
1700
- using Pointer = Element*;
1701
- using NonConstPointer = typename platform::remove_const<Element>::type*;
1702
-
1703
- using UnderlyingIterator = PredicatedTileIteratorResidualLast<
1704
- layout::PitchLinearShape<
1705
- Shape::kRow * kInterleavedK,
1706
- Shape::kColumn / kInterleavedK>,
1707
- Element,
1708
- layout::PitchLinear,
1709
- (kAdvanceRank == 0 ? 0 : 1),
1710
- ThreadMap,
1711
- AccessSize>;
1712
-
1713
- using AccessType = typename UnderlyingIterator::AccessType;
1714
-
1715
- /// Fragment object to be loaded or stored
1716
- using Fragment = cutlass::Array<
1717
- Element,
1718
- ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
1719
-
1720
- /// Predicate vector stores mask to guard accesses
1721
- using Mask = typename UnderlyingIterator::Mask;
1722
-
1723
- /// Parameters object is precomputed state and is host-constructible
1724
- class Params {
1725
- private:
1726
- friend PredicatedTileIteratorResidualLast;
1727
-
1728
- /// Parameters object
1729
- typename UnderlyingIterator::Params params_;
1730
-
1731
- public:
1732
- CUTLASS_HOST_DEVICE
1733
- Params() {}
1734
-
1735
- /// Construct the Params object from the interleaved layout's leading stride, viewed as a pitch-linear layout
1736
- CUTLASS_HOST_DEVICE
1737
- Params(Layout const& layout)
1738
- : params_(layout::PitchLinear(layout.stride(0))) {}
1739
-
1740
- CUTLASS_HOST_DEVICE
1741
- Params(typename UnderlyingIterator::Params::Base const& base)
1742
- : params_(base) {}
1743
- };
1744
-
1745
- private:
1746
- //
1747
- // Data members
1748
- //
1749
-
1750
- /// Underlying pitch-linear tile iterator
1751
- UnderlyingIterator iterator_;
1752
-
1753
- public:
1754
- /// Constructs a TileIterator from its precomputed state, threadblock offset,
1755
- /// and thread ID
1756
- CUTLASS_HOST_DEVICE
1757
- PredicatedTileIteratorResidualLast(
1758
- /// Precomputed parameters object
1759
- Params const& params,
1760
- /// Pointer to start of tensor
1761
- Pointer pointer,
1762
- /// Extent of tensor
1763
- TensorCoord extent,
1764
- /// ID of each participating thread
1765
- int thread_id,
1766
- /// Initial offset of threadblock
1767
- TensorCoord const& threadblock_offset,
1768
- int const* indices =
1769
- nullptr ///< gather/scatter indices, note no support for
1770
- ///< gather/scatter at this specialization
1771
- )
1772
- : iterator_(
1773
- params.params_,
1774
- pointer,
1775
- layout::PitchLinearCoord(
1776
- extent.row() * kInterleavedK,
1777
- extent.column() / kInterleavedK),
1778
- thread_id,
1779
- layout::PitchLinearCoord(
1780
- threadblock_offset.row() * kInterleavedK,
1781
- threadblock_offset.column() / kInterleavedK)) {}
1782
-
1783
- /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
1784
- /// offset
1785
- CUTLASS_HOST_DEVICE
1786
- PredicatedTileIteratorResidualLast(
1787
- Params const& params, ///< Precomputed parameters object
1788
- Pointer pointer, ///< Pointer to start of tensor
1789
- TensorCoord extent, ///< Extent of tensor
1790
- int thread_id ///< ID of each participating thread
1791
- )
1792
- : PredicatedTileIteratorResidualLast(
1793
- params,
1794
- pointer,
1795
- extent,
1796
- thread_id,
1797
- make_Coord(0, 0)) {}
1798
-
1799
- /// Adds a pointer offset in units of Element
1800
- CUTLASS_HOST_DEVICE
1801
- void add_pointer_offset(LongIndex pointer_offset) {
1802
- iterator_.add_pointer_offset(pointer_offset);
1803
- }
1804
-
1805
- /// Advances to the next tile in memory.
1806
- ///
1807
- /// The first time this method is called, predicates are updated, and the
1808
- /// iterator's internal pointer is reverted to the first "steady state" tile.
1809
- /// Subsequent calls are lightweight and must only update the internal
1810
- /// pointer.
1811
- CUTLASS_HOST_DEVICE
1812
- PredicatedTileIteratorResidualLast& operator++() {
1813
- ++iterator_;
1814
- return *this;
1815
- }
1816
-
1817
- /// Advances to the next tile in memory.
1818
- ///
1819
- /// The first time this method is called, predicates are updated, and the
1820
- /// iterator's internal pointer is reverted to the first "steady state" tile.
1821
- /// Subsequent calls are lightweight and must only update the internal
1822
- /// pointer.
1823
- CUTLASS_HOST_DEVICE
1824
- PredicatedTileIteratorResidualLast operator++(int) {
1825
- PredicatedTileIteratorResidualLast self(*this);
1826
- operator++();
1827
- return self;
1828
- }
1829
-
1830
- /// Clears the predicate set efficiently
1831
- CUTLASS_HOST_DEVICE
1832
- void clear_mask(bool enable = true) {
1833
- iterator_.clear_mask(enable);
1834
- }
1835
-
1836
- CUTLASS_HOST_DEVICE
1837
- void set_residual_tile(bool enable) {
1838
- iterator_.set_residual_tile(enable);
1839
- }
1840
-
1841
- /// Enables the predicate mask by forwarding to the underlying iterator
1842
- CUTLASS_HOST_DEVICE
1843
- void enable_mask() {
1844
- iterator_.enable_mask();
1845
- }
1846
-
1847
- /// Sets the predicate mask, overriding value stored in predicate iterator
1848
- CUTLASS_HOST_DEVICE
1849
- void set_mask(Mask const& mask) {
1850
- iterator_.set_mask(mask);
1851
- }
1852
-
1853
- /// Gets the mask
1854
- CUTLASS_HOST_DEVICE
1855
- void get_mask(Mask& mask) {
1856
- iterator_.get_mask(mask);
1857
- }
1858
-
1859
- /// Loads a fragment from memory
1860
- CUTLASS_DEVICE
1861
- void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
1862
- iterator_.load_with_pointer_offset(frag, pointer_offset);
1863
- }
1864
-
1865
- /// Loads a fragment from memory
1866
- CUTLASS_DEVICE
1867
- void load(Fragment& frag) {
1868
- load_with_pointer_offset(frag, 0);
1869
- }
1870
-
1871
- /// Store a fragment to memory
1872
- CUTLASS_DEVICE
1873
- void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
1874
- iterator_.store_with_pointer_offset(frag, pointer_offset);
1875
- }
1876
-
1877
- /// Store a fragment to memory
1878
- CUTLASS_DEVICE
1879
- void store(Fragment const& frag) {
1880
- store_with_pointer_offset(frag, 0);
1881
- }
1882
- };
1883
-
1884
- ////////////////////////////////////////////////////////////////////////////////
1885
-
1886
- /// Specialization of PredicatedTileIteratorResidualLast for row-major
1887
- /// interleaved data. It is mapped to the congruous layout.
1888
- ///
1889
- /// Satisfies: ForwardTileIteratorConcept |
1890
- /// ReadableContiguousTileIteratorConcept |
1891
- /// WriteableContiguousTileIteratorConcept |
1892
- /// MaskedTileIteratorConcept
1893
- ///
1894
- template <
1895
- typename Shape_,
1896
- typename Element_,
1897
- int AdvanceRank,
1898
- typename ThreadMap_,
1899
- int AccessSize,
1900
- int InterleavedK>
1901
- class PredicatedTileIteratorResidualLast<
1902
- Shape_,
1903
- Element_,
1904
- layout::RowMajorInterleaved<InterleavedK>,
1905
- AdvanceRank,
1906
- ThreadMap_,
1907
- AccessSize,
1908
- false> {
1909
- public:
1910
- static_assert(
1911
- AdvanceRank == 0 || AdvanceRank == 1,
1912
- "Specialization for pitch-linear iterator may along advance along the "
1913
- "contiguous(rank=0) or strided(rank=1) dimension.");
1914
-
1915
- using Shape = Shape_;
1916
- using Element = Element_;
1917
- static int const kInterleavedK = InterleavedK;
1918
- using Layout = layout::RowMajorInterleaved<kInterleavedK>;
1919
- static int const kAdvanceRank = AdvanceRank;
1920
- using ThreadMap = ThreadMap_;
1921
-
1922
- using Index = typename Layout::Index;
1923
- using LongIndex = typename Layout::LongIndex;
1924
-
1925
- using TensorRef = TensorRef<Element, Layout>;
1926
- using TensorView = TensorView<Element, Layout>;
1927
- using TensorCoord = typename Layout::TensorCoord;
1928
-
1929
- using Pointer = Element*;
1930
- using NonConstPointer = typename platform::remove_const<Element>::type*;
1931
-
1932
- using UnderlyingIterator = PredicatedTileIteratorResidualLast<
1933
- layout::PitchLinearShape<
1934
- Shape::kColumn * kInterleavedK,
1935
- Shape::kRow / kInterleavedK>,
1936
- Element,
1937
- layout::PitchLinear,
1938
- (kAdvanceRank == 0 ? 1 : 0),
1939
- ThreadMap,
1940
- AccessSize>;
1941
-
1942
- using AccessType = typename UnderlyingIterator::AccessType;
1943
-
1944
- /// Fragment object to be loaded or stored
1945
- using Fragment = cutlass::Array<
1946
- Element,
1947
- ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
1948
-
1949
- /// Predicate vector stores mask to guard accesses
1950
- using Mask = typename UnderlyingIterator::Mask;
1951
-
1952
- /// Parameters object is precomputed state and is host-constructible
1953
- class Params {
1954
- private:
1955
- friend PredicatedTileIteratorResidualLast;
1956
-
1957
- /// Parameters object
1958
- typename UnderlyingIterator::Params params_;
1959
-
1960
- public:
1961
- CUTLASS_HOST_DEVICE
1962
- Params() {}
1963
-
1964
- /// Construct the Params object from the interleaved layout's leading stride, viewed as a pitch-linear layout
1965
- CUTLASS_HOST_DEVICE
1966
- Params(Layout const& layout)
1967
- : params_(layout::PitchLinear(layout.stride(0))) {}
1968
-
1969
- CUTLASS_HOST_DEVICE
1970
- Params(typename UnderlyingIterator::Params::Base const& base)
1971
- : params_(base) {}
1972
- };
1973
-
1974
- private:
1975
- //
1976
- // Data members
1977
- //
1978
-
1979
- /// Underlying pitch-linear tile iterator
1980
- UnderlyingIterator iterator_;
1981
-
1982
- public:
1983
- /// Constructs a TileIterator from its precomputed state, threadblock offset,
1984
- /// and thread ID
1985
- CUTLASS_HOST_DEVICE
1986
- PredicatedTileIteratorResidualLast(
1987
- /// Precomputed parameters object
1988
- Params const& params,
1989
- /// Pointer to start of tensor
1990
- Pointer pointer,
1991
- /// Extent of tensor
1992
- TensorCoord extent,
1993
- /// ID of each participating thread
1994
- int thread_id,
1995
- /// Initial offset of threadblock
1996
- TensorCoord const& threadblock_offset,
1997
- int const* indices =
1998
- nullptr ///< gather/scatter indices, note no support for
1999
- ///< gather/scatter at this specialization
2000
- )
2001
- : iterator_(
2002
- params.params_,
2003
- pointer,
2004
- layout::PitchLinearCoord(
2005
- extent.column() * kInterleavedK,
2006
- extent.row() / kInterleavedK),
2007
- thread_id,
2008
- layout::PitchLinearCoord(
2009
- threadblock_offset.column() * kInterleavedK,
2010
- threadblock_offset.row() / kInterleavedK)) {}
2011
-
2012
- /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
2013
- /// offset
2014
- CUTLASS_HOST_DEVICE
2015
- PredicatedTileIteratorResidualLast(
2016
- Params const& params, ///< Precomputed parameters object
2017
- Pointer pointer, ///< Pointer to start of tensor
2018
- TensorCoord extent, ///< Extent of tensor
2019
- int thread_id ///< ID of each participating thread
2020
- )
2021
- : PredicatedTileIteratorResidualLast(
2022
- params,
2023
- pointer,
2024
- extent,
2025
- thread_id,
2026
- make_Coord(0, 0)) {}
2027
-
2028
- /// Adds a pointer offset in units of Element
2029
- CUTLASS_HOST_DEVICE
2030
- void add_pointer_offset(LongIndex pointer_offset) {
2031
- iterator_.add_pointer_offset(pointer_offset);
2032
- }
2033
-
2034
- /// Advances to the next tile in memory.
2035
- ///
2036
- /// The first time this method is called, predicates are updated, and the
2037
- /// iterator's internal pointer is reverted to the first "steady state" tile.
2038
- /// Subsequent calls are lightweight and must only update the internal
2039
- /// pointer.
2040
- CUTLASS_HOST_DEVICE
2041
- PredicatedTileIteratorResidualLast& operator++() {
2042
- ++iterator_;
2043
- return *this;
2044
- }
2045
-
2046
- /// Advances to the next tile in memory.
2047
- ///
2048
- /// The first time this method is called, predicates are updated, and the
2049
- /// iterator's internal pointer is reverted to the first "steady state" tile.
2050
- /// Subsequent calls are lightweight and must only update the internal
2051
- /// pointer.
2052
- CUTLASS_HOST_DEVICE
2053
- PredicatedTileIteratorResidualLast operator++(int) {
2054
- PredicatedTileIteratorResidualLast self(*this);
2055
- operator++();
2056
- return self;
2057
- }
2058
-
2059
- /// Clears the predicate set efficiently
2060
- CUTLASS_HOST_DEVICE
2061
- void clear_mask(bool enable = true) {
2062
- iterator_.clear_mask(enable);
2063
- }
2064
-
2065
- CUTLASS_HOST_DEVICE
2066
- void set_residual_tile(bool enable) {
2067
- iterator_.set_residual_tile(enable);
2068
- }
2069
-
2070
- /// Enables the predicate mask by forwarding to the underlying iterator
2071
- CUTLASS_HOST_DEVICE
2072
- void enable_mask() {
2073
- iterator_.enable_mask();
2074
- }
2075
-
2076
- /// Sets the predicate mask, overriding value stored in predicate iterator
2077
- CUTLASS_HOST_DEVICE
2078
- void set_mask(Mask const& mask) {
2079
- iterator_.set_mask(mask);
2080
- }
2081
-
2082
- /// Gets the mask
2083
- CUTLASS_HOST_DEVICE
2084
- void get_mask(Mask& mask) {
2085
- iterator_.get_mask(mask);
2086
- }
2087
-
2088
- /// Loads a fragment from memory
2089
- CUTLASS_DEVICE
2090
- void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
2091
- iterator_.load_with_pointer_offset(frag, pointer_offset);
2092
- }
2093
-
2094
- /// Loads a fragment from memory
2095
- CUTLASS_DEVICE
2096
- void load(Fragment& frag) {
2097
- load_with_pointer_offset(frag, 0);
2098
- }
2099
-
2100
- /// Store a fragment to memory
2101
- CUTLASS_DEVICE
2102
- void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
2103
- iterator_.store_with_pointer_offset(frag, pointer_offset);
2104
- }
2105
-
2106
- /// Store a fragment to memory
2107
- CUTLASS_DEVICE
2108
- void store(Fragment const& frag) {
2109
- store_with_pointer_offset(frag, 0);
2110
- }
2111
- };
2112
-
2113
- ////////////////////////////////////////////////////////////////////////////////
2114
-
2115
- } // namespace threadblock
2116
- } // namespace transform
2117
- } // namespace cutlass
2118
-
2119
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h DELETED
@@ -1,55 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include "warp_iterator_from_smem.h"
35
-
36
- template <typename WarpIterator>
37
- struct TransposeWarpIterator {
38
- using Iterator = char;
39
- static bool constexpr kSupportsTranspose = false;
40
- };
41
-
42
- template <
43
- /// Operand identity
44
- cutlass::gemm::Operand Operand,
45
- /// Data type of A elements
46
- typename Element,
47
- typename InstructionShape,
48
- bool kTranspose>
49
- struct TransposeWarpIterator<
50
- cutlass::gemm::warp::
51
- WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> {
52
- using Iterator = cutlass::gemm::warp::
53
- WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>;
54
- static bool constexpr kSupportsTranspose = true;
55
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h DELETED
@@ -1,283 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Inspired by
33
- "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h". Loads tiles of GEMM
34
- operands from a RowMajor shared-memory layout into registers for use by A100
35
- TensorCores.
36
-
37
- The differences from "mma_tensor_op_tile_access_iterator.h" are:
38
- (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly
39
- faster). (2) We support transposing the operand (e.g. reading `A.transpose()`
40
- when the shared memory holds `A`).
41
-
42
- This is only implemented for specific shapes.
43
- */
44
- #pragma once
45
-
46
- #include <cutlass/gemm/gemm.h>
47
-
48
- ////////////////////////////////////////////////////////////////////////////////
49
- namespace cutlass {
50
- namespace gemm {
51
- namespace warp {
52
-
53
- template <
54
- /// Operand identity
55
- Operand Operand_,
56
- /// Data type of A elements
57
- typename Element_,
58
- typename InstructionShape_,
59
- bool kTranspose = false>
60
- class WarpIteratorFromSmem {
61
- public:
62
- /// Shape of tile to load (concept: MatrixShape)
63
- using Shape = cutlass::MatrixShape<32, 32>;
64
-
65
- /// Operand tag
66
- static Operand const kOperand = Operand_;
67
- static_assert(
68
- kOperand == Operand::kA,
69
- "No support for OperandB at the moment");
70
-
71
- /// Basic check
72
- static_assert(
73
- kOperand == Operand::kA || kOperand == Operand::kB,
74
- "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma.");
75
-
76
- /// Element type
77
- using Element = Element_;
78
- static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");
79
-
80
- /// Layout of source tile
81
- using Layout = cutlass::layout::RowMajor;
82
-
83
- /// Shape of one matrix product operation (concept: MatrixShape)
84
- using InstructionShape = InstructionShape_;
85
- static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
86
- static_assert(
87
- InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
88
- "Only supports 16x8x8 / 16x8x16");
89
-
90
- /// Delta between *MMA operations (in units of *MMA operations, concept:
91
- /// MatrixShape)
92
- static int const kOpDelta = 1;
93
-
94
- /// Number of participating threads
95
- static int const kThreads = 32;
96
-
97
- /// TensorRef type for loading element from a tensor
98
- using TensorRef = TensorRef<Element, Layout>;
99
-
100
- /// Index type
101
- using Index = typename TensorRef::Index;
102
-
103
- /// Long Index type
104
- using LongIndex = typename TensorRef::LongIndex;
105
-
106
- /// Coordinate for an element in the tensor
107
- using TensorCoord = typename TensorRef::TensorCoord;
108
-
109
- /// Number of elements accessed per Shared Memory load
110
- static int const kElementsPerAccess =
111
- (sizeof_bits<Element>::value >= 32 ? 1
112
- : 32 / sizeof_bits<Element>::value);
113
-
114
- using InstructionCount = MatrixShape<
115
- Shape::kRow / InstructionShape::kRow,
116
- Shape::kColumn / InstructionShape::kColumn>;
117
-
118
- static int const kIterations = (kOperand == Operand::kA)
119
- ? InstructionCount::kColumn
120
- : InstructionCount::kRow;
121
-
122
- public:
123
- //
124
- // Derived quantities
125
- //
126
-
127
- /// Fragment object holding a thread's part of a tile
128
- using Fragment = Array<
129
- Element,
130
- (kOperand == Operand::kA)
131
- ? (Shape::kRow* InstructionShape::kColumn / kThreads)
132
- : (Shape::kColumn* InstructionShape::kRow / kThreads)>;
133
-
134
- /// Memory access type
135
- // using AccessType = AlignedArray<Element, kElementsPerAccess>;
136
- using AccessType = Array<unsigned, 4>;
137
-
138
- static int constexpr kWarpShapeDivisibleInner =
139
- (kOperand == Operand::kA ? InstructionShape::kColumn
140
- : InstructionShape::kRow);
141
- static int constexpr kAccessesInner =
142
- (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
143
- // Number of 32bits tiles to load per `ldmatrix`
144
- static int const kTilesPerInstruction = InstructionShape::kRow / 8;
145
- static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");
146
-
147
- private:
148
- /// Underlying tensor reference
149
- TensorRef ref_;
150
-
151
- /// Origin
152
- MatrixCoord origin_;
153
-
154
- /// Iterations in a tile
155
- int iterations_;
156
-
157
- public:
158
- /// Constructor from TensorRef
159
- CUTLASS_HOST_DEVICE
160
- WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
161
- : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {}
162
- CUTLASS_HOST_DEVICE
163
- WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
164
- : ref_(ref), iterations_(0) {
165
- // See also:
166
- // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
167
- // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
168
- // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
169
- int ldsm_vec_num = (lane_id >> 3);
170
- if (kOperand == Operand::kA) {
171
- origin_ = MatrixCoord(lane_id % 8, 0);
172
- static_assert(
173
- InstructionCount::kRow * kTilesPerInstruction == 4,
174
- "can't use ldmatrix.x4");
175
- int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
176
- int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
177
- int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
178
- MatrixCoord offset(
179
- access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
180
- inner_idx * 4 * kElementsPerAccess);
181
- if (kTranspose) {
182
- offset = MatrixCoord(offset.column(), offset.row());
183
- }
184
- origin_ += offset;
185
- } else {
186
- // Note: This is not tested or used
187
- origin_ = MatrixCoord(0, lane_id % 8);
188
- static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
189
- CUTLASS_PRAGMA_UNROLL
190
- for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn;
191
- ++inst_n_idx) {
192
- CUTLASS_PRAGMA_UNROLL
193
- for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
194
- int access_idx = inner_idx + kAccessesInner * inst_n_idx;
195
-
196
- MatrixCoord offset(
197
- inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8);
198
-
199
- if (access_idx == ldsm_vec_num) {
200
- if (kTranspose) {
201
- offset = MatrixCoord(offset.column(), offset.row());
202
- }
203
- origin_ += offset;
204
- }
205
- }
206
- }
207
- }
208
-
209
- ref_.add_coord_offset(origin_);
210
- }
211
-
212
- /// Advances an iterator along logical dimensions of matrix in units of whole
213
- /// tiles
214
- CUTLASS_HOST_DEVICE
215
- WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) {
216
- TensorCoord coord_offset(
217
- tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
218
- if (kTranspose) {
219
- coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()};
220
- }
221
- origin_ += coord_offset;
222
-
223
- ref_.add_coord_offset(coord_offset);
224
-
225
- return *this;
226
- }
227
-
228
- /// Advances the iterator along the advance dimension
229
- CUTLASS_DEVICE
230
- void advance() {
231
- if (kOperand == Operand::kA) {
232
- add_tile_offset({0, 1});
233
- } else {
234
- add_tile_offset({1, 0});
235
- }
236
-
237
- iterations_ = 0;
238
- }
239
-
240
- /// Increments the iteration count within a tile; once it reaches kIterations, advances to the next tile
241
- CUTLASS_HOST_DEVICE
242
- WarpIteratorFromSmem& operator++() {
243
- iterations_++;
244
-
245
- if (iterations_ >= kIterations)
246
- advance();
247
-
248
- return *this;
249
- }
250
-
251
- /// Loads a fragment from memory at the location pointed to by the iterator.
252
- CUTLASS_DEVICE
253
- void load(Fragment& frag) const {
254
- AccessType* access_ptr = reinterpret_cast<AccessType*>(&frag);
255
- using LoadLayout = typename platform::
256
- conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;
257
-
258
- CUTLASS_PRAGMA_UNROLL
259
- for (int access_m_idx = 0; access_m_idx <
260
- (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
261
- ++access_m_idx) {
262
- MatrixCoord offset;
263
- if (kOperand == Operand::kA) {
264
- offset = MatrixCoord(
265
- access_m_idx * 16, iterations_ * InstructionShape::kColumn);
266
- } else {
267
- offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
268
- }
269
- if (kTranspose) {
270
- offset = MatrixCoord(offset.column(), offset.row());
271
- }
272
- cutlass::arch::ldsm<LoadLayout, 4>(
273
- access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
274
- }
275
- }
276
- };
277
-
278
- ////////////////////////////////////////////////////////////////////////////////
279
-
280
- } // namespace warp
281
- } // namespace gemm
282
- } // namespace cutlass
283
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h DELETED
The diff for this file is too large to render. See raw diff
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h DELETED
@@ -1,1322 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #ifdef HAS_PYTORCH
35
- #include <ATen/cuda/CUDAGeneratorImpl.h>
36
- #include <ATen/cuda/CUDAGraphsUtils.cuh>
37
- #endif
38
-
39
- #include <curand_kernel.h>
40
- #include <cmath>
41
- #include <cinttypes>
42
- #include <vector>
43
-
44
- #include "cutlass/fast_math.h"
45
- #include "cutlass/gemm/gemm.h"
46
- #include "cutlass/layout/matrix.h"
47
- #include "cutlass/layout/vector.h"
48
- #include "cutlass/matrix.h"
49
- #include "cutlass/numeric_types.h"
50
- #include "cutlass/tensor_ref.h"
51
-
52
- #include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
53
- #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
54
- #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
55
- #include "cutlass/gemm/device/default_gemm_configuration.h"
56
- #include "cutlass/gemm/kernel/default_gemm.h"
57
- #include "cutlass/gemm/threadblock/default_mma.h"
58
- #include "cutlass/gemm/threadblock/default_mma_core_simt.h"
59
- #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
60
- #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
61
- #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
62
- #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
63
- #include "cutlass/matrix_shape.h"
64
- #include "cutlass/platform/platform.h"
65
- #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
66
- #include "debug_utils.h"
67
- #include "epilogue/epilogue_pipelined.h"
68
- #include "epilogue/epilogue_rescale_output.h"
69
- #include "gemm/custom_mma.h"
70
- #include "gemm/find_default_mma.h"
71
- #include "gemm/mma_from_smem.h"
72
- #include "gemm_kernel_utils.h"
73
- #include "transform/tile_smem_loader.h"
74
-
75
- using namespace gemm_kernel_utils;
76
-
77
- namespace {
78
- template <typename scalar_t, typename Arch>
79
- constexpr int getWarpsPerSmFw() {
80
- return (
81
- Arch::kMinComputeCapability >= 80 &&
82
- !cutlass::platform::is_same<scalar_t, float>::value
83
- ? 16
84
- : 12);
85
- }
86
- static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
87
- // source: https://stackoverflow.com/a/51549250
88
- return (value >= 0)
89
- ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
90
- : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
91
- }
92
- } // namespace
93
-
94
- // If ToBatchHookType_ is supplied other than this default (which is
95
- // never the case in the xformers library) then the user is
96
- // defining the logic which each block uses to find its data to work on,
97
- // with the advance_to_batch function with the following signature.
98
- // It should return false if there is no work to do for this block.
99
- // In general this will not work with saving for backward due to fixed layout
100
- // for logsumexp and incompatible rngs for dropout, so is likely only useful for
101
- // custom inference.
102
- struct DefaultToBatchHook {
103
- template <typename Params>
104
- CUTLASS_DEVICE static bool advance_to_batch(
105
- Params&,
106
- int64_t& /* q_start */,
107
- int64_t& /* k_start */) {
108
- return true;
109
- }
110
- };
111
-
112
- template <
113
- // The datatype of Q/K/V
114
- typename scalar_t_,
115
- // Architecture we are targeting (eg `cutlass::arch::Sm80`)
116
- typename ArchTag,
117
- // If Q/K/V are correctly aligned in memory and we can run a fast kernel
118
- bool isAligned_,
119
- int kQueriesPerBlock_,
120
- int kKeysPerBlock_,
121
- // upperbound on `max(value.shape[-1], query.shape[-1])`
122
- int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
123
- // This is quite slower on V100 for some reason
124
- // Set to false if you know at compile-time you will never need dropout
125
- bool kSupportsDropout_ = true,
126
- bool kSupportsBias_ = true,
127
- typename ToBatchHookType_ = DefaultToBatchHook>
128
- struct AttentionKernel {
129
- enum CustomMaskType {
130
- NoCustomMask = 0,
131
- CausalFromTopLeft = 1,
132
- CausalFromBottomRight = 2,
133
- NumCustomMaskTypes,
134
- };
135
-
136
- using scalar_t = scalar_t_;
137
- using accum_t = float;
138
- using lse_scalar_t = float;
139
- using output_t = scalar_t;
140
- // Accumulator between 2 iterations
141
- // Using `accum_t` improves perf on f16 at the cost of
142
- // numerical errors
143
- using output_accum_t = accum_t;
144
- static constexpr bool kSupportsDropout = kSupportsDropout_;
145
- static constexpr bool kSupportsBias = kSupportsBias_;
146
- static constexpr int kKeysPerBlock = kKeysPerBlock_;
147
- static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
148
- static constexpr int kMaxK = kMaxK_;
149
- static constexpr bool kIsAligned = isAligned_;
150
- static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
151
- static constexpr int32_t kAlignLSE = 32; // block size of backward
152
- static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
153
- static constexpr bool kPreloadV =
154
- ArchTag::kMinComputeCapability >= 80 && kIsHalf;
155
- static constexpr bool kKeepOutputInRF = kSingleValueIteration;
156
- static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
157
- !cutlass::platform::is_same<output_accum_t, output_t>::value;
158
-
159
- static_assert(kQueriesPerBlock % 32 == 0, "");
160
- static_assert(kKeysPerBlock % 32 == 0, "");
161
- static constexpr int kNumWarpsPerBlock =
162
- kQueriesPerBlock * kKeysPerBlock / (32 * 32);
163
- static constexpr int kWarpSize = 32;
164
-
165
- // Launch bounds
166
- static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
167
- static constexpr int kMinBlocksPerSm =
168
- getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
169
-
170
- struct Params {
171
- // Input tensors
172
- scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
173
- scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
174
- scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
175
- scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
176
- int32_t* seqstart_q_ptr = nullptr;
177
- int32_t* seqstart_k_ptr = nullptr;
178
-
179
- int32_t* seqlen_k_ptr = nullptr;
180
- uint32_t causal_diagonal_offset = 0;
181
-
182
- // Output tensors
183
- output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
184
- // [num_queries, num_heads, head_dim_value]
185
- output_accum_t* output_accum_ptr = nullptr;
186
- // [num_heads, num_queries] - can be null
187
- lse_scalar_t* logsumexp_ptr = nullptr;
188
-
189
- // Scale
190
- accum_t scale = 0.0;
191
-
192
- // Dimensions/strides
193
- int32_t head_dim = 0;
194
- int32_t head_dim_value = 0;
195
- int32_t num_queries = 0;
196
- int32_t num_keys = 0;
197
- int32_t num_keys_absolute = 0;
198
-
199
- uint8_t custom_mask_type = NoCustomMask;
200
-
201
- int32_t q_strideM = 0;
202
- int32_t k_strideM = 0;
203
- int32_t v_strideM = 0;
204
- int32_t bias_strideM = 0;
205
-
206
- int32_t o_strideM = 0;
207
-
208
- // Everything below is only used in `advance_to_block`
209
- // and shouldn't use registers
210
- int32_t q_strideH = 0;
211
- int32_t k_strideH = 0;
212
- int32_t v_strideH = 0;
213
- int64_t bias_strideH = 0;
214
-
215
- int64_t q_strideB = 0;
216
- int64_t k_strideB = 0;
217
- int64_t v_strideB = 0;
218
- int64_t bias_strideB = 0;
219
-
220
- int32_t num_batches = 0;
221
- int32_t num_heads = 0;
222
-
223
- // dropout
224
- bool use_dropout = false;
225
- unsigned long long dropout_batch_head_rng_offset = 0;
226
- float dropout_prob = 0.0f;
227
- #ifdef HAS_PYTORCH
228
- at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
229
- #endif
230
-
231
- // Moves pointers to what we should process
232
- // Returns "false" if there is no work to do
233
- CUTLASS_DEVICE bool advance_to_block() {
234
- auto batch_id = blockIdx.z;
235
- auto head_id = blockIdx.y;
236
- auto query_start = blockIdx.x * kQueriesPerBlock;
237
-
238
- auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
239
-
240
- if (kSupportsDropout) {
241
- dropout_batch_head_rng_offset =
242
- batch_id * num_heads * num_queries * num_keys +
243
- head_id * num_queries * num_keys;
244
- }
245
-
246
- int64_t q_start = 0, k_start = 0;
247
- // Advance to current batch - in case of different sequence lengths
248
- constexpr bool kToBatchHook =
249
- !cutlass::platform::is_same<ToBatchHookType_, DefaultToBatchHook>::
250
- value;
251
- if (kToBatchHook) {
252
- // Call out to a custom implementation.
253
- if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) {
254
- return false;
255
- }
256
- } else if (seqstart_q_ptr != nullptr) {
257
- assert(seqstart_k_ptr != nullptr);
258
- seqstart_q_ptr += batch_id;
259
-
260
- q_start = seqstart_q_ptr[0];
261
- int64_t q_next_start = seqstart_q_ptr[1];
262
- int64_t k_end;
263
- seqstart_k_ptr += batch_id;
264
-
265
- if (seqlen_k_ptr) {
266
- k_start = seqstart_k_ptr[0];
267
- k_end = k_start + seqlen_k_ptr[batch_id];
268
- } else {
269
- k_start = seqstart_k_ptr[0];
270
- k_end = seqstart_k_ptr[1];
271
- }
272
-
273
- num_queries = q_next_start - q_start;
274
- num_keys = k_end - k_start;
275
-
276
- if (query_start >= num_queries) {
277
- return false;
278
- }
279
- } else {
280
- query_ptr += batch_id * q_strideB;
281
- key_ptr += batch_id * k_strideB;
282
- value_ptr += batch_id * v_strideB;
283
- output_ptr += int64_t(batch_id * num_queries) * o_strideM;
284
- if (output_accum_ptr != nullptr) {
285
- output_accum_ptr +=
286
- int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
287
- }
288
- q_start = 0;
289
- k_start = 0;
290
- }
291
-
292
- // Advance to the current batch / head / query_start
293
- query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
294
- key_ptr += k_start * k_strideM + head_id * k_strideH;
295
-
296
- value_ptr += k_start * v_strideM + head_id * v_strideH;
297
- output_ptr +=
298
- int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
299
-
300
- if (kSupportsBias && attn_bias_ptr != nullptr) {
301
- attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
302
- }
303
- if (output_accum_ptr != nullptr) {
304
- output_accum_ptr +=
305
- int64_t(q_start + query_start) * (head_dim_value * num_heads) +
306
- head_id * head_dim_value;
307
- } else {
308
- // Accumulate directly in the destination buffer (eg for f32)
309
- output_accum_ptr = (accum_t*)output_ptr;
310
- }
311
-
312
- if (logsumexp_ptr != nullptr) {
313
- // lse[batch_id, head_id, query_start]
314
- logsumexp_ptr +=
315
- batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
316
- }
317
-
318
- // Custom masking
319
- if (custom_mask_type == CausalFromBottomRight) {
320
- causal_diagonal_offset = num_keys - num_queries;
321
- }
322
- // We use num_keys_absolute to index into the rng_state
323
- // We need this index to match between forward and backwards
324
- num_keys_absolute = num_keys;
325
- if (custom_mask_type == CausalFromTopLeft ||
326
- custom_mask_type == CausalFromBottomRight) {
327
- // the bottom row of the current block is query_start + kQueriesPerBlock
328
- // the last active key is then query_start + causal_diagonal_offset +
329
- // kQueriesPerBlock so num_keys is the min between actual num_keys and
330
- // this to avoid extra computations
331
- num_keys = cutlass::fast_min(
332
- int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock),
333
- num_keys);
334
- }
335
-
336
- num_queries -= query_start;
337
- num_batches = 0; // no longer used after
338
-
339
- // If num_queries == 1, and there is only one key head we're wasting
340
- // 15/16th of tensor core compute In that case :
341
- // - we only launch kernels for head_id % kQueriesPerBlock == 0
342
- // - we iterate over heads instead of queries (strideM = strideH)
343
- if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) {
344
- if (head_id % kQueriesPerBlock != 0)
345
- return false;
346
- q_strideM = q_strideH;
347
- num_queries = num_heads;
348
- num_heads = 1; // unused but here for intent
349
- // remove causal since n_query = 1
350
- // otherwise, offset would change with head !
351
- custom_mask_type = NoCustomMask;
352
- o_strideM = head_dim_value;
353
- }
354
-
355
- // Make sure the compiler knows these variables are the same on all
356
- // the threads of the warp.
357
- // Only worth doing if they could have been modified above.
358
- query_ptr = warp_uniform(query_ptr);
359
- key_ptr = warp_uniform(key_ptr);
360
- value_ptr = warp_uniform(value_ptr);
361
- if (kSupportsBias) {
362
- attn_bias_ptr = warp_uniform(attn_bias_ptr);
363
- }
364
- output_ptr = warp_uniform(output_ptr);
365
- output_accum_ptr = warp_uniform(output_accum_ptr);
366
- logsumexp_ptr = warp_uniform(logsumexp_ptr);
367
- num_queries = warp_uniform(num_queries);
368
- num_keys = warp_uniform(num_keys);
369
- num_heads = warp_uniform(num_heads);
370
- o_strideM = warp_uniform(o_strideM);
371
- custom_mask_type = warp_uniform(custom_mask_type);
372
- return true;
373
- }
374
-
375
- __host__ dim3 getBlocksGrid() const {
376
- return dim3(
377
- ceil_div(num_queries, (int32_t)kQueriesPerBlock),
378
- num_heads,
379
- num_batches);
380
- }
381
-
382
- __host__ dim3 getThreadsGrid() const {
383
- return dim3(kWarpSize, kNumWarpsPerBlock, 1);
384
- }
385
- };
386
-
387
- struct MM0 {
388
- /*
389
- In this first matmul, we compute a block of `Q @ K.T`.
390
- While the calculation result is still hot in registers, we update
391
- `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
392
- into a shared-memory ("AccumulatorSharedStorage") that is used later as
393
- operand A for the second matmul (see MM1)
394
- */
395
- using GemmType = DefaultGemmType<ArchTag, scalar_t>;
396
-
397
- using OpClass = typename GemmType::OpClass;
398
- using DefaultConfig =
399
- typename cutlass::gemm::device::DefaultGemmConfiguration<
400
- OpClass,
401
- ArchTag,
402
- scalar_t,
403
- scalar_t,
404
- scalar_t, // ElementC
405
- accum_t // ElementAccumulator
406
- >;
407
- static constexpr int kAlignmentA =
408
- kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
409
- static constexpr int kAlignmentB =
410
- kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
411
- using ThreadblockShape = cutlass::gemm::
412
- GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
413
- using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
414
- using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
415
- scalar_t, // ElementA,
416
- cutlass::layout::RowMajor, // LayoutA,
417
- kAlignmentA,
418
- scalar_t, // ElementB,
419
- cutlass::layout::ColumnMajor, // LayoutB,
420
- kAlignmentB,
421
- accum_t,
422
- cutlass::layout::RowMajor, // LayoutC,
423
- OpClass,
424
- ArchTag, // ArchTag
425
- ThreadblockShape, // ThreadblockShape
426
- WarpShape, // WarpShape
427
- typename GemmType::InstructionShape, // InstructionShape
428
- ArchTag::kMinComputeCapability >= 80 && kIsHalf
429
- ? 4
430
- : DefaultConfig::kStages,
431
- typename GemmType::Operator // Operator
432
- >::DefaultMma;
433
- using MmaCore = typename DefaultMma::MmaCore;
434
- using IteratorA = typename DefaultMma::IteratorA;
435
- using IteratorB = typename DefaultMma::IteratorB;
436
- using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
437
- using Mma = typename cutlass::platform::conditional<
438
- kSingleValueIteration,
439
- typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
440
- DefaultThreadblockMma>::type;
441
- using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
442
- typename Mma::Operator::IteratorC,
443
- accum_t,
444
- kWarpSize>::Iterator;
445
- static_assert(
446
- MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
447
- MmaCore::WarpCount::kK ==
448
- kNumWarpsPerBlock,
449
- "");
450
-
451
- // used for efficient load of bias tile Bij from global to shared memory
452
- using BiasLoader = TileSmemLoader<
453
- scalar_t,
454
- cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
455
- MmaCore::kThreads,
456
- // input restriction: kv_len has to be a multiple of this value
457
- 128 / cutlass::sizeof_bits<scalar_t>::value>;
458
-
459
- // Epilogue to store to shared-memory in a format that we can use later for
460
- // the second matmul
461
- using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
462
- typename Mma::Operator::IteratorC,
463
- typename Mma::Operator,
464
- scalar_t,
465
- WarpShape,
466
- ThreadblockShape>;
467
- using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
468
- };
469
-
470
- struct MM1 {
471
- /**
472
- Second matmul: perform `attn @ V` where `attn` is the attention (not
473
- normalized) and stored in shared memory
474
- */
475
- using GemmType = DefaultGemmType<ArchTag, scalar_t>;
476
-
477
- using OpClass = typename GemmType::OpClass;
478
- using DefaultConfig =
479
- typename cutlass::gemm::device::DefaultGemmConfiguration<
480
- OpClass,
481
- ArchTag,
482
- scalar_t,
483
- scalar_t,
484
- output_accum_t, // ElementC
485
- accum_t // ElementAccumulator
486
- >;
487
- static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
488
- static constexpr int kAlignmentB =
489
- kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
490
- using ThreadblockShape = cutlass::gemm::
491
- GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
492
- using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
493
- using InstructionShape = typename GemmType::InstructionShape;
494
-
495
- using LayoutB = cutlass::layout::RowMajor;
496
- using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
497
- scalar_t, // ElementA,
498
- cutlass::layout::RowMajor, // LayoutA,
499
- kAlignmentA,
500
- scalar_t, // ElementB,
501
- LayoutB, // LayoutB,
502
- kAlignmentB,
503
- output_accum_t,
504
- cutlass::layout::RowMajor, // LayoutC,
505
- accum_t,
506
- OpClass,
507
- ArchTag,
508
- ThreadblockShape,
509
- WarpShape,
510
- typename GemmType::InstructionShape,
511
- typename DefaultConfig::EpilogueOutputOp,
512
- void, // ThreadblockSwizzle - not used
513
- ArchTag::kMinComputeCapability >= 80 && kIsHalf
514
- ? 4
515
- : DefaultConfig::kStages,
516
- false, // SplitKSerial
517
- typename GemmType::Operator>;
518
-
519
- using WarpIteratorA = typename cutlass::gemm::threadblock::
520
- DefaultWarpIteratorAFromSharedMemory<
521
- typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
522
- typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
523
- typename DefaultGemm::Mma::Policy::Operator::IteratorA,
524
- typename DefaultGemm::Mma::Policy>::WarpIterator;
525
- using DefaultMmaFromSmem =
526
- typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
527
- typename DefaultGemm::Mma,
528
- MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
529
- WarpIteratorA,
530
- false>; // kScaleOperandA
531
- using Mma = typename DefaultMmaFromSmem::Mma;
532
- using IteratorB = typename Mma::IteratorB;
533
- using WarpCount = typename Mma::WarpCount;
534
- static_assert(
535
- WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
536
- "");
537
-
538
- using DefaultEpilogue = typename DefaultGemm::Epilogue;
539
- using OutputTileIterator =
540
- typename cutlass::epilogue::threadblock::PredicatedTileIterator<
541
- typename DefaultEpilogue::OutputTileIterator::ThreadMap,
542
- output_t>;
543
- using OutputTileIteratorAccum =
544
- typename cutlass::epilogue::threadblock::PredicatedTileIterator<
545
- typename DefaultEpilogue::OutputTileIterator::ThreadMap,
546
- output_accum_t>;
547
- };
548
-
549
- static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
550
- static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
551
- static constexpr int64_t kAlignmentV = 1;
552
-
553
- // Shared storage - depends on kernel params
554
- struct ScalingCoefs {
555
- cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
556
- cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
557
- cutlass::Array<accum_t, kQueriesPerBlock> mi;
558
- cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
559
- cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
560
- addition_storage;
561
- };
562
-
563
- struct SharedStorageEpilogueAtEnd : ScalingCoefs {
564
- struct SharedStorageAfterMM0 {
565
- // Everything here might be overwritten during MM0
566
- union {
567
- typename MM0::BiasLoader::SmemTile bias;
568
- typename MM0::AccumulatorSharedStorage si;
569
- };
570
- typename MM1::Mma::SharedStorage mm1;
571
- };
572
-
573
- union {
574
- typename MM0::Mma::SharedStorage mm0;
575
- SharedStorageAfterMM0 after_mm0;
576
- typename MM1::DefaultEpilogue::SharedStorage epilogue;
577
- };
578
-
579
- CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
580
- epilogue_shared_storage() {
581
- return epilogue;
582
- }
583
- };
584
-
585
- struct SharedStorageEpilogueInLoop : ScalingCoefs {
586
- struct SharedStorageAfterMM0 {
587
- // Everything here might be overwritten during MM0
588
- union {
589
- typename MM0::BiasLoader::SmemTile bias;
590
- typename MM0::AccumulatorSharedStorage si;
591
- };
592
- typename MM1::Mma::SharedStorage mm1;
593
- typename MM1::DefaultEpilogue::SharedStorage epilogue;
594
- };
595
-
596
- union {
597
- typename MM0::Mma::SharedStorage mm0;
598
- SharedStorageAfterMM0 after_mm0;
599
- };
600
-
601
- CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
602
- epilogue_shared_storage() {
603
- return after_mm0.epilogue;
604
- }
605
- };
606
-
607
- using SharedStorage = typename cutlass::platform::conditional<
608
- kSingleValueIteration || kKeepOutputInRF,
609
- SharedStorageEpilogueAtEnd,
610
- SharedStorageEpilogueInLoop>::type;
611
-
612
- static bool __host__ check_supported(Params const& p) {
613
- CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
614
- CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
615
- CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
616
- if (kSupportsBias) {
617
- CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
618
- XFORMERS_CHECK(
619
- p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
620
- "attn_bias is not correctly aligned (strideB)");
621
- XFORMERS_CHECK(
622
- p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
623
- "attn_bias is not correctly aligned (strideH)");
624
- XFORMERS_CHECK(
625
- p.bias_strideM % kAlignmentQ == 0,
626
- "attn_bias is not correctly aligned");
627
- }
628
- XFORMERS_CHECK(
629
- p.q_strideM % kAlignmentQ == 0,
630
- "query is not correctly aligned (strideM)");
631
- XFORMERS_CHECK(
632
- p.k_strideM % kAlignmentK == 0,
633
- "key is not correctly aligned (strideM)");
634
- XFORMERS_CHECK(
635
- p.v_strideM % kAlignmentV == 0,
636
- "value is not correctly aligned (strideM)");
637
- XFORMERS_CHECK(
638
- p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
639
- "query is not correctly aligned (strideH)");
640
- XFORMERS_CHECK(
641
- p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
642
- "key is not correctly aligned (strideH)");
643
- XFORMERS_CHECK(
644
- p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
645
- "value is not correctly aligned (strideH)");
646
- XFORMERS_CHECK(
647
- p.custom_mask_type < NumCustomMaskTypes,
648
- "invalid value for `custom_mask_type`");
649
- return true;
650
- }
651
-
652
- static void CUTLASS_DEVICE attention_kernel(Params& p) {
653
- // In this block, we will only ever:
654
- // - read query[query_start:query_end, :]
655
- // - write to output[query_start:query_end, :]
656
-
657
- extern __shared__ char smem_buffer[];
658
- SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
659
- auto& m_prime = shared_storage.m_prime;
660
- auto& s_prime = shared_storage.s_prime;
661
- auto& mi = shared_storage.mi;
662
- auto& out_rescale = shared_storage.out_rescale;
663
- const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
664
-
665
- static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
666
- if (thread_id() < kQueriesPerBlock) {
667
- s_prime[thread_id()] = accum_t(0);
668
- out_rescale[thread_id()] = accum_t(1.0);
669
- m_prime[thread_id()] =
670
- -cutlass::platform::numeric_limits<accum_t>::infinity();
671
- mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
672
- }
673
- typename MM1::Mma::FragmentC accum_o;
674
- accum_o.clear();
675
-
676
- auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
677
- using OutputTileIterator = typename MM1::OutputTileIterator;
678
- return OutputTileIterator(
679
- typename OutputTileIterator::Params{(int32_t)p.o_strideM},
680
- p.output_ptr,
681
- typename OutputTileIterator::TensorCoord{
682
- p.num_queries, p.head_dim_value},
683
- thread_id(),
684
- {0, col});
685
- };
686
-
687
- auto createOutputAccumIter = [&](int col) ->
688
- typename MM1::OutputTileIteratorAccum {
689
- using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
690
- return OutputTileIteratorAccum(
691
- typename OutputTileIteratorAccum::Params{
692
- (int32_t)(p.head_dim_value * p.num_heads)},
693
- p.output_accum_ptr,
694
- typename OutputTileIteratorAccum::TensorCoord{
695
- p.num_queries, p.head_dim_value},
696
- thread_id(),
697
- {0, col});
698
- };
699
-
700
- #ifdef HAS_PYTORCH
701
- curandStatePhilox4_32_10_t curand_state_init;
702
- if (kSupportsDropout && p.use_dropout) {
703
- const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
704
-
705
- // each element of the attention matrix P with shape
706
- // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
707
- // offset in RNG sequence. we initialize the RNG state with offset that
708
- // starts at the beginning of a (n_queries, n_keys) matrix for this
709
- // block's batch_id and head_id
710
- // initializing rng state is very expensive, so we run once per kernel,
711
- // rather than once per iteration. each iteration takes a copy of the
712
- // initialized RNG state and offsets it as needed.
713
- curand_init(
714
- std::get<0>(seeds),
715
- 0,
716
- std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
717
- &curand_state_init);
718
- }
719
- #endif
720
-
721
- // Iterate through keys
722
- for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
723
- iter_key_start += kKeysPerBlock) {
724
- int32_t problem_size_0_m =
725
- cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
726
- int32_t problem_size_0_n = cutlass::fast_min(
727
- int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
728
- int32_t const& problem_size_0_k = p.head_dim;
729
- int32_t const& problem_size_1_n = p.head_dim_value;
730
- int32_t const& problem_size_1_k = problem_size_0_n;
731
-
732
- auto prologueV = [&](int blockN) {
733
- typename MM1::Mma::IteratorB iterator_V(
734
- typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)},
735
- p.value_ptr + iter_key_start * p.v_strideM,
736
- {problem_size_1_k, problem_size_1_n},
737
- thread_id(),
738
- cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
739
- MM1::Mma::prologue(
740
- shared_storage.after_mm0.mm1,
741
- iterator_V,
742
- thread_id(),
743
- problem_size_1_k);
744
- };
745
-
746
- __syncthreads(); // Need to have shared memory initialized, and `m_prime`
747
- // updated from end of prev iter
748
- //
749
- // MATMUL: Q.K_t
750
- //
751
- // Computes the block-matrix product of:
752
- // (a) query[query_start:query_end, :]
753
- // with
754
- // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
755
- // and stores that into `shared_storage.si`
756
- //
757
-
758
- // Compute threadblock location
759
- cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
760
-
761
- cutlass::MatrixCoord tb_offset_A{
762
- tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
763
-
764
- cutlass::MatrixCoord tb_offset_B{
765
- tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
766
-
767
- // Construct iterators to A and B operands
768
- typename MM0::IteratorA iterator_A(
769
- typename MM0::IteratorA::Params(
770
- typename MM0::MmaCore::LayoutA(p.q_strideM)),
771
- p.query_ptr,
772
- {problem_size_0_m, problem_size_0_k},
773
- thread_id(),
774
- tb_offset_A);
775
-
776
- typename MM0::IteratorB iterator_B(
777
- typename MM0::IteratorB::Params(
778
- typename MM0::MmaCore::LayoutB(p.k_strideM)),
779
- p.key_ptr + iter_key_start * p.k_strideM,
780
- {problem_size_0_k, problem_size_0_n},
781
- thread_id(),
782
- tb_offset_B);
783
-
784
- auto my_warp_id = warp_uniform(warp_id());
785
- auto my_lane_id = lane_id();
786
-
787
- // Construct thread-scoped matrix multiply
788
- typename MM0::Mma mma(
789
- shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
790
-
791
- typename MM0::Mma::FragmentC accum;
792
-
793
- accum.clear();
794
-
795
- auto gemm_k_iterations =
796
- (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
797
-
798
- // Compute threadblock-scoped matrix multiply-add
799
- mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
800
- __syncthreads();
801
-
802
- if (kPreloadV) {
803
- prologueV(0);
804
- } else {
805
- MM1::Mma::drain_cp_asyncs();
806
- }
807
-
808
- typename MM0::Mma::Operator::IteratorC::TensorCoord
809
- iteratorC_tile_offset = {
810
- (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
811
- (my_warp_id % MM0::Mma::WarpCount::kM),
812
- (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
813
- (my_warp_id / MM0::Mma::WarpCount::kM)};
814
-
815
- // multiply by scaling factor
816
- if (kSupportsBias) {
817
- accum =
818
- cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
819
- }
820
-
821
- // apply attention bias if applicable
822
- if (kSupportsBias && p.attn_bias_ptr != nullptr) {
823
- // load bias tile Bij into shared memory
824
- typename MM0::BiasLoader::GmemTileIterator bias_iter(
825
- {cutlass::layout::RowMajor(p.bias_strideM)},
826
- // attn_bias_pointer points to matrix of size (n_queries, n_keys)
827
- // for the relevant batch_id and head_id
828
- p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start,
829
- {problem_size_0_m, problem_size_0_n},
830
- thread_id());
831
- cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
832
- shared_storage.after_mm0.bias.data(),
833
- cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
834
- typename MM0::BiasLoader::SmemTileIterator smem_tile_iter(
835
- bias_tensor_ref, thread_id());
836
- MM0::BiasLoader::load(bias_iter, smem_tile_iter);
837
-
838
- // Pij += Bij, Pij is in register fragment and Bij is in shared memory
839
- auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
840
- my_lane_id, my_warp_id, iteratorC_tile_offset);
841
- MM0::AccumLambdaIterator::iterateRows(
842
- lane_offset,
843
- [&](int accum_m) {},
844
- [&](int accum_m, int accum_n, int idx) {
845
- if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
846
- accum[idx] += bias_tensor_ref.at({accum_m, accum_n});
847
- }
848
- },
849
- [&](int accum_m) {});
850
- }
851
-
852
- // Mask out last if causal
853
- // This is only needed if upper-right corner of current query / key block
854
- // intersects the mask Coordinates of upper-right corner of current block
855
- // is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The
856
- // first masked element is x = y + offset -> query_start + offset There is
857
- // intersection (and we need to mask) if min(iter_key_start +
858
- // kKeysPerBlock, num_keys)) >= query_start + offset
859
- if (p.custom_mask_type &&
860
- cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >=
861
- (query_start + p.causal_diagonal_offset)) {
862
- auto query_start = blockIdx.x * kQueriesPerBlock;
863
- auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
864
- my_lane_id, my_warp_id, iteratorC_tile_offset);
865
- int32_t last_col;
866
- MM0::AccumLambdaIterator::iterateRows(
867
- lane_offset,
868
- [&](int accum_m) {
869
- // last absolute col is (last absolute query + offset)
870
- // last local col is (last absolute query + offset -
871
- // iter_key_start)
872
- last_col = query_start + accum_m + p.causal_diagonal_offset -
873
- iter_key_start;
874
- },
875
- [&](int accum_m, int accum_n, int idx) {
876
- if (accum_n > last_col) {
877
- accum[idx] =
878
- -cutlass::platform::numeric_limits<accum_t>::infinity();
879
- }
880
- },
881
- [&](int accum_m) {});
882
- }
883
- // Update `mi` from accum stored in registers
884
- // Also does accum[i] <- exp(accum[i] - mi)
885
- iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
886
- accum_o,
887
- accum,
888
- mi,
889
- m_prime,
890
- s_prime,
891
- out_rescale,
892
- shared_storage.addition_storage,
893
- my_lane_id,
894
- thread_id(),
895
- my_warp_id,
896
- p.num_keys - iter_key_start,
897
- iter_key_start == 0,
898
- iteratorC_tile_offset,
899
- kSupportsBias ? 1.0f : p.scale);
900
-
901
- // Output results to shared-memory
902
- int warp_idx_mn_0 = my_warp_id %
903
- (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
904
- auto output_tile_coords = cutlass::MatrixCoord{
905
- warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
906
- warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
907
-
908
- MM0::B2bGemm::accumToSmem(
909
- shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
910
-
911
- __syncthreads();
912
-
913
- #ifdef HAS_PYTORCH
914
- // apply dropout (if applicable) after we've written Pij to smem.
915
- // dropout is applied by multiplying each element of Pij by:
916
- // - 0 with probability dropout_p
917
- // - 1 / (1 - dropout_p) with probability 1 - dropout_p
918
- //
919
- // for backward purposes we want to be able to map each element of the
920
- // attention matrix to the same random uniform number as the one we used
921
- // in forward, without needing to use the same iteration order or having
922
- // to store the dropout matrix. its possible to do this in registers but
923
- // it ends up being very slow because each thread having noncontiguous
924
- // strips of the Pij tile means we have to skip around a lot, and also
925
- // have to generate a single random number at a time
926
- if (kSupportsDropout && p.use_dropout) {
927
- auto si = shared_storage.after_mm0.si.accum_ref();
928
- // each thread handles a contiguous sequence of elements from Sij, all
929
- // coming from the same row. the reason they have to come from the same
930
- // row is that the sampling random numbers from a contiguous random
931
- // number sequence is much more efficient than jumping around, and the
932
- // linear offset of each element of S (the global matrix) maps to an
933
- // offset in a random number sequence. for S, the end of a row and the
934
- // beginning of the next have adjacent offsets, but for Sij, this is not
935
- // necessarily the case.
936
- const int num_threads = blockDim.x * blockDim.y * blockDim.z;
937
- const int threads_per_row =
938
- cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n);
939
- const int elts_per_thread = cutlass::round_nearest(
940
- cutlass::ceil_div(problem_size_0_n, threads_per_row), 4);
941
-
942
- const int thread_i = thread_id() / threads_per_row;
943
- const int thread_start_j =
944
- (thread_id() % threads_per_row) * elts_per_thread;
945
-
946
- if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
947
- curandStatePhilox4_32_10_t curand_state = curand_state_init;
948
- skipahead(
949
- static_cast<unsigned long long>(
950
- (query_start + thread_i) * p.num_keys_absolute +
951
- (iter_key_start + thread_start_j)),
952
- &curand_state);
953
- const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
954
-
955
- // apply dropout scaling to elements this thread is responsible for,
956
- // in chunks of 4
957
- for (int sij_start_col_idx = thread_start_j; sij_start_col_idx <
958
- cutlass::fast_min(thread_start_j + elts_per_thread,
959
- problem_size_0_n);
960
- sij_start_col_idx += 4) {
961
- const float4 rand_uniform_quad = curand_uniform4(&curand_state);
962
-
963
- CUTLASS_PRAGMA_UNROLL
964
- for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
965
- si.at({thread_i, sij_start_col_idx + quad_idx}) *=
966
- static_cast<scalar_t>(
967
- dropout_scale *
968
- ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
969
- }
970
- }
971
- }
972
- __syncthreads(); // p.use_dropout should have same value kernel-wide
973
- }
974
- #endif
975
-
976
- //
977
- // MATMUL: Attn . V
978
- // Run the matmul `attn @ V` for a block of attn and V.
979
- // `attn` is read from shared memory (in `shared_storage_si`)
980
- // `V` is read from global memory (with iterator_B)
981
- //
982
-
983
- const int64_t nBlockN = kSingleValueIteration
984
- ? 1
985
- : ceil_div(
986
- (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
987
- for (int blockN = 0; blockN < nBlockN; ++blockN) {
988
- int gemm_k_iterations =
989
- (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
990
-
991
- // Compute threadblock-scoped matrix multiply-add and store it in accum
992
- // (in registers)
993
- if (!kPreloadV) {
994
- __syncthreads(); // we share shmem between mma and epilogue
995
- }
996
-
997
- typename MM1::Mma::IteratorB iterator_V(
998
- typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)},
999
- p.value_ptr + iter_key_start * p.v_strideM,
1000
- {problem_size_1_k, problem_size_1_n},
1001
- thread_id(),
1002
- cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
1003
- typename MM1::Mma mma_pv(
1004
- // operand A: Pij_dropped in shared memory
1005
- shared_storage.after_mm0.si.accum_ref(),
1006
- // operand B: shared memory staging area for Vj, which is loaded
1007
- // from global memory
1008
- shared_storage.after_mm0.mm1.operand_B_ref(),
1009
- (int)thread_id(),
1010
- (int)my_warp_id,
1011
- (int)my_lane_id);
1012
- mma_pv.set_prologue_done(kPreloadV);
1013
- if (!kKeepOutputInRF) {
1014
- accum_o.clear();
1015
- }
1016
- mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
1017
- __syncthreads();
1018
-
1019
- if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
1020
- prologueV(blockN + 1);
1021
- }
1022
-
1023
- if (!kKeepOutputInRF) {
1024
- MM1::Mma::drain_cp_asyncs();
1025
- DISPATCH_BOOL(
1026
- iter_key_start == 0, kIsFirst, ([&] {
1027
- DISPATCH_BOOL(
1028
- (iter_key_start + kKeysPerBlock) >= p.num_keys,
1029
- kIsLast,
1030
- ([&] {
1031
- using DefaultEpilogue = typename MM1::DefaultEpilogue;
1032
- using DefaultOp =
1033
- typename MM1::DefaultConfig::EpilogueOutputOp;
1034
- using ElementCompute = typename DefaultOp::ElementCompute;
1035
- using EpilogueOutputOp = typename cutlass::epilogue::
1036
- thread::MemoryEfficientAttentionNormalize<
1037
- typename cutlass::platform::conditional<
1038
- kIsLast::value,
1039
- output_t,
1040
- output_accum_t>::type,
1041
- output_accum_t,
1042
- DefaultOp::kCount,
1043
- typename DefaultOp::ElementAccumulator,
1044
- ElementCompute,
1045
- kIsFirst::value,
1046
- kIsLast::value,
1047
- cutlass::Array<ElementCompute, kQueriesPerBlock>>;
1048
- using Epilogue = typename cutlass::epilogue::threadblock::
1049
- EpiloguePipelined<
1050
- typename DefaultEpilogue::Shape,
1051
- typename MM1::Mma::Operator,
1052
- DefaultEpilogue::kPartitionsK,
1053
- typename cutlass::platform::conditional<
1054
- kIsLast::value,
1055
- typename MM1::OutputTileIterator,
1056
- typename MM1::OutputTileIteratorAccum>::type,
1057
- typename DefaultEpilogue::
1058
- AccumulatorFragmentIterator,
1059
- typename DefaultEpilogue::WarpTileIterator,
1060
- typename DefaultEpilogue::SharedLoadIterator,
1061
- EpilogueOutputOp,
1062
- typename DefaultEpilogue::Padding,
1063
- DefaultEpilogue::kFragmentsPerIteration,
1064
- true, // IterationsUnroll
1065
- typename MM1::OutputTileIteratorAccum // Read
1066
- // iterator
1067
- >;
1068
-
1069
- int col = blockN * MM1::Mma::Shape::kN;
1070
- auto source_iter = createOutputAccumIter(col);
1071
- auto dest_iter = call_conditional<
1072
- kIsLast::value,
1073
- decltype(createOutputIter),
1074
- decltype(createOutputAccumIter)>::
1075
- apply(createOutputIter, createOutputAccumIter, col);
1076
- EpilogueOutputOp rescale(s_prime, out_rescale);
1077
- Epilogue epilogue(
1078
- shared_storage.epilogue_shared_storage(),
1079
- thread_id(),
1080
- my_warp_id,
1081
- my_lane_id);
1082
- epilogue(rescale, dest_iter, accum_o, source_iter);
1083
- }));
1084
- }));
1085
- if (!kSingleValueIteration) {
1086
- __syncthreads();
1087
- }
1088
- }
1089
- }
1090
- __syncthreads(); // we modify `m_prime` after
1091
- }
1092
-
1093
- if (kKeepOutputInRF) {
1094
- constexpr bool kIsFirst = true;
1095
- constexpr bool kIsLast = true;
1096
- using DefaultEpilogue = typename MM1::DefaultEpilogue;
1097
- using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
1098
- using ElementCompute = typename DefaultOp::ElementCompute;
1099
- using EpilogueOutputOp =
1100
- typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
1101
- output_t, // output
1102
- output_accum_t, // source
1103
- DefaultOp::kCount,
1104
- typename DefaultOp::ElementAccumulator, // accum
1105
- output_accum_t, // compute
1106
- kIsFirst,
1107
- kIsLast,
1108
- cutlass::Array<ElementCompute, kQueriesPerBlock>>;
1109
- using Epilogue =
1110
- typename cutlass::epilogue::threadblock::EpiloguePipelined<
1111
- typename DefaultEpilogue::Shape,
1112
- typename MM1::Mma::Operator,
1113
- DefaultEpilogue::kPartitionsK,
1114
- typename MM1::OutputTileIterator, // destination
1115
- typename DefaultEpilogue::AccumulatorFragmentIterator,
1116
- typename DefaultEpilogue::WarpTileIterator,
1117
- typename DefaultEpilogue::SharedLoadIterator,
1118
- EpilogueOutputOp,
1119
- typename DefaultEpilogue::Padding,
1120
- DefaultEpilogue::kFragmentsPerIteration,
1121
- true, // IterationsUnroll
1122
- typename MM1::OutputTileIteratorAccum // source tile
1123
- >;
1124
- auto dest_iter = createOutputIter(0);
1125
- EpilogueOutputOp rescale(s_prime, out_rescale);
1126
- Epilogue epilogue(
1127
- shared_storage.epilogue_shared_storage(),
1128
- thread_id(),
1129
- warp_id(),
1130
- lane_id());
1131
- MM1::Mma::drain_cp_asyncs();
1132
- epilogue(rescale, dest_iter, accum_o);
1133
- }
1134
-
1135
- // 7. Calculate logsumexp
1136
- // To make the backward easier, we pad logsumexp with `inf`
1137
- // this avoids a few bound checks, and is not more expensive during fwd
1138
- static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
1139
- if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
1140
- auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
1141
- constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
1142
- if (thread_id() < p.num_queries) {
1143
- p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
1144
- cutlass::fast_log(accum_t(s_prime[thread_id()]));
1145
- } else if (thread_id() < lse_dim) {
1146
- p.logsumexp_ptr[thread_id()] =
1147
- cutlass::platform::numeric_limits<accum_t>::infinity();
1148
- }
1149
- }
1150
- }
1151
-
1152
- template <typename WarpIteratorC>
1153
- CUTLASS_DEVICE static void iterative_softmax(
1154
- typename WarpIteratorC::Fragment& frag_o, // output so far
1155
- typename WarpIteratorC::Fragment& frag,
1156
- cutlass::Array<accum_t, kQueriesPerBlock>& mi,
1157
- cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
1158
- cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
1159
- cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
1160
- cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
1161
- addition_storage,
1162
- int8_t lane_id,
1163
- int8_t thread_id,
1164
- int8_t warp_id,
1165
- int max_col,
1166
- bool is_first,
1167
- typename WarpIteratorC::TensorCoord const& tile_offset,
1168
- float scaling) {
1169
- /* Iterates on the accumulator and corresponding position on result matrix
1170
-
1171
- (1) Update `mi[r]` to the max value of the row `r`
1172
- (2) In a second iteration do the following:
1173
- (a) accum <- exp(accum - mi)
1174
- (b) m_prime <- exp(m_prime - mi)
1175
- (c) s_prime <- s_prime * m_prime + sum(accum)
1176
-
1177
- All of this is done on registers, before we store all of this
1178
- on shared memory for the next matmul with Value.
1179
- */
1180
- using Fragment = typename WarpIteratorC::Fragment;
1181
- using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
1182
- WarpIteratorC,
1183
- accum_t,
1184
- kWarpSize>::Iterator;
1185
- // Convert to `accum_t` (rather than double)
1186
- constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
1187
-
1188
- static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
1189
- static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
1190
-
1191
- frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
1192
-
1193
- auto lane_offset =
1194
- LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
1195
-
1196
- // First update `mi` to the max per-row
1197
- {
1198
- accum_t max;
1199
- LambdaIterator::iterateRows(
1200
- lane_offset,
1201
- [&](int accum_m) {
1202
- max = -cutlass::platform::numeric_limits<accum_t>::infinity();
1203
- },
1204
- [&](int accum_m, int accum_n, int idx) {
1205
- if (accum_n < max_col) {
1206
- max = cutlass::fast_max(max, frag[idx]);
1207
- }
1208
- },
1209
- [&](int accum_m) {
1210
- // Having 4x atomicMax seems faster than reduce within warp
1211
- // first...
1212
- atomicMaxFloat(&mi[accum_m], max);
1213
- });
1214
- }
1215
-
1216
- // Make sure we all share the update values for `mi`
1217
- __syncthreads();
1218
-
1219
- // Doing this `exp` is quite expensive. Let's
1220
- // split it across the warps
1221
- bool restore_mi_to_minus_inf = false;
1222
- if (lane_id < kLinesPerWarp) {
1223
- int id = warp_id * kLinesPerWarp + lane_id;
1224
- auto m_prime_id = m_prime[id];
1225
- auto mi_id = mi[id];
1226
- bool changed = m_prime_id < mi_id; // `false` if both are -inf
1227
- if (changed) {
1228
- auto m_prime_exp = exp2f(m_prime_id - mi_id);
1229
- out_rescale[id] = m_prime_exp;
1230
- s_prime[id] *= m_prime_exp;
1231
- } else {
1232
- // Only when bias is enabled, it's possible that all the first values
1233
- // of attention are masked to `-inf`. In that case we want to avoid
1234
- // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
1235
- if (kSupportsBias &&
1236
- mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
1237
- restore_mi_to_minus_inf = true;
1238
- mi[id] = 0.0f;
1239
- }
1240
- out_rescale[id] = 1.0f;
1241
- }
1242
- }
1243
- __syncthreads(); // Update output fragments
1244
- if (kKeepOutputInRF && !is_first) {
1245
- accum_t line_rescale;
1246
- LambdaIterator::iterateRows(
1247
- lane_offset,
1248
- [&](int accum_m) { line_rescale = out_rescale[accum_m]; },
1249
- [&](int accum_m, int accum_n, int idx) {
1250
- frag_o[idx] = frag_o[idx] * line_rescale;
1251
- },
1252
- [&](int accum_m) {});
1253
- }
1254
- // Update accum_m, accum_n, ...
1255
- {
1256
- accum_t mi_row, total_row;
1257
- LambdaIterator::iterateRows(
1258
- lane_offset,
1259
- [&](int accum_m) { mi_row = mi[accum_m]; },
1260
- [&](int accum_m, int accum_n, int idx) {
1261
- frag[idx] =
1262
- (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
1263
- },
1264
- [&](int accum_m) {});
1265
- LambdaIterator::iterateRows(
1266
- lane_offset,
1267
- [&](int accum_m) { total_row = 0.0; },
1268
- [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
1269
- [&](int accum_m) {
1270
- if (LambdaIterator::reduceSameRow(
1271
- lane_id, total_row, [](accum_t a, accum_t b) {
1272
- return a + b;
1273
- })) {
1274
- // NOTE: we could atomically add `total_row` to `s_prime`, but
1275
- // it's faster (and deterministic) to avoid atomics here
1276
- addition_storage
1277
- [accum_m + kQueriesPerBlock * tile_offset.column()] =
1278
- total_row;
1279
- }
1280
- });
1281
- }
1282
- __syncthreads();
1283
- if (lane_id < kLinesPerWarp) {
1284
- int id = warp_id * kLinesPerWarp + lane_id;
1285
- accum_t total_row = s_prime[id];
1286
- if (restore_mi_to_minus_inf) {
1287
- // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
1288
- mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
1289
- } else {
1290
- m_prime[id] = mi[id];
1291
- }
1292
- CUTLASS_PRAGMA_UNROLL
1293
- for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
1294
- total_row += addition_storage[id + kQueriesPerBlock * i];
1295
- }
1296
- s_prime[id] = total_row;
1297
- }
1298
- }
1299
-
1300
- static CUTLASS_DEVICE int8_t lane_id() {
1301
- return threadIdx.x;
1302
- }
1303
- static CUTLASS_DEVICE int8_t warp_id() {
1304
- return threadIdx.y;
1305
- }
1306
- static CUTLASS_DEVICE int16_t thread_id() {
1307
- return threadIdx.x + threadIdx.y * blockDim.x;
1308
- }
1309
- };
1310
-
1311
- template <typename AK>
1312
- __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
1313
- attention_kernel_batched_impl(typename AK::Params p) {
1314
- if (!p.advance_to_block()) {
1315
- return;
1316
- }
1317
- AK::attention_kernel(p);
1318
- }
1319
-
1320
- template <typename AK>
1321
- __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
1322
- attention_kernel_batched(typename AK::Params params);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py DELETED
@@ -1,144 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- from typing import List
34
- import torch
35
- import subprocess
36
- import sys
37
- import tempfile
38
- import os
39
- import numpy as np
40
-
41
-
42
- TORCH_DTYPE_NAME = {
43
- torch.float32: "f32",
44
- torch.float16: "f16",
45
- torch.bfloat16: "b16"
46
- }
47
- NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()}
48
-
49
- def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor:
50
- # PyTorch >= 2.0
51
- if hasattr(tensor, 'untyped_storage'):
52
- return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage())
53
- return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped())
54
-
55
- class PipedSubprocess:
56
- def __init__(self, binary: str) -> None:
57
- self.binary = binary
58
- self.tempdir_ctx = tempfile.TemporaryDirectory()
59
-
60
- def __enter__(self) -> "PipedSubprocess":
61
- self.subp = subprocess.Popen(self.binary, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True, bufsize=0)
62
- self.tempdir = self.tempdir_ctx.__enter__()
63
- self.file_counter = 0
64
- return self
65
-
66
- def __exit__(self, exc_type, exc_val, exc_tb) -> None:
67
- self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb)
68
-
69
- def temp_filename(self, suffix: str) -> str:
70
- self.file_counter += 1
71
- return os.path.join(self.tempdir, f"{self.file_counter}{suffix}")
72
-
73
- def write(self, *args) -> None:
74
- for a in args:
75
- self.subp.stdin.write(str(a) + " ")
76
-
77
- def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None:
78
- print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}")
79
- tensor_u8 = _tensor_from_storage(tensor, torch.uint8)
80
- self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0])
81
- filename = self.temp_filename(f"{name}.tensor")
82
- assert tensor.storage_offset() == 0
83
- with open(filename, "wb+") as fd:
84
- fd.write(bytes(tensor_u8.numpy()))
85
- self.write("file", filename)
86
- self.write("tensor_end")
87
-
88
- for stride_name, stride_value in zip(stride_names, tensor.stride()):
89
- self.write(stride_name, stride_value)
90
-
91
- def readTensor(self, name, stride_name, shape) -> torch.Tensor:
92
- tmpfile = self.temp_filename(f"{name}.tensor")
93
- self.write("tmpfile", tmpfile)
94
-
95
- self.readExpect("tensor_begin")
96
- dtype_str, name = self.read().split(":")
97
- print(f"C++->Py : {dtype_str}:{name}")
98
- u8len = int(self.read())
99
- dtype = NAME_TORCH_DTYPE[dtype_str]
100
-
101
- self.readExpect("file")
102
- self.readExpect(tmpfile)
103
-
104
- with open(tmpfile, "rb") as fd:
105
- data = fd.read(u8len)
106
- # `np.array` is not strictly needed, but avoids a torch warning
107
- tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len)
108
- self.readExpect("tensor_end")
109
-
110
- tensor = _tensor_from_storage(tensor_u8, dtype)
111
- strides = []
112
- for sn in stride_name:
113
- self.readExpect(sn)
114
- strides.append(int(self.read()))
115
- if len(strides) != shape:
116
- strides.append(1)
117
- assert len(strides) == len(shape), name
118
- return torch.as_strided(tensor, shape, strides)
119
-
120
- def readNamed(self, name: str):
121
- self.readExpect(name)
122
- return self.read()
123
-
124
- def readExpect(self, what: str) -> None:
125
- r = self.read()
126
- if r != what:
127
- raise ValueError(f"Read {r} but expected {what}")
128
-
129
- def read(self):
130
- read_all = []
131
- # Skip initial whitespace
132
- while True:
133
- r = self.subp.stdout.read(1)
134
- if r not in [' ', "\n"]:
135
- read_all.append(r)
136
- break
137
- # Read data
138
- while True:
139
- r = self.subp.stdout.read(1)
140
- if r in [' ', "\n"]:
141
- break
142
- read_all.append(r)
143
- return ''.join(read_all)
144
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h DELETED
@@ -1,90 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include <cutlass/cutlass.h>
35
- #include "cutlass/aligned_buffer.h"
36
- #include "cutlass/array.h"
37
- #include "cutlass/layout/matrix.h"
38
- #include "cutlass/layout/pitch_linear.h"
39
- #include "cutlass/numeric_types.h"
40
- #include "cutlass/transform/pitch_linear_thread_map.h"
41
- #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
42
- #include "cutlass/transform/threadblock/regular_tile_iterator.h"
43
-
44
- template <
45
- typename scalar_t, // scalar type
46
- typename ThreadblockTileShape, // size of tile to load
47
- int Threads, // number of participating threads
48
- int ElementsPerAccess> // thread access width in elements
49
- class TileSmemLoader {
50
- public:
51
- using SmemTile =
52
- cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;
53
-
54
- using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
55
- cutlass::layout::PitchLinearShape<
56
- ThreadblockTileShape::kColumn, // contiguous
57
- ThreadblockTileShape::kRow>, // strided
58
- Threads, // Threads
59
- ElementsPerAccess>; // ElementsPerAccess
60
-
61
- using GmemTileIterator =
62
- cutlass::transform::threadblock::PredicatedTileIterator<
63
- ThreadblockTileShape, // Shape
64
- scalar_t, // Element
65
- cutlass::layout::RowMajor, // Layout
66
- 0, // AdvanceRank
67
- ThreadMap>; // ThreadMap
68
-
69
- using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator<
70
- ThreadblockTileShape, // Shape
71
- scalar_t, // Element
72
- cutlass::layout::RowMajor, // Layout
73
- 0, // AdvanceRank
74
- ThreadMap>; // ThreadMap
75
-
76
- using Fragment = typename GmemTileIterator::Fragment;
77
-
78
- /// load a tile from global memory into shared memory
79
- CUTLASS_DEVICE
80
- static void load(
81
- GmemTileIterator tile_load_iter,
82
- SmemTileIterator tile_store_iter) {
83
- Fragment tb_frag;
84
- tb_frag.clear();
85
- tile_load_iter.load(tb_frag);
86
- tile_store_iter.store(tb_frag);
87
-
88
- __syncthreads();
89
- }
90
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h DELETED
@@ -1,154 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
34
-
35
- The epilogue rearranges the result of a matrix product through shared memory to match canonical
36
- tensor layouts in global memory. Epilogues support conversion and reduction operations.
37
-
38
- */
39
-
40
- #pragma once
41
-
42
- #include "cutlass/cutlass.h"
43
- #include "cutlass/numeric_types.h"
44
- #include "cutlass/array.h"
45
-
46
- #include "cutlass/gemm/gemm.h"
47
-
48
- #include "cutlass/epilogue/thread/linear_combination.h"
49
- #include "cutlass/epilogue/thread/linear_combination_clamp.h"
50
- #include "cutlass/epilogue/thread/conversion_op.h"
51
- #include "cutlass/epilogue/thread/reduction_op.h"
52
-
53
- #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
54
-
55
- #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
56
- #include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
57
- #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
58
- #include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
59
- #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
60
- #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
61
- #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
62
- #include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
63
-
64
- // #include "cutlass/epilogue/threadblock/epilogue.h"
65
- #include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
66
-
67
- #include "fused_bias_act_epilogue.h"
68
- #include "../warp/fused_bias_act_fragment_iterator_tensor_op.h"
69
- #include "output_tile_thread_map_for_fused_bias.h"
70
- #include "default_thread_map_tensor_op_for_fused_bias.h"
71
-
72
- ////////////////////////////////////////////////////////////////////////////////
73
-
74
- namespace cutlass {
75
- namespace epilogue {
76
- namespace threadblock {
77
-
78
- ////////////////////////////////////////////////////////////////////////////////
79
-
80
-
81
- ////////////////////////////////////////////////////////////////////////////////
82
-
83
- /// Defines sensible defaults for epilogues for TensorOps.
84
- template <
85
- typename Shape_,
86
- typename WarpMmaTensorOp_,
87
- int PartitionsK,
88
- typename OutputOp_,
89
- int ElementsPerAccess
90
- >
91
- struct DefaultFusedBiasActEpilogueTensorOp {
92
-
93
- using Shape = Shape_;
94
- using WarpMmaTensorOp = WarpMmaTensorOp_;
95
- static int const kPartitionsK = PartitionsK;
96
- using OutputOp = OutputOp_;
97
- static int const kElementsPerAccess = ElementsPerAccess;
98
- using ElementOutput = typename OutputOp::ElementOutput;
99
- using LayoutC = typename WarpMmaTensorOp::LayoutC;
100
- using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
101
-
102
- //
103
- // Thread map
104
- //
105
-
106
- using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias<
107
- Shape,
108
- typename WarpMmaTensorOp::Shape,
109
- kPartitionsK,
110
- ElementOutput,
111
- kElementsPerAccess
112
- >::Type;
113
-
114
- using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
115
- OutputTileThreadMap,
116
- ElementOutput
117
- >;
118
-
119
- using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
120
- cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
121
- typename WarpMmaTensorOp::Shape,
122
- typename WarpMmaTensorOp::Policy::Operator::Shape,
123
- typename WarpMmaTensorOp::Policy::Operator::ElementC,
124
- typename WarpMmaTensorOp::Policy::Operator::FragmentC,
125
- LayoutC>,
126
- cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp<
127
- typename WarpMmaTensorOp::Shape,
128
- typename WarpMmaTensorOp::Policy::Operator::Shape,
129
- typename WarpMmaTensorOp::Policy::Operator::ElementC,
130
- typename WarpMmaTensorOp::Policy::Operator::FragmentC,
131
- LayoutC> >::type;
132
-
133
- //
134
- // Define the epilogue
135
- //
136
- using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue<
137
- Shape,
138
- WarpMmaTensorOp,
139
- kPartitionsK,
140
- OutputTileIterator,
141
- AccumulatorFragmentIterator,
142
- OutputOp
143
- >;
144
- };
145
-
146
- ////////////////////////////////////////////////////////////////////////////////
147
-
148
- ////////////////////////////////////////////////////////////////////////////////
149
-
150
- } // namespace threadblock
151
- } // namespace epilogue
152
- } // namespace cutlass
153
-
154
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h DELETED
@@ -1,113 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief
34
-
35
- */
36
-
37
- #pragma once
38
-
39
- #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
40
- #include "cutlass/gemm/gemm.h"
41
- #include "cutlass/layout/pitch_linear.h"
42
-
43
- ////////////////////////////////////////////////////////////////////////////////
44
-
45
- namespace cutlass {
46
- namespace epilogue {
47
- namespace threadblock {
48
-
49
- ////////////////////////////////////////////////////////////////////////////////
50
-
51
- /// Defines the optimal thread map for TensorOp accumulator layouts
52
- template <
53
- typename ThreadblockShape_,
54
- typename WarpShape_,
55
- int PartitionsK,
56
- typename Element_,
57
- int ElementsPerAccess
58
- >
59
- struct DefaultThreadMapTensorOpForFusedBias {
60
-
61
- using ThreadblockShape = ThreadblockShape_;
62
- using WarpShape = WarpShape_;
63
- static int const kPartitionsK = PartitionsK;
64
- using Element = Element_;
65
- static int const kElementsPerAccess = ElementsPerAccess;
66
-
67
- //
68
- // Definitions
69
- //
70
-
71
- struct Detail {
72
-
73
- /// Tensor Operations fundamentally perform operations on 8 rows
74
- static int const kTensorOpRows = 8;
75
- static int const kWarpSize = 32;
76
-
77
- static_assert(
78
- !(ThreadblockShape::kM % WarpShape::kM) &&
79
- !(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
80
-
81
- /// Number of warps
82
- using WarpCount = gemm::GemmShape<
83
- ThreadblockShape::kM / WarpShape::kM,
84
- ThreadblockShape::kN / WarpShape::kN,
85
- kPartitionsK
86
- >;
87
-
88
- /// Number of participating threads
89
- static int const kThreads = WarpCount::kCount * kWarpSize;
90
- };
91
-
92
- //
93
- // ThreadMap
94
- //
95
-
96
- /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
97
- using Type = OutputTileOptimalThreadMapBiasAct <
98
- OutputTileShape<ThreadblockShape::kN, Detail::kTensorOpRows, Detail::WarpCount::kM, 1, 1>,
99
- OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>,
100
- Detail::kThreads,
101
- kElementsPerAccess,
102
- sizeof_bits<Element>::value
103
- >;
104
- };
105
-
106
- ///////////////////////////////////////////////////////////////////////////////
107
- ////////////////////////////////////////////////////////////////////////////////
108
-
109
- } // namespace threadblock
110
- } // namespace epilogue
111
- } // namespace cutlass
112
-
113
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h DELETED
@@ -1,213 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
34
-
35
- The epilogue rearranges the result of a matrix product through shared memory to match canonical
36
- tensor layouts in global memory. Epilogues support conversion and reduction operations.
37
-
38
- */
39
-
40
- #pragma once
41
- #include "cutlass/cutlass.h"
42
- #include CUDA_STD_HEADER(cassert)
43
- #include "cutlass/numeric_types.h"
44
- #include "cutlass/array.h"
45
- #include "cutlass/layout/vector.h"
46
- #include "cutlass/layout/tensor.h"
47
- #include "cutlass/tensor_coord.h"
48
- #include "cutlass/aligned_buffer.h"
49
- #include "cutlass/functional.h"
50
- #include "cutlass/gemm/gemm.h"
51
- #include "cutlass/transform/pitch_linear_thread_map.h"
52
- #include "cutlass/transform/threadblock/regular_tile_iterator.h"
53
- #include "cutlass/epilogue/threadblock/epilogue_base.h"
54
- #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
55
-
56
- ////////////////////////////////////////////////////////////////////////////////
57
-
58
- namespace cutlass {
59
- namespace epilogue {
60
- namespace threadblock {
61
-
62
- ////////////////////////////////////////////////////////////////////////////////
63
-
64
- /// Epilogue operator without splitk
65
- template <
66
- typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
67
- typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
68
- int PartitionsK, ///< Number of partitions of the K dimension
69
- typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
70
- typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
71
- typename OutputOp_ ///< Output operator
72
- >
73
- class FusedBiasActEpilogue {
74
-
75
- public:
76
-
77
- using Shape = Shape_;
78
- using WarpMmaOperator = WarpMmaOperator_;
79
- static int const kPartitionsK = PartitionsK;
80
- using OutputTileIterator = OutputTileIterator_;
81
- using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
82
- using OutputOp = OutputOp_;
83
-
84
- /// Output layout is always row-major
85
- using Layout = layout::RowMajor;
86
- using LongIndex = typename Layout::LongIndex;
87
-
88
- /// The complete warp-level accumulator tile
89
- using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
90
-
91
- /// Output element
92
- using ElementOutput = typename OutputTileIterator::Element;
93
-
94
- /// Output access size
95
- static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
96
-
97
-
98
- public:
99
-
100
-
101
- static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
102
-
103
- static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
104
- "Divisibility");
105
-
106
- public:
107
-
108
- /// Constructor
109
- CUTLASS_DEVICE
110
- FusedBiasActEpilogue(
111
- ){ }
112
-
113
- /// Streams the result to global memory
114
- CUTLASS_DEVICE
115
- void operator()(
116
- OutputOp const &output_op, ///< Output operator
117
- AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
118
- AccumulatorTile & fused_bias_act_accumlators,
119
- OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
120
-
121
- bool need_bias = output_op.is_source_needed();
122
-
123
- if (need_bias)
124
- compute_source_needed_(output_op, accumulators, fused_bias_act_accumlators, source_iterator);
125
- else
126
- compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators);
127
-
128
-
129
- }
130
-
131
- CUTLASS_DEVICE
132
- void operator()(
133
- OutputOp const &output_op, ///< Output operator
134
- AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
135
- AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
136
-
137
- compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators);
138
- }
139
-
140
- CUTLASS_DEVICE
141
- void compute_source_needed_(
142
- OutputOp const &output_op, ///< Output operator
143
- AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
144
- AccumulatorTile & fused_bias_act_accumlators,
145
- OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
146
-
147
- typename OutputTileIterator::Fragment source_fragment;
148
-
149
-
150
- source_fragment.clear();
151
-
152
- AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
153
- AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators);
154
-
155
- CUTLASS_PRAGMA_UNROLL
156
- for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
157
-
158
- source_iterator.load(source_fragment);
159
- ++source_iterator;
160
-
161
- typename AccumulatorFragmentIterator::Fragment accum_fragment;
162
-
163
- accum_fragment_iterator.load(accum_fragment);
164
- ++accum_fragment_iterator;
165
-
166
- typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment;
167
- fused_bias_act_fragment = output_op(accum_fragment, source_fragment);
168
-
169
- fused_bias_act_fragment_iterator.store(fused_bias_act_fragment);
170
- ++fused_bias_act_fragment_iterator;
171
- }
172
- }
173
-
174
- CUTLASS_DEVICE
175
- void compute_source_no_needed_(
176
- OutputOp const &output_op, ///< Output operator
177
- AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
178
- AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
179
-
180
-
181
- AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
182
- AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators);
183
-
184
-
185
-
186
- CUTLASS_PRAGMA_UNROLL
187
- for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) {
188
-
189
- typename AccumulatorFragmentIterator::Fragment accum_fragment;
190
-
191
- accum_fragment_iterator.load(accum_fragment);
192
- ++accum_fragment_iterator;
193
-
194
- typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment;
195
- fused_bias_act_fragment = output_op(accum_fragment);
196
-
197
- fused_bias_act_fragment_iterator.store(fused_bias_act_fragment);
198
- ++fused_bias_act_fragment_iterator;
199
- }
200
- }
201
-
202
- };
203
-
204
-
205
-
206
-
207
- ////////////////////////////////////////////////////////////////////////////////
208
-
209
- } // namespace threadblock
210
- } // namespace epilogue
211
- } // namespace cutlass
212
-
213
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h DELETED
@@ -1,311 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles.
34
-
35
-
36
- */
37
-
38
- #pragma once
39
-
40
- #include "cutlass/cutlass.h"
41
- #include "cutlass/numeric_types.h"
42
- #include "cutlass/array.h"
43
- #include "cutlass/layout/matrix.h"
44
- #include "cutlass/matrix_shape.h"
45
- #include "cutlass/tensor_ref.h"
46
- #include "cutlass/fast_math.h"
47
-
48
- #include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
49
- ////////////////////////////////////////////////////////////////////////////////
50
-
51
- namespace cutlass {
52
- namespace epilogue {
53
- namespace threadblock {
54
-
55
- ////////////////////////////////////////////////////////////////////////////////
56
-
57
- ////////////////////////////////////////////////////////////////////////////////
58
-
59
- namespace detail {
60
-
61
- /// RowArrangement determines how one or more warps cover a region of consecutive rows.
62
- template <
63
- typename Shape,
64
- int WarpsRemaining,
65
- int ElementsPerAccess,
66
- int ElementSize,
67
- bool Is2dTile
68
- >
69
- struct RowArrangementBiasAct;
70
-
71
- /// RowArrangement in which each warp's access is a 1D tiled arrangement.
72
- template <
73
- typename Shape,
74
- int WarpsRemaining,
75
- int ElementsPerAccess,
76
- int ElementSize
77
- >
78
- struct RowArrangementBiasAct<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, false> {
79
- static int const kWarpSize = 32;
80
- static int const kElementsPerAccess = ElementsPerAccess;
81
- static int const kElementSize = ElementSize;
82
-
83
- static int const kIterationsRow = 1;
84
- static int const kDeltaRow = 1;
85
- static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize;
86
- static int const kDeltaColumn = kWarpSize * kElementsPerAccess;
87
-
88
- static int const kAccessWidth = kWarpSize;
89
- static int const kAccessRows = 1;
90
- static int const kWarpPartitionsRow = 1;
91
- static int const kWarpPartitionsColumn = WarpsRemaining;
92
- };
93
-
94
- /// RowArrangement in which each warp's access is a 2D tiled arrangement.
95
- template <
96
- typename Shape,
97
- int WarpsRemaining,
98
- int ElementsPerAccess,
99
- int ElementSize
100
- >
101
- struct RowArrangementBiasAct<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, true> {
102
-
103
- static int const kMemoryAccessSize = 4;//128;
104
- static int const kWarpSize = 32;
105
-
106
- static int const kElementsPerAccess = ElementsPerAccess;
107
- static int const kElementSize = ElementSize;
108
-
109
- struct Detail {
110
- static int const kShapeRow = Shape::kRow / WarpsRemaining;
111
- static int const kShapeWidth = Shape::kColumn / kElementsPerAccess;
112
-
113
- static int const kTargetMemoryAccessWidth =
114
- kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8);
115
-
116
- static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth;
117
- };
118
-
119
- static int const kAccessWidth =
120
- (Detail::kTargetAccessRows > Detail::kShapeRow ?
121
- kWarpSize / Detail::kShapeRow
122
- : const_min(
123
- Detail::kShapeWidth,
124
- const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8))
125
- ));
126
-
127
- static int const kAccessRows =
128
- (Detail::kTargetAccessRows > Detail::kShapeRow ?
129
- Detail::kShapeRow
130
- : const_min(Shape::kRow, kWarpSize / kAccessWidth));
131
-
132
- static int const kIterationsRow = Detail::kShapeRow / kAccessRows;
133
- static int const kDeltaRow = kAccessRows;
134
-
135
- static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth;
136
- static int const kDeltaColumn = kAccessWidth * kElementsPerAccess;
137
-
138
- static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access");
139
- static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" );
140
- static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" );
141
-
142
- static int const kWarpPartitionsRow = 1;
143
- static int const kWarpPartitionsColumn = 1;
144
- };
145
-
146
- }
147
-
148
- ////////////////////////////////////////////////////////////////////////////////
149
-
150
- /// Template metaprogram for partitioning a 4D space across warps to achieve several performance
151
- /// objectives:
152
- ///
153
- /// - coalesced memory accesses in units of 16 Byte lines
154
- /// - minimal address arithmetic
155
- /// - minimal predicate calculations
156
- ///
157
- template <
158
- typename Shape_,
159
- typename Count_,
160
- int Threads,
161
- int ElementsPerAccess,
162
- int ElementSize
163
- >
164
- struct OutputTileOptimalThreadMapBiasAct {
165
-
166
- using Shape = Shape_;
167
- using Count = Count_;
168
-
169
- static int const kWarpSize = 32;
170
- static int const kThreads = Threads;
171
- static int const kWarpCount = kThreads / kWarpSize;
172
-
173
- static int const kElementsPerAccess = ElementsPerAccess;
174
- static int const kElementSize = ElementSize;
175
-
176
- //
177
- // Metaprogram computation
178
- //
179
-
180
- struct Detail {
181
-
182
- // Clusters
183
- static int const kIterationsCluster =
184
- ((Shape::kCluster > kWarpCount) ?
185
- Shape::kCluster / kWarpCount
186
- : 1);
187
-
188
- static int const kDeltaCluster =
189
- ((Shape::kCluster > kWarpCount) ?
190
- Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster
191
- : 1);
192
-
193
- static int const kCompactedDeltaCluster =
194
- ((Shape::kCluster > kWarpCount) ?
195
- Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster
196
- : 1);
197
-
198
- static int const kWarpPartitionsCluster =
199
- ((Shape::kCluster > kWarpCount) ?
200
- kWarpCount
201
- : kWarpCount / Shape::kCluster);
202
-
203
- static int const kWarpsRemainingForGroups =
204
- ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster);
205
-
206
- // Groups
207
- static int const kIterationsGroup =
208
- ((Shape::kGroup > kWarpsRemainingForGroups) ?
209
- Shape::kGroup / kWarpsRemainingForGroups
210
- : 1);
211
-
212
- static int const kDeltaGroup =
213
- ((Shape::kGroup > kWarpsRemainingForGroups) ?
214
- Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup
215
- : 1);
216
-
217
- static int const kCompactedDeltaGroup =
218
- ((Shape::kGroup > kWarpsRemainingForGroups) ?
219
- Shape::kRow * Shape::kGroup / kIterationsGroup
220
- : 1);
221
-
222
- static int const kWarpPartitionsGroup =
223
- ((Shape::kGroup > kWarpsRemainingForGroups) ?
224
- 1
225
- : kWarpsRemainingForGroups / Shape::kGroup);
226
-
227
- static int const kWarpsRemainingForRows =
228
- ((Shape::kGroup > kWarpsRemainingForGroups) ?
229
- 1
230
- : kWarpsRemainingForGroups / Shape::kGroup);
231
-
232
- // Rows
233
- using RowArrangement = detail::RowArrangementBiasAct<
234
- Shape,
235
- kWarpsRemainingForRows,
236
- kElementsPerAccess,
237
- kElementSize,
238
- (Shape::kRow > kWarpsRemainingForRows)
239
- >;
240
-
241
- // Warp partitions
242
- using WarpPartitions = OutputTileShape<
243
- RowArrangement::kWarpPartitionsColumn,
244
- RowArrangement::kWarpPartitionsRow,
245
- kWarpPartitionsGroup,
246
- kWarpPartitionsCluster,
247
- 1>;
248
-
249
- static int const kAccessWidth = RowArrangement::kAccessWidth;
250
- static int const kAccessRows = RowArrangement::kAccessRows;
251
- };
252
-
253
- //
254
- // Output
255
- //
256
-
257
- using Iterations = OutputTileShape<
258
- Detail::RowArrangement::kIterationsColumn,
259
- Detail::RowArrangement::kIterationsRow,
260
- Detail::kIterationsGroup,
261
- Detail::kIterationsCluster,
262
- 1>;
263
-
264
- using Delta = OutputTileShape<
265
- Detail::RowArrangement::kDeltaColumn,
266
- Detail::RowArrangement::kDeltaRow,
267
- Detail::kDeltaGroup,
268
- Detail::kDeltaCluster,
269
- 1>;
270
-
271
- /// Initial offset function
272
- CUTLASS_HOST_DEVICE
273
- static MatrixCoord initial_offset(int thread_idx) {
274
-
275
- int warp_idx = thread_idx / kWarpSize;
276
- int lane_idx = thread_idx % kWarpSize;
277
-
278
- // Compute warp location
279
- int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
280
- int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
281
-
282
- int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
283
- int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
284
-
285
- int row_idx = residual_group / Detail::WarpPartitions::kRow;
286
- int col_idx = residual_group % Detail::WarpPartitions::kRow;
287
-
288
- // Compute per-lane offset
289
- int lane_row_offset = lane_idx / Detail::kAccessWidth;
290
- int lane_col_offset = lane_idx % Detail::kAccessWidth;
291
-
292
- // Compute coordinate in output space
293
- int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup;
294
- int group_offset = group_idx * Shape::kRow * Count::kRow;
295
- int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
296
- int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
297
-
298
- return MatrixCoord(
299
- cluster_offset + group_offset + row_offset + lane_row_offset,
300
- (column_offset + lane_col_offset) * kElementsPerAccess
301
- );
302
- }
303
-
304
- };
305
-
306
-
307
- ////////////////////////////////////////////////////////////////////////////////
308
-
309
- } // namespace threadblock
310
- } // namespace epilogue
311
- } // namespace cutlass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h DELETED
@@ -1,189 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile
34
- that participate in one warp-level store operation.
35
-
36
- Typically, the accumulator tile is the largest single block of register-backed storage
37
- within the kernel. Storing it to memory is best accomplished by partitioning it into
38
- smaller tiles and storing these sequentially.
39
-
40
- Round trips through shared memory during the Epilogue phase require partitioning, as
41
- shared memory capacity is typically insufficient for a threadblock's total accumulator
42
- size.
43
- */
44
-
45
- #pragma once
46
-
47
- #include "cutlass/array.h"
48
- #include "cutlass/layout/matrix.h"
49
-
50
- #include "cutlass/epilogue/warp/tensor_op_policy.h"
51
-
52
- ////////////////////////////////////////////////////////////////////////////////
53
-
54
- namespace cutlass {
55
- namespace epilogue {
56
- namespace warp {
57
-
58
- ////////////////////////////////////////////////////////////////////////////////
59
-
60
- ///
61
- template <
62
- typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape)
63
- typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape)
64
- typename OperatorElementC, ///< matrix multiply operation data type (concept: data type)
65
- typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array)
66
- typename Layout ///< target shared memory layout
67
- >
68
- class FusedBiasActFragmentIteratorTensorOp;
69
-
70
- ////////////////////////////////////////////////////////////////////////////////
71
-
72
- /// Partial specialization for row-major shared memory
73
- template <
74
- typename WarpShape_, ///< shape of the warp-level GEMM tile
75
- typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape)
76
- typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type)
77
- typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array)
78
- >
79
- class FusedBiasActFragmentIteratorTensorOp<WarpShape_, OperatorShape_, OperatorElementC_, OperatorFragmentC_, layout::RowMajor> {
80
- public:
81
-
82
- using WarpShape = WarpShape_;
83
- using OperatorShape = OperatorShape_;
84
- using OperatorElementC = OperatorElementC_;
85
- using OperatorFragmentC = OperatorFragmentC_;
86
- using Layout = layout::RowMajor;
87
-
88
- using Policy = TensorOpPolicy<WarpShape, OperatorShape, Layout>;
89
-
90
- /// This is the fragment size produced by one access of the iterator.
91
- using Fragment = Array<
92
- OperatorElementC,
93
- Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;
94
-
95
- /// This is the complete warp-level accumulator tile.
96
- using AccumulatorTile = Array<
97
- OperatorElementC,
98
- OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>;
99
-
100
- using OutputAccumulatorTile = AccumulatorTile;
101
-
102
- /// Number of times this iterator can be incremented
103
- static int const kIterations = Policy::kIterations;
104
-
105
- private:
106
-
107
- /// Internal access type
108
- using AccessType = Array<OperatorElementC, Policy::kElementsPerAccess>;
109
-
110
- private:
111
-
112
- //
113
- // Data members
114
- //
115
-
116
- /// Accumulator tile
117
- AccessType *accumulators_;
118
-
119
- /// Internal index
120
- int index_;
121
-
122
- public:
123
-
124
- /// Constructs an iterator
125
- CUTLASS_HOST_DEVICE
126
- FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum):
127
- accumulators_(reinterpret_cast<AccessType *>(&accum)),
128
- index_(0) {
129
- }
130
-
131
- /// Increments
132
- CUTLASS_HOST_DEVICE
133
- FusedBiasActFragmentIteratorTensorOp &operator++() {
134
- ++index_;
135
- return *this;
136
- }
137
-
138
- /// Decrements
139
- CUTLASS_HOST_DEVICE
140
- FusedBiasActFragmentIteratorTensorOp &operator--() {
141
- --index_;
142
- return *this;
143
- }
144
-
145
- /// Loads a fragment from the referenced part of the accumulator tile
146
- CUTLASS_HOST_DEVICE
147
- void load(Fragment &frag, int index_offset = 0) const {
148
-
149
- int index = index_ + index_offset;
150
-
151
- AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
152
-
153
- CUTLASS_PRAGMA_UNROLL
154
- for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
155
-
156
- int accumulator_access_offset =
157
- index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;
158
-
159
- frag_ptr[n] = accumulators_[accumulator_access_offset];
160
- }
161
- }
162
- /// Stores a fragment from the referenced part of the accumulator tile
163
- CUTLASS_HOST_DEVICE
164
- void store(Fragment &frag, int index_offset = 0) const {
165
-
166
- int index = index_ + index_offset;
167
-
168
- AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
169
-
170
- CUTLASS_PRAGMA_UNROLL
171
- for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
172
-
173
- int accumulator_access_offset =
174
- index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;
175
-
176
- accumulators_[accumulator_access_offset] = frag_ptr[n];
177
- }
178
- }
179
- };
180
-
181
- ////////////////////////////////////////////////////////////////////////////////
182
-
183
- ////////////////////////////////////////////////////////////////////////////////
184
-
185
- } // namespace warp
186
- } // namespace epilogue
187
- } // namespace cutlass
188
-
189
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h DELETED
@@ -1,427 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include "cutlass/cutlass.h"
35
-
36
- #include "cutlass/array.h"
37
- #include "cutlass/matrix_shape.h"
38
- #include "cutlass/layout/matrix.h"
39
- #include "cutlass/layout/tensor.h"
40
- #include "cutlass/numeric_conversion.h"
41
-
42
- namespace cutlass {
43
- namespace gemm {
44
- namespace warp {
45
-
46
-
47
- ////////////////////////////////////////////////////////////////////////////////
48
-
49
- template <
50
- /// Size of the matrix to load (concept: MatrixShape)
51
- typename Shape_,
52
- /// Size of the accumulation tile shape (concept: MatrixShape)
53
- typename AccumulatorShape_,
54
- /// KBlocks columns to compute residual
55
- int KBlocksColumn_,
56
- /// Accumulator Element type
57
- typename ElementAccumulator_,
58
- /// Element type
59
- typename Element_,
60
- /// Layout of operand in memory
61
- typename Layout_,
62
- /// Shape of one matrix product operation (concept: MatrixShape)
63
- typename InstructionShape_,
64
- /// Whether beta is zero
65
- bool IsBetaZero_ >
66
- class MmaTensorOpPureFragmentIterator;
67
-
68
-
69
- // Partial specialization for col-major accumulator tile
70
- // And Element type is the same as Accumulator Element type
71
-
72
- template <
73
- /// Shape of warp tile to load (concept: MatrixShape)
74
- typename Shape_,
75
- /// Shape of the warp accumulation tile (concept: MatrixShape)
76
- typename AccumulatorShape_,
77
- /// KBlocks columns to compute residual
78
- int KBlocksColumn_,
79
- /// Element type
80
- typename Element_,
81
- /// Shape of one matrix product operation (concept: MatrixShape)
82
- typename InstructionShape_>
83
- class MmaTensorOpPureFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Element_, Element_,
84
- cutlass::layout::ColumnMajor,
85
- InstructionShape_, true> {
86
- public:
87
-
88
- /// Shape of warp tile to load (concept: MatrixShape)
89
- using Shape = Shape_;
90
-
91
- /// Shape of the warp accumulation tile (concept: MatrixShape)
92
- using AccumulatorShape = AccumulatorShape_;
93
-
94
- /// KBlocks columns to compute residual
95
- static int const kKBlockColumn = KBlocksColumn_;
96
-
97
- /// Element type
98
- using Element = Element_;
99
-
100
- /// Layout of source tile
101
- using Layout = cutlass::layout::ColumnMajor;
102
-
103
- /// Shape of one matrix product operation (concept: MatrixShape)
104
- using InstructionShape = InstructionShape_;
105
-
106
- /// Whether beta is zero
107
- static bool const IsBetaZero = true;
108
-
109
- /// Number of participating threads
110
- static int const kThreads = 32;
111
-
112
- /// Internal structure of iterator - made public to enable introspection
113
- struct Policy {
114
- static_assert(
115
- !(Shape::kRow % InstructionShape::kM) &&
116
- !(Shape::kColumn % InstructionShape::kN),
117
- "Shape of warp-level Mma must be divisible by operator shape.");
118
- static_assert(
119
- !(AccumulatorShape::kRow % Shape::kRow) &&
120
- !(AccumulatorShape::kColumn % Shape::kColumn),
121
- "Shape of Warp Accumulator must be divisible by warp shape.");
122
- static_assert(
123
- !(kKBlockColumn % Shape::kColumn),
124
- "KBlock size must be divisible by warp shape.");
125
-
126
- /// Number of times this iterator can be incremented
127
- static int const kIterations = AccumulatorShape::kCount / Shape::kCount;
128
- };
129
-
130
- private:
131
-
132
- static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;
133
-
134
- /// Number of mma operations performed by a warp
135
- using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
136
- Shape::kColumn / InstructionShape::kN>;
137
- /// Number of mma operations performed by the entire accumulator
138
- using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
139
- AccumulatorShape::kColumn / InstructionShape::kN>;
140
-
141
- /// Number of K iterations
142
- static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
143
- static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
144
- static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
145
- * (AccumulatorShape::kRow / Shape::kRow);
146
- static int const kResidualIndex = kResidualColumn / Shape::kColumn
147
- * (AccumulatorShape::kRow / Shape::kRow);
148
-
149
- public:
150
-
151
- //
152
- // Derived quantities
153
- //
154
-
155
- /// Fragment object holding a thread's part of a tile
156
- /// This is the fragment size produced by one access of the iterator.
157
- using Fragment = Array<Element, Shape::kCount / kThreads>;
158
-
159
- /// Accumulator Fragment object
160
- using AccumulatorFragment = Array<Element, AccumulatorShape::kCount / kThreads>;
161
-
162
-
163
- private:
164
-
165
- /// Internal access type
166
- using AccessType = Array<Element, kElementsPerAccess>;
167
-
168
- private:
169
- //
170
- // Data members
171
- //
172
-
173
- /// Accumulator tile
174
- AccessType const *accumulators_;
175
-
176
- /// Internal index
177
- int index_;
178
-
179
- /// Used to access residual tile first
180
- bool is_residual_tile_;
181
-
182
- public:
183
- /// Constructs an iterator
184
- CUTLASS_HOST_DEVICE
185
- MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum)
186
- : accumulators_(reinterpret_cast<AccessType const *>(&accum)),
187
- index_(0), is_residual_tile_(true) {}
188
-
189
- /// Add offset
190
- CUTLASS_HOST_DEVICE
191
- void add_offset(int index_offset) {
192
- index_ += index_offset;
193
- if(is_residual_tile_ && index_ >= kKBlockColumnIterations) {
194
- index_ = index_ - kKBlockColumnIterations + kResidualIndex;
195
- is_residual_tile_ = false;
196
- }
197
- }
198
-
199
- /// Increments
200
- CUTLASS_HOST_DEVICE
201
- MmaTensorOpPureFragmentIterator &operator++() {
202
- add_offset(1);
203
- return *this;
204
- }
205
-
206
- /// Decrements
207
- CUTLASS_HOST_DEVICE
208
- MmaTensorOpPureFragmentIterator &operator--() {
209
- add_offset(-1);
210
- return *this;
211
- }
212
-
213
- /// Loads a fragment from the referenced part of the accumulator tile
214
- CUTLASS_HOST_DEVICE
215
- void load(Fragment &frag) const {
216
-
217
- AccessType src_fragment;
218
- src_fragment.clear();
219
-
220
-
221
- AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
222
-
223
- int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
224
- int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
225
- * MmaIterations::kColumn;
226
-
227
- CUTLASS_PRAGMA_UNROLL
228
- for (int n = 0; n < MmaIterations::kColumn; n++) {
229
- for (int m = 0; m < MmaIterations::kRow; m++) {
230
- int accumulator_access_offset =
231
- (n + index_n) * AccumulatorIterations::kRow + m + index_m;
232
-
233
- frag_ptr[n * MmaIterations::kRow + m].clear();
234
- if(!(is_residual_tile_ && index_ >= kResidualIndex))
235
- frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset];
236
- // frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment);
237
- }
238
- }
239
- }
240
-
241
- };
242
-
243
- // Partial specialization for row-major accumulator tile
244
-
245
- template <
246
- /// Shape of warp tile to load (concept: MatrixShape)
247
- typename Shape_,
248
- /// Shape of the warp accumulation tile (concept: MatrixShape)
249
- typename AccumulatorShape_,
250
- /// KBlocks columns to compute residual
251
- int KBlocksColumn_,
252
- /// Accumulator Element type
253
- typename ElementAccumulator_,
254
- /// Element type
255
- typename Element_,
256
- /// Shape of one matrix product operation (concept: MatrixShape)
257
- typename InstructionShape_>
258
- class MmaTensorOpPureFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, ElementAccumulator_, Element_,
259
- cutlass::layout::RowMajor,
260
- InstructionShape_, true> {
261
- public:
262
-
263
- /// Shape of warp tile to load (concept: MatrixShape)
264
- using Shape = Shape_;
265
-
266
- /// Shape of the warp accumulation tile (concept: MatrixShape)
267
- using AccumulatorShape = AccumulatorShape_;
268
-
269
- /// KBlocks columns to compute residual
270
- static int const kKBlockColumn = KBlocksColumn_;
271
-
272
- /// Accumulator Element type
273
- using ElementAccumulator = ElementAccumulator_;
274
-
275
- /// Element type
276
- using Element = Element_;
277
-
278
- /// Layout of source tile
279
- using Layout = cutlass::layout::RowMajor;
280
-
281
- /// Shape of one matrix product operation (concept: MatrixShape)
282
- using InstructionShape = InstructionShape_;
283
-
284
- /// Whether beta is zero
285
- static bool const IsBetaZero = true;
286
-
287
- /// Number of participating threads
288
- static int const kThreads = 32;
289
-
290
- /// Internal structure of iterator - made public to enable introspection
291
- struct Policy {
292
- static_assert(
293
- !(Shape::kRow % InstructionShape::kM) &&
294
- !(Shape::kColumn % InstructionShape::kN),
295
- "Shape of warp-level Mma must be divisible by operator shape.");
296
- static_assert(
297
- !(AccumulatorShape::kRow % Shape::kRow) &&
298
- !(AccumulatorShape::kColumn % Shape::kColumn),
299
- "Shape of Warp Accumulator must be divisible by warp shape.");
300
- static_assert(
301
- !(kKBlockColumn % Shape::kColumn),
302
- "KBlock size must be divisible by warp shape.");
303
-
304
- /// Number of times this iterator can be incremented
305
- static int const kIterations = AccumulatorShape::kCount / Shape::kCount;
306
- };
307
-
308
- private:
309
-
310
- static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;
311
-
312
- /// Number of mma operations performed by a warp
313
- using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
314
- Shape::kColumn / InstructionShape::kN>;
315
- /// Number of mma operations performed by the entire accumulator
316
- using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
317
- AccumulatorShape::kColumn / InstructionShape::kN>;
318
-
319
- /// Number of K iterations
320
- static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
321
- static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
322
- static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
323
- * (AccumulatorShape::kRow / Shape::kRow);
324
- static int const kResidualIndex = kResidualColumn / Shape::kColumn
325
- * (AccumulatorShape::kRow / Shape::kRow);
326
-
327
- public:
328
-
329
- //
330
- // Derived quantities
331
- //
332
-
333
- /// Fragment object holding a thread's part of a tile
334
- /// This is the fragment size produced by one access of the iterator.
335
- using Fragment = Array<Element, Shape::kCount / kThreads>;
336
-
337
- /// Accumulator Fragment object
338
- using AccumulatorFragment = Array<ElementAccumulator, AccumulatorShape::kCount / kThreads>;
339
-
340
-
341
- private:
342
-
343
- /// Internal access type
344
- using AccessType = Array<ElementAccumulator, kElementsPerAccess>;
345
- using FragmentAccessType = Array<Element, kElementsPerAccess>;
346
-
347
- private:
348
- //
349
- // Data members
350
- //
351
-
352
- /// Accumulator tile
353
- AccessType const *accumulators_;
354
-
355
- /// Internal index
356
- int index_;
357
-
358
- /// Used to access residual tile first
359
- bool is_residual_tile_;
360
-
361
- public:
362
- /// Constructs an iterator
363
- CUTLASS_HOST_DEVICE
364
- MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum)
365
- : accumulators_(reinterpret_cast<AccessType const *>(&accum)),
366
- index_(0), is_residual_tile_(true) {}
367
-
368
- /// Add offset
369
- CUTLASS_HOST_DEVICE
370
- void add_offset(int index_offset) {
371
- index_ += index_offset;
372
- if(is_residual_tile_ && index_ >= kKBlockColumnIterations) {
373
- index_ = index_ - kKBlockColumnIterations + kResidualIndex;
374
- is_residual_tile_ = false;
375
- }
376
- }
377
-
378
- /// Increments
379
- CUTLASS_HOST_DEVICE
380
- MmaTensorOpPureFragmentIterator &operator++() {
381
- add_offset(1);
382
- return *this;
383
- }
384
-
385
- /// Decrements
386
- CUTLASS_HOST_DEVICE
387
- MmaTensorOpPureFragmentIterator &operator--() {
388
- add_offset(-1);
389
- return *this;
390
- }
391
-
392
- /// Loads a fragment from the referenced part of the accumulator tile
393
- CUTLASS_HOST_DEVICE
394
- void load(Fragment &frag) const {
395
-
396
-
397
- FragmentAccessType src_fragment;
398
- src_fragment.clear();
399
-
400
- FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);
401
-
402
- int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
403
- int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
404
- * MmaIterations::kColumn;
405
-
406
- CUTLASS_PRAGMA_UNROLL
407
- for (int m = 0; m < MmaIterations::kRow; m++) {
408
- for (int n = 0; n < MmaIterations::kColumn; n++) {
409
- int accumulator_access_offset =
410
- (m + index_m) * AccumulatorIterations::kColumn + n + index_n;
411
-
412
- frag_ptr[m * MmaIterations::kColumn + n].clear();
413
- if(!(is_residual_tile_ && index_ >= kResidualIndex))
414
- frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]);
415
- }
416
- }
417
- }
418
-
419
- };
420
-
421
- ////////////////////////////////////////////////////////////////////////////////
422
-
423
- } // namespace warp
424
- } // namespace gemm
425
- } // namespace cutlass
426
-
427
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py DELETED
@@ -1,129 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import gen_turing_and_volta as api_generator
34
- import gen_sample as sample_creater
35
- import gen_cmake as cmake_creater
36
- import gen_verify as verify_creater
37
- import gen_device as b2b_fused_generator
38
- import replace_fix_impl_header
39
-
40
- import argparse
41
- import os
42
- import json
43
-
44
-
45
- parser = argparse.ArgumentParser(description="Generates Fused Multi-GEMM CUTLASS Kernels")
46
- parser.add_argument("--config-file", default="config.json", help="JSON file containing configuration to generate")
47
- parser.add_argument("--gen-name", default="FusedMultiGemmForward", help="Specific the output name")
48
- parser.add_argument("--output-dir", default="", help="Specifies the output dir")
49
- parser.add_argument("--cutlass-dir", default="", help="Specifies the dependent CUTLASS repo dir")
50
- parser.add_argument("--gen-include-cutlass-dir", default="", help="Specifies the generated CUTLASS code include dir, if needed.")
51
- args = parser.parse_args()
52
-
53
- gen_name = args.gen_name
54
-
55
- cutlass_deps_dir = args.cutlass_dir
56
-
57
- output_dir = args.output_dir
58
- output_dir += "/"
59
-
60
- cutlass_deps_root = args.gen_include_cutlass_dir
61
- if cutlass_deps_root == '':
62
- cutlass_deps_root = cutlass_deps_dir + "/include/"
63
- cutlass_deps_root +='/'
64
-
65
-
66
- if not os.path.exists(output_dir):
67
- os.makedirs(output_dir)
68
-
69
- if not os.path.exists(output_dir + "/" + "auto_gen"):
70
- os.mkdir(output_dir + "/" + "auto_gen")
71
-
72
- if not os.path.exists(output_dir + "/" + "fixed_impl"):
73
- os.mkdir(output_dir + "/" + "fixed_impl" )
74
-
75
- if not os.path.exists(output_dir + "/" + "sample"):
76
- os.mkdir(output_dir + "/" + "sample" )
77
-
78
- if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "device"):
79
- os.mkdir(output_dir + "/" + "auto_gen" + "/" + "device")
80
- if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "kernel"):
81
- os.mkdir(output_dir + "/" + "auto_gen" + "/" + "kernel")
82
- if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "threadblock"):
83
- os.mkdir(output_dir + "/" + "auto_gen" + "/" + "threadblock")
84
-
85
- with open(args.config_file, 'r') as infile:
86
- gemm_info_dict = json.load(infile)
87
-
88
- keys = sorted(gemm_info_dict.keys())
89
- fuse_gemm_info = [gemm_info_dict[k] for k in keys]
90
-
91
-
92
- for_cutlass_gen_user_include_header_file = [
93
- cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h",
94
- cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h",
95
- ]
96
-
97
- for_fused_wrapper = [
98
- cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h",
99
- cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h",
100
- "auto_gen/device/" + gen_name + ".h",
101
- cutlass_deps_root + "cutlass/gemm/device/gemm_batched.h",
102
- cutlass_deps_root + "cutlass/cutlass.h",
103
- ]
104
-
105
- # Copy fixed implementation to the output directory
106
- fix_impl = replace_fix_impl_header.replace_fix_impl("../fixed_impl/", output_dir +"/fixed_impl/", cutlass_deps_root)
107
- fix_impl.gen_code()
108
-
109
- auto_gen_output_dir = output_dir + "/auto_gen/"
110
- project_root = ""
111
- turing_plus = b2b_fused_generator.gen_device(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, cutlass_deps_root, project_root, auto_gen_output_dir)
112
- turing_plus.gen_code(75, 'hmma1688', False)
113
-
114
- api = api_generator.gen_one_API(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir)
115
- api.gen_code()
116
-
117
- # Generate C++ sample
118
- os.system("cp ../leaky_bias.h " + output_dir + "/sample/")
119
- os.system("cp ../utils.h " + output_dir + "/sample/")
120
-
121
- sample_dir = output_dir + "/sample/"
122
- sample = sample_creater.gen_test(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, sample_dir)
123
- sample.gen_cpp_sample()
124
-
125
- cmake_gen = cmake_creater.gen_build_sys(cutlass_deps_dir, output_dir)
126
- cmake_gen.gen_code()
127
-
128
- verify = verify_creater.gen_verify(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir)
129
- verify.gen_code()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py DELETED
@@ -1,131 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- class gen_build_sys:
34
- def __init__(self, cutlass_deps_dir, output_dir = "../"):
35
- self.output_dir = output_dir
36
- self.cutlass_deps_dir = cutlass_deps_dir
37
-
38
- def gen_top(self):
39
- code = ""
40
- code += '''\
41
- # Auto Generated code - Do not edit.
42
-
43
- cmake_minimum_required(VERSION 3.8)
44
- project(CUTLASS_MULTI_GEMMS LANGUAGES CXX CUDA)
45
- find_package(CUDAToolkit)
46
- set(CUDA_PATH ${{CUDA_TOOLKIT_ROOT_DIR}})
47
- set(CUTLASS_PATH \"{cutlass_deps_dir}/include\")
48
- set(CUTLASS_UTIL_PATH \"{cutlass_deps_dir}/tools/util/include\")
49
- list(APPEND CMAKE_MODULE_PATH ${{CUDAToolkit_LIBRARY_DIR}})
50
- '''.format(cutlass_deps_dir=self.cutlass_deps_dir)
51
-
52
- code += '''\
53
- set(GPU_ARCHS \"\" CACHE STRING
54
- \"List of GPU architectures (semicolon-separated) to be compiled for.\")
55
-
56
- if(\"${GPU_ARCHS}\" STREQUAL \"\")
57
- set(GPU_ARCHS \"70\")
58
- endif()
59
-
60
- foreach(arch ${GPU_ARCHS})
61
- set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -gencode arch=compute_${arch},code=sm_${arch}\")
62
- if(SM STREQUAL 70 OR SM STREQUAL 75)
63
- set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -DWMMA\")
64
- set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -DWMMA\")
65
- set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -DWMMA\")
66
- endif()
67
- endforeach()
68
-
69
- set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS}\")
70
- set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS}\")
71
- set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -Wall\")
72
-
73
- set(CMAKE_C_FLAGS_DEBUG \"${CMAKE_C_FLAGS_DEBUG} -Wall -O0\")
74
- set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0\")
75
- set(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall\")
76
-
77
- set(CMAKE_CXX_STANDARD 11)
78
- set(CMAKE_CXX_STANDARD_REQUIRED ON)
79
-
80
- if(CMAKE_CXX_STANDARD STREQUAL \"11\")
81
- set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-extended-lambda\")
82
- set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\")
83
- endif()
84
-
85
- set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -O3\")
86
- set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -O3\")
87
- set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-strict-aliasing\")
88
-
89
- set(COMMON_HEADER_DIRS
90
- ${PROJECT_SOURCE_DIR}
91
- ${CUDAToolkit_INCLUDE_DIRS}
92
- )
93
-
94
- set(COMMON_LIB_DIRS
95
- ${CUDAToolkit_LIBRARY_DIR}
96
- )
97
- list(APPEND COMMON_HEADER_DIRS ${CUTLASS_PATH})
98
- list(APPEND COMMON_HEADER_DIRS ${CUTLASS_UTIL_PATH})
99
- '''
100
- code += '''\
101
- include_directories(
102
- ${COMMON_HEADER_DIRS}
103
- )
104
-
105
- link_directories(
106
- ${COMMON_LIB_DIRS}
107
- )
108
-
109
- add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
110
- add_definitions(-DGOOGLE_CUDA=1)
111
-
112
- add_executable(sample
113
- sample/sample.cu
114
- one_api.cu
115
- )
116
- target_link_libraries(sample PRIVATE
117
- -lcudart
118
- -lnvToolsExt
119
- ${CMAKE_THREAD_LIBS_INIT}
120
- )
121
-
122
- if(NOT DEFINED LIB_INSTALL_PATH)
123
- set(LIB_INSTALL_PATH ${CMAKE_CURRENT_BINARY_DIR})
124
- endif()
125
- '''
126
- return code
127
-
128
- def gen_code(self):
129
- top_code = self.gen_top()
130
- with open(self.output_dir + "CMakeLists.txt", "w") as f:
131
- f.write(top_code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py DELETED
@@ -1,120 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import ast
34
-
35
- fuse_gemm_info = [
36
- {
37
- 'epilogue': {
38
- 'tp': 'LeakyRelu', #'CustomizedLeaky_RELU'
39
- 'bias': {'addbias': False, 'bias_tp': 'mat'},
40
- 'args': [('float', 'leaky_alpha', 1.3), ],
41
- 'func': '''
42
- y = max(leaky_alpha * x, x)
43
- y = y * x
44
- '''
45
- }
46
- },
47
-
48
- ]
49
- class AnalysisNodeVisitor(ast.NodeVisitor):
50
- def visit_Import(self,node):
51
- ast.NodeVisitor.generic_visit(self, node)
52
-
53
- def visit_ImportFrom(self,node):
54
- ast.NodeVisitor.generic_visit(self, node)
55
-
56
- def visit_Assign(self,node):
57
- print('Node type: Assign and fields: ', node._fields)
58
- # print('Node type: Assign and targets value: ', node.targets, node.value)
59
-
60
- ast.NodeVisitor.generic_visit(self, node)
61
-
62
- def visit_BinOp(self, node):
63
- print('Node type: BinOp and fields: ', node._fields)
64
- print('node op: ', type(node.op).__name__)
65
- ast.NodeVisitor.generic_visit(self, node)
66
-
67
- def visit_Expr(self, node):
68
- print('Node type: Expr and fields: ', node._fields)
69
- ast.NodeVisitor.generic_visit(self, node)
70
-
71
- def visit_Num(self,node):
72
- print('Node type: Num and fields: ', node._fields)
73
- print('Node type: Num: ', node.n)
74
-
75
- def visit_Name(self,node):
76
- print('Node type: Name and fields: ', node._fields)
77
- print('Node type: Name and fields: ', type(node.ctx).__name__, node.id)
78
-
79
- ast.NodeVisitor.generic_visit(self, node)
80
-
81
- def visit_Str(self, node):
82
- print('Node type: Str and fields: ', node._fields)
83
-
84
- class CodeVisitor(ast.NodeVisitor):
85
- def visit_BinOp(self, node):
86
- if isinstance(node.op, ast.Add):
87
- node.op = ast.Sub()
88
- self.generic_visit(node)
89
-
90
- def visit_Assign(self, node):
91
- print('Assign %s' % node.value)
92
- self.generic_visit(node)
93
-
94
- def visit_Name(self, node):
95
- print("Name:", node.id)
96
- self.generic_visit(node)
97
-
98
-
99
- def visit_FunctionDef(self, node):
100
- print('Function Name:%s'% node.name.op)
101
- self.generic_visit(node)
102
- func_log_stmt = ast.Print(
103
- dest = None,
104
- values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)],
105
- nl = True,
106
- lineno = 0,
107
- col_offset = 0,
108
- )
109
- node.body.insert(0, func_log_stmt)
110
-
111
- visitor = AnalysisNodeVisitor()
112
-
113
- code = \
114
- '''
115
-
116
- a=max(leaky_alpha * x, x +1)
117
-
118
- '''
119
-
120
- visitor.visit(ast.parse(code))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py DELETED
@@ -1,469 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- from typing import *
34
-
35
- import helper
36
- import gen_ir
37
-
38
- import gen_kernel as gen_ker
39
-
40
-
41
- class gen_device:
42
- def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, cutlass_deps_root, project_root, output_dir = "../"):
43
- self.fuse_gemm_info = fuse_gemm_info
44
- self.raw_gemm_info = fuse_gemm_info
45
- self.b2b_num = len(fuse_gemm_info)
46
- self.user_header_file = user_header_file
47
- self.args = {}
48
- # device arg struct memebr
49
- self.arg_member = []
50
- self.gen_class_name = gen_class_name
51
- self.gen_kernel_name = gen_class_name + "Kernel"
52
- self.template_args = []
53
- self.__tempalate_arg_list = {'Stages': int, 'SplitKSerial': bool, 'IsBetaZero': bool, 'AlignmentA': int, 'AlignmentB': int}
54
-
55
- self.file_name = output_dir + "/device/" +gen_class_name +".h"
56
- self.sample_dir = output_dir
57
-
58
-
59
- self.cutlass_deps_root = cutlass_deps_root
60
- self.project_root = project_root
61
- self.this_file_root = output_dir + "/device/"
62
-
63
- self.first_use_1stage = False
64
-
65
- ## gen kernel
66
- self.gen_kernel = gen_ker.gen_kernel(self.template_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root)
67
-
68
-
69
- def __check_arg_type(self, temp_arg):
70
- if temp_arg in self.__tempalate_arg_list.keys():
71
- return self.__tempalate_arg_list[temp_arg]
72
-
73
- find_sub = False
74
- for candidate_arg in self.__tempalate_arg_list.keys():
75
- if (temp_arg.find(candidate_arg) != -1):
76
- return self.__tempalate_arg_list[candidate_arg]
77
-
78
- return 'typename'
79
-
80
- # def gen_B2b2bGemm_class():
81
- def set_arch(self, sm_cap, mma_tp):
82
- if sm_cap == 75 or sm_cap == 80 or sm_cap == 86:
83
- self.arch = "cutlass::arch::Sm" + str(sm_cap)
84
-
85
- if mma_tp is 'hmma1688':
86
- self.mma_shape = [16, 8, 8]
87
- self.mma_tp = 'hmma'
88
- elif mma_tp is 'imma8816':
89
- self.mma_tp = 'imma'
90
- self.mma_shape = [8, 8, 16]
91
- else:
92
- return 0
93
-
94
- def gen_include_header(self):
95
- code = '''\
96
- /* Auto Generated code - Do not edit.*/
97
-
98
- #pragma once
99
-
100
- #include \"{cutlass_root}cutlass/cutlass.h\"
101
- #include \"{cutlass_root}cutlass/numeric_types.h\"
102
- #include \"{cutlass_root}cutlass/arch/arch.h\"
103
- #include \"{cutlass_root}cutlass/device_kernel.h\"
104
-
105
- #include \"{cutlass_root}cutlass/gemm/threadblock/threadblock_swizzle.h\"
106
-
107
- #include \"{cutlass_root}cutlass/gemm/device/default_gemm_configuration.h\"
108
- #include \"{cutlass_root}cutlass/epilogue/thread/linear_combination_relu.h\"
109
- #include \"{cutlass_root}cutlass/epilogue/thread/linear_combination.h\"
110
-
111
- #include \"{project_root}../kernel/b2b_gemm.h\"
112
- #include \"{project_root}../kernel/default_b2b_gemm.h\"
113
- '''.format(cutlass_root=self.cutlass_deps_root, project_root=self.project_root, this_file_root=self.this_file_root)
114
- include_user_header = ""
115
- for header in self.user_header_file:
116
- include_user_header += "#include \"" + header + "\"\n"
117
- return code + include_user_header
118
-
119
- def gen_code(self, sm_cap, mma_tp, ifprint = True):
120
- self.set_arch(sm_cap, mma_tp)
121
-
122
- self.update_b2b_args()
123
- print(self.fuse_gemm_info)
124
- self.update_b2b_class_template_args()
125
-
126
- func_code = self.gen_all_func()
127
- member_var_code = "private:\n typename B2bGemmKernel::Params params_;\n"
128
-
129
- gen_code = gen_ir.gen_template_class(self.gen_class_name, self.template_args, func_code + member_var_code)
130
- code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("device", gen_code)))
131
-
132
- if ifprint:
133
- print(code)
134
-
135
- print("[INFO]: Gen device code output Dir: is ", self.file_name)
136
- with open(self.file_name, 'w+') as f:
137
- f.write(code)
138
-
139
-
140
- gen_kernel = self.gen_kernel.gen_code(self.first_use_1stage)
141
- print(gen_kernel)
142
-
143
- def update_b2b_class_template_args(self):
144
- for arg in self.args.keys():
145
- self.template_args.append([self.__check_arg_type(arg), arg, self.args[arg]])
146
-
147
- def update_b2b_args(self):
148
-
149
- self.args['ElementA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_tp'])
150
- self.args['LayoutA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_format'])
151
-
152
- cnt = 0
153
-
154
- warp_M_tile = 32
155
-
156
- # Determine maximum N_tile
157
- Max_Ntile = 0
158
- for layer in self.fuse_gemm_info:
159
- n_tile = layer['mnk'][1]
160
- if n_tile > Max_Ntile:
161
- Max_Ntile = n_tile
162
- if Max_Ntile >= 256:
163
- warp_M_tile = 16
164
-
165
- stages_temp = []
166
-
167
- for layer in self.fuse_gemm_info:
168
- cnt_str = str(cnt)
169
- B_tp_str= 'ElementB' + cnt_str
170
- B_format_str = 'LayoutB' + cnt_str
171
- C_tp_str= 'ElementC' + cnt_str
172
- C_format_str = 'LayoutC' + cnt_str
173
- Acc_str = 'ElementAccumulator' + cnt_str
174
-
175
- self.args[B_tp_str] = helper.type_2_cutlass_type(layer['B_tp'])
176
- self.args[B_format_str] = helper.type_2_cutlass_type(layer['B_format'])
177
- self.args[C_tp_str] = helper.type_2_cutlass_type(layer['C_tp'])
178
- self.args[C_format_str] = helper.type_2_cutlass_type(layer['C_format'])
179
- self.args[Acc_str] = helper.type_2_cutlass_type(layer['Acc_tp'])
180
-
181
-
182
- mnk = layer['mnk'][:]
183
-
184
- tile_mnk = mnk[:]
185
-
186
- tile_mnk[2] = 32 # force the ktile is 32
187
-
188
- #N tile gen
189
- if mnk[1] > 1024:
190
- assert(0)
191
- elif mnk[1] > 512:
192
- tile_mnk[1] = 1024
193
- elif mnk[1] > 256:
194
- tile_mnk[1] = 512
195
- elif mnk[1] > 128:
196
- tile_mnk[1] = 256
197
- elif mnk[1] > 64:
198
- tile_mnk[1] = 128
199
- elif mnk[1] > 32:
200
- tile_mnk[1] = 64
201
- else :
202
- tile_mnk[1] = 32
203
-
204
- if tile_mnk[1] == 512:
205
- stages_temp.append(1)
206
- else:
207
- stages_temp.append(2)
208
-
209
- tile_mnk[0] = 4 * warp_M_tile
210
-
211
-
212
-
213
- epilogue_setted_type = helper.get_epilogue_tp(layer)
214
- cutlass_epilogue_name = "LinearCombinationRelu"
215
- if epilogue_setted_type.lower() == 'leakyrelu':
216
- cutlass_epilogue_name = "LinearCombinationLeakyRelu"
217
- elif epilogue_setted_type.lower() == 'identity':
218
- cutlass_epilogue_name = "LinearCombination"
219
-
220
- epilogue_str = 'EpilogueOutputOp' + cnt_str
221
- if cnt != len(self.fuse_gemm_info) - 1:
222
- n = layer['mnk'][1]
223
- Fragments = tile_mnk[1] // 8 * 2
224
- self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "<ElementC0_, " + str(Fragments) +", ElementAccumulator0_, ElementAccumulator0_>"
225
- else:
226
- n = layer['mnk'][1]
227
- n_mod_8 = n % 4
228
- N_align_elements = 1
229
- if n_mod_8 == 0:
230
- N_align_elements = 8
231
- elif n_mod_8 == 4:
232
- N_align_elements = 4
233
- elif n_mod_8 == 2 or n_mod_8 == 6:
234
- N_align_elements = 2
235
-
236
- self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "<ElementC0_, " + str(N_align_elements) + ", ElementAccumulator0_, ElementAccumulator0_>"
237
-
238
-
239
-
240
- ThreadBlockShape_str = 'ThreadblockShape' + cnt_str
241
-
242
- self.args[ThreadBlockShape_str] = helper.cvt_2_cutlass_shape(tile_mnk)
243
-
244
- WarpShape_str = 'WarpShape' + cnt_str
245
- tile_mnk[0] = warp_M_tile
246
- self.args[WarpShape_str] = helper.cvt_2_cutlass_shape(tile_mnk)
247
- cnt += 1
248
-
249
-
250
- self.args['ElementD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_tp'])
251
- self.args['LayoutD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_format'])
252
-
253
- self.args['InstructionShape'] = helper.cvt_2_cutlass_shape(self.mma_shape)
254
- self.args['OperatorClass'] = 'arch::OpClassTensorOp'
255
- self.args['ArchTag'] = self.arch
256
- self.args['ThreadblockSwizzle'] = 'threadblock::GemmBatchedIdentityThreadblockSwizzle'
257
-
258
-
259
- for i in range(self.b2b_num):
260
- self.args[helper.var_idx('Stages', i)] = "2"
261
-
262
- self.args['AlignmentA'] = str(8)
263
- self.args['AlignmentB'] = str(8)
264
- self.args['SplitKSerial'] = 'false'
265
- self.args['Operator'] = 'typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB0_, ElementC0_, ElementAccumulator0_>::Operator'
266
- self.args['IsBetaZero'] = 'false'
267
-
268
-
269
- def gen_using_kernel(self):
270
- code = "using B2bGemmKernel = typename kernel::DefaultB2bGemm<\n"
271
- code += " " + "ElementA,\n"
272
- code += " " + "LayoutA,\n"
273
-
274
- for i in range(self.b2b_num):
275
- code += " " + helper.var_idx("ElementB", i) + ",\n"
276
- code += " " + helper.var_idx("LayoutB", i) + ",\n"
277
- code += " " + helper.var_idx("ElementC", i) + ",\n"
278
- code += " " + helper.var_idx("LayoutC", i) + ",\n"
279
- code += " " + helper.var_idx("ElementAccumulator", i) + ",\n"
280
- code += " " + helper.var_idx("EpilogueOutputOp", i) + ",\n"
281
- code += " " + helper.var_idx("ThreadblockShape", i) + ",\n"
282
- code += " " + helper.var_idx("WarpShape", i) + ",\n"
283
-
284
- code += " " + "ElementD,\n"
285
- code += " " + "LayoutD,\n"
286
- code += " " + "InstructionShape,\n"
287
- code += " " + "OperatorClass,\n"
288
- code += " " + "ArchTag,\n"
289
- code += " " + "ThreadblockSwizzle,\n"
290
-
291
- for i in range(self.b2b_num):
292
- code += " " + helper.var_idx("Stages", i) + ",\n"
293
-
294
-
295
- code += " " + "AlignmentA,\n"
296
- code += " " + "AlignmentB,\n"
297
- code += " " + "SplitKSerial,\n"
298
- code += " " + "Operator,\n"
299
- code += " " + "IsBetaZero_\n"
300
-
301
- code += ">::B2bGemmKernel;\n\n"
302
-
303
- return code
304
-
305
- def gen_args(self):
306
-
307
- def gen_arg_member(b2b_num):
308
- data_members = []
309
-
310
- for i in range(b2b_num):
311
- member_type = "GemmCoord"
312
- member_name = "problem_size_" + str(i)
313
- data_members.append((member_type, member_name))
314
-
315
- member_type = "TensorRef<ElementA const, LayoutA>"
316
- member_name = "ref_A0"
317
- data_members.append((member_type, member_name))
318
-
319
- for i in range(b2b_num):
320
- member_type = "TensorRef<ElementB" + str(i) + " const, LayoutB" + str(i) +">"
321
- member_name = "ref_B" + str(i)
322
- data_members.append((member_type, member_name))
323
- member_type = "TensorRef<ElementC" + str(i) + " const, LayoutC" + str(i) +">"
324
- member_name = "ref_C" + str(i)
325
- data_members.append((member_type, member_name))
326
-
327
- member_type = "TensorRef<ElementD, LayoutD>"
328
- member_name = helper.var_idx("ref_D", b2b_num - 1)
329
- data_members.append((member_type, member_name))
330
-
331
- for i in range(b2b_num):
332
- member_type = "typename EpilogueOutputOp" + str(i) + "::Params"
333
- member_name = "epilogue" + str(i)
334
- data_members.append((member_type, member_name))
335
-
336
- data_members.append(('int', 'batch_count'))
337
-
338
- return data_members
339
-
340
- def gen_arg_struct_default_ctor(struct_name, data_members, inital_param_num, inital_value):
341
- constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \
342
- gen_ir.indentation + struct_name + " (): "
343
- for i in range(inital_param_num):
344
- final_param = ','
345
- if i == inital_param_num - 1:
346
- final_param = '{ }'
347
- constructs_code += data_members[i][1] + inital_value + final_param
348
-
349
- constructs_code += "\n"
350
- return constructs_code
351
-
352
- def gen_arg_struct_ctor(struct_name, data_members):
353
- constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \
354
- gen_ir.indentation + struct_name + " (\n"
355
- cnt = 0
356
- param_num = len(data_members)
357
- for param in data_members:
358
- final = ',\n'
359
- if cnt == param_num - 1:
360
- final = '\n):\n'
361
- constructs_code += gen_ir.indentation + param[0] + " " + param[1] + "_" + final
362
- cnt += 1
363
-
364
- cnt = 0
365
- for param in data_members:
366
- final = '),\n'
367
- if cnt == param_num - 1:
368
- final = ") { }\n"
369
- constructs_code += gen_ir.indentation + param[1] + "(" + param[1] + "_" + final
370
- cnt += 1
371
-
372
- constructs_code += "\n"
373
- return constructs_code
374
-
375
- # (variable type, variable name)
376
- struct_member = gen_arg_member(self.b2b_num)
377
- self.arg_member = struct_member
378
-
379
- codeBody = ""
380
- for each_member in struct_member:
381
- codeBody += gen_ir.indentation + each_member[0] + " " + each_member[1] + ";\n"
382
-
383
- codeBody += gen_arg_struct_default_ctor("Arguments", struct_member, self.b2b_num, "(0,0,0)") + "\n"
384
- codeBody += gen_arg_struct_ctor("Arguments", struct_member) + "\n"
385
- struct_code = gen_ir.gen_struct("Arguments", codeBody)
386
- return struct_code
387
-
388
- def gen_func_constructs(self):
389
- code = self.gen_class_name +"() {}"
390
- return code
391
-
392
- def gen_func_initialize(self):
393
- code = "Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {\n" + \
394
- "// Determine grid shape\n" + \
395
- "ThreadblockSwizzle threadblock_swizzle;\n" + \
396
- "cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(\n" + \
397
- " args.problem_size_0, \n" + \
398
- " { ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK },\n" + \
399
- " args.batch_count);\n" + \
400
- "// Initialize the Params structure\n" + \
401
- "params_ = typename B2bGemmKernel::Params{\n"
402
- for i in range(self.b2b_num):
403
- code += helper.var_idx(" args.problem_size_", i) + ",\n"
404
- code += " grid_shape,\n" + \
405
- " args.ref_A0.non_const_ref(),\n"
406
- for i in range(self.b2b_num):
407
- code += helper.var_idx(" args.ref_B", i) + ".non_const_ref(),\n"
408
- code += helper.var_idx(" args.ref_C", i) + ".non_const_ref(),\n"
409
-
410
- code += helper.var_idx(" args.ref_D", self.b2b_num - 1) + ",\n"
411
- for i in range(self.b2b_num):
412
- code += helper.var_idx(" args.epilogue", i) + ",\n"
413
-
414
- code += " args.batch_count\n"
415
- code += "};\n" + \
416
- "return Status::kSuccess;\n" + \
417
- "}\n"
418
- return code
419
-
420
- def gen_func_run(self):
421
- code = "Status run(cudaStream_t stream = nullptr) {\n" + \
422
- "\n" + \
423
- " ThreadblockSwizzle threadblock_swizzle;\n" + \
424
- "\n" + \
425
- " dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);\n" + \
426
- " dim3 block(B2bGemmKernel::kThreadCount, 1, 1);\n" + \
427
- "\n" + \
428
- " cudaError_t result;\n" + \
429
- "\n" + \
430
- " int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));\n" + \
431
- " if (smem_size >= (48 << 10)) {\n" + \
432
- " result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);\n" + \
433
- "\n" + \
434
- " if (result != cudaSuccess) {\n" + \
435
- " return Status::kErrorInternal;\n" + \
436
- " }\n" + \
437
- " }\n" + \
438
- " cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);\n" + \
439
- " result = cudaGetLastError();\n" + \
440
- " return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;\n" + \
441
- " }\n"
442
-
443
- return code
444
- def gen_func_operator(self):
445
- opeartor_with_arg_code = "Status operator()(\n" + \
446
- " Arguments const &args,\n" + \
447
- " void *workspace = nullptr,\n" + \
448
- " cudaStream_t stream = nullptr) {\n" + \
449
- " Status status = initialize(args, workspace);\n" + \
450
- " \n" + \
451
- " if (status == Status::kSuccess) {\n" + \
452
- " status = run(stream);\n" + \
453
- " }\n" + \
454
- " return status;\n" + \
455
- "}\n"
456
- operator_code = "Status operator()(\n" + \
457
- " cudaStream_t stream = nullptr) {\n" + \
458
- " Status status = run(stream);\n" + \
459
- " return status;\n" + \
460
- "}\n"
461
- return opeartor_with_arg_code + "\n" + operator_code
462
-
463
- def gen_all_func(self):
464
- return self.gen_using_kernel() + "\n" + \
465
- self.gen_args() + "\n" + \
466
- self.gen_func_constructs() + "\n" + \
467
- self.gen_func_initialize() + "\n" + \
468
- self.gen_func_run() + "\n" + \
469
- self.gen_func_operator()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py DELETED
@@ -1,249 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import helper
34
-
35
-
36
- indentation = " "
37
-
38
-
39
- def append_word(word):
40
- code = ""
41
- code += word
42
- code += " "
43
- return code
44
-
45
-
46
- def gen_namespace(namespace, codeBody):
47
- code_gen = "namespace " + namespace + " {\n"
48
- code_gen += codeBody
49
- code_gen += "} // namespace " + namespace + "\n"
50
- return code_gen
51
-
52
-
53
- def gen_expression(type, lval, rval = None):
54
- code_gen = ""
55
- code_gen += append_word(type)
56
- code_gen += append_word(lval)
57
- if rval is not None:
58
- code_gen += append_word("=")
59
- code_gen += append_word(rval)
60
- return code_gen
61
-
62
-
63
- def gen_class(name, codeBody, inheritance_code = None):
64
- code_gen = ""
65
- if inheritance_code is None:
66
- code_gen = "class " + name + "{\n"
67
- else:
68
- code_gen = "class " + name + " : "+ inheritance_code + "{\n"
69
- code_gen += codeBody
70
- code_gen += "}; // class " + name + "\n"
71
- return code_gen
72
-
73
-
74
- def gen_struct(name, codeBody, specialized = None):
75
- specialized_code = ""
76
- if specialized is not None:
77
- specialized_code = "<" + specialized + ">"
78
- code_gen = "struct " + name + specialized_code + "{\n"
79
- code_gen += codeBody
80
- code_gen += "}; // struct " + name + "\n"
81
- return code_gen
82
-
83
-
84
- def gen_template_arg(arg_type, arg_name, default_val = None):
85
- rval = None
86
- if default_val is not None:
87
- rval = str(default_val)
88
-
89
- arg_typename = ""
90
- if arg_type is int:
91
- arg_typename = "int"
92
- elif arg_type is bool:
93
- arg_typename = "bool"
94
- else:
95
- arg_typename = "typename"
96
-
97
- internal_arg_name = arg_name + "_"
98
-
99
- code_gen = indentation
100
- code_gen += gen_expression(arg_typename, internal_arg_name, rval)
101
-
102
- return code_gen
103
-
104
-
105
- def gen_template_args(args, set_default = True):
106
- arg_len = len(args)
107
- cnt = 1
108
- code_gen = ""
109
- for arg_tuple in args:
110
- arg_type = arg_tuple[0]
111
- arg_name = arg_tuple[1]
112
- arg_default_val = None
113
- if len(arg_tuple) == 3 and set_default:
114
- arg_default_val = arg_tuple[2]
115
-
116
- code_gen += gen_template_arg(arg_type, arg_name, arg_default_val)
117
- if cnt != arg_len:
118
- code_gen += ",\n"
119
- cnt += 1
120
-
121
- return code_gen
122
-
123
-
124
- def gen_template_head(args, set_default = True):
125
- code_gen = "template <\n"
126
- code_gen += gen_template_args(args, set_default)
127
- code_gen += ">\n"
128
- return code_gen
129
-
130
-
131
- def export_template_args(args):
132
- code_gen = "public:\n"
133
- for arg_tuple in args:
134
- code_gen += indentation
135
- arg_type = arg_tuple[0]
136
- arg_name = arg_tuple[1]
137
- internal_arg_name = arg_name + "_"
138
-
139
- typename = ""
140
- if arg_type is int:
141
- typename = "static int const"
142
- elif arg_type is bool:
143
- typename = "static bool const"
144
- else:
145
- typename = "using"
146
-
147
- code_gen += gen_expression(typename, arg_name, internal_arg_name)
148
- code_gen += ";\n"
149
- return code_gen
150
-
151
-
152
- def gen_template_class(class_name, args, codeBody, set_default = True, inheritance_code = None):
153
- code_gen = ""
154
-
155
- code_gen += gen_template_head(args, set_default)
156
- code_gen += gen_class(class_name, export_template_args(args) + codeBody, inheritance_code)
157
-
158
- return code_gen
159
-
160
-
161
- def gen_template_struct(struct_name, args, codeBody, speicalized = None, set_default = True, export_args = True):
162
- code_gen = ""
163
- code_gen += gen_template_head(args, set_default)
164
- code = export_template_args(args) + codeBody
165
- if export_args is False:
166
- code = codeBody
167
- code_gen += gen_struct(struct_name, code , speicalized)
168
-
169
- return code_gen
170
-
171
-
172
- def gen_declare_template_struct(name, *params):
173
- code = name + "<"
174
- cnt = 0
175
- param_num = len(params)
176
- for param in params:
177
- final = ", "
178
- if cnt == param_num - 1:
179
- final = ""
180
- code += param + final
181
- cnt += 1
182
- code += ">;\n"
183
- return code
184
-
185
-
186
- def filtered_param(params, name_and_value_pair, keep_ = False):
187
- rtn_template_args = []
188
- speicalized_template_args = []
189
-
190
- for param in params:
191
- param_name = ""
192
- if len(param) >= 1:
193
- param_name = param[1]
194
- else:
195
- param_name = param[0]
196
-
197
- hit_flag = False
198
- set_value = ""
199
- for n_v_pair in name_and_value_pair:
200
-
201
- filter_name = n_v_pair[0]
202
- set_value = n_v_pair[1]
203
-
204
- if param_name == (filter_name + "_") or param_name == filter_name :
205
- hit_flag = True
206
- break
207
-
208
-
209
- if hit_flag is False:
210
- rtn_template_args.append(param)
211
-
212
- if hit_flag is True:
213
- speicalized_template_args.append(set_value)
214
- else:
215
- if keep_ is True:
216
- speicalized_template_args.append(param_name + "_")
217
- else:
218
- speicalized_template_args.append(param_name)
219
-
220
-
221
- specialized_template_arg_str = helper.list_2_string(speicalized_template_args)
222
-
223
- return rtn_template_args, specialized_template_arg_str
224
-
225
-
226
- def gen_func(func_name, arg_lists, code_body, only_declare = False, with_cudaStream = True):
227
- code = "void " + func_name + "(\n"
228
- for arg in arg_lists:
229
- arg_tp = arg[0]
230
- arg_nm = arg[1]
231
- code += " " + arg_tp + " " + arg_nm + ",\n"
232
- code += "cudaStream_t stream)"
233
- if only_declare :
234
- return code
235
- code += "{\n"
236
-
237
- code += code_body + "\n"
238
- code += "}\n"
239
- return code
240
-
241
-
242
- def indent_level(code, level = 0):
243
- rtn_code = ""
244
- for i in range(level):
245
- rtn_code += " "
246
-
247
- rtn_code += code
248
-
249
- return rtn_code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py DELETED
@@ -1,476 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import gen_ir
34
- import helper
35
- import gen_threadblock as gen_tb
36
-
37
-
38
- class gen_default_Gemm:
39
- def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
40
- self.gen_class_name = "B2bGemm"
41
- self.template_param = template_param
42
- self.b2b_num = b2b_num
43
-
44
- self.cutlass_deps_root = cutlass_deps_root
45
- self.project_root = project_root
46
-
47
- def gen_B2bMma(self, specialized_template_args):
48
- code = "using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<\n"
49
- code += specialized_template_args
50
- code += ">::ThreadblockB2bMma;\n"
51
-
52
- # print(code)
53
- return code
54
-
55
- def gen_epilogue(self):
56
- epilogue_code = ""
57
- epilogue_code += helper.var_idx("static const int kPartitionsK", self.b2b_num - 1) + helper.var_idx(" = ThreadblockShape", self.b2b_num - 1) + helper.var_idx("::kK / WarpShape", self.b2b_num - 1) + "::kK;\n"
58
-
59
- epilogue_code += "using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<\n"
60
- epilogue_code += " " + helper.var_idx("ThreadblockShape", self.b2b_num - 1) + ",\n"
61
- epilogue_code += " " + helper.var_idx("typename B2bMma::Operator", self.b2b_num - 1) + ",\n"
62
- epilogue_code += " " + helper.var_idx("kPartitionsK", self.b2b_num - 1) + ",\n"
63
- epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + ",\n"
64
- epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + "::kCount\n"
65
- epilogue_code += ">::Epilogue;\n"
66
-
67
- epilogue_code += "using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;\n\n"
68
-
69
- return epilogue_code
70
-
71
-
72
- def gen_include_header(self):
73
- code = '''
74
- /* Auto Generated code - Do not edit.*/
75
-
76
- #pragma once
77
- #include \"{cutlass_dir}cutlass/cutlass.h\"
78
-
79
- #include \"{cutlass_dir}cutlass/layout/matrix.h\"
80
- #include \"{cutlass_dir}cutlass/numeric_types.h\"
81
-
82
- #include \"{cutlass_dir}cutlass/epilogue/threadblock/epilogue.h\"
83
- #include \"{cutlass_dir}cutlass/epilogue/thread/linear_combination.h\"
84
-
85
- #include \"{cutlass_dir}cutlass/gemm/gemm.h\"
86
- #include \"{cutlass_dir}cutlass/gemm/kernel/gemm_pipelined.h\"
87
- #include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\"
88
- #include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\"
89
- #include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\"
90
- #include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_simt.h\"
91
- #include \"{cutlass_dir}cutlass/gemm/threadblock/threadblock_swizzle.h\"
92
- #include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_tensor_op.h\"
93
- #include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h\"
94
- #include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_simt.h\"
95
-
96
- #include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\"
97
-
98
- #include \"../kernel/b2b_gemm.h\"
99
- #include \"../threadblock/default_b2b_mma.h\"
100
- '''.format(cutlass_dir=self.cutlass_deps_root)
101
- return code
102
-
103
- def gen_code(self):
104
- gen_using = ''
105
- # Generate default template struct
106
- gen_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, self.template_param,"", speicalized = None, set_default=False)
107
-
108
-
109
- filter_list = []
110
- filter_list.append(('Stages', 2))
111
- filter_list.append(("OperatorClass", "arch::OpClassTensorOp"))
112
- filter_list.append(("ArchTag", "arch::Sm75"))
113
-
114
- for i in range(self.b2b_num):
115
- filter_list.append((helper.var_idx("LayoutC", i), "layout::RowMajor"))
116
-
117
-
118
- rtn_template_args, speicalized_template_args = gen_ir.filtered_param(self.template_param, filter_list, keep_= True)
119
-
120
-
121
- B2bMma_code = self.gen_B2bMma(speicalized_template_args)
122
- epilogue_and_rest_code = self.gen_epilogue()
123
-
124
- gen_special_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, rtn_template_args, B2bMma_code + epilogue_and_rest_code, speicalized = speicalized_template_args, set_default=False)
125
-
126
- code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", gen_code + gen_special_code)))
127
-
128
- return self.gen_include_header() + code
129
-
130
-
131
- class gen_Kernel:
132
- def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
133
- self.gen_class_name = "B2bGemm"
134
- self.template_param = template_param
135
- self.b2bnum = b2b_num
136
-
137
- self.cutlass_deps_root = cutlass_deps_root
138
- self.project_root = project_root
139
-
140
- def gen_include_header(self):
141
- code = '''
142
- #pragma once
143
-
144
- #include \"{cutlass_dir}cutlass/cutlass.h\"
145
- #include \"{cutlass_dir}cutlass/gemm/gemm.h\"
146
- #include \"{cutlass_dir}cutlass/matrix_coord.h\"\n'''.format(cutlass_dir=self.cutlass_deps_root)
147
- return code
148
-
149
- def gen_Params(self):
150
- gen_param = ""
151
- for i in range(self.b2bnum):
152
- gen_param += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + ";\n"
153
- gen_param += " " + "cutlass::gemm::GemmCoord grid_tiled_shape;\n"
154
- gen_param += " " + "typename B2bMma::IteratorA0::Params params_A0;\n"
155
- gen_param += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0;\n"
156
-
157
- for i in range(self.b2bnum):
158
- gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::Params params_B", i) + ";\n"
159
- gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ";\n"
160
- if i == self.b2bnum - 1:
161
- gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_C", i) + ";\n"
162
- gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ";\n"
163
-
164
- else:
165
- gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::Params params_C", i) + ";\n"
166
- gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ";\n"
167
-
168
-
169
-
170
-
171
- gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_D", self.b2bnum - 1) + ";\n"
172
- gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ";\n"
173
-
174
- for i in range(self.b2bnum):
175
- gen_param += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + ";\n"
176
-
177
- gen_param += " " + 'int batch_count' + ";\n"
178
- gen_param += " " + 'int gemm_k_iterations_0' + ";\n"
179
-
180
-
181
- return gen_param
182
-
183
- def gen_Memberfunc(self):
184
- code_default = "\nCUTLASS_HOST_DEVICE\n"
185
- code_default += "Params()"
186
-
187
- code_default += " { } \n\n"
188
-
189
- code_construct = "\nCUTLASS_HOST_DEVICE\n"
190
- code_construct += "Params(\n"
191
-
192
- for i in range(self.b2bnum):
193
- code_construct += " " + helper.var_idx("cutlass::gemm::GemmCoord const & problem_size_", i) + ",\n"
194
-
195
- code_construct += " " + "cutlass::gemm::GemmCoord const & grid_tiled_shape,\n"
196
-
197
- code_construct += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0,\n"
198
-
199
- for i in range(self.b2bnum):
200
- code_construct += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ",\n"
201
- if i == self.b2bnum - 1:
202
- code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ",\n"
203
- else:
204
- code_construct += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ",\n"
205
-
206
- code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ",\n"
207
- for i in range(self.b2bnum):
208
- code_construct += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + helper.var_idx(" = typename OutputOp", i) + "::Params(),\n"
209
-
210
- code_construct += " " + "int batch_count = 1\n"
211
-
212
- code_construct += "):\n"
213
-
214
- for i in range(self.b2bnum):
215
- code_construct += " " + helper.var_idx("problem_size_", i) + helper.var_idx("(problem_size_", i) + "),\n"
216
-
217
- code_construct += " " + "grid_tiled_shape(grid_tiled_shape),\n"
218
- code_construct += " " + "params_A0(ref_A0.layout()),\n"
219
- code_construct += " " + "ref_A0(ref_A0),\n"
220
-
221
- for i in range(self.b2bnum):
222
- code_construct += " " + helper.var_idx("params_B", i) + helper.var_idx("(ref_B", i) + ".layout()),\n"
223
- code_construct += " " + helper.var_idx("ref_B", i) + helper.var_idx("(ref_B", i) + "),\n"
224
- code_construct += " " + helper.var_idx("params_C", i) + helper.var_idx("(ref_C", i) + ".layout()),\n"
225
- code_construct += " " + helper.var_idx("ref_C", i) + helper.var_idx("(ref_C", i) + "),\n"
226
-
227
- code_construct += " " + helper.var_idx("params_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + ".layout()),\n"
228
- code_construct += " " + helper.var_idx("ref_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + "),\n"
229
-
230
- for i in range(self.b2bnum):
231
- code_construct += " " + helper.var_idx("output_op_", i) + helper.var_idx("(output_op_", i) + "), \n"
232
-
233
- code_construct += " " + "batch_count(batch_count) {\n"
234
- code_construct += " " + helper.var_idx("gemm_k_iterations_", 0) + helper.var_idx(" = (problem_size_", 0) + helper.var_idx(".k() + B2bMma::Shape", 0) + helper.var_idx("::kK - 1) / B2bMma::Shape", 0) + "::kK;\n"
235
-
236
- code_construct += "}\n"
237
-
238
- return code_default + code_construct
239
-
240
- def gen_using(self):
241
- code_using = ""
242
-
243
- for i in range(self.b2bnum - 1):
244
- code_using += " " + helper.var_idx("using OutputOp", i) + helper.var_idx(" = typename B2bMma::OutputOp", i) + ";\n"
245
-
246
- code_using += " " + helper.var_idx("using OutputOp", self.b2bnum - 1) + " = typename Epilogue::OutputOp;\n"
247
-
248
- for i in range(self.b2bnum - 1):
249
- code_using += " " + helper.var_idx("using FusedAddBiasEpilogue", i) + helper.var_idx(" = typename B2bMma::FusedAddBiasEpilogue", i) +";\n"
250
-
251
-
252
- code_using += " " + "using WarpCount0 = typename B2bMma::WarpCount0;\n"
253
- code_using += " " + "static int const kThreadCount = 32 * WarpCount0::kCount;\n"
254
-
255
- code_using += gen_ir.gen_struct("Params", self.gen_Params() + self.gen_Memberfunc())
256
-
257
- code_using += "union SharedStorage {\n"
258
- code_using += " " + "typename B2bMma::B2bMmaSharedStorage main_loop;\n"
259
- code_using += " " + "typename Epilogue::SharedStorage epilogue;\n"
260
- code_using += "};\n"
261
-
262
- return code_using
263
-
264
- def gen_can_implement(self):
265
- gen_code = ""
266
- return gen_code
267
-
268
- def gen_operator_and_constr(self):
269
- ctr_code = "CUTLASS_HOST_DEVICE\n"
270
- ctr_code += self.gen_class_name + "() { } \n\n"
271
- operator_code = "CUTLASS_DEVICE\n"
272
- operator_code += "void operator()(Params const &params, SharedStorage &shared_storage) {\n"
273
- operator_code += " " + "ThreadblockSwizzle threadblock_swizzle;\n"
274
- operator_code += " " + "cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n"
275
- operator_code += " " + "int batch_idx = threadblock_tile_offset.k();\n"
276
- operator_code += " " + "if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||\n"
277
- operator_code += " " + "params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {\n"
278
- operator_code += " " + " " + "return;\n"
279
- operator_code += " " + "}\n"
280
-
281
- operator_code += " " + "cutlass::MatrixCoord tb_offset_A0{\n"
282
- operator_code += " " + " " + "threadblock_tile_offset.m() * B2bMma::Shape0::kM,\n"
283
- operator_code += " " + " " + "0\n"
284
- operator_code += " " + "};\n"
285
-
286
- for i in range(self.b2bnum):
287
- operator_code += " " + helper.var_idx("cutlass::MatrixCoord tb_offset_B", i) + "{\n"
288
- operator_code += " " + " " + "0,\n"
289
- operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", i) + "::kN\n"
290
- operator_code += " " + "};\n"
291
-
292
- operator_code += " " + "int thread_idx = threadIdx.x;\n\n"
293
-
294
- operator_code += " " + "MatrixCoord threadblock_offset(\n"
295
- operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.m() * B2bMma::Shape", self.b2bnum - 1) + "::kM,\n"
296
- operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", self.b2bnum - 1) + "::kN\n"
297
- operator_code += " " + ");\n"
298
-
299
- operator_code += " " + "typename B2bMma::IteratorA0 iterator_A0(\n"
300
- operator_code += " " + " " + "params.params_A0,\n"
301
- operator_code += " " + " " + "params.ref_A0.data(),\n"
302
- operator_code += " " + " " + "params.problem_size_0.mk(),\n"
303
- operator_code += " " + " " + "thread_idx,\n"
304
- operator_code += " " + " " + "tb_offset_A0);\n"
305
-
306
- operator_code += " " + "iterator_A0.add_pointer_offset(batch_idx * params.problem_size_0.m() * params.problem_size_0.k());\n\n"
307
-
308
-
309
- for i in range (self.b2bnum):
310
- operator_code += " " + helper.var_idx("typename B2bMma::IteratorB", i ) + helper.var_idx(" iterator_B", i) + "(\n"
311
- operator_code += " " + " " + helper.var_idx("params.params_B", i) + ",\n"
312
- operator_code += " " + " " + helper.var_idx("params.ref_B", i) + ".data(),\n"
313
- operator_code += " " + " " + helper.var_idx("params.problem_size_", i) + ".kn(),\n"
314
- operator_code += " " + " " + "thread_idx,\n"
315
- operator_code += " " + " " + helper.var_idx("tb_offset_B", i) + ");\n"
316
- operator_code += " " + helper.var_idx("iterator_B", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * params.problem_size_", i) + ".k());\n\n"
317
-
318
-
319
- for i in range (self.b2bnum - 1):
320
- operator_code += " " + helper.var_idx("typename FusedAddBiasEpilogue", i ) + helper.var_idx("::OutputTileIterator iterator_C", i) + "(\n"
321
- operator_code += " " + " " + helper.var_idx("params.params_C", i) + ",\n"
322
- operator_code += " " + " " + helper.var_idx("params.ref_C", i) + ".data(),\n"
323
- operator_code += " " + " " + helper.var_idx("params.problem_size_" , i) + ".mn(),\n"
324
- operator_code += " " + " " + "thread_idx,\n"
325
- operator_code += " " + " " + "threadblock_offset" + ");\n"
326
- operator_code += " " + helper.var_idx("int ref_C", i) + helper.var_idx("_stride = params.ref_C", i) + ".stride()[0];\n"
327
- operator_code += " " + helper.var_idx("iterator_C", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * (ref_C", i) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", i) + ".m()));\n\n"
328
-
329
-
330
- for i in range (self.b2bnum - 1):
331
- operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n"
332
-
333
-
334
- operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n"
335
- operator_code += " " + "int lane_idx = threadIdx.x % 32;\n"
336
-
337
- for i in range (self.b2bnum - 1):
338
- operator_code += " " + helper.var_idx("OutputOp", i) + helper.var_idx(" output_op_", i) + helper.var_idx("(params.output_op_", i) + ");\n"
339
-
340
- operator_code += " " + "B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);\n"
341
-
342
- operator_code += " " + "typename B2bMma::FragmentC0 src_accum;\n"
343
- operator_code += " " + helper.var_idx("typename B2bMma::FragmentC", self.b2bnum - 1)+ " accumulators;\n"
344
-
345
- operator_code += " " + "src_accum.clear();\n"
346
- operator_code += " " + "accumulators.clear();\n"
347
- operator_code += " " + "b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, "
348
-
349
- for i in range(self.b2bnum):
350
- operator_code += helper.var_idx("iterator_B", i) + ", "
351
-
352
- operator_code += "src_accum"
353
- if self.b2bnum != 1:
354
- operator_code += ", "
355
- for i in range(self.b2bnum - 1):
356
- operator_code += helper.var_idx("output_op_", i) + ", "
357
-
358
- for i in range(self.b2bnum - 1):
359
- operator_code += helper.var_idx("epilogue_", i) + ", "
360
-
361
- for i in range(self.b2bnum - 1):
362
- final = ", "
363
- if i == self.b2bnum - 2:
364
- final =""
365
- operator_code += helper.var_idx("iterator_C", i) + final
366
- operator_code += ");\n"
367
-
368
- operator_code += " " + helper.var_idx("OutputOp", self.b2bnum - 1) + helper.var_idx(" output_op_", self.b2bnum - 1) + helper.var_idx("(params.output_op_", self.b2bnum - 1) + ");\n"
369
- operator_code += " " + "threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n"
370
-
371
-
372
-
373
- operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_C", self.b2bnum - 1) + "(\n"
374
- operator_code += " " + " " + helper.var_idx("params.params_C", self.b2bnum - 1) + ",\n"
375
- operator_code += " " + " " + helper.var_idx("params.ref_C", self.b2bnum - 1) + ".data(),\n"
376
- operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n"
377
- operator_code += " " + " " + "thread_idx,\n"
378
- operator_code += " " + " " + "threadblock_offset\n"
379
- operator_code += " " + ");\n"
380
- operator_code += " " + helper.var_idx("int ref_C", self.b2bnum - 1) + helper.var_idx("_stride = params.ref_C", self.b2bnum - 1) + ".stride()[0];\n"
381
-
382
- operator_code += " " + helper.var_idx("iterator_C", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * (ref_C", self.b2bnum - 1) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", self.b2bnum - 1) + ".m()));\n\n"
383
-
384
- operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_D", self.b2bnum - 1) + "(\n"
385
- operator_code += " " + " " + helper.var_idx("params.params_D", self.b2bnum - 1) + ",\n"
386
- operator_code += " " + " " + helper.var_idx("params.ref_D", self.b2bnum - 1) + ".data(),\n"
387
- operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n"
388
- operator_code += " " + " " + "thread_idx,\n"
389
- operator_code += " " + " " + "threadblock_offset\n"
390
- operator_code += " " + ");\n"
391
- operator_code += " " + helper.var_idx("iterator_D", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * params.problem_size_", self.b2bnum - 1) + ".m());\n\n"
392
-
393
-
394
- operator_code += " " + "Epilogue epilogue(\n"
395
- operator_code += " " + " " + "shared_storage.epilogue,\n"
396
- operator_code += " " + " " + "thread_idx,\n"
397
- operator_code += " " + " " + "warp_idx,\n"
398
- operator_code += " " + " " + "lane_idx\n"
399
- operator_code += " " + ");\n"
400
-
401
- operator_code += " " + "epilogue("
402
- operator_code += helper.var_idx("output_op_", self.b2bnum - 1) + ", "
403
- operator_code += helper.var_idx("iterator_D", self.b2bnum - 1) + ", "
404
- operator_code += "accumulators, "
405
- operator_code += helper.var_idx("iterator_C", self.b2bnum - 1) + ");\n"
406
- operator_code += "}\n"
407
-
408
- return ctr_code + operator_code
409
-
410
- def gen_include_header(self):
411
- code = '''
412
- #pragma once
413
-
414
- #include \"{cutlass_dir}cutlass/cutlass.h\"
415
-
416
- #include \"{cutlass_dir}cutlass/gemm/gemm.h\"
417
- #include \"{cutlass_dir}cutlass/matrix_coord.h\"
418
- #include \"{cutlass_dir}cutlass/semaphore.h\"
419
- '''.format(cutlass_dir=self.cutlass_deps_root)
420
- return code
421
- def gen_code(self):
422
-
423
- template_param = []
424
- template_param.append(("typename", "B2bMma"))
425
- template_param.append(("typename", "Epilogue"))
426
- template_param.append(("typename", "ThreadblockSwizzle"))
427
- template_param.append((bool, "SplitKSerial"))
428
-
429
- code_body = ""
430
- code_body += self.gen_using()
431
- code_body += self.gen_operator_and_constr()
432
-
433
- struct_code = gen_ir.gen_template_struct(self.gen_class_name, template_param, code_body)
434
- code = self.gen_include_header()
435
- code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", struct_code)))
436
-
437
- return self.gen_include_header() + code
438
-
439
-
440
-
441
class gen_kernel:
    """Drives generation of the kernel-level headers and the threadblock code.

    Owns three sub-generators: one for ``default_b2b_gemm.h``, one for
    ``b2b_gemm.h`` and the threadblock generator from ``gen_tb``.  Kernel
    headers are written below ``output_dir + "/kernel/"``.
    """

    def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root):
        # Arguments shared by the two kernel-level sub-generators.
        shared_args = (template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)

        self.template_param = template_param

        # The class name is fixed; the caller's gen_class_name only feeds the
        # kernel name and the sub-generators.
        self.gen_class_name = "B2bGemm"
        self.gen_kernel_name = gen_class_name + "Kernel"
        self.template_args = []

        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root

        self.gen_default_b2b_gemm = gen_default_Gemm(*shared_args)
        self.gen_Kerenl = gen_Kernel(*shared_args)

        # Threadblock generator (additionally needs the output directory).
        self.gen_threadBlock = gen_tb.gen_threadblock(
            template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root)

        self.file_dir = output_dir + "/kernel/"

    def gen_code(self, first_use_1stage):
        """Write ``default_b2b_gemm.h`` and ``b2b_gemm.h`` into ``self.file_dir``,
        then run the threadblock generator with ``first_use_1stage``."""
        outputs = (
            (self.gen_default_b2b_gemm, "default_b2b_gemm.h",
             "[INFO]: Gen kernel code [default_b2b_gemm.h]output Dir: is "),
            (self.gen_Kerenl, "b2b_gemm.h",
             "[INFO]: Gen kernel code [b2b_gemm.h]output Dir: is "),
        )

        for generator, file_name, message in outputs:
            text = generator.gen_code()
            print(message, self.file_dir)
            with open(self.file_dir + file_name, "w+") as f:
                f.write(text)

        # Threadblock-level code is generated last.
        self.gen_threadBlock.gen_code(first_use_1stage)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py DELETED
@@ -1,232 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import helper
34
- import gen_ir as ir
35
-
36
class gen_test:
    """Generates ``sample.cu``, a C++/CUDA driver for the fused GEMM chain.

    The emitted program times the fused kernel (``one_api``) and the unfused
    CUTLASS reference (``<gen_class_name>_verify``), then compares the last
    GEMM's outputs with ``check_result``.

    ``fuse_gemm_info`` holds one entry per GEMM.  Each entry is read for the
    keys ``'mnk'``, ``'A_format'``, ``'B_format'``, ``'C_format'``, ``'A_tp'``,
    ``'B_tp'``, ``'C_tp'`` and ``'Acc_tp'``; epilogue details are obtained
    through the ``helper`` module.
    """

    # One indentation level of the emitted C++ source.
    _TAB = " "

    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        self.fuse_gemm_info = fuse_gemm_info
        self.gen_class_name = gen_class_name
        self.user_header_file = user_header_file
        self.sample_dir = output_dir
        self.b2b_num = len(fuse_gemm_info)

    def gen_cpp_sample(self):
        """Build the sample program and write it to ``<sample_dir>/sample.cu``."""
        import os

        code = self._gen_prologue()
        code += self._gen_problem_setup()
        code += self._gen_fused_run()
        code += self._gen_unfused_run()
        code += self._gen_result_check()

        # os.path.join yields the same path as plain concatenation when
        # sample_dir ends with a separator (e.g. the default "../") and also
        # works when the caller omits the trailing separator.
        with open(os.path.join(self.sample_dir, "sample.cu"), "w") as f:
            f.write(code)

    def _gen_prologue(self):
        """Emit includes, the start of ``main``, argument parsing, device query
        and the per-GEMM alpha/beta scalars."""
        tab = self._TAB

        code = "/* Auto Generated code - Do not edit.*/\n"
        code += "#include <cstdio> \n"
        # atoi() and srand() are declared in <cstdlib>.
        code += "#include <cstdlib> \n"

        code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n"
        code += "#include \"cutlass/cutlass.h\" \n"

        code += "#include \"../cutlass_irrelevant.h\" \n"
        code += "#include \"../cutlass_verify.h\" \n"

        code += "#include \"leaky_bias.h\" \n"

        code += "#include \"utils.h\" \n"

        code += "int main(int args, char * argv[]) {\n"
        # M is mandatory; refuse to read argv[1] when it was not supplied.
        code += tab + "if(args < 2) {\n"
        code += tab + tab + "printf(\"Usage: %s M [K0] [B]\\n\", argv[0]);\n"
        code += tab + tab + "return -1;\n"
        code += tab + "}\n"
        code += tab + "int M = atoi(argv[1]);\n"
        # 'mnk' is ordered (m, n, k): the default K of the first GEMM is index 2.
        code += tab + "int K0 = " + str(self.fuse_gemm_info[0]['mnk'][2]) + ";\n"
        # Optional arguments: only read argv[2] / argv[3] when present.  The
        # conditions carry no trailing ';' so the assignments stay guarded.
        code += tab + "if(args >= 3)\n"
        code += tab + tab + "K0 = atoi(argv[2]);\n"
        code += tab + "int B = 1;\n"
        code += tab + "if(args >= 4)\n"
        code += tab + tab + "B = atoi(argv[3]);\n"

        code += tab + "srand(1234UL);\n"
        code += tab + "int device_id = 0;\n"
        code += tab + "cudaGetDevice(&device_id);\n"
        code += tab + "cudaDeviceProp prop;\n"
        code += tab + "cudaGetDeviceProperties(&prop, device_id);\n"
        code += tab + "int sm = prop.major *10 + prop.minor;\n"
        code += "using ElementCompute = cutlass::half_t;\n"

        for i in range(self.b2b_num):
            code += tab + helper.var_idx("ElementCompute alpha", i) + " = ElementCompute(1);\n"
            # beta is 1 when the epilogue adds a bias, 0 otherwise.
            if helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i]):
                code += tab + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(1);\n"
            else:
                code += tab + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(0);\n"

        return code

    def _gen_problem_setup(self):
        """Emit the flop counter, per-GEMM problem sizes and device buffers."""
        tab = self._TAB
        last = self.b2b_num - 1

        code = tab + "size_t flops = 0;\n"

        for i in range(self.b2b_num):
            info = self.fuse_gemm_info[i]
            n = info['mnk'][1]
            k = info['mnk'][2]

            bias_shape = helper.get_epilogue_bias_shape(info)

            # Only the first GEMM's K is a runtime value (K0).
            this_k = "K0" if i == 0 else str(k)

            code += tab + "flops += size_t(2) * size_t(M) * size_t(B) * " + "size_t(" + str(n) + ") * size_t(" + this_k + ");\n"

            code += tab + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(" + "M" + ", " + str(n) + ", " + this_k + ");\n"

            code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_A", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".k());\n"
            code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_B", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".n() * problem_size_", i) + ".k());\n"
            code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_C", i) + "(B * " + str(bias_shape[0]) + " * " + str(bias_shape[1]) + ");\n"
            code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_D_cutlass_ref", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".n());\n"

            code += tab + helper.var_idx("Mat_A", i) + ".init();\n"
            code += tab + helper.var_idx("Mat_B", i) + ".init();\n"
            code += tab + helper.var_idx("Mat_C", i) + ".init();\n"

        # Output buffer of the fused kernel, sized from the last problem.  The
        # index is spelled out instead of relying on a leaked loop variable.
        code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_D", last) + helper.var_idx("(B * problem_size_", last) + helper.var_idx(".m() * problem_size_", last) + ".n());\n"

        return code

    def _gen_fused_run(self):
        """Emit the ``Param arguments`` initializer and the timed fused loop."""
        tab = self._TAB
        arg = tab + tab
        last = self.b2b_num - 1

        code = tab + "Param arguments = {\n"
        code += arg + "M,\n"
        code += arg + "K0,\n"
        code += arg + "B,\n"

        code += arg + "reinterpret_cast<const void*>(Mat_A0.device_ptr),\n"
        for i in range(self.b2b_num):
            info = self.fuse_gemm_info[i]
            code += arg + "reinterpret_cast<const void*>(" + helper.var_idx("Mat_B", i) + ".device_ptr" + "),\n"
            if helper.get_epilogue_add_bias_or_not(info):
                code += arg + "reinterpret_cast<const void*>(" + helper.var_idx("Mat_C", i) + ".device_ptr" + "),\n"
            else:
                code += arg + "reinterpret_cast<const void*>(NULL),\n"

            acc_tp = helper.get_epilogue_compute_tp(info)
            for epilogue_arg in helper.get_epilogue_args(info):
                # The third element of each epilogue argument is its value.
                code += arg + helper.type_2_cutlass_type(acc_tp) + "(" + str(epilogue_arg[2]) + "),\n"

            if i != last:
                code += arg + "reinterpret_cast<void*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr" + "),\n"
            else:
                # The last entry closes the initializer.
                code += arg + "reinterpret_cast<void*>(" + helper.var_idx("Mat_D", i) + ".device_ptr" + ")};\n"

        code += tab + "TI(FUSED_CUTLASS);\n"
        code += tab + "for(int i = 0; i < 100; i++){\n"
        code += arg + "one_api(arguments, sm, NULL);\n"
        code += tab + "}\n"
        code += tab + "TO(FUSED_CUTLASS, \"FUSED_CUTLASS\", 100);\n"

        code += "\n"
        return code

    def _gen_unfused_run(self):
        """Emit one ``Gemm<i>::Arguments`` per GEMM and the timed reference loop."""
        tab = self._TAB
        arg = tab + tab
        code = ""

        for i in range(self.b2b_num):
            info = self.fuse_gemm_info[i]
            N_str = str(info['mnk'][1])
            K_str = "K0" if i == 0 else str(info['mnk'][2])

            # Default leading dimensions, then layout-dependent overrides.
            # Layout names are compared by value (==), not identity.
            ldmA = K_str
            ldmB = K_str
            ldmC = N_str
            ldmBias = str(helper.get_epilogue_bias_ldm(info))

            if info['A_format'] == 'Col':
                ldmA = "M"
            if info['B_format'] == 'Row':
                ldmB = N_str
            if info['C_format'] == 'Col':
                ldmC = "M"

            # GEMM 0 reads Mat_A0; later GEMMs read the previous reference output.
            if i == 0:
                a_buffer = helper.var_idx("Mat_A", i)
            else:
                a_buffer = helper.var_idx("Mat_D_cutlass_ref", i - 1)

            M_bias = str(helper.get_epilogue_bias_shape(info)[0])

            code += tab + helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n"
            code += arg + helper.var_idx("problem_size_", i) + ",\n"
            code += arg + "{reinterpret_cast<" + helper.type_2_cutlass_type(info['A_tp']) + "*>(" + a_buffer + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n"
            code += arg + "{reinterpret_cast<" + helper.type_2_cutlass_type(info['B_tp']) + "*>(" + helper.var_idx("Mat_B", i) + ".device_ptr), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n"
            code += arg + "{reinterpret_cast<" + helper.type_2_cutlass_type(info['C_tp']) + "*>(" + helper.var_idx("Mat_C", i) + ".device_ptr), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n"
            code += arg + "{reinterpret_cast<" + helper.type_2_cutlass_type(info['C_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr), " + ldmC + "}, " + "M * " + ldmC + ",\n"
            code += arg + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
            for epilogue_arg in helper.get_epilogue_args(info):
                code += ", " + helper.type_2_cutlass_type(info['Acc_tp']) + "(" + str(epilogue_arg[2]) + ")"
            code += tab + " },\n"
            code += arg + "B};\n"

        code += tab + "TI(UNFUSED_CUTLASS);\n"
        code += tab + "for(int i = 0; i < 100; i++){\n"
        code += arg + self.gen_class_name + "_verify(\n"
        for i in range(self.b2b_num):
            code += arg + tab + helper.var_idx("arguments_", i) + ",\n"
        code += arg + tab + "NULL);\n"

        code += tab + "}\n"
        code += tab + "TO(UNFUSED_CUTLASS, \"UNFUSED_CUTLASS\", 100);\n"
        return code

    def _gen_result_check(self):
        """Emit the device-to-host copies, the result comparison and the end of ``main``."""
        tab = self._TAB
        last = self.b2b_num - 1

        code = tab + helper.var_idx("Mat_D_cutlass_ref", last) + ".d2h();\n"
        code += tab + helper.var_idx("Mat_D", last) + ".d2h();\n"
        code += tab + helper.var_idx("check_result(Mat_D_cutlass_ref", last) + helper.var_idx(".host_ptr, Mat_D", last) \
            + helper.var_idx(".host_ptr, Mat_D", last) + ".elements);\n"

        code += "\n\n}\n"
        return code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py DELETED
@@ -1,1013 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import gen_ir
34
- import helper
35
-
36
-
37
class gen_default_b2b_mma:
    """Emits the ``DefaultB2bMma`` threadblock-level template for ``b2b_num`` chained GEMMs.

    The generated text consists of an unspecialized template declaration plus
    a specialization for ``LayoutD == cutlass::layout::RowMajor`` whose body is
    a sequence of ``using`` declarations (MMA cores, tile iterators, fragment
    iterators, fused bias epilogues and the final ``ThreadblockB2bMma``).
    """

    def __init__(self, template_param, gen_class_name, b2b_num,cutlass_deps_root, project_root):
        # The emitted struct name is fixed; the gen_class_name argument is not stored.
        self.gen_class_name = "DefaultB2bMma"
        self.template_param = template_param
        self.b2b_num = b2b_num

        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root

    def gen_include_header(self):
        """Return the include preamble; CUTLASS paths are prefixed with
        ``self.cutlass_deps_root`` verbatim, project-relative paths are fixed."""
        root = self.cutlass_deps_root
        cutlass_include = '#include "{}{}"'.format

        core_headers = (
            "cutlass/cutlass.h",
            "cutlass/numeric_types.h",
            "cutlass/arch/arch.h",
        )
        threadblock_headers = (
            "cutlass/transform/threadblock/predicated_tile_iterator.h",
            "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h",
            "cutlass/gemm/threadblock/default_mma_core_sm70.h",
            "cutlass/gemm/threadblock/default_mma_core_sm75.h",
            "cutlass/gemm/threadblock/default_mma_core_sm80.h",
        )
        local_includes = (
            '#include "../threadblock/b2b_mma_pipelined.h"',
            '#include "../../fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h"',
            '#include "../../fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h"',
            '#include "../../fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h"',
        )

        lines = ["", "/* Auto Generated code - Do not edit.*/", "", "#pragma once", ""]
        lines.extend(cutlass_include(root, header) for header in core_headers)
        lines.append("")
        lines.extend(cutlass_include(root, header) for header in threadblock_headers)
        lines.append("")
        lines.extend(local_includes)
        return "\n".join(lines) + "\n"

    def gen_using_MmaCore(self, stage):
        """Return one ``using MmaCore<i> = DefaultMmaCore<...>`` declaration per GEMM.

        ``stage`` is passed through as the stage-count template argument.
        """
        declarations = []
        for idx in range(self.b2b_num):
            core_type = gen_ir.gen_declare_template_struct(
                "typename cutlass::gemm::threadblock::DefaultMmaCore",
                helper.var_idx("ThreadblockShape", idx),
                helper.var_idx("WarpShape", idx),
                "InstructionShape",
                "ElementA", "LayoutA",
                helper.var_idx("ElementB", idx), helper.var_idx("LayoutB", idx),
                helper.var_idx("ElementAccumulator", idx), "layout::RowMajor",
                "OperatorClass", str(stage), "Operator")
            declarations.append("using MmaCore" + str(idx) + " = " + core_type)
        return "".join(declarations)

    def gen_using_FusedAddBiasEpilogue(self):
        """Return ``using FusedAddBiasEpilogue<i>`` declarations for every GEMM except the last."""
        epilogue_type = "typename cutlass::epilogue::threadblock::DefaultFusedBiasActEpilogueTensorOp"

        declarations = []
        for idx in range(self.b2b_num - 1):
            alias = helper.var_idx("using FusedAddBiasEpilogue", idx)
            arguments = "".join((
                helper.var_idx("<ThreadblockShape", idx),
                helper.var_idx(",typename MmaCore", idx),
                helper.var_idx("::MmaPolicy::Operator, 1, EpilogueOutputOp", idx),
                ", 2>::Epilogue",
            ))
            declarations.append(alias + " = " + epilogue_type + arguments + ";\n")
        return "".join(declarations)

    def gen_using_Iterator(self):
        """Return the ``IteratorA0`` declaration followed by one ``IteratorB<i>`` per GEMM."""
        tile_iterator = "cutlass::transform::threadblock::PredicatedTileIterator"

        # Operand A is only iterated for the first GEMM.
        pieces = [
            "using IteratorA0" + " = " + gen_ir.gen_declare_template_struct(
                tile_iterator,
                "cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>",
                "ElementA", "LayoutA", "1",
                "typename MmaCore0::IteratorThreadMapA", "AlignmentA_")
        ]

        for idx in range(self.b2b_num):
            core = "MmaCore" + str(idx)
            iterator_b = gen_ir.gen_declare_template_struct(
                tile_iterator,
                "cutlass::MatrixShape<" + core + "::Shape::kK, " + core + "::Shape::kN>",
                helper.var_idx("ElementB", idx), helper.var_idx("LayoutB", idx), "0",
                "typename " + core + "::IteratorThreadMapB", "AlignmentB_")
            pieces.append("using IteratorB" + str(idx) + " = " + iterator_b)

        return "".join(pieces)

    def gen_fragment_iterator(self):
        """Return the ``AccumulatorLayout`` alias and a ``FragmentIteratorA<i>``
        declaration for every GEMM after the first.

        Each fragment iterator is parameterised by the current GEMM's warp /
        instruction shape and the previous GEMM's warp shape.
        """
        fragment_iterator = "cutlass::gemm::warp::MmaTensorOpPureFragmentIterator"

        pieces = ["using AccumulatorLayout = cutlass::layout::ColumnMajor;\n"]
        for idx in range(1, self.b2b_num):
            current = "MmaCore" + str(idx)
            previous = "MmaCore" + str(idx - 1)
            declared = gen_ir.gen_declare_template_struct(
                fragment_iterator,
                "cutlass::MatrixShape<" + current + "::WarpShape::kM, " + current + "::InstructionShape::kK>",
                "cutlass::MatrixShape<" + previous + "::WarpShape::kM, " + previous + "::WarpShape::kN>",
                current + "::Shape::kK",
                helper.var_idx("ElementAccumulator", idx - 1), "ElementA",
                "AccumulatorLayout", "InstructionShape_", "true")
            pieces.append("using FragmentIteratorA" + str(idx) + " = " + declared)

        return "".join(pieces)

    def gen_threadblockmma(self):
        """Return the ``using ThreadblockB2bMma = B2bMmaPipelined<...>`` declaration.

        The template arguments are passed to ``gen_declare_template_struct`` as
        a single comma-separated string.
        """
        # Arguments of the first GEMM (global and shared-memory iterators for A and B).
        leading = [
            "typename MmaCore0::Shape",
            "IteratorA0",
            "typename MmaCore0::SmemIteratorA",
            "IteratorB0",
            "typename MmaCore0::SmemIteratorB",
        ]

        # Later GEMMs take A from a fragment iterator instead of shared memory.
        for idx in range(1, self.b2b_num):
            tag = str(idx)
            leading += [
                "typename MmaCore" + tag + "::Shape",
                "FragmentIteratorA" + tag,
                "IteratorB" + tag,
                "typename MmaCore" + tag + "::SmemIteratorB",
            ]

        leading += ["ElementAccumulator0", "layout::RowMajor"]
        leading += ["EpilogueOutputOp" + str(idx) for idx in range(self.b2b_num - 1)]
        leading += ["FusedAddBiasEpilogue" + str(idx) for idx in range(self.b2b_num - 1)]
        leading += ["typename MmaCore" + str(idx) + "::MmaPolicy" for idx in range(self.b2b_num)]

        stages = [helper.var_idx("Stages", idx) for idx in range(self.b2b_num)]

        # Every leading argument is followed by ", "; the stage list closes the
        # argument string without a trailing separator.
        param_list = "".join(item + ", " for item in leading) + ", ".join(stages)

        return "using ThreadblockB2bMma" + " = " + gen_ir.gen_declare_template_struct(
            "cutlass::gemm::threadblock::B2bMmaPipelined", param_list)

    def gen_code(self):
        """Return the complete header text: includes, then the primary template
        and its ``LayoutD = RowMajor`` specialization inside
        ``cutlass::gemm::threadblock``."""
        # Primary (unspecialized) template with an empty body.
        primary = gen_ir.gen_template_struct(
            self.gen_class_name, self.template_param, "", speicalized = None, set_default=False)

        # Body of the specialization; the MMA cores are declared with 2 stages.
        body_parts = (
            self.gen_using_MmaCore(2),
            self.gen_using_Iterator(),
            self.gen_fragment_iterator(),
            self.gen_using_FusedAddBiasEpilogue(),
            self.gen_threadblockmma(),
        )
        specialized_body = body_parts[0] + body_parts[1] + body_parts[2] + body_parts[3] + body_parts[4]

        # Specialize layout D -> cutlass::layout::RowMajor.
        remaining_params, specialized_params = gen_ir.filtered_param(
            self.template_param, [ ('LayoutD', "cutlass::layout::RowMajor")], keep_= True)

        specialization = gen_ir.gen_template_struct(
            self.gen_class_name, remaining_params, specialized_body,
            speicalized = specialized_params, set_default=False)

        namespaced = gen_ir.gen_namespace(
            "cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", primary + specialization)))

        return self.gen_include_header() + namespaced
213
-
214
-
215
- class gen_b2b_mme_pipelined:
216
- def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
217
- self.gen_class_name = "B2bMmaPipelined"
218
- self.template_param = template_param
219
- self.b2b_num = b2b_num
220
- self.cutlass_deps_root = cutlass_deps_root
221
- self.project_root = project_root
222
-
223
-
224
- def gen_include_header(self):
225
- code = '''
226
- #pragma once
227
-
228
- #include \"{cutlass_dir}cutlass/cutlass.h\"
229
- #include \"{cutlass_dir}cutlass/array.h\"
230
- #include \"{cutlass_dir}cutlass/aligned_buffer.h\"
231
- #include \"{cutlass_dir}cutlass/numeric_conversion.h\"
232
-
233
- #include \"{cutlass_dir}cutlass/numeric_types.h\"
234
- #include \"{cutlass_dir}cutlass/matrix_shape.h\"
235
-
236
- #include \"{cutlass_dir}cutlass/gemm/gemm.h\"
237
- #include \"{cutlass_dir}cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h\"
238
-
239
- #include \"../threadblock/b2b_mma_base.h\"\n'''.format(cutlass_dir = self.cutlass_deps_root)
240
- return code
241
-
242
-
243
- def gen_using(self):
244
- code_using = "using FragmentA0 = typename IteratorA0::Fragment;\n"
245
-
246
- code_using += "using Base = B2bMmaBase<"
247
- for i in range(self.b2b_num):
248
- code_using += helper.var_idx("Shape", i) + "_, "
249
- for i in range(self.b2b_num):
250
- code_using += helper.var_idx("Policy", i) + "_, "
251
- for i in range(self.b2b_num):
252
- code_using += helper.var_idx("Stage", i) + "_, "
253
- code_using = code_using[: -2] + ">;\n"
254
-
255
-
256
- for i in range(self.b2b_num):
257
- code_using += helper.var_idx("using FragmentB", i) + helper.var_idx(" = typename IteratorB", i) + "::Fragment;\n"
258
- code_using += helper.var_idx("using FragmentC", i) + helper.var_idx(" = typename Policy", i) + "::Operator::FragmentC;\n"
259
- code_using += helper.var_idx("using Operator", i) + helper.var_idx(" = typename Policy", i) + "::Operator;\n"
260
-
261
- for i in range(self.b2b_num - 1):
262
- code_using += helper.var_idx("using IteratorC", i) + helper.var_idx(" = typename FusedAddBiasEpilogue", i) + "::OutputTileIterator;\n"
263
-
264
- code_using += "using ArchTag = typename Policy0::Operator::ArchTag;\n"
265
- code_using += "static ComplexTransform const kTransformA0 = Operator0::kTransformA;\n"
266
-
267
- for i in range(self.b2b_num):
268
- code_using += helper.var_idx("static ComplexTransform const kTransformB", i) + helper.var_idx(" = Operator", i) + "::kTransformB;\n"
269
-
270
- code_using += "private:\n"
271
- code_using += "using WarpFragmentA0 = typename Operator0::FragmentA;\n"
272
- code_using += "using WarpFragmentB0 = typename Operator0::FragmentB;\n"
273
-
274
- for i in range(1, self.b2b_num):
275
- code_using += helper.var_idx("using WarpFragmentA", i) + helper.var_idx(" = typename FragmentIteratorA", i) + "::Fragment;\n"
276
- code_using += helper.var_idx("using WarpFragmentB", i) + helper.var_idx(" = typename Operator", i) + "::FragmentB;\n"
277
-
278
- code_using += "protected:\n"
279
-
280
- code_using += "SmemIteratorA0 smem_iterator_A_;\n"
281
-
282
- for i in range(self.b2b_num):
283
- code_using += helper.var_idx("SmemIteratorB", i) + helper.var_idx(" smem_iterator_B", i) + "_;\n"
284
-
285
- return code_using
286
-
287
-
288
- def gen_operator(self, first_use_1stage = False):
289
- code = ""
290
- def gen_operator_param(b2b_num):
291
- param_code = ""
292
- param_code += "int gemm_k_iterations_0,\n"
293
- param_code += helper.var_idx("FragmentC", b2b_num-1) + helper.var_idx(" &accum", b2b_num-1) + ",\n"
294
- param_code += "IteratorA0 iterator_A,\n"
295
-
296
- for i in range(b2b_num):
297
- param_code += helper.var_idx("IteratorB", i) + " " + helper.var_idx("iterator_B", i) + ",\n"
298
-
299
- param_code += "FragmentC0 const &src_accum, \n"
300
-
301
- for i in range(b2b_num - 1):
302
- param_code += helper.var_idx("OutputOp", i) + " " + helper.var_idx("output_op_", i) + ",\n"
303
- for i in range(b2b_num - 1):
304
- param_code += helper.var_idx("FusedAddBiasEpilogue", i) + " " + helper.var_idx("epilogue_", i) + ",\n"
305
- for i in range(b2b_num - 1):
306
- param_code += helper.var_idx("IteratorC", i) + " " + helper.var_idx("iterator_C", i) + ",\n"
307
-
308
-
309
- param_code += "TransformA0 transform_A0 = TransformA0(), \n"
310
-
311
- for i in range(b2b_num):
312
- final = "(),\n"
313
- if i == b2b_num - 1:
314
- final = "()\n"
315
- param_code += helper.var_idx("TransformB", i) + " " + helper.var_idx("transform_B", i) + " = " +helper.var_idx("TransformB", i) + final
316
-
317
- return param_code
318
-
319
-
320
-
321
- def gen_first_gemm_1stage(b2b_num):
322
- accu_code = " FragmentC0 accum0 = src_accum;\n"
323
- if b2b_num == 1:
324
- accu_code = " accum0 = src_accum;\n"
325
-
326
- code ="\
327
- \n\
328
- FragmentA0 tb_frag_A;\n\
329
- FragmentB0 tb_frag_B0;\n\
330
- \n\
331
- int smem_write_stage_idx = 1;\n\
332
- \n\
333
- tb_frag_A.clear();\n\
334
- tb_frag_B0.clear();\n\
335
- \n\
336
- // The last kblock is loaded in the prolog\n\
337
- iterator_A.load(tb_frag_A);\n\
338
- iterator_B0.load(tb_frag_B0);\n\
339
- \n\
340
- ++iterator_A;\n\
341
- ++iterator_B0;\n\
342
- \n\
343
- WarpFragmentA0 warp_frag_A0;\n\
344
- WarpFragmentB0 warp_frag_B0;\n\
345
- \n\
346
- Operator0 warp_mma0;\n\
347
- \n\
348
- // Avoid reading out of bounds\n\
349
- if (gemm_k_iterations_0 <= 1) {\n\
350
- iterator_A.clear_mask();\n\
351
- iterator_B0.clear_mask();\n\
352
- }\n\
353
- \n\
354
- // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\
355
- // shared memory loads (which have the tightest latency requirement).\n\
356
- \n\
357
- //\n\
358
- // Mainloop\n\
359
- //\n\
360
- \n\
361
- // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\
362
- CUTLASS_GEMM_LOOP\n\
363
- for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\
364
- \n\
365
- this->smem_iterator_A_.store(tb_frag_A);\n\
366
- this->smem_iterator_B0_.store(tb_frag_B0);\n\
367
- \n\
368
- __syncthreads();\n\
369
- //\n\
370
- // Loop over GEMM K dimension\n\
371
- //\n\
372
- \n\
373
- CUTLASS_PRAGMA_UNROLL\n\
374
- for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\
375
- \n\
376
- // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\
377
- // as the case may be.\n\
378
- \n\
379
- this->warp_tile_iterator_A0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\
380
- this->warp_tile_iterator_B0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\
381
- \n\
382
- this->warp_tile_iterator_A0_.load(warp_frag_A0);\n\
383
- this->warp_tile_iterator_B0_.load(warp_frag_B0);\n\
384
- \n\
385
- ++this->warp_tile_iterator_A0_;\n\
386
- ++this->warp_tile_iterator_B0_;\n\
387
- \n\
388
- warp_mma0(accum0, warp_frag_A0, warp_frag_B0, accum0);\n\
389
- }\n\
390
- this->warp_tile_iterator_A0_.add_tile_offset({0, -Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\
391
- this->warp_tile_iterator_B0_.add_tile_offset({-Policy0::kPartitionsK * Base::kWarpGemmIterations0, 0});\n\
392
- \n\
393
- __syncthreads();\n\
394
- iterator_A.load(tb_frag_A);\n\
395
- iterator_B0.load(tb_frag_B0);\n\
396
- \n\
397
- ++iterator_A;\n\
398
- ++iterator_B0;\n\
399
- \n\
400
- if(gemm_k_iterations_0 <= 2) {\n\
401
- iterator_A.clear_mask();\n\
402
- iterator_B0.clear_mask();\n\
403
- }\n\
404
- }\n"
405
-
406
- return accu_code + code
407
-
408
-
409
- def gen_first_gemm_2stage(b2b_num):
410
-
411
- accu_code = " FragmentC0 accum0 = src_accum;\n"
412
- if b2b_num == 1:
413
- accu_code = " accum0 = src_accum;\n"
414
-
415
- code ="\
416
- \n\
417
- FragmentA0 tb_frag_A;\n\
418
- FragmentB0 tb_frag_B0;\n\
419
- \n\
420
- tb_frag_A.clear();\n\
421
- tb_frag_B0.clear();\n\
422
- \n\
423
- // The last kblock is loaded in the prolog\n\
424
- iterator_A.load(tb_frag_A);\n\
425
- iterator_B0.load(tb_frag_B0);\n\
426
- \n\
427
- ++iterator_A;\n\
428
- ++iterator_B0;\n\
429
- \n\
430
- this->smem_iterator_A_.store(tb_frag_A);\n\
431
- this->smem_iterator_B0_.store(tb_frag_B0);\n\
432
- \n\
433
- ++this->smem_iterator_A_;\n\
434
- ++this->smem_iterator_B0_;\n\
435
- \n\
436
- __syncthreads();\n\
437
- \n\
438
- // Pair of fragments used to overlap shared memory loads and math instructions\n\
439
- WarpFragmentA0 warp_frag_A0[2];\n\
440
- WarpFragmentB0 warp_frag_B0[2];\n\
441
- \n\
442
- this->warp_tile_iterator_A0_.set_kgroup_index(0);\n\
443
- this->warp_tile_iterator_B0_.set_kgroup_index(0);\n\
444
- \n\
445
- this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);\n\
446
- this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);\n\
447
- \n\
448
- ++this->warp_tile_iterator_A0_;\n\
449
- ++this->warp_tile_iterator_B0_;\n\
450
- \n\
451
- Operator0 warp_mma0;\n\
452
- \n\
453
- int smem_write_stage_idx = 1;\n\
454
- \n\
455
- // Avoid reading out of bounds\n\
456
- if (gemm_k_iterations_0 <= 1) {\n\
457
- iterator_A.clear_mask();\n\
458
- iterator_B0.clear_mask();\n\
459
- }\n\
460
- \n\
461
- // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\
462
- // shared memory loads (which have the tightest latency requirement).\n\
463
- iterator_A.load(tb_frag_A);\n\
464
- \n\
465
- //\n\
466
- // Mainloop\n\
467
- //\n\
468
- \n\
469
- // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\
470
- CUTLASS_GEMM_LOOP\n\
471
- for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\
472
- \n\
473
- //\n\
474
- // Loop over GEMM K dimension\n\
475
- //\n\
476
- \n\
477
- CUTLASS_PRAGMA_UNROLL\n\
478
- for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\
479
- \n\
480
- // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\
481
- // as the case may be.\n\
482
- \n\
483
- if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {\n\
484
- \n\
485
- // Write fragments to shared memory\n\
486
- this->smem_iterator_A_.store(tb_frag_A);\n\
487
- \n\
488
- this->smem_iterator_B0_.store(tb_frag_B0);\n\
489
- \n\
490
- __syncthreads();\n\
491
- \n\
492
- // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\
493
- // shared memory loads (which have the tightest latency requirement).\n\
494
- iterator_A.load(tb_frag_A);\n\
495
- \n\
496
- ++this->smem_iterator_B0_;\n\
497
- ++this->smem_iterator_A_;\n\
498
- \n\
499
- \n\
500
- // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory\n\
501
- if (smem_write_stage_idx == 1) {\n\
502
- this->smem_iterator_A_.add_tile_offset({0, -Base::Stage0});\n\
503
- this->smem_iterator_B0_.add_tile_offset({-Base::Stage0, 0});\n\
504
- }\n\
505
- else {\n\
506
- this->warp_tile_iterator_A0_.add_tile_offset(\n\
507
- {0, -Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\
508
- this->warp_tile_iterator_B0_.add_tile_offset(\n\
509
- {-Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0,\n\
510
- 0});\n\
511
- }\n\
512
- \n\
513
- smem_write_stage_idx ^= 1;\n\
514
- }\n\
515
- \n\
516
- this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);\n\
517
- this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);\n\
518
- \n\
519
- this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);\n\
520
- this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);\n\
521
- \n\
522
- ++this->warp_tile_iterator_A0_;\n\
523
- ++this->warp_tile_iterator_B0_;\n\
524
- \n\
525
- if (warp_mma_k == 0) {\n\
526
- \n\
527
- iterator_B0.load(tb_frag_B0);\n\
528
- \n\
529
- ++iterator_A;\n\
530
- ++iterator_B0;\n\
531
- \n\
532
- // Avoid reading out of bounds if this was the last loop iteration\n\
533
- if (gemm_k_iterations_0 <= 2) {\n\
534
- iterator_A.clear_mask();\n\
535
- iterator_B0.clear_mask();\n\
536
- }\n\
537
- }\n\
538
- \n\
539
- warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);\n\
540
- }\n\
541
- }\n"
542
- return accu_code + code
543
-
544
- def gen_other_gemms_2stage(b2b_num):
545
-
546
- code = ""
547
-
548
- def gemm_teamplate(id):
549
- code = "// " + str(id + 1) + " Gemm"
550
- code += " /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile\n"
551
-
552
- code += " " + helper.var_idx("FragmentC", id - 1) + helper.var_idx(" after_epilogue_accu", id - 1) + ";\n"
553
- code += " " + helper.var_idx("epilogue_", id - 1) + helper.var_idx("(output_op_", id - 1) + helper.var_idx(", accum", id - 1) \
554
- + helper.var_idx(", after_epilogue_accu", id - 1) + helper.var_idx(", iterator_C", id - 1) +");\n"
555
-
556
- # FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
557
- code += " " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx(" warp_tile_iterator_A", id) +"_(" + helper.var_idx("after_epilogue_accu", id - 1) + ");\n"
558
- # FragmentB1 tb_frag_B1;
559
- code += " " + helper.var_idx("FragmentB", id) + " " + helper.var_idx("tb_frag_B", id) + ";\n"
560
- # tb_frag_B1.clear();
561
- code += " " + helper.var_idx("tb_frag_B", id) + ".clear();\n"
562
- # iterator_B1.load(tb_frag_B1);
563
- code += " " + helper.var_idx("iterator_B", id) + ".load(" + helper.var_idx("tb_frag_B", id) + ");\n"
564
- # ++iterator_B1;
565
- code += " " + "++" + helper.var_idx("iterator_B", id) + ";\n"
566
- # this->smem_iterator_B1_.store(tb_frag_B1);
567
- code += " " + helper.var_idx("this->smem_iterator_B", id) + "_.store(" + helper.var_idx("tb_frag_B", id) + ");\n"
568
- # ++this->smem_iterator_B1_;
569
- code += " " + helper.var_idx("++this->smem_iterator_B", id) + "_;\n"
570
- # __syncthreads();
571
- code += " " + "__syncthreads();\n"
572
- # WarpFragmentA1 warp_frag_A1[2];
573
- code += " " + helper.var_idx("WarpFragmentA", id) + helper.var_idx(" warp_frag_A", id) + "[2];\n"
574
- # WarpFragmentB1 warp_frag_B1[2];
575
- code += " " + helper.var_idx("WarpFragmentB", id) + helper.var_idx(" warp_frag_B", id) + "[2];\n"
576
- # this->warp_tile_iterator_B1_.set_kgroup_index(0);
577
- code += " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.set_kgroup_index(0);\n"
578
- # warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
579
- code += " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[0]);\n"
580
- # this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
581
- code += " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[0]);\n"
582
- # ++warp_tile_iterator_A1_;
583
- code += " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n"
584
- # ++this->warp_tile_iterator_B1_;
585
- code += " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n"
586
- # Operator1 warp_mma1;
587
- code += " " + helper.var_idx("Operator", id) + " " + helper.var_idx("warp_mma", id) + ";\n"
588
- # smem_write_stage_idx = 1;
589
- code += " " + "smem_write_stage_idx = 1;\n"
590
- # int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
591
- code += " " + helper.var_idx("int gemm_k_iterations_", id) + " = " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx("::Policy::kIterations / Base::kWarpGemmIterations", id) +";\n"
592
- # if (gemm_k_iterations_1 <= 1) {
593
- # iterator_B1.clear_mask();
594
- # }
595
- code += " " + "if (" + helper.var_idx("gemm_k_iterations_", id) + " <= 1 ){\n" \
596
- + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \
597
- + " " +"}\n"
598
- # CUTLASS_PRAGMA_UNROLL
599
- code += " " + "CUTLASS_PRAGMA_UNROLL\n"
600
- # for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
601
- code += " " + helper.var_idx("for (; gemm_k_iterations_", id) + helper.var_idx(" > 0; --gemm_k_iterations_", id) + ") {\n"
602
- # CUTLASS_PRAGMA_UNROLL
603
- code += " " + " " + "CUTLASS_PRAGMA_UNROLL\n"
604
- # for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {
605
- code += " " + " " + helper.var_idx("for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations", id) + "; ++warp_mma_k) {\n"
606
- # if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
607
- code += " " + " " + " " + helper.var_idx("if (warp_mma_k == Base::kWarpGemmIterations", id) + " - 1) {\n"
608
- # this->smem_iterator_B1_.store(tb_frag_B1);
609
- code += " " + " " + " " + " " + helper.var_idx(" this->smem_iterator_B", id) + helper.var_idx("_.store(tb_frag_B", id) + ");\n"
610
- # __syncthreads();
611
- code += " " + " " + " " + " " + "__syncthreads();\n"
612
- # ++smem_iterator_B1_;
613
- code += " " + " " + " " + " " + helper.var_idx(" ++smem_iterator_B", id) + "_;\n"
614
- # if (smem_write_stage_idx == 1) {
615
- # smem_iterator_B1_.add_tile_offset({-Base::Stage, 0});
616
- # }
617
- code += " " + " " + " " + " " + "if ( smem_write_stage_idx == 1 ) {\n" \
618
- + " " + " " + " " + " " + " " + helper.var_idx("smem_iterator_B", id) + helper.var_idx("_.add_tile_offset({-Base::Stage", i) + ", 0});\n" \
619
- + " " + " " + " " + " " +"}\n"
620
- # else {
621
- # this->warp_tile_iterator_B1_.add_tile_offset(
622
- # {-Base::Stage * Policy1::kPartitionsK *
623
- # Base::kWarpGemmIterations1,
624
- # 0});
625
- # }
626
- code += " " + " " + " " + " " + "else {\n" \
627
- + " " + " " + " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.add_tile_offset(\n" \
628
- + " " + " " + " " + " " + " " + helper.var_idx("{-Base::Stage", id) + helper.var_idx(" * Policy", id) + "::kPartitionsK *\n" \
629
- + " " + " " + " " + " " + " " + helper.var_idx("Base::kWarpGemmIterations", id) + ",\n" \
630
- + " " + " " + " " + " " + " " + "0});\n" \
631
- + " " + " " + " " + " " + "}\n"
632
-
633
- # smem_write_stage_idx ^= 1;
634
- # }
635
- code += " " + " " + " " + " " + "smem_write_stage_idx ^= 1;\n" \
636
- + " " + " " + " " + "}\n"
637
-
638
- # this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
639
- code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations", id) + ");\n"
640
- # warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
641
- code += " " + " " + " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[(warp_mma_k + 1) % 2]);\n"
642
- # this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
643
- code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[(warp_mma_k + 1) % 2]);\n"
644
- # ++warp_tile_iterator_A1_;
645
- code += " " + " " + " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n"
646
- # ++this->warp_tile_iterator_B1_;
647
- code += " " + " " + " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n"
648
- # if (warp_mma_k == 0) {
649
- # iterator_B1.load(tb_frag_B1);
650
- # ++iterator_B1;
651
- # if (gemm_k_iterations_1 <= 2) {
652
- # iterator_B1.clear_mask();
653
- # }
654
- # }
655
- code += " " + " " + " " + " if (warp_mma_k == 0) {\n" \
656
- + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + helper.var_idx(".load(tb_frag_B", id) + ");\n" \
657
- + " " + " " + " " + " " + helper.var_idx("++iterator_B", id) +";\n" \
658
- + " " + " " + " " + " " + helper.var_idx("if (gemm_k_iterations_", id) +" <= 2) {\n" \
659
- + " " + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \
660
- + " " + " " + " " + " " + "}\n" \
661
- + " " + " " + " " + "}\n"
662
- # warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum);
663
- # }
664
- # }
665
- code += " " + " " + " " + helper.var_idx("warp_mma", id) + helper.var_idx("(accum", id) + helper.var_idx(", warp_frag_A", id) + helper.var_idx("[warp_mma_k % 2], warp_frag_B", id) + helper.var_idx("[warp_mma_k % 2], accum", id) + ");\n" \
666
- + " " + " " + "}\n" \
667
- + " " + "}\n\n\n"
668
-
669
- return code
670
-
671
- for i in range (1, b2b_num):
672
- clear_accu = ""
673
- if i != b2b_num - 1:
674
- clear_accu = " " + helper.var_idx("FragmentC", i) + helper.var_idx(" accum", i) +";\n"
675
- clear_accu += " " + helper.var_idx("accum", i) +".clear();\n"
676
- code += clear_accu + gemm_teamplate(i)
677
-
678
- return code
679
-
680
- operator_code = " CUTLASS_DEVICE\n\
681
- void operator()(\n " + gen_operator_param(self.b2b_num) + ") {\n"
682
- if first_use_1stage:
683
- operator_code += gen_first_gemm_1stage(self.b2b_num)
684
- else:
685
- operator_code += gen_first_gemm_2stage(self.b2b_num)
686
- operator_code += gen_other_gemms_2stage(self.b2b_num) + "}\n"
687
- return operator_code
688
-
689
- def gen_construct_func(self):
690
- name = self.gen_class_name
691
- func_code = "CUTLASS_DEVICE\n"
692
- func_code += name + "(\n" \
693
- + " " + "typename Base::B2bMmaSharedStorage &shared_storage,\n" \
694
- + " " + "int thread_idx,\n" \
695
- + " " + "int warp_idx,\n" \
696
- + " " + "int lane_idx\n" \
697
- + "):\n"
698
- func_code += " " + "Base(shared_storage, thread_idx, warp_idx, lane_idx),\n" \
699
- + " " + "smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),\n"
700
-
701
- for i in range(self.b2b_num):
702
- final = ",\n"
703
- if i == self.b2b_num - 1:
704
- final = " {\n"
705
- func_code += helper.var_idx("smem_iterator_B", i) + helper.var_idx("_(shared_storage.sharedStorage", i) +".operand_B_ref(), thread_idx)" + final
706
-
707
- func_code += " " + "int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);\n"
708
- func_code += " " + "int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);\n"
709
-
710
- func_code += " " + "int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;\n"
711
- func_code += " " + "int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;\n"
712
-
713
- for i in range(self.b2b_num):
714
- func_code += " " + helper.var_idx("int tile_offset_k", i) + helper.var_idx(" = Base::kWarpGemmIterations", i) + " * warp_idx_k;\n"
715
-
716
- func_code += " " + "this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k0});\n"
717
-
718
- for i in range(self.b2b_num):
719
- func_code += " " + helper.var_idx("this->warp_tile_iterator_B", i) + helper.var_idx("_.add_tile_offset({tile_offset_k", i) + ", warp_idx_n});\n"
720
-
721
- func_code += "}\n"
722
-
723
- return func_code
724
-
725
- def gen_member_func(self, first_use_1stage):
726
- code = "public:\n"
727
- code += self.gen_operator(first_use_1stage)
728
- code += self.gen_construct_func()
729
-
730
- return code
731
-
732
- def gen_code(self, first_use_1stage):
733
-
734
- def gen_template_args(b2b_num):
735
- template_param = []
736
- template_param.append(("typename", "Shape0"))
737
- template_param.append(("typename", "IteratorA0"))
738
- template_param.append(("typename", "SmemIteratorA0"))
739
- template_param.append(("typename", "IteratorB0"))
740
- template_param.append(("typename", "SmemIteratorB0"))
741
-
742
- for i in range(1, b2b_num):
743
- template_param.append(("typename", helper.var_idx("Shape", i)))
744
- template_param.append(("typename", helper.var_idx("FragmentIteratorA", i)))
745
- template_param.append(("typename", helper.var_idx("IteratorB", i)))
746
- template_param.append(("typename", helper.var_idx("SmemIteratorB", i)))
747
-
748
- template_param.append(("typename", "ElementC"))
749
- template_param.append(("typename", "LayoutC"))
750
-
751
- for i in range(0, b2b_num - 1):
752
- template_param.append(("typename", helper.var_idx("OutputOp", i)))
753
-
754
- for i in range(0, b2b_num - 1):
755
- template_param.append(("typename", helper.var_idx("FusedAddBiasEpilogue", i)))
756
-
757
- for i in range(0, b2b_num):
758
- template_param.append(("typename", helper.var_idx("Policy", i)))
759
- for i in range(0, b2b_num):
760
- template_param.append((int, helper.var_idx("Stage", i)))
761
-
762
- template_param.append(("typename","TransformA0", "NumericArrayConverter<typename SmemIteratorA0_::Element, typename IteratorA0_::Element, IteratorA0_::Fragment::kElements>"))
763
-
764
- for i in range(0, b2b_num):
765
- cvtr = helper.var_idx("NumericArrayConverter<typename SmemIteratorB", i) + helper.var_idx("_::Element, typename IteratorB", i) + helper.var_idx("_::Element, IteratorB", i) + "_::Fragment::kElements>"
766
- template_param.append(("typename", helper.var_idx("TransformB", i), cvtr))
767
-
768
- template_param.append(("typename", "Enable", "bool"))
769
-
770
- return template_param
771
-
772
- template_param = gen_template_args(self.b2b_num)
773
- inheritance_code = "public B2bMmaBase<"
774
- for i in range(self.b2b_num):
775
- inheritance_code += helper.var_idx("Shape", i) + "_, "
776
- for i in range(self.b2b_num):
777
- inheritance_code += helper.var_idx("Policy", i) + "_, "
778
- for i in range(self.b2b_num - 1):
779
- inheritance_code += helper.var_idx("Stage", i) + "_, "
780
- inheritance_code += helper.var_idx("Stage", self.b2b_num - 1) + "_"
781
- inheritance_code += ">"
782
-
783
- code_body = ""
784
- using_code= self.gen_using()
785
- func_code = self.gen_member_func(first_use_1stage)
786
-
787
- code_body = using_code + func_code
788
-
789
- class_code = gen_ir.gen_template_class(self.gen_class_name, template_param, code_body, inheritance_code = inheritance_code)
790
-
791
- code = self.gen_include_header()
792
- code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code)))
793
- # print(code)
794
- return code
795
-
796
-
797
- class gen_b2b_mma_base:
798
- def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
799
- self.gen_class_name = gen_class_name
800
- self.template_param = template_param
801
- self.b2b_num = b2b_num
802
- self.cutlass_deps_root = cutlass_deps_root
803
- self.project_root = project_root
804
-
805
- def gen_include_header(self):
806
- code = '''
807
- #pragma once
808
-
809
- #include \"{cutlass_dirs}cutlass/aligned_buffer.h\"
810
- #include \"{cutlass_dirs}cutlass/arch/memory.h\"
811
- #include \"{cutlass_dirs}cutlass/array.h\"
812
- #include \"{cutlass_dirs}cutlass/cutlass.h\"
813
- #include \"{cutlass_dirs}cutlass/gemm/gemm.h\"
814
- #include \"{cutlass_dirs}cutlass/matrix_shape.h\"
815
- #include \"{cutlass_dirs}cutlass/numeric_types.h\"\n'''.format(cutlass_dirs=self.cutlass_deps_root)
816
- return code
817
-
818
- def gen_shared_storage(self):
819
- code = \
820
- " template< \n\
821
- typename Shape_,\n\
822
- typename Policy_,\n\
823
- int ThisStage_\n\
824
- >\n\
825
- class SharedStorage {\n\
826
- public:\n\
827
- using Shape = Shape_;\n\
828
- using Policy = Policy_;\n\
829
- static int const ThisStage = ThisStage_;\n\
830
- using Operator = typename Policy::Operator;\n\
831
- \
832
- using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;\n\
833
- \
834
- /// Tensor reference to the B operand \n\
835
- using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;\n\
836
- \n\
837
- /// Shape of the A matrix operand in shared memory \n\
838
- using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,\n\
839
- Shape::kK * ThisStage +\n\
840
- Policy::SmemPaddingA::kColumn>;\n\
841
- \n\
842
- /// Shape of the B matrix operand in shared memory\n\
843
- using ShapeB =\n\
844
- MatrixShape<Shape::kK * ThisStage + Policy::SmemPaddingB::kRow,\n\
845
- Shape::kN + Policy::SmemPaddingB::kColumn>;\n\
846
- \n\
847
- public:\n\
848
- \n\
849
- /// Buffer for A operand\n\
850
- AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;\n\
851
- \n\
852
- /// Buffer for B operand\n\
853
- AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;\n\
854
- \n\
855
- public:\n\
856
- \n\
857
- /// Returns a layout object for the A matrix\n\
858
- CUTLASS_DEVICE\n\
859
- static typename Operator::LayoutA LayoutA() {\n\
860
- return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});\n\
861
- }\n\
862
- \n\
863
- /// Returns a layout object for the B matrix\n\
864
- CUTLASS_HOST_DEVICE\n\
865
- static typename Operator::LayoutB LayoutB() {\n\
866
- return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});\n\
867
- }\n\
868
- \n\
869
- /// Returns a TensorRef to the A operand\n\
870
- CUTLASS_HOST_DEVICE\n\
871
- TensorRefA operand_A_ref() {\n\
872
- return TensorRefA{operand_A.data(), LayoutA()};\n\
873
- }\n\
874
- \n\
875
- /// Returns a TensorRef to the B operand\n\
876
- CUTLASS_HOST_DEVICE\n\
877
- TensorRefB operand_B_ref() {\n\
878
- return TensorRefB{operand_B.data(), LayoutB()};\n\
879
- }\n\
880
- CUTLASS_HOST_DEVICE\n\
881
- void * get_B_Shared_ptr() {\n\
882
- return operand_B.data();\n\
883
- }\n\
884
- };\n"
885
- return code
886
-
887
- def gen_using_and_misc(self, b2b_num):
888
- code_using = ""
889
- for i in range(b2b_num):
890
- code_using += "using Operator" +str(i) + " = typename Policy" + str(i) +"::Operator;\n"
891
-
892
- for i in range(b2b_num):
893
- code_using += "using WarpGemm" +str(i) + " = typename Policy" + str(i) +"::Operator::Shape;\n"
894
-
895
- for i in range(b2b_num):
896
- code_using += "using WarpCount" +str(i) + " = GemmShape<" + helper.var_idx("Shape", i) +"::kM / " + helper.var_idx("WarpGemm", i) +"::kM, "\
897
- + helper.var_idx("Shape", i) +"::kN / " + helper.var_idx("WarpGemm", i) +"::kN, "\
898
- + helper.var_idx("Shape", i) +"::kK / " + helper.var_idx("WarpGemm", i) +"::kK>;\n"
899
-
900
- code_misc = ""
901
- for i in range(b2b_num):
902
- code_misc += "static int const " + helper.var_idx("kWarpGemmIterations", i) + " = (" + helper.var_idx("WarpGemm", i) + "::kK / " + helper.var_idx("Operator", i) +"::Policy::MmaShape::kK);\n"
903
-
904
- code = code_using + code_misc + self.gen_shared_storage()
905
-
906
- for i in range(b2b_num):
907
- code += "using " + helper.var_idx("SharedStorage", i) + " = SharedStorage<" + helper.var_idx("Shape", i) + ", " + helper.var_idx("Policy", i) +", " + helper.var_idx("Stage", i) + ">;\n"
908
-
909
- def gen_union_shared_storage(b2b_num):
910
- code = ""
911
- for i in range(b2b_num):
912
- code += " " +helper.var_idx("SharedStorage", i) + " " + helper.var_idx("sharedStorage", i) +";\n"
913
- return code
914
-
915
- code += "union B2bMmaSharedStorage {\n" + gen_union_shared_storage(self.b2b_num) + "};\n"
916
-
917
- for i in range(b2b_num - 1):
918
- code += helper.var_idx("void * C", i) + "_smm_ptr;\n"
919
-
920
- return code
921
-
922
- def gen_protected(self):
923
- code = "\nprotected:\n"
924
- code += "typename Operator0::IteratorA warp_tile_iterator_A0_;\n"
925
- for i in range(self.b2b_num):
926
- code += "typename Operator" +str(i) + "::IteratorB" +" warp_tile_iterator_B" + str(i) + "_;\n"
927
- return code
928
-
929
- def gen_public_member(self):
930
- code = "\npublic:\n"
931
-
932
- code += "CUTLASS_DEVICE\n"
933
- code += \
934
- "B2bMmaBase(\n" + \
935
- " B2bMmaSharedStorage & shared_storage,\n" + \
936
- " int thread_idx,\n" + \
937
- " int warp_idx,\n" + \
938
- " int lane_idx\n" + \
939
- "):\n" + \
940
- " warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),\n"
941
- for i in range(self.b2b_num):
942
- final = ",\n"
943
- if i == self.b2b_num-1:
944
- final = "\n"
945
-
946
- iterator = " warp_tile_iterator_B" + str(i) + "_"
947
- shared_storage = "shared_storage.sharedStorage" + str(i) + ".operand_B_ref()"
948
- code += iterator + "(" + shared_storage + ", lane_idx)" + final
949
-
950
-
951
- code += "{\n"
952
- for i in range(self.b2b_num - 1):
953
- code += helper.var_idx(" C", i) + helper.var_idx("_smm_ptr = shared_storage.sharedStorage", i) + ".get_B_Shared_ptr();\n"
954
- code += "}\n"
955
-
956
- return code
957
-
958
- def gen_code(self):
959
-
960
- template_arg = []
961
- for i in range(self.b2b_num):
962
- template_arg.append(("typename", helper.var_idx("Shape", i)))
963
- for i in range(self.b2b_num):
964
- template_arg.append(("typename", helper.var_idx("Policy", i)))
965
- for i in range(self.b2b_num):
966
- template_arg.append((int, helper.var_idx("Stage", i)))
967
-
968
-
969
-
970
- code_body = self.gen_using_and_misc(self.b2b_num)
971
- code_body += self.gen_protected()
972
- code_body += self.gen_public_member()
973
-
974
- class_code = gen_ir.gen_template_class("B2bMmaBase", template_arg, code_body)
975
-
976
- code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code)))
977
-
978
- return code
979
-
980
-
981
- class gen_threadblock:
982
- def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root):
983
- self.gen_class_name = gen_class_name
984
- self.template_param = template_param
985
- self.b2b_num = b2b_num
986
- self.file_dir = output_dir + "/threadblock/"
987
-
988
- self.cutlass_deps_root = cutlass_deps_root
989
- self.project_root = project_root
990
-
991
-
992
- self.gen_b2b_mma_base = gen_b2b_mma_base(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)
993
- self.gen_b2b_mma_pipelined = gen_b2b_mme_pipelined(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)
994
- self.gen_default_b2b_mma = gen_default_b2b_mma(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)
995
-
996
-
997
- def gen_code(self, first_use_1stage):
998
-
999
- base_code = self.gen_b2b_mma_base.gen_code()
1000
- print("[INFO]: Gen kernel code [b2b_mma_base.h]output Dir: is ", self.file_dir)
1001
-
1002
- with open(self.file_dir + "b2b_mma_base.h", "w+") as f:
1003
- f.write(base_code)
1004
- pipeline_code = self.gen_b2b_mma_pipelined.gen_code(first_use_1stage = first_use_1stage)
1005
- print("[INFO]: Gen kernel code [b2b_mma_pipelined.h]output Dir: is ", self.file_dir)
1006
-
1007
- with open(self.file_dir + "b2b_mma_pipelined.h", "w+") as f:
1008
- f.write(pipeline_code)
1009
- default_code = self.gen_default_b2b_mma.gen_code()
1010
- print("[INFO]: Gen kernel code [default_b2b_mma.h]output Dir: is ", self.file_dir)
1011
-
1012
- with open(self.file_dir + "default_b2b_mma.h", "w+") as f:
1013
- f.write(default_code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py DELETED
@@ -1,456 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import helper
34
- import gen_ir as ir
35
-
36
- class gen_turing_impl:
37
- def __init__(self,fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
38
- self.fuse_gemm_info = fuse_gemm_info
39
- self.class_name = gen_class_name
40
- self.gen_class_name = gen_class_name + "_turing_impl"
41
- self.user_header_file = ""
42
- for header in user_header_file:
43
- self.user_header_file += "#include \"" + header + "\"\n"
44
- self.output_dir = output_dir
45
- self.b2b_num = len(fuse_gemm_info)
46
-
47
- self.gen_turing_unfused = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
48
-
49
- def gen_using(self):
50
- code_using = "using b2b_gemm = typename cutlass::gemm::device::" + self.class_name + "<cutlass::half_t>;"
51
-
52
- return code_using + "\n"
53
-
54
- def gen_initialize(self):
55
- code = ""
56
- for i in range(self.b2b_num):
57
- code_this = ""
58
-
59
- code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n"
60
- beta = "(1)"
61
-
62
- if helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i]) is False:
63
- beta = "(0)"
64
- code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n"
65
- k_str = str(self.fuse_gemm_info[i]['mnk'][2])
66
- if i == 0:
67
- k_str = "K0"
68
- code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n"
69
- code += code_this
70
- code += "typename b2b_gemm::Arguments arguments{\n"
71
-
72
- for i in range(self.b2b_num):
73
- code += " " + helper.var_idx("problem_size_", i) + ",\n"
74
-
75
-
76
- code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", 0) + "), " + helper.var_idx("problem_size_", 0) + ".k()},\n"
77
-
78
- for i in range(self.b2b_num):
79
-
80
- ldmB = str(self.fuse_gemm_info[i]['mnk'][2])
81
- if i == 0:
82
- ldmB = "K0"
83
-
84
- if self.fuse_gemm_info[i]['B_format'] is 'Row':
85
- ldmB = str(self.fuse_gemm_info[i]['mnk'][1])
86
-
87
- ldmC = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i]))
88
-
89
- code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "},\n"
90
- code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmC + "},\n"
91
- code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", self.b2b_num -1) + "), " + helper.var_idx("problem_size_", self.b2b_num - 1) + ".n()},\n"
92
-
93
-
94
- for i in range(self.b2b_num):
95
- code += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
96
- for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
97
- arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1]
98
- code += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")"
99
- code += "},\n"
100
- code += " " + "Batch};\n\n"
101
-
102
- code += " " "b2b_gemm gemm_op;\n"
103
- code += " " + "gemm_op.initialize(arguments);\n"
104
- return code + "\n"
105
-
106
-
107
-
108
- def gen_run(self):
109
- code = " " + "gemm_op(stream);\n"
110
-
111
- return code
112
-
113
- def gen_wrapper(self):
114
- code_body = ""
115
-
116
- arg_lists = []
117
- arg_lists.append(["int", "M"])
118
- arg_lists.append(["int", "K0"])
119
- arg_lists.append(["int", "Batch"])
120
- arg_lists.append(["void*", helper.var_idx("A", 0)])
121
- for i in range(self.b2b_num):
122
- arg_lists.append(["void*", helper.var_idx("B", i)])
123
- arg_lists.append(["void*", helper.var_idx("C", i)])
124
- arg_lists.append(["void*", helper.var_idx("D", i)])
125
- epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
126
- acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
127
- for arg in epilogue_args:
128
- arg_tp = arg[0]
129
- arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
130
- arg_lists.append([arg_tp, arg_name])
131
-
132
- if self.b2b_num == 1:
133
- code_body += self.gen_turing_unfused.gen_using(False) #False -> Turing, True -> Volta
134
- code_body += self.gen_turing_unfused.gen_initialize()
135
- code_body += self.gen_turing_unfused.gen_run()
136
- else:
137
- code_body += self.gen_using()
138
- code_body += self.gen_initialize()
139
- code_body += self.gen_run()
140
-
141
- code = ir.gen_func(self.gen_class_name, arg_lists, code_body)
142
-
143
- return code
144
-
145
- def gen_code(self):
146
-
147
- code = self.gen_wrapper()
148
- helper.write_2_headfile("turing_impl.h", self.output_dir, self.user_header_file + "\n" + code)
149
-
150
- class gen_volta_turing_fuse_act_impl:
151
- def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
152
- self.fuse_gemm_info = fuse_gemm_info
153
- self.gen_class_name = gen_class_name + "_volta_impl"
154
- self.user_header_file = ""
155
- for header in user_header_file:
156
- self.user_header_file += "#include \"" + header + "\"\n"
157
- self.output_dir = output_dir
158
- self.b2b_num = len(fuse_gemm_info)
159
-
160
- def perf_tiling(self, layer_mnk):
161
- mnk = layer_mnk[:]
162
- block_tile = mnk[:]
163
- block_tile[2] = 32 # force the K tile to be 32
164
-
165
- # M tile gen
166
- block_tile[0] = 32
167
-
168
- # N tile gen
169
- if mnk[1] > 128:
170
- block_tile[1] = 256
171
- elif mnk[1] > 64:
172
- block_tile[1] = 128
173
- elif mnk[1] > 32:
174
- block_tile[1] = 64
175
- else :
176
- block_tile[1] = 32
177
-
178
- warp_tile = block_tile[:]
179
- if block_tile[1] == 256:
180
- warp_tile[1] = 64
181
- elif block_tile[1] == 128:
182
- warp_tile[1] = 32
183
- elif block_tile[1] == 64:
184
- warp_tile[1] = 32
185
- else :
186
- warp_tile[1] = 32
187
-
188
- warp_tile[0] = 32
189
-
190
- return block_tile, warp_tile
191
-
192
-
193
- def process_epilogue(self, epilogue_tp, n, C_tp, Acc_tp):
194
- epilogue_setted_type = epilogue_tp
195
- cutlass_epilogue_name = "LinearCombinationRelu"
196
- if epilogue_setted_type.lower() == 'leakyrelu':
197
- cutlass_epilogue_name = "LinearCombinationLeakyRelu"
198
- elif epilogue_setted_type.lower() == 'identity':
199
- cutlass_epilogue_name = "LinearCombination"
200
-
201
-
202
- n_mod_8 = n % 4
203
- N_align_elements = 1
204
- if n_mod_8 == 0:
205
- N_align_elements = 8
206
- elif n_mod_8 == 4:
207
- N_align_elements = 4
208
- elif n_mod_8 == 2 or n_mod_8 == 6:
209
- N_align_elements = 2
210
-
211
- epilogue_str = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "<" + C_tp + ", " + str(N_align_elements) + ", " + Acc_tp + ", " + Acc_tp + ">"
212
-
213
- return epilogue_str
214
-
215
- def gen_using(self, volta = True):
216
- code_using = ""
217
- volta_arch = "cutlass::arch::Sm70"
218
- volta_tc = "cutlass::gemm::GemmShape<8, 8, 4>"
219
-
220
- turing_arch = "cutlass::arch::Sm75"
221
- turing_tc = "cutlass::gemm::GemmShape<16, 8, 8>"
222
-
223
- arch = ""
224
- tc = ""
225
- if volta:
226
- arch = volta_arch
227
- tc = volta_tc
228
- else:
229
- arch = turing_arch
230
- tc = turing_tc
231
-
232
- for i in range(self.b2b_num):
233
-
234
- k = self.fuse_gemm_info[i]['mnk'][2]
235
-
236
- k_mod_8 = k % 4
237
- ab_ldm = 1
238
- if k_mod_8 == 0:
239
- ab_ldm = 8
240
- elif k_mod_8 == 4:
241
- ab_ldm = 4
242
- elif k_mod_8 == 2 or k_mod_8 == 6:
243
- ab_ldm = 2
244
-
245
- block_tile, warp_tile = self.perf_tiling(self.fuse_gemm_info[i]['mnk'])
246
-
247
- this_gemm_config = helper.var_idx("using Gemm", i) + " = cutlass::gemm::device::GemmBatched<\n"
248
- this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + ",\n"
249
- this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_format']) + ",\n"
250
- this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + ",\n"
251
- this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_format']) + ",\n"
252
- this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + ",\n"
253
- this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_format']) + ",\n"
254
- this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + ",\n"
255
- this_gemm_config += " " + "cutlass::arch::OpClassTensorOp,\n"
256
- this_gemm_config += " " + arch + ",\n"
257
- this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(block_tile[0]) + ", " + str(block_tile[1]) + ", " + str(block_tile[2]) + ">,\n"
258
- this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(warp_tile[0]) + ", " + str(warp_tile[1]) + ", " + str(warp_tile[2]) + ">,\n"
259
- this_gemm_config += " " + tc + ",\n"
260
- this_gemm_config += " " + self.process_epilogue(helper.get_epilogue_tp(self.fuse_gemm_info[i]), self.fuse_gemm_info[i]['mnk'][1], helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']), helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp'])) + ",\n"
261
- this_gemm_config += " " + "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,\n"
262
- this_gemm_config += " " + "2,\n"
263
- this_gemm_config += " " + str(ab_ldm) + ",\n"
264
- this_gemm_config += " " + str(ab_ldm) + ">;\n"
265
-
266
- code_using += this_gemm_config + "\n"
267
-
268
- return code_using + "\n"
269
-
270
- def gen_initialize(self):
271
- code = ""
272
- for i in range(self.b2b_num):
273
- code_this = ""
274
-
275
- N_str = str(self.fuse_gemm_info[i]['mnk'][1])
276
-
277
- code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n"
278
- beta = "(1)"
279
- if helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) is False:
280
- beta = "(0)"
281
- code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n"
282
-
283
- k_str = str(self.fuse_gemm_info[i]['mnk'][2])
284
- if i == 0:
285
- k_str = "K0"
286
- code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n"
287
- code_this += helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n"
288
- code_this += " " + helper.var_idx("problem_size_", i) + ",\n"
289
- ldmA = k_str
290
- ldmB = k_str
291
- ldmC = str(self.fuse_gemm_info[i]['mnk'][1])
292
-
293
- ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i]))
294
-
295
- if self.fuse_gemm_info[i]['A_format'] is 'Col':
296
- ldmA = "M"
297
- if self.fuse_gemm_info[i]['B_format'] is 'Row':
298
- ldmB = str(self.fuse_gemm_info[i]['mnk'][1])
299
- if self.fuse_gemm_info[i]['C_format'] is 'Col':
300
- ldmC = "M"
301
-
302
- if i == 0:
303
- code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", i) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n"
304
- else:
305
- code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("D", i - 1) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n"
306
-
307
- code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n"
308
-
309
- M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0])
310
-
311
- code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n"
312
- code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", i) + "), " + ldmC + "}, " + "M * " + ldmC + ",\n"
313
- code_this += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
314
- for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
315
- arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1]
316
- code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")"
317
- code_this += " },\n"
318
- code_this += " " + "Batch};\n"
319
-
320
- code_this += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n"
321
- code_this += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(arguments_", i) + ", nullptr);\n"
322
-
323
- code += code_this + "\n"
324
- return code + "\n"
325
-
326
-
327
- def gen_run(self):
328
- code = ""
329
- for i in range(self.b2b_num):
330
- code_this = ""
331
- code_this += " " + helper.var_idx("gemm_op_", i) + "(stream);\n"
332
-
333
- code += code_this
334
- return code
335
-
336
- def gen_wrapper(self):
337
- code_body = ""
338
-
339
- arg_lists = []
340
- arg_lists.append(["int", "M"])
341
- arg_lists.append(["int", "K0"])
342
- arg_lists.append(["int", "Batch"])
343
- arg_lists.append(["void*", helper.var_idx("A", 0)])
344
- for i in range(self.b2b_num):
345
- arg_lists.append(["void*", helper.var_idx("B", i)])
346
- arg_lists.append(["void*", helper.var_idx("C", i)])
347
- arg_lists.append(["void*", helper.var_idx("D", i)])
348
- epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
349
- acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
350
- for arg in epilogue_args:
351
- arg_tp = arg[0]
352
- arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
353
- arg_lists.append([arg_tp, arg_name])
354
- code_body += self.gen_using()
355
- code_body += self.gen_initialize()
356
- code_body += self.gen_run()
357
-
358
- code = ir.gen_func(self.gen_class_name, arg_lists, code_body)
359
-
360
- return code
361
-
362
- def gen_code(self):
363
- code = self.gen_wrapper()
364
- helper.write_2_headfile("volta_impl.h", self.output_dir, self.user_header_file + "\n" + code)
365
-
366
- class gen_one_API:
367
- def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
368
- self.fuse_gemm_info = fuse_gemm_info
369
- self.gen_class_name = gen_class_name
370
- self.user_header_file = ""
371
- for header in user_header_file:
372
- self.user_header_file += "#include \"" + header + "\"\n"
373
- self.output_dir = output_dir
374
- self.b2b_num = len(fuse_gemm_info)
375
-
376
- self.gen_volta = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
377
-
378
- self.gen_turing = gen_turing_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
379
-
380
- def gen_CUTLASS_irrelevant_API(self):
381
- code = ""
382
- code += "#include <cuda_runtime.h>\n"
383
- code += "#include <cassert>\n"
384
-
385
- param_name = "Fused" + str(self.b2b_num) + "xGemm_"
386
- for i in range(self.b2b_num):
387
- param_name += str(self.fuse_gemm_info[i]['mnk'][1]) + "_"
388
- param_name += "Params"
389
- params = ""
390
- params += " " + "int M;\n"
391
- params += " " + "int K0;\n"
392
- params += " " + "int Batch;\n"
393
- params += " " + "const void* A0;\n"
394
- for i in range(self.b2b_num):
395
- params += " " + "const void* " + helper.var_idx("B", i) + ";\n"
396
- params += " " + "const void* " + helper.var_idx("C", i) + ";\n"
397
- epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
398
- acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
399
- for arg in epilogue_args:
400
- arg_tp = arg[0]
401
- arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
402
- params += " " + arg_tp + " " + arg_name + ";\n"
403
- params += " " + "void* " + helper.var_idx("D", i) + ";\n"
404
- code += ir.gen_struct(param_name, params)
405
- code += "using Param = " + param_name + ";\n"
406
- code += "void one_api( const Param & param, int sm, cudaStream_t stream);\n"
407
-
408
-
409
- return code
410
-
411
- def gen_one_api(self):
412
- code = ""
413
- code += "/* Auto Generated code - Do not edit.*/\n"
414
- code += "#include \"cutlass_irrelevant.h\"\n"
415
- code += "#include \"api.h\"\n"
416
- code += "void one_api( const Param & param, int sm, cudaStream_t stream) {\n"
417
-
418
- code += " " + "if (sm == 70) \n"
419
- code += " " + " " + self.gen_class_name + "_volta_impl(param.M, param.K0, param.Batch, const_cast<void*>(param.A0), "
420
- for i in range(self.b2b_num):
421
- code += helper.var_idx("const_cast<void*>(param.B", i) + "), "
422
- code += helper.var_idx("const_cast<void*>(param.C", i) + "), "
423
- code += helper.var_idx("param.D", i) + ", "
424
- epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
425
- for arg in epilogue_args:
426
- arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
427
- code += "param." + arg_name + ", "
428
- code += "stream);\n"
429
- code += " " + "else if(sm >= 75) \n"
430
- code += " " + " " + self.gen_class_name + "_turing_impl(param.M, param.K0, param.Batch, const_cast<void*>(param.A0), "
431
- for i in range(self.b2b_num):
432
- code += helper.var_idx("const_cast<void*>(param.B", i) + "), "
433
- code += helper.var_idx("const_cast<void*>(param.C", i) + "), "
434
- code += helper.var_idx("param.D", i) + ", "
435
- epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
436
- for arg in epilogue_args:
437
- arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
438
- code += "param." + arg_name + ", "
439
- code += "stream);\n"
440
- code += " " + "else assert(0);\n"
441
- code += "}\n"
442
- return code
443
-
444
- def gen_code(self):
445
-
446
- turing_code = self.gen_turing.gen_wrapper()
447
- volta_code = self.gen_volta.gen_wrapper()
448
- cutlass_irrelevant_code = self.gen_CUTLASS_irrelevant_API()
449
-
450
- one_api_code = self.gen_one_api()
451
- with open(self.output_dir + "one_api.cu", "w+") as f:
452
- f.write(one_api_code)
453
-
454
- helper.write_2_headfile("cutlass_irrelevant.h", self.output_dir, cutlass_irrelevant_code)
455
-
456
- helper.write_2_headfile("api.h", self.output_dir, self.user_header_file + "\n" + turing_code + volta_code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py DELETED
@@ -1,92 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import helper
34
- import gen_ir as ir
35
-
36
- import gen_turing_and_volta as gen_basic
37
-
38
-
39
- class gen_verify:
40
- def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
41
- self.fuse_gemm_info = fuse_gemm_info
42
- self.name = gen_class_name + "_verify"
43
- self.b2b_num = len(fuse_gemm_info)
44
- self.params = []
45
- self.user_header_file = ""
46
- for header in user_header_file:
47
- self.user_header_file += "#include \"" + header + "\"\n"
48
- self.separate_cutlass = gen_basic.gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
49
- self.gen_params()
50
- self.output_dir = output_dir
51
-
52
-
53
- def gen_code(self):
54
- code = ""
55
- code += self.user_header_file
56
- code += self.separate_cutlass.gen_using(False) #False -> Turing, True -> Volta
57
-
58
- code_body = ""
59
- for i in range(self.b2b_num):
60
- code_body += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n"
61
- code_body += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(Arguments_", i) + ", nullptr);\n"
62
-
63
- code_body += self.separate_cutlass.gen_run()
64
-
65
- code += ir.gen_func(self.name, self.params, code_body)
66
- helper.write_2_headfile("cutlass_verify.h", self.output_dir, code)
67
-
68
-
69
- def gen_params(self):
70
- for i in range(self.b2b_num):
71
- self.params.append(
72
- (
73
- helper.var_idx("typename Gemm", i)+ "::Arguments",
74
- helper.var_idx("Arguments_", i)
75
- )
76
- )
77
-
78
-
79
- def get_params(self, declaration = True):
80
- code = ""
81
- if declaration:
82
- for param in self.params:
83
- code += param[0] + " " + param[1] + ";\n"
84
-
85
- return code
86
-
87
-
88
- def gen_initialize():
89
- code = ""
90
- initialize_code = self.separate_cutlass.gen_initialize()
91
-
92
- code = ir.gen_func("initialize", [[]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py DELETED
@@ -1,135 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- def type_2_cutlass_type(input_type = "fp16"):
34
- # float point type
35
- if input_type == "fp32":
36
- return "float"
37
- if input_type == "bf16":
38
- return "cutlass::bfloat16_t"
39
- if input_type == "fp16":
40
- return "cutlass::half_t"
41
-
42
- # integer type
43
- if(input_type == "int32"):
44
- return "int32_t"
45
- if(input_type == "int8"):
46
- return "int8_t"
47
-
48
- if input_type == 'Row':
49
- return 'cutlass::layout::RowMajor'
50
- if input_type == 'Col':
51
- return 'cutlass::layout::ColumnMajor'
52
-
53
- def cvt_2_cutlass_shape(gemm_shape):
54
- # gemm shape
55
- if len(gemm_shape) == 3:
56
- val = "cutlass::gemm::GemmShape<" \
57
- + str(gemm_shape[0]) + ", " \
58
- + str(gemm_shape[1]) + ", " \
59
- + str(gemm_shape[2]) + ">"
60
- return val
61
-
62
-
63
- def write_2_headfile(filename, file_dir, string):
64
- with open(file_dir + filename, 'w') as f:
65
- f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string)
66
-
67
- def var_idx(variable, index):
68
- return variable + str(index)
69
-
70
-
71
- def list_2_string(input_list, ):
72
- rtn_string = ""
73
-
74
- cnt = 0
75
-
76
- for element in input_list:
77
- final = ", \n"
78
- if cnt == len(input_list) - 1:
79
- final = "\n"
80
- cnt += 1
81
- rtn_string += str(element) + final
82
-
83
- return rtn_string
84
-
85
-
86
- def get_epilogue_info(layer_info):
87
- return layer_info['epilogue']
88
-
89
- def get_epilogue_tp(layer_info):
90
- epilogue_info = get_epilogue_info(layer_info)
91
- return epilogue_info['tp']
92
-
93
- def get_epilogue_add_bias_or_not(layer_info):
94
- epilogue_info = get_epilogue_info(layer_info)
95
- return epilogue_info['bias']['addbias']
96
-
97
- def get_epilogue_add_bias_tp(layer_info):
98
- epilogue_info = get_epilogue_info(layer_info)
99
- return epilogue_info['bias']['bias_tp']
100
-
101
- def get_epilogue_args(layer_info):
102
- epilogue_info = get_epilogue_info(layer_info)
103
- return epilogue_info['args']
104
-
105
- def get_epilogue_bias_shape(layer_info):
106
- bias_tp = get_epilogue_add_bias_tp(layer_info).lower()
107
- mn_shape = layer_info['mnk'][:-1]
108
-
109
- if bias_tp == 'mat':
110
- mn_shape[0] = 'M'
111
- return mn_shape
112
- elif bias_tp == 'vec':
113
- mn_shape[0] = 1
114
- return mn_shape
115
- else:
116
- assert(0)
117
-
118
- def get_epilogue_bias_ldm(layer_info):
119
- bias_tp = get_epilogue_add_bias_tp(layer_info).lower()
120
- mn_shape = layer_info['mnk'][:-1]
121
-
122
- c_layout = layer_info['C_format'].lower()
123
-
124
- if c_layout != 'row':
125
- assert(0)
126
-
127
- if bias_tp == 'mat':
128
- return mn_shape[1]
129
- elif bias_tp == 'vec':
130
- return 0
131
- else:
132
- assert(0)
133
-
134
- def get_epilogue_compute_tp(layer_info):
135
- return layer_info['Acc_tp']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py DELETED
@@ -1,67 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import os
34
-
35
class replace_fix_impl:
    """Copy a source tree while re-rooting its quoted cutlass includes.

    Every file below ``src_dir`` is written to the same relative location
    below ``dst_dir``. Lines that start with ``#include "cutlass`` get
    ``cutlass_deps_root`` inserted right after the opening quote; all other
    lines are copied unchanged.
    """

    def __init__(self, src_dir, dst_dir, cutlass_deps_root):
        self.src_dir = src_dir
        self.dst_dir = dst_dir
        self.cutlass_deps_root = cutlass_deps_root

    def gen_code(self):
        """Walk ``src_dir`` and emit the rewritten copy under ``dst_dir``."""
        include_open = "#include \""
        cutlass_include = "#include \"cutlass"

        for cur_dir, _sub_dirs, file_names in os.walk(self.src_dir):
            # Derive the mirror directory from the relative path rather than
            # by slicing strings, so a trailing separator on src_dir cannot
            # glue path components together.
            rel_dir = os.path.relpath(cur_dir, self.src_dir)
            if rel_dir == os.curdir:
                out_dir = self.dst_dir
            else:
                out_dir = os.path.join(self.dst_dir, rel_dir)

            # makedirs also creates missing parents of dst_dir and tolerates
            # a directory that already exists.
            os.makedirs(out_dir, exist_ok=True)

            for file_name in file_names:
                with open(os.path.join(cur_dir, file_name), 'r') as current_file:
                    lines = current_file.readlines()

                output_lines = []
                for line in lines:
                    if line.startswith(cutlass_include):
                        line = include_open + self.cutlass_deps_root + line[len(include_open):]
                    output_lines.append(line)

                with open(os.path.join(out_dir, file_name), "w+") as dest_file:
                    dest_file.writelines(output_lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h DELETED
@@ -1,292 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
- #include <cuda_fp16.h>
34
-
35
- template <typename T>
36
- __device__
37
- T add(T const & a, T const &b){
38
- return (a + b);
39
- }
40
-
41
- template <>
42
- __device__
43
- half2 add(half2 const & a, half2 const &b){
44
- return (__hadd2(a,b));
45
- }
46
-
47
- template <typename T>
48
- struct RELU{
49
- __device__
50
- T operator()(T const & a){
51
- return a > T(0) ? a : T(0);
52
- }
53
- __device__
54
- half2 operator()(half2 const & a){
55
- float2 a_fp32x2 = __half22float2(a);
56
- a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f;
57
- a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f;
58
- if(a_fp32x2.x < 0.f || a_fp32x2.y < 0.f)
59
- printf(" %f %f\n", a_fp32x2.x ,a_fp32x2.y);
60
- return __float22half2_rn(a_fp32x2);
61
- }
62
- };
63
-
64
- template <typename T>
65
- struct LEAKY_RELU{
66
- __device__
67
- T operator()(T const & a, T const & scale = half(1)){
68
- return a > T(0) ? a : scale * a;
69
- }
70
- __device__
71
- half2 operator()(half2 const & a, half const & scale = half(1)){
72
- half2 zero = __half2half2(half(0));
73
- half2 gt_zero = __hge2(a, zero);
74
- half2 le_zero = __hle2(a, zero);
75
-
76
-
77
- half2 scale_f16x2 = __half2half2(scale);
78
- half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero);
79
- return __hmul2(a, mask_scale_f16x2);
80
- }
81
- };
82
-
83
- template <int N, int BLOCKDIM>
84
- __global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){
85
-
86
- constexpr bool N_MOD_2 = N & 1 ? false : true;
87
-
88
- using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
89
-
90
- constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
91
-
92
- constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
93
-
94
- LEAKY_RELU<half> Act;
95
- Access_tp src_v[iter];
96
- Access_tp bias_v[iter];
97
-
98
- int batch_id = blockIdx.y;
99
- int batch_offset = batch_id * gridDim.x * N;
100
-
101
- for(int i = 0; i < iter; i++){
102
- int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
103
- if (idx < N){
104
- src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
105
- if (mat_bias)
106
- bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
107
- else
108
- bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
109
- *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]),scale);
110
- }
111
-
112
- }
113
- }
114
-
115
-
116
-
117
- template <int N, int BLOCKDIM>
118
- __global__ void leaky_and_activation(half* inout, half scale){
119
-
120
- constexpr bool N_MOD_2 = N & 1 ? false : true;
121
-
122
- using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
123
-
124
- constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
125
-
126
- constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
127
-
128
- int batch_id = blockIdx.y;
129
- int batch_offset = batch_id * gridDim.x * N;
130
-
131
- LEAKY_RELU<half> Act;
132
- Access_tp src_v[iter];
133
-
134
- for(int i = 0; i < iter; i++){
135
- int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
136
- if (idx < N){
137
- src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
138
- *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i], scale);
139
- }
140
-
141
- }
142
- }
143
-
144
-
145
-
146
- template <int N, int BLOCKDIM>
147
- void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){
148
-
149
- dim3 grid(m, b);
150
- if (bias == nullptr)
151
- leaky_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, scale);
152
- else
153
- leaky_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, scale, mat_bias);
154
- }
155
-
156
- template <int N, int BLOCKDIM>
157
- __global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){
158
-
159
- constexpr bool N_MOD_2 = N & 1 ? false : true;
160
-
161
- using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
162
-
163
- constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
164
-
165
- constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
166
-
167
- RELU<half> Act;
168
- Access_tp src_v[iter];
169
- Access_tp bias_v[iter];
170
-
171
- int batch_id = blockIdx.y;
172
- int batch_offset = batch_id * gridDim.x * N;
173
-
174
- for(int i = 0; i < iter; i++){
175
- int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
176
- if (idx < N){
177
- src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
178
- if (mat_bias)
179
- bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
180
- else
181
- bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
182
- *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]));
183
- }
184
-
185
- }
186
- }
187
-
188
-
189
-
190
- template <int N, int BLOCKDIM>
191
- __global__ void relu_and_activation(half* inout){
192
-
193
- constexpr bool N_MOD_2 = N & 1 ? false : true;
194
-
195
- using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
196
-
197
- constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
198
-
199
- constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
200
-
201
- int batch_id = blockIdx.y;
202
- int batch_offset = batch_id * gridDim.x * N;
203
-
204
- RELU<half> Act;
205
- Access_tp src_v[iter];
206
-
207
- for(int i = 0; i < iter; i++){
208
- int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
209
- if (idx < N){
210
- src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
211
- *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i]);
212
- }
213
-
214
- }
215
- }
216
-
217
-
218
-
219
- template <int N, int BLOCKDIM>
220
- void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){
221
- dim3 grid(m, b);
222
- if (bias == nullptr)
223
- relu_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout);
224
- else
225
- relu_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, mat_bias);
226
- }
227
-
228
-
229
- template <int N, int BLOCKDIM>
230
- __global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){
231
-
232
- constexpr bool N_MOD_2 = N & 1 ? false : true;
233
-
234
- using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
235
-
236
- constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
237
-
238
- constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
239
-
240
- int batch_id = blockIdx.y;
241
- int batch_offset = batch_id * gridDim.x * N;
242
-
243
- Access_tp src_v[iter];
244
- Access_tp bias_v[iter];
245
-
246
- for(int i = 0; i < iter; i++){
247
- int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
248
- if (idx < N){
249
- src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
250
- if (mat_bias)
251
- bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
252
- else
253
- bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
254
- *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i]));
255
- }
256
-
257
- }
258
- }
259
-
260
- template <int N, int BLOCKDIM>
261
- __global__ void identity_and_activation(half* inout){
262
-
263
- constexpr bool N_MOD_2 = N & 1 ? false : true;
264
-
265
- using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
266
-
267
- constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
268
-
269
- constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
270
-
271
- int batch_id = blockIdx.y;
272
- int batch_offset = batch_id * gridDim.x * N;
273
- Access_tp src_v[iter];
274
-
275
- for(int i = 0; i < iter; i++){
276
- int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
277
- if (idx < N){
278
- src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
279
- *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]);
280
- }
281
-
282
- }
283
- }
284
-
285
- template <int N, int BLOCKDIM>
286
- void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){
287
- dim3 grid(m, b);
288
- if (bias == nullptr)
289
- identity_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout);
290
- else
291
- identity_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, mat_bias);
292
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h DELETED
@@ -1,94 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
// TI(tag): start a timed region. Declares the start/end events and the
// elapsed-time variable for `tag` in the current scope, creates both events
// and records the start event. Pair with TO(tag, ...) in the same scope.
#define TI(tag)                                \
  float _event_time_ ##tag;                    \
  cudaEvent_t _event_start_ ##tag;             \
  cudaEvent_t _event_end_ ##tag;               \
  cudaEventCreate(& _event_start_ ##tag);      \
  cudaEventCreate(& _event_end_ ##tag);        \
  cudaEventRecord(_event_start_ ##tag);
40
-
41
// TO(tag, str, times): end the timed region opened by TI(tag).
// Records and waits on the end event, divides the elapsed milliseconds by
// `times`, and prints the per-iteration time in microseconds labelled with
// `str`, followed by the last CUDA error string after a device sync.
// `times` is parenthesized so expression arguments (e.g. `warmup + runs`)
// divide as a whole.
#define TO(tag, str, times)                                                           \
  cudaEventRecord(_event_end_ ##tag);                                                 \
  cudaEventSynchronize(_event_end_ ##tag);                                            \
  cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag);  \
  float _event_time_once_ ##tag = _event_time_ ##tag / (times);                       \
  printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000);                  \
  cudaDeviceSynchronize();                                                            \
  printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError()));
49
-
50
- template<typename T>
51
- struct memory_unit{
52
- T* host_ptr;
53
- T* device_ptr;
54
- int size_bytes;
55
- int elements;
56
- void h2d(){
57
- cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice);
58
- }
59
- void d2h(){
60
- cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost);
61
- }
62
- void free_all(){
63
- free(host_ptr);
64
- cudaFree(device_ptr);
65
- }
66
- memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){
67
- host_ptr = (T*) malloc(elements_ * sizeof(T));
68
- cudaMalloc((void**)&device_ptr, elements_ * sizeof(T));
69
- }
70
- void init(int abs_range = 1){
71
- for(int i = 0; i < elements; i++){
72
- host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range);
73
- }
74
- h2d();
75
- }
76
- };
77
-
78
/// Count the elements of `b` that deviate from the reference `a` by more
/// than 1% relative error, print the total and return it.
///
/// Elements are compared after conversion to float. Exactly equal values
/// (including matching infinities and 0 vs 0) always pass; any comparison
/// involving NaN is counted as a mismatch.
template<typename T>
int check_result(T * a, T * b, int N){
  constexpr float kRelTol = 1e-2f;
  int cnt = 0;
  for(int i = 0; i < N; i ++){
    float const expected = float(a[i]);
    float const actual = float(b[i]);

    // Equal values need no relative-error test; this also avoids 0/0 and
    // inf-inf producing NaN for matching results.
    if (expected == actual)
      continue;

    // Absolute values are computed by hand so the check can never bind to
    // the integer `abs` overload, which would truncate sub-1 differences.
    float diff = expected - actual;
    if (diff < 0.f)
      diff = -diff;
    float const magnitude = expected < 0.f ? -expected : expected;

    // Negated `<=` so that a NaN difference is counted as an error.
    if (!(diff <= kRelTol * magnitude))
    {
      cnt++;
    }
  }
  printf("total err: %d / %d\n", cnt, N);
  return cnt;
}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h DELETED
@@ -1,499 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Performs a dual gemm in one fused kernel:
33
- ```
34
- D0 = epilogue0(X @ B0, C0)
35
- D1 = epilogue1(X @ B1, C1)
36
- D2 = element_wise(D0, D1)
37
- ```
38
- */
39
-
40
- #pragma once
41
-
42
- #include "cutlass/cutlass.h"
43
- #include "cutlass/numeric_types.h"
44
- #include "cutlass/arch/arch.h"
45
- #include "cutlass/device_kernel.h"
46
-
47
- #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
48
-
49
- #include "cutlass/gemm/device/default_gemm_configuration.h"
50
- #include "cutlass/gemm/threadblock/default_mma.h"
51
- #include "cutlass/epilogue/thread/linear_combination_relu.h"
52
- #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
53
-
54
- #include "../kernel/dual_gemm.h"
55
- #include "../dual_gemm_common.h"
56
-
57
- ////////////////////////////////////////////////////////////////////////////////
58
-
59
- namespace cutlass {
60
- namespace gemm {
61
- namespace device {
62
-
63
- /////////////////////////////////////////////////////////////////////////////////////////////////
64
-
65
- template <
66
- /// Element type for A matrix operand
67
- typename ElementA_,
68
- /// Layout type for A matrix operand
69
- typename LayoutA_,
70
- /// Element type for B matrix operand
71
- typename ElementB_,
72
- /// Layout type for B0 matrix operand
73
- typename LayoutB0_,
74
- /// Layout type for B1 matrix operand
75
- typename LayoutB1_,
76
- /// Element type for C and D matrix operands
77
- typename ElementC_,
78
- /// Layout type for C and D matrix operands
79
- typename LayoutC_,
80
- /// Element type for internal accumulation
81
- typename ElementAccumulator_,
82
- /// Operator class tag
83
- typename OperatorClass_,
84
- /// Tag indicating architecture to tune for
85
- typename ArchTag_,
86
- /// Threadblock-level tile size (concept: GemmShape)
87
- typename ThreadblockShape_,
88
- /// Warp-level tile size (concept: GemmShape)
89
- typename WarpShape_,
90
- /// Instruction-level tile size (concept: GemmShape)
91
- typename InstructionShape_,
92
- /// Epilogue output operator
93
- typename EpilogueOutputOp0_,
94
- typename EpilogueOutputOp1_,
95
- typename EpilogueOutputOp2_,
96
- /// Threadblock-level swizzling operator
97
- typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
98
- /// Number of stages used in the pipelined mainloop
99
- int Stages =
100
- DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
101
- ElementC_, ElementAccumulator_>::kStages,
102
- bool StoreD0 = true,
103
- bool StoreD1 = true,
104
- /// If true, kernel supports split-K with serial reduction
105
- bool SplitKSerial = false,
106
- /// Access granularity of A matrix in units of elements
107
- int AlignmentA =
108
- DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
109
- ElementC_, ElementAccumulator_>::kAlignmentA,
110
- /// Access granularity of B matrix in units of elements
111
- int AlignmentB =
112
- DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
113
- ElementC_, ElementAccumulator_>::kAlignmentB,
114
- /// Operation performed by GEMM
115
- typename Operator_ = typename DefaultGemmConfiguration<
116
- OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
117
- ElementAccumulator_>::Operator>
118
- class DualGemm {
119
- public:
120
-
121
- using ElementA = ElementA_;
122
- using LayoutA = LayoutA_;
123
- using TensorRefA = TensorRef<ElementA const, LayoutA>;
124
- using ElementB = ElementB_;
125
- using LayoutB0 = LayoutB0_;
126
- using LayoutB1 = LayoutB1_;
127
- using TensorRefB0 = TensorRef<ElementB const, LayoutB0>;
128
- using TensorRefB1 = TensorRef<ElementB const, LayoutB1>;
129
- using ElementC = ElementC_;
130
- using LayoutC = LayoutC_;
131
- using TensorRefC = TensorRef<ElementC const, LayoutC>;
132
- using TensorRefD = TensorRef<ElementC, LayoutC>;
133
- using ElementAccumulator = ElementAccumulator_;
134
- using OperatorClass = OperatorClass_;
135
- using ArchTag = ArchTag_;
136
- using ThreadblockShape = ThreadblockShape_;
137
- using WarpShape = WarpShape_;
138
- using InstructionShape = InstructionShape_;
139
- using EpilogueOutputOp0 = EpilogueOutputOp0_;
140
- using EpilogueOutputOp1 = EpilogueOutputOp1_;
141
- using EpilogueOutputOp2 = EpilogueOutputOp2_;
142
- using ThreadblockSwizzle = ThreadblockSwizzle_;
143
- using Operator = Operator_;
144
- static int const kStages = Stages;
145
- static int const kAlignmentA = AlignmentA;
146
- static int const kAlignmentB = AlignmentB;
147
- static int const kAlignmentC = EpilogueOutputOp1::kCount;
148
- static bool const kSplitKSerial = SplitKSerial;
149
- static bool constexpr kStoreD0 = StoreD0;
150
- static bool constexpr kStoreD1 = StoreD1;
151
- static ComplexTransform const kTransformA = ComplexTransform::kNone;
152
- static ComplexTransform const kTransformB = ComplexTransform::kNone;
153
-
154
- using LayoutScaleBias = layout::RowMajor;
155
- /// Define the kernel
156
- /// Define the threadblock-scoped matrix multiply-accumulate
157
- static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented");
158
- static_assert(kStages >= 3, "Only multistage is implemented");
159
- using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
160
- ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB,
161
- ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
162
- ThreadblockShape, WarpShape,
163
- InstructionShape, Stages, Operator>::ThreadblockMma;
164
- using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
165
- ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB,
166
- ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
167
- ThreadblockShape, WarpShape,
168
- InstructionShape, Stages, Operator>::ThreadblockMma;
169
- using DualMma = threadblock::DualMmaMultistage<
170
- typename Mma0::Shape,
171
- typename Mma0::IteratorA,
172
- typename Mma0::SmemIteratorA,
173
- Mma0::kCacheOpA,
174
- typename Mma0::IteratorB,
175
- typename Mma0::SmemIteratorB,
176
- Mma0::kCacheOpB,
177
- typename Mma1::IteratorB,
178
- typename Mma1::SmemIteratorB,
179
- typename Mma0::ElementC,
180
- typename Mma0::LayoutC,
181
- typename Mma0::Policy,
182
- typename Mma1::Policy,
183
- Mma0::kStages,
184
- SharedMemoryClearOption::kNone
185
- >;
186
-
187
- static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
188
-
189
- /// Define the epilogue
190
- using Epilogue0 =
191
- typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
192
- ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0,
193
- EpilogueOutputOp0::kCount>::Epilogue;
194
- using Epilogue1 =
195
- typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
196
- ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1,
197
- EpilogueOutputOp1::kCount>::Epilogue;
198
-
199
- /// Define the kernel-level GEMM operator.
200
- using DualGemmKernel = kernel::DualGemm<
201
- DualMma,
202
- Epilogue0, Epilogue1, EpilogueOutputOp2,
203
- ThreadblockSwizzle, kSplitKSerial,
204
- kStoreD0, kStoreD1>;
205
-
206
- /// Argument structure
207
- struct Arguments {
208
-
209
- //
210
- // Data members
211
- //
212
-
213
- DualGemmMode mode;
214
- GemmCoord problem_size;
215
- TensorRef<ElementA const, LayoutA> ref_A0;
216
- TensorRef<ElementB const, LayoutB0> ref_B0;
217
- TensorRef<ElementC const, LayoutC> ref_C0;
218
- TensorRef<ElementC, LayoutC> ref_D0;
219
- TensorRef<ElementB const, LayoutB1> ref_B1;
220
- TensorRef<ElementC const, LayoutC> ref_C1;
221
- TensorRef<ElementC, LayoutC> ref_D1;
222
- TensorRef<ElementC, LayoutC> ref_D2;
223
- typename EpilogueOutputOp0::Params epilogue0;
224
- typename EpilogueOutputOp1::Params epilogue1;
225
- typename EpilogueOutputOp2::Params epilogue2;
226
- int split_k_slices;
227
-
228
- int batch_count;
229
- int64_t batch_stride_A;
230
- int64_t batch_stride_B0;
231
- int64_t batch_stride_B1;
232
- int64_t batch_stride_C;
233
- int64_t batch_stride_D;
234
-
235
- //
236
- // Methods
237
- //
238
-
239
- /// Default ctor
240
- CUTLASS_HOST_DEVICE
241
- Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
242
-
243
- }
244
-
245
- /// Constructs an Arguments structure
246
- CUTLASS_HOST_DEVICE
247
- Arguments(
248
- DualGemmMode mode,
249
- GemmCoord problem_size_,
250
- TensorRef<ElementA const, LayoutA> ref_A0_,
251
- TensorRef<ElementB const, LayoutB0> ref_B0_,
252
- TensorRef<ElementC const, LayoutC> ref_C0_,
253
- TensorRef<ElementC, LayoutC> ref_D0_,
254
- TensorRef<ElementB const, LayoutB1> ref_B1_,
255
- TensorRef<ElementC const, LayoutC> ref_C1_,
256
- TensorRef<ElementC, LayoutC> ref_D1_,
257
- TensorRef<ElementC, LayoutC> ref_D2_,
258
- typename EpilogueOutputOp0::Params epilogue0_ =
259
- typename EpilogueOutputOp0::Params(),
260
- typename EpilogueOutputOp1::Params epilogue1_ =
261
- typename EpilogueOutputOp1::Params(),
262
- typename EpilogueOutputOp2::Params epilogue2_ =
263
- typename EpilogueOutputOp2::Params(),
264
- int split_k_slices_ = 1,
265
- int batch_count = 1,
266
- int64_t batch_stride_A = 0,
267
- int64_t batch_stride_B0 = 0,
268
- int64_t batch_stride_B1 = 0,
269
- int64_t batch_stride_C = 0,
270
- int64_t batch_stride_D = 0
271
- ):
272
- mode(mode),
273
- problem_size(problem_size_),
274
- ref_A0(ref_A0_),
275
- ref_B0(ref_B0_),
276
- ref_C0(ref_C0_),
277
- ref_D0(ref_D0_),
278
- ref_B1(ref_B1_),
279
- ref_C1(ref_C1_),
280
- ref_D1(ref_D1_),
281
- ref_D2(ref_D2_),
282
- epilogue0(epilogue0_),
283
- epilogue1(epilogue1_),
284
- epilogue2(epilogue2_),
285
- split_k_slices(split_k_slices_),
286
- batch_count(batch_count),
287
- batch_stride_A(batch_stride_A),
288
- batch_stride_B0(batch_stride_B0),
289
- batch_stride_B1(batch_stride_B1),
290
- batch_stride_C(batch_stride_C),
291
- batch_stride_D(batch_stride_D) {
292
-
293
- }
294
- };
295
-
296
- private:
297
-
298
- /// Kernel parameters object
299
- typename DualGemmKernel::Params params_;
300
-
301
- public:
302
-
303
- /// Constructs the GEMM.
304
- DualGemm() = default;
305
-
306
- /// Determines whether the GEMM can execute the given problem.
307
- static Status can_implement(Arguments const &args) {
308
-
309
- if (args.mode == DualGemmMode::kBatched && kSplitKSerial) {
310
- return Status::kErrorInvalidProblem;
311
- }
312
- if (!kSplitKSerial && args.split_k_slices > 1) {
313
- return Status::kErrorInvalidProblem;
314
- }
315
- if (kStoreD0 != (args.ref_D0.data() != nullptr)) {
316
- return Status::kErrorInternal;
317
- }
318
- if (kStoreD1 != (args.ref_D1.data() != nullptr)) {
319
- return Status::kErrorInternal;
320
- }
321
-
322
- Status status = DualGemmKernel::can_implement(
323
- args.problem_size,
324
- args.ref_A0.non_const_ref(),
325
- args.ref_B0.non_const_ref(),
326
- args.ref_C0.non_const_ref(),
327
- args.ref_D0,
328
- args.ref_B1.non_const_ref(),
329
- args.ref_C1.non_const_ref(),
330
- args.ref_D1,
331
- args.ref_D2
332
- );
333
-
334
- if (status != Status::kSuccess) {
335
- return status;
336
- }
337
-
338
- return Status::kSuccess;
339
- }
340
-
341
- /// Gets the workspace size
342
- static size_t get_workspace_size(Arguments const &args) {
343
-
344
- size_t bytes = 0;
345
-
346
- if (kSplitKSerial && args.split_k_slices > 1) {
347
- // Determine grid shape
348
- ThreadblockSwizzle threadblock_swizzle;
349
-
350
- cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
351
- args.problem_size,
352
- {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
353
- args.split_k_slices);
354
-
355
- bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
356
- }
357
-
358
- return bytes;
359
- }
360
-
361
- /// Initializes GEMM state from arguments.
362
- Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
363
-
364
- // Determine grid shape
365
- ThreadblockSwizzle threadblock_swizzle;
366
-
367
- cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
368
- args.problem_size,
369
- {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
370
- args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices);
371
-
372
- if (kSplitKSerial) {
373
- if (args.split_k_slices > 1) {
374
- if (!workspace) {
375
- return Status::kErrorWorkspaceNull;
376
- }
377
-
378
- size_t bytes = get_workspace_size(args);
379
-
380
- cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
381
-
382
- if (result != cudaSuccess) {
383
- return Status::kErrorInternal;
384
- }
385
- }
386
- }
387
- else {
388
-
389
- if (args.split_k_slices > 1) {
390
- return Status::kErrorInvalidProblem;
391
- }
392
- }
393
-
394
- // Initialize the Params structure
395
- params_ = typename DualGemmKernel::Params{
396
- args.mode,
397
- args.problem_size,
398
- grid_shape,
399
- args.ref_A0.non_const_ref(),
400
- args.ref_B0.non_const_ref(),
401
- args.ref_C0.non_const_ref(),
402
- args.ref_D0,
403
- args.ref_B1.non_const_ref(),
404
- args.ref_C1.non_const_ref(),
405
- args.ref_D1,
406
- args.ref_D2,
407
- args.epilogue0,
408
- args.epilogue1,
409
- args.epilogue2,
410
- reinterpret_cast<int *>(workspace),
411
- args.batch_stride_A,
412
- args.batch_stride_B0,
413
- args.batch_stride_B1,
414
- args.batch_stride_C,
415
- args.batch_stride_D,
416
- };
417
-
418
- return Status::kSuccess;
419
- }
420
-
421
- /// Lightweight update given a subset of arguments
422
- Status update(Arguments const &args, void *workspace = nullptr) {
423
-
424
- if (kSplitKSerial && args.split_k_slices > 1) {
425
- if (!workspace) {
426
- return Status::kErrorWorkspaceNull;
427
- }
428
- }
429
-
430
- params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
431
- params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
432
- params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
433
- params_.ref_D0.reset(args.ref_D0.data());
434
- params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
435
- params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
436
- params_.ref_D1.reset(args.ref_D1.data());
437
- params_.ref_D2.reset(args.ref_D2.data());
438
- params_.output_op_0 = args.epilogue0;
439
- params_.output_op_1 = args.epilogue1;
440
- params_.output_op_2 = args.epilogue2;
441
- params_.semaphore = reinterpret_cast<int *>(workspace);
442
-
443
- return Status::kSuccess;
444
- }
445
-
446
- /// Runs the kernel using initialized state.
447
- Status run(cudaStream_t stream = nullptr) {
448
-
449
- ThreadblockSwizzle threadblock_swizzle;
450
-
451
- dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
452
- dim3 block(DualGemmKernel::kThreadCount, 1, 1);
453
-
454
- cudaError_t result;
455
-
456
- int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage));
457
- if (smem_size >= (48 << 10)) {
458
- result = cudaFuncSetAttribute(Kernel<DualGemmKernel>,
459
- cudaFuncAttributeMaxDynamicSharedMemorySize,
460
- smem_size);
461
-
462
- if (result != cudaSuccess) {
463
- return Status::kErrorInternal;
464
- }
465
- }
466
-
467
- cutlass::Kernel<DualGemmKernel><<<grid, block, smem_size, stream>>>(params_);
468
-
469
- result = cudaGetLastError();
470
-
471
- return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
472
- }
473
-
474
- /// Runs the kernel using initialized state.
475
- Status operator()(cudaStream_t stream = nullptr) {
476
- return run(stream);
477
- }
478
-
479
- /// Runs the kernel using initialized state.
480
- Status operator()(
481
- Arguments const &args,
482
- void *workspace = nullptr,
483
- cudaStream_t stream = nullptr) {
484
-
485
- Status status = initialize(args, workspace, stream);
486
-
487
- if (status == Status::kSuccess) {
488
- status = run(stream);
489
- }
490
-
491
- return status;
492
- }
493
- };
494
-
495
- } // namespace device
496
- } // namespace gemm
497
- } // namespace cutlass
498
-
499
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h DELETED
@@ -1,52 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Defines common types used for all DualGemm operators.
33
- */
34
- #pragma once
35
-
36
- namespace cutlass {
37
- namespace gemm {
38
-
39
- /////////////////////////////////////////////////////////////////////////////////////////////////
40
-
41
- enum class DualGemmMode {
42
- kGemm,
43
- kBatched,
44
- kInvalid
45
- };
46
-
47
- ////////////////////////////////////////////////////////////////////////////////////////////////////
48
-
49
- } // namespace gemm
50
- } // namespace cutlass
51
-
52
- ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h DELETED
@@ -1,938 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include <iostream>
34
- #include <fstream>
35
- #include <sstream>
36
- #include <type_traits>
37
-
38
- #include "cutlass/util/host_tensor.h"
39
- #include "cutlass/util/tensor_view_io.h"
40
- #include "cutlass/util/distribution.h"
41
- #include "cutlass/util/reference/host/tensor_fill.h"
42
- #include "cutlass/util/reference/host/tensor_copy.h"
43
- #include "cutlass/util/reference/host/tensor_compare.h"
44
- #include "cutlass/util/reference/host/tensor_norm.h"
45
- #include "cutlass/util/reference/device/gemm.h"
46
- #include "cutlass/util/reference/device/tensor_relu.h"
47
-
48
- #include "cutlass/platform/platform.h"
49
- #include "cutlass/gemm/gemm.h"
50
- #include "cutlass/gemm/device/gemm_universal.h"
51
-
52
- #include "dual_gemm_common.h"
53
- #include "helper.h"
54
-
55
- #define CHECK_GT(val1, val2) \
56
- if((val1) <= (val2)) \
57
- std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
58
- #define CHECK_TRUE(val) \
59
- if(!(val)) \
60
- std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
61
-
62
- template <
63
- typename OutputOp,
64
- typename Element,
65
- typename Layout>
66
- struct TensorEpilogueForEachFunc {
67
- /// View type
68
- using TensorView = cutlass::TensorView<Element, Layout>;
69
-
70
- /// Coordinate in tensor's index space
71
- using TensorCoord = typename TensorView::TensorCoord;
72
-
73
- /// Parameters structure
74
- struct Params {
75
-
76
- //
77
- // Data members
78
- //
79
-
80
- TensorView view_x0;
81
- TensorView view_x1;
82
- TensorView view_y;
83
- OutputOp output_op;
84
-
85
-
86
- //
87
- // Methods
88
- //
89
-
90
- Params(
91
- TensorView view_x0_ = TensorView(),
92
- TensorView view_x1_ = TensorView(),
93
- TensorView view_y_ = TensorView(),
94
- OutputOp output_op_ = OutputOp(typename OutputOp::Params{})
95
- ):
96
- view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) {
97
- }
98
- };
99
-
100
- Params params;
101
-
102
- CUTLASS_DEVICE
103
- TensorEpilogueForEachFunc(Params const &params): params(params) {
104
-
105
- }
106
-
107
- CUTLASS_DEVICE
108
- void operator()(TensorCoord const &coord) {
109
- Element const & x0 = params.view_x0.at(coord);
110
- Element const & x1 = params.view_x1.at(coord);
111
- Element& y = params.view_y.at(coord);
112
- y = params.output_op(x0, x1);
113
- }
114
- };
115
-
116
- template <
117
- typename OutputOp,
118
- typename Element,
119
- typename Layout>
120
- void TensorEpilogueForEach(
121
- cutlass::TensorView<Element, Layout> x0,
122
- cutlass::TensorView<Element, Layout> x1,
123
- cutlass::TensorView<Element, Layout> y) {
124
-
125
- using Func = TensorEpilogueForEachFunc<OutputOp, Element, Layout>;
126
- using Params = typename Func::Params;
127
-
128
- cutlass::reference::device::TensorForEach<Func, Layout::kRank, Params>(
129
- y.extent(),
130
- Params(x0, x1, y)
131
- );
132
- }
133
-
134
- ////////////////////////////////////////////////////////////////////////////////
135
-
136
- template <typename Gemm0_, typename Gemm1_>
137
- struct NonFusedDualGemmRun
138
- {
139
-
140
- using Gemm0 = Gemm0_;
141
- using Gemm1 = Gemm1_;
142
- using ElementAccumulator = typename Gemm0::ElementAccumulator;
143
- using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
144
-
145
- /// Initialization
146
- cutlass::Distribution::Kind init_A;
147
- cutlass::Distribution::Kind init_B;
148
- cutlass::Distribution::Kind init_C;
149
- cutlass::Distribution::Kind init_Bias;
150
- uint64_t seed;
151
-
152
- //
153
- // Methods
154
- //
155
-
156
- NonFusedDualGemmRun(
157
- cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
158
- cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
159
- cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
160
- cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
161
- uint64_t seed_ = 2080
162
- ):
163
- init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }
164
-
165
- /// Helper to initialize a tensor view
166
- template <typename Element, typename Layout>
167
- bool initialize_tensor(
168
- cutlass::TensorView<Element, Layout> view,
169
- cutlass::Distribution::Kind dist_kind,
170
- uint64_t seed) {
171
-
172
- if (dist_kind == cutlass::Distribution::Uniform) {
173
-
174
- cutlass::reference::host::TensorFillRandomUniform(
175
- view, seed, 2, -2, 0);
176
- }
177
- else if (dist_kind == cutlass::Distribution::Identity) {
178
-
179
- cutlass::reference::host::TensorFillIdentity(view);
180
- }
181
- else if (dist_kind == cutlass::Distribution::Gaussian) {
182
-
183
- cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
184
- }
185
- else if (dist_kind == cutlass::Distribution::Sequential) {
186
-
187
- cutlass::reference::host::BlockFillSequential(
188
- view.data(), view.capacity());
189
- }
190
- else if (dist_kind == cutlass::Distribution::AllZeros) {
191
- cutlass::reference::host::TensorFill(view, Element(0));
192
- }
193
- else if (dist_kind == cutlass::Distribution::AllOnes) {
194
- cutlass::reference::host::TensorFill(view, Element(1));
195
- }
196
- else {
197
- std::cerr << "Not implemented\n";
198
- return false;
199
- }
200
-
201
- return true;
202
- }
203
-
204
-
205
-
206
-
207
- /// Executes one test
208
- bool run(
209
- cutlass::gemm::GemmCoord problem_size,
210
- ElementCompute alpha0 = ElementCompute(1),
211
- ElementCompute beta0 = ElementCompute(0),
212
- ElementCompute alpha1 = ElementCompute(1),
213
- ElementCompute beta1 = ElementCompute(0),
214
- bool is_profiling = true,
215
- bool relu = false,
216
- int warm_ups = 1,
217
- int runs = 100) {
218
-
219
- //
220
- // Allocate the GEMM workspace
221
- //
222
-
223
- cutlass::HostTensor<
224
- typename Gemm0::ElementA,
225
- typename Gemm0::LayoutA> tensor_A0(problem_size.mk());
226
-
227
- cutlass::HostTensor<
228
- typename Gemm0::ElementB,
229
- typename Gemm0::LayoutB> tensor_B0(problem_size.kn());
230
-
231
- cutlass::HostTensor<
232
- typename Gemm0::ElementC,
233
- typename Gemm0::LayoutC> tensor_C0(problem_size.mn());
234
-
235
- cutlass::HostTensor<
236
- typename Gemm1::ElementC,
237
- typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()});
238
-
239
- cutlass::HostTensor<
240
- typename Gemm0::ElementC,
241
- typename Gemm0::LayoutC> tensor_D0(problem_size.mn());
242
-
243
- cutlass::HostTensor<
244
- typename Gemm0::ElementC,
245
- typename Gemm0::LayoutC> reference_D0(problem_size.mn());
246
-
247
- cutlass::HostTensor<
248
- typename Gemm1::ElementB,
249
- typename Gemm1::LayoutB> tensor_B1(problem_size.kn());
250
-
251
- cutlass::HostTensor<
252
- typename Gemm1::ElementC,
253
- typename Gemm1::LayoutC> tensor_C1(problem_size.mn());
254
-
255
- cutlass::HostTensor<
256
- typename Gemm1::ElementC,
257
- typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()});
258
-
259
- cutlass::HostTensor<
260
- typename Gemm1::ElementC,
261
- typename Gemm1::LayoutC> tensor_D1(problem_size.mn());
262
-
263
- cutlass::HostTensor<
264
- typename Gemm1::ElementC,
265
- typename Gemm1::LayoutC> reference_D1(problem_size.mn());
266
-
267
-
268
- CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
269
- CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
270
- CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
271
- CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
272
- CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
273
- CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
274
- CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));
275
-
276
- cutlass::reference::host::TensorFill(
277
- tensor_D0.host_view());
278
- cutlass::reference::host::TensorFill(
279
- tensor_D1.host_view());
280
- cutlass::reference::host::TensorFill(
281
- reference_D0.host_view());
282
- cutlass::reference::host::TensorFill(
283
- reference_D1.host_view());
284
-
285
- tensor_A0.sync_device();
286
- tensor_B0.sync_device();
287
- tensor_C0.sync_device();
288
- tensor_Bias0.sync_device();
289
- tensor_D0.sync_device();
290
- reference_D0.sync_device();
291
- tensor_B1.sync_device();
292
- tensor_C1.sync_device();
293
- tensor_Bias1.sync_device();
294
- tensor_D1.sync_device();
295
- reference_D1.sync_device();
296
-
297
- //
298
- // Initialize the GEMM operator
299
- //
300
-
301
- int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1;
302
- typename Gemm0::Arguments arguments_0{
303
- problem_size,
304
- tensor_A0.device_ref(),
305
- tensor_B0.device_ref(),
306
- {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
307
- tensor_D0.device_ref(),
308
- {alpha0, beta0},
309
- split_k_slices
310
- };
311
-
312
- split_k_slices = Gemm1::kSplitKSerial ? 2 : 1;
313
- typename Gemm1::Arguments arguments_1{
314
- problem_size,
315
- tensor_A0.device_ref(),
316
- tensor_B1.device_ref(),
317
- {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
318
- tensor_D1.device_ref(),
319
- {alpha1, beta1},
320
- split_k_slices
321
- };
322
-
323
-
324
- Gemm0 gemm_op_0;
325
- Gemm1 gemm_op_1;
326
-
327
- // Allocate workspace memory
328
- cutlass::device_memory::allocation<uint8_t> workspace0(gemm_op_0.get_workspace_size(arguments_0));
329
- cutlass::device_memory::allocation<uint8_t> workspace1(gemm_op_1.get_workspace_size(arguments_1));
330
-
331
- cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get());
332
-
333
- CUTLASS_CHECK(status);
334
-
335
- status = gemm_op_1.initialize(arguments_1, workspace1.get());
336
-
337
- CUTLASS_CHECK(status);
338
-
339
- for(int i = 0; i < warm_ups; i++) {
340
- status = gemm_op_0();
341
- CUTLASS_CHECK(status);
342
- status = gemm_op_1();
343
- CUTLASS_CHECK(status);
344
- }
345
-
346
- if (is_profiling) {
347
- //
348
- // Profile the GEMM
349
- //
350
-
351
- cudaEvent_t start, stop1, stop2;
352
- cudaEventCreate(&start);
353
- cudaEventCreate(&stop1);
354
- cudaEventCreate(&stop2);
355
-
356
- cudaEventRecord(start);
357
-
358
- for(int i = 0; i < runs; i++) {
359
- status = gemm_op_0();
360
-
361
- CUTLASS_CHECK(status);
362
- }
363
- cudaEventRecord(stop1);
364
- for(int i = 0; i < runs; i++) {
365
- status = gemm_op_1();
366
-
367
- CUTLASS_CHECK(status);
368
- }
369
-
370
- cudaEventRecord(stop2);
371
- cudaDeviceSynchronize();
372
- float gemm0Time, gemm1Time, totalTime;
373
- cudaEventElapsedTime(&gemm0Time, start, stop1);
374
- cudaEventElapsedTime(&gemm1Time, stop1, stop2);
375
- cudaEventElapsedTime(&totalTime, start, stop2);
376
- std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
377
- std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
378
- std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
379
- }
380
-
381
- tensor_D0.sync_host();
382
- tensor_D1.sync_host();
383
-
384
- //
385
- // Verify
386
- //
387
- cutlass::reference::device::Gemm<
388
- typename Gemm0::ElementA, typename Gemm0::LayoutA,
389
- typename Gemm0::ElementB, typename Gemm0::LayoutB,
390
- typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
391
- ElementAccumulator, typename Gemm0::Operator>
392
- reference_gemm_0;
393
-
394
- cutlass::reference::device::Gemm<
395
- typename Gemm1::ElementA, typename Gemm1::LayoutA,
396
- typename Gemm1::ElementB, typename Gemm1::LayoutB,
397
- typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
398
- ElementAccumulator, typename Gemm1::Operator>
399
- reference_gemm_1;
400
-
401
- reference_gemm_0(
402
- problem_size,
403
- alpha0,
404
- tensor_A0.device_ref(),
405
- tensor_B0.device_ref(),
406
- beta0,
407
- {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
408
- reference_D0.device_ref()
409
- );
410
-
411
- if(relu) {
412
- cutlass::reference::device::TensorReLu(reference_D0.device_view());
413
- }
414
-
415
- reference_gemm_1(
416
- problem_size,
417
- alpha1,
418
- tensor_A0.device_ref(),
419
- tensor_B1.device_ref(),
420
- beta1,
421
- {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
422
- reference_D1.device_ref()
423
- );
424
-
425
- if(relu) {
426
- cutlass::reference::device::TensorReLu(reference_D1.device_view());
427
- }
428
-
429
- // Wait for kernels to finish
430
- cudaDeviceSynchronize();
431
- reference_D0.sync_host();
432
- reference_D1.sync_host();
433
-
434
- CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
435
- CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
436
- CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
437
- CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
438
-
439
- bool passed0 = cutlass::reference::host::TensorEquals(
440
- reference_D1.host_view(),
441
- tensor_D1.host_view());
442
- CHECK_TRUE(passed0);
443
-
444
- bool passed1 = cutlass::reference::host::TensorEquals(
445
- reference_D1.host_view(),
446
- tensor_D1.host_view());
447
- CHECK_TRUE(passed1);
448
- if (!passed0 || !passed1) {
449
-
450
- std::stringstream fname;
451
-
452
- fname << "error_DualGemm_device_nonfused.txt";
453
- std::cerr << "Dumping results in " << fname.str() << "\n";
454
-
455
- std::ofstream file(fname.str());
456
-
457
- file
458
- << "A0 =\n" << tensor_A0.host_view()
459
- << "\nB0 =\n" << tensor_B0.host_view()
460
- << "\nC0 =\n" << tensor_C0.host_view()
461
- << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
462
- << "\nD0 =\n" << tensor_D0.host_view()
463
- << "\nB1 =\n" << tensor_B1.host_view()
464
- << "\nC1 =\n" << tensor_C1.host_view()
465
- << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
466
- << "\n\nReference =\n" << reference_D1.host_view()
467
- << "\nComputed =\n" << tensor_D1.host_view();
468
- }
469
- return passed0 && passed1;
470
- }
471
- };
472
-
473
- template <typename DualGemm_>
474
- struct DualFusedGemmRun
475
- {
476
-
477
- using DualGemm = DualGemm_;
478
- using ElementAccumulator = typename DualGemm::ElementAccumulator;
479
- using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute;
480
- using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2;
481
-
482
- /// Initialization
483
- cutlass::Distribution::Kind init_A;
484
- cutlass::Distribution::Kind init_B;
485
- cutlass::Distribution::Kind init_C;
486
- cutlass::Distribution::Kind init_Scale;
487
- cutlass::Distribution::Kind init_Bias;
488
- uint64_t seed;
489
-
490
- //
491
- // Methods
492
- //
493
-
494
- DualFusedGemmRun(
495
- cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
496
- cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
497
- cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
498
- cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
499
- cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
500
- uint64_t seed_ = 2080
501
- ):
502
- init_A(init_A_), init_B(init_B_), init_C(init_C_),
503
- init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
504
-
505
- /// Helper to initialize a tensor view
506
- template <typename Element, typename Layout>
507
- bool initialize_tensor(
508
- cutlass::TensorView<Element, Layout> view,
509
- cutlass::Distribution::Kind dist_kind,
510
- uint64_t seed) {
511
-
512
- if (dist_kind == cutlass::Distribution::Uniform) {
513
-
514
- cutlass::reference::host::TensorFillRandomUniform(
515
- view, seed, 2, -2, 0);
516
- }
517
- else if (dist_kind == cutlass::Distribution::Identity) {
518
-
519
- cutlass::reference::host::TensorFillIdentity(view);
520
- }
521
- else if (dist_kind == cutlass::Distribution::Gaussian) {
522
-
523
- cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
524
- }
525
- else if (dist_kind == cutlass::Distribution::Sequential) {
526
-
527
- cutlass::reference::host::BlockFillSequential(
528
- view.data(), view.capacity());
529
- }
530
- else if (dist_kind == cutlass::Distribution::AllZeros) {
531
- cutlass::reference::host::TensorFill(view, Element(0));
532
- }
533
- else if (dist_kind == cutlass::Distribution::AllOnes) {
534
- cutlass::reference::host::TensorFill(view, Element(1));
535
- }
536
- else {
537
- std::cerr << "Not implemented\n";
538
- return false;
539
- }
540
-
541
- return true;
542
- }
543
-
544
-
545
-
546
-
547
- /// Executes one test
548
- bool run(
549
- cutlass::gemm::GemmCoord problem_size,
550
- ElementCompute alpha0 = ElementCompute(1),
551
- ElementCompute beta0 = ElementCompute(1),
552
- ElementCompute alpha1 = ElementCompute(1),
553
- ElementCompute beta1 = ElementCompute(1),
554
- int batch_count = 1,
555
- bool broadcast_b1 = false,
556
- bool is_profiling = true,
557
- bool relu = false,
558
- int warm_ups = 1,
559
- int runs = 100) {
560
-
561
- //
562
- // Allocate the GEMM workspace
563
- //
564
-
565
- cutlass::HostTensor<
566
- typename DualGemm::ElementA,
567
- typename DualGemm::LayoutA> tensor_A0(
568
- cutlass::platform::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
569
- cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
570
- cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k()));
571
-
572
- cutlass::HostTensor<
573
- typename DualGemm::ElementB,
574
- typename DualGemm::LayoutB0> tensor_B0(
575
- cutlass::platform::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
576
- cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
577
- cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
578
-
579
- cutlass::HostTensor<
580
- typename DualGemm::ElementC,
581
- typename DualGemm::LayoutC> tensor_C0(
582
- cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
583
- cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
584
- cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
585
-
586
- cutlass::HostTensor<
587
- typename DualGemm::ElementC,
588
- typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()});
589
-
590
- cutlass::HostTensor<
591
- typename DualGemm::ElementC,
592
- typename DualGemm::LayoutC> tensor_D0(
593
- cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
594
- cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
595
- cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
596
-
597
- cutlass::HostTensor<
598
- typename DualGemm::ElementC,
599
- typename DualGemm::LayoutC> reference_D0(
600
- cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
601
- cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
602
- cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
603
-
604
- cutlass::HostTensor<
605
- typename DualGemm::ElementB,
606
- typename DualGemm::LayoutB1> tensor_B1(
607
- cutlass::platform::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
608
- cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
609
- cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
610
- if (broadcast_b1) {
611
- tensor_B1.resize({problem_size.k(), batch_count});
612
- }
613
-
614
- cutlass::HostTensor<
615
- typename DualGemm::ElementC,
616
- typename DualGemm::LayoutC> tensor_C1(
617
- cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
618
- cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
619
- cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
620
-
621
- cutlass::HostTensor<
622
- typename DualGemm::ElementC,
623
- typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()});
624
-
625
- cutlass::HostTensor<
626
- typename DualGemm::ElementC,
627
- typename DualGemm::LayoutC> tensor_D1(
628
- cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
629
- cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
630
- cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
631
-
632
- cutlass::HostTensor<
633
- typename DualGemm::ElementC,
634
- typename DualGemm::LayoutC> tensor_D2(
635
- cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
636
- cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
637
- cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
638
-
639
- cutlass::HostTensor<
640
- typename DualGemm::ElementC,
641
- typename DualGemm::LayoutC> reference_D1(
642
- cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
643
- cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
644
- cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
645
-
646
- cutlass::HostTensor<
647
- typename DualGemm::ElementC,
648
- typename DualGemm::LayoutC> reference_D2(
649
- cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
650
- cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
651
- cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
652
-
653
- CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
654
- CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118));
655
- CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
656
- CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011));
657
- CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113));
658
- CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
659
- CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
660
-
661
- cutlass::reference::host::TensorFill(
662
- tensor_D0.host_view());
663
- cutlass::reference::host::TensorFill(
664
- tensor_D1.host_view());
665
- cutlass::reference::host::TensorFill(
666
- tensor_D2.host_view());
667
- cutlass::reference::host::TensorFill(
668
- reference_D0.host_view());
669
- cutlass::reference::host::TensorFill(
670
- reference_D1.host_view());
671
- cutlass::reference::host::TensorFill(
672
- reference_D2.host_view());
673
-
674
- tensor_A0.sync_device();
675
- tensor_B0.sync_device();
676
- tensor_C0.sync_device();
677
- tensor_Bias0.sync_device();
678
- tensor_B1.sync_device();
679
- tensor_C1.sync_device();
680
- tensor_Bias1.sync_device();
681
- tensor_D0.sync_device();
682
- tensor_D1.sync_device();
683
- tensor_D2.sync_device();
684
- reference_D0.sync_device();
685
- reference_D1.sync_device();
686
- reference_D2.sync_device();
687
-
688
- //
689
- // Batch strides (irrelevant when batch_count == 1)
690
- //
691
-
692
- int64_t batch_stride_A = problem_size.m() * problem_size.k();
693
- int64_t batch_stride_B0 = problem_size.k() * problem_size.n();
694
- int64_t batch_stride_B1 = problem_size.k() * problem_size.n();
695
- if (broadcast_b1) {
696
- // B1 is a (column) vector
697
- batch_stride_B1 = problem_size.k();
698
- }
699
- int64_t batch_stride_Bias = problem_size.n();
700
- int64_t batch_stride_D = problem_size.m() * problem_size.n();
701
-
702
- //
703
- // Initialize the GEMM operator
704
- //
705
-
706
- int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;
707
- typename cutlass::TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC> nullptr_ref{};
708
- decltype(nullptr_ref) ref_B0, ref_B1;
709
- if (beta0 != ElementCompute(0)) {
710
- ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)};
711
- }
712
- if (beta1 != ElementCompute(0)) {
713
- ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
714
- }
715
- typename DualGemm::Arguments arguments{
716
- (batch_count > 1 ?
717
- cutlass::gemm::DualGemmMode::kBatched :
718
- cutlass::gemm::DualGemmMode::kGemm),
719
- problem_size,
720
- tensor_A0.device_ref(),
721
- tensor_B0.device_ref(),
722
- ref_B0,
723
- DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
724
- (broadcast_b1 ?
725
- typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
726
- tensor_B1.device_ref()),
727
- ref_B1,
728
- DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
729
- tensor_D2.device_ref(),
730
- {alpha0, beta0},
731
- {alpha1, beta1},
732
- {},
733
- split_k_slices,
734
- batch_count,
735
- batch_stride_A,
736
- batch_stride_B0,
737
- batch_stride_B1,
738
- batch_stride_Bias,
739
- batch_stride_D,
740
- };
741
-
742
- //
743
- // Run the GEMM
744
- //
745
-
746
- DualGemm b2b_gemm_op;
747
-
748
- cutlass::device_memory::allocation<uint8_t> workspace(b2b_gemm_op.get_workspace_size(arguments));
749
-
750
- cutlass::Status status = b2b_gemm_op.can_implement(arguments);
751
-
752
- CUTLASS_CHECK(status);
753
-
754
- status = b2b_gemm_op.initialize(arguments, workspace.get());
755
-
756
- CUTLASS_CHECK(status);
757
-
758
- for(int i = 0; i < warm_ups; i++) {
759
- status = b2b_gemm_op();
760
- CUTLASS_CHECK(status);
761
- }
762
-
763
- if (is_profiling) {
764
- //
765
- // Profile the GEMM
766
- //
767
-
768
- cudaEvent_t start, stop;
769
- cudaEventCreate(&start);
770
- cudaEventCreate(&stop);
771
-
772
- cudaEventRecord(start);
773
-
774
- for(int i = 0; i < runs; i++) {
775
- status = b2b_gemm_op();
776
- CUTLASS_CHECK(status);
777
- }
778
-
779
- cudaEventRecord(stop);
780
- cudaDeviceSynchronize();
781
- float gemmTime;
782
- cudaEventElapsedTime(&gemmTime, start, stop);
783
- std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
784
- }
785
-
786
- tensor_D0.sync_host();
787
- tensor_D1.sync_host();
788
- tensor_D2.sync_host();
789
-
790
- //
791
- // Verify
792
- //
793
-
794
- using GemmUniversal0 = cutlass::gemm::device::GemmUniversal<
795
- typename DualGemm::ElementA, typename DualGemm::LayoutA,
796
- typename DualGemm::ElementB, typename DualGemm::LayoutB0,
797
- typename DualGemm::ElementC, typename DualGemm::LayoutC,
798
- ElementAccumulator
799
- >;
800
-
801
- GemmUniversal0 reference_gemm0;
802
-
803
- typename GemmUniversal0::Arguments args0 {
804
- (batch_count > 1 ?
805
- cutlass::gemm::GemmUniversalMode::kBatched :
806
- cutlass::gemm::GemmUniversalMode::kGemm),
807
- problem_size,
808
- batch_count,
809
- {alpha0, beta0},
810
- tensor_A0.device_data(),
811
- tensor_B0.device_data(),
812
- tensor_Bias0.device_data(),
813
- reference_D0.device_data(),
814
- batch_stride_A,
815
- batch_stride_B0,
816
- batch_stride_Bias,
817
- batch_stride_D,
818
- tensor_A0.stride(0),
819
- tensor_B0.stride(0),
820
- 0, // zero stride for the bias vector
821
- reference_D0.stride(0),
822
- };
823
-
824
- status = reference_gemm0.can_implement(args0);
825
- CUTLASS_CHECK(status);
826
- status = reference_gemm0(args0);
827
- CUTLASS_CHECK(status);
828
-
829
- using GemmUniversal1 = cutlass::gemm::device::GemmUniversal<
830
- typename DualGemm::ElementA, typename DualGemm::LayoutA,
831
- typename DualGemm::ElementB, typename DualGemm::LayoutB1,
832
- typename DualGemm::ElementC, typename DualGemm::LayoutC,
833
- ElementAccumulator
834
- >;
835
-
836
- GemmUniversal1 reference_gemm1;
837
-
838
- typename GemmUniversal1::Arguments args1 {
839
- (batch_count > 1 ?
840
- cutlass::gemm::GemmUniversalMode::kBatched :
841
- cutlass::gemm::GemmUniversalMode::kGemm),
842
- problem_size,
843
- batch_count,
844
- {alpha1, beta1},
845
- tensor_A0.device_data(),
846
- tensor_B1.device_data(),
847
- tensor_Bias1.device_data(),
848
- reference_D1.device_data(),
849
- batch_stride_A,
850
- batch_stride_B1,
851
- batch_stride_Bias,
852
- batch_stride_D,
853
- tensor_A0.stride(0),
854
- (broadcast_b1 ? 0 : tensor_B1.stride(0)),
855
- 0, // zero stride for the bias vector
856
- reference_D1.stride(0),
857
- };
858
-
859
- status = reference_gemm1.can_implement(args1);
860
- CUTLASS_CHECK(status);
861
- status = reference_gemm1(args1);
862
- CUTLASS_CHECK(status);
863
-
864
- if(relu) {
865
- cutlass::reference::device::TensorReLu(reference_D0.device_view());
866
- cutlass::reference::device::TensorReLu(reference_D1.device_view());
867
- }
868
-
869
- TensorEpilogueForEach<EpilogueOutputOp2>(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view());
870
- cudaDeviceSynchronize();
871
- reference_D0.sync_host();
872
- reference_D1.sync_host();
873
- reference_D2.sync_host();
874
-
875
- CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
876
- CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
877
- CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0);
878
- CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0);
879
-
880
- bool passed_out0 = true;
881
- if (DualGemm::kStoreD0) {
882
- CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
883
- passed_out0 = cutlass::reference::host::TensorEquals(
884
- reference_D0.host_view(),
885
- tensor_D0.host_view());
886
- }
887
- CHECK_TRUE(passed_out0);
888
-
889
- bool passed_out1 = true;
890
- if (DualGemm::kStoreD1) {
891
- CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
892
- passed_out1 = cutlass::reference::host::TensorEquals(
893
- reference_D1.host_view(),
894
- tensor_D1.host_view());
895
- }
896
- CHECK_TRUE(passed_out1);
897
-
898
- bool passed_out2 = cutlass::reference::host::TensorEquals(
899
- reference_D2.host_view(),
900
- tensor_D2.host_view());
901
- CHECK_TRUE(passed_out2);
902
-
903
- bool passed = passed_out0 && passed_out1 && passed_out2;
904
- if (!passed)
905
- {
906
- std::stringstream fname;
907
-
908
- fname << "error_DualGemm_device_fused.txt";
909
- std::cerr << "Dumping results in " << fname.str() << "\n";
910
-
911
- std::ofstream file(fname.str());
912
-
913
- file
914
- << "A0 =\n" << tensor_A0.host_view()
915
- << "\nB0 =\n" << tensor_B0.host_view()
916
- << "\nC0 =\n" << tensor_C0.host_view()
917
- << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
918
- << "\nB1 =\n" << tensor_B1.host_view()
919
- << "\nC1 =\n" << tensor_C1.host_view()
920
- << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
921
- << "\n\nReference0 =\n" << reference_D0.host_view()
922
- << "\nComputed0 =\n" << tensor_D0.host_view()
923
- << "\n\nReference1 =\n" << reference_D1.host_view()
924
- << "\nComputed1 =\n" << tensor_D1.host_view()
925
- << "\n\nReference2 =\n" << reference_D2.host_view()
926
- << "\nComputed2 =\n" << tensor_D2.host_view();
927
- }
928
- //std::cout << "A0 " << tensor_A0.host_view() << std::endl;
929
- // std::cout << "reference_D0 " << reference_D0.host_view() << std::endl;
930
- // std::cout << "reference_D1 " << reference_D1.host_view() << std::endl;
931
- // std::cout << "reference_D2 " << reference_D2.host_view() << std::endl;
932
- //std::cout << "reference_D0 " << reference_D0.host_view() << std::endl;
933
- return passed;
934
- }
935
-
936
- };
937
-
938
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h DELETED
@@ -1,545 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/cutlass.h"
38
-
39
- #include "cutlass/gemm/gemm.h"
40
- #include "cutlass/matrix_coord.h"
41
- #include "cutlass/semaphore.h"
42
-
43
- #include "../threadblock/dual_mma_multistage.h"
44
- #include "../threadblock/dual_epilogue.h"
45
- #include "../dual_gemm_common.h"
46
-
47
- /////////////////////////////////////////////////////////////////////////////////////////////////
48
-
49
- namespace cutlass {
50
- namespace gemm {
51
- namespace kernel {
52
-
53
- /////////////////////////////////////////////////////////////////////////////////////////////////
54
-
55
- template <
56
- typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate
57
- typename Epilogue0_, ///! Epilogue
58
- typename Epilogue1_, ///! Epilogue
59
- typename OutputOp2_, ///! Epilogue
60
- typename ThreadblockSwizzle_, ///! Threadblock swizzling function
61
- bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled.
62
- bool StoreD0,
63
- bool StoreD1
64
- >
65
- struct DualGemm {
66
-
67
- using DualMma = DualMma_;
68
-
69
- using Epilogue0 = Epilogue0_;
70
- using Epilogue1 = Epilogue1_;
71
- using OutputOp0 = typename Epilogue0::OutputOp;
72
- using OutputOp1 = typename Epilogue1::OutputOp;
73
- using OutputOp2 = OutputOp2_;
74
- using ThreadblockSwizzle = ThreadblockSwizzle_;
75
- static constexpr bool kStoreD0 = StoreD0;
76
- static constexpr bool kStoreD1 = StoreD1;
77
-
78
- using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue<
79
- typename Epilogue0::Shape,
80
- typename Epilogue0::WarpMmaOperator,
81
- Epilogue0::kPartitionsK,
82
- typename Epilogue0::OutputTileIterator,
83
- typename Epilogue0::AccumulatorFragmentIterator,
84
- typename Epilogue0::WarpTileIterator,
85
- typename Epilogue0::SharedLoadIterator,
86
- OutputOp0,
87
- OutputOp1,
88
- OutputOp2,
89
- typename Epilogue0::Padding,
90
- kStoreD0,
91
- kStoreD1,
92
- Epilogue0::kFragmentsPerIteration,
93
- true // IterationsUnroll
94
- >;
95
-
96
- using ElementA = typename DualMma::IteratorA::Element;
97
- using ElementB = typename DualMma::IteratorB0::Element;
98
- using ElementC = typename DualEpilogue::OutputTileIterator::Element;
99
-
100
- static bool const kSplitKSerial = SplitKSerial;
101
- static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1),
102
- "Split-K serial requires buffers for D0/D1 for reduction");
103
-
104
- /// Warp count (concept: GemmShape)
105
- using WarpCount0 = typename DualMma::WarpCount;
106
- static int const kThreadCount = 32 * WarpCount0::kCount;
107
-
108
- /// Parameters structure
109
- struct Params {
110
- DualGemmMode mode;
111
- cutlass::gemm::GemmCoord problem_size;
112
- cutlass::gemm::GemmCoord grid_tiled_shape;
113
- int swizzle_log_tile;
114
-
115
- // Mma0
116
- typename DualMma::IteratorA::Params params_A0;
117
- typename DualMma::IteratorA::TensorRef ref_A0;
118
- typename DualMma::IteratorB0::Params params_B0;
119
- typename DualMma::IteratorB0::TensorRef ref_B0;
120
- typename Epilogue0::OutputTileIterator::Params params_C0;
121
- typename Epilogue0::OutputTileIterator::TensorRef ref_C0;
122
- typename Epilogue0::OutputTileIterator::Params params_D0;
123
- typename Epilogue0::OutputTileIterator::TensorRef ref_D0;
124
- typename OutputOp0::Params output_op_0;
125
-
126
- // Mma1
127
- typename DualMma::IteratorB1::Params params_B1;
128
- typename DualMma::IteratorB1::TensorRef ref_B1;
129
- typename Epilogue1::OutputTileIterator::Params params_C1;
130
- typename Epilogue1::OutputTileIterator::TensorRef ref_C1;
131
- typename Epilogue1::OutputTileIterator::Params params_D1;
132
- typename Epilogue1::OutputTileIterator::TensorRef ref_D1;
133
- typename OutputOp1::Params output_op_1;
134
-
135
- typename Epilogue1::OutputTileIterator::Params params_D2;
136
- typename Epilogue1::OutputTileIterator::TensorRef ref_D2;
137
- typename OutputOp2::Params output_op_2;
138
-
139
- int *semaphore;
140
- int gemm_k_size;
141
-
142
- int64_t batch_stride_A;
143
- int64_t batch_stride_B0;
144
- int64_t batch_stride_B1;
145
- int64_t batch_stride_C;
146
- int64_t batch_stride_D;
147
-
148
- //
149
- // Methods
150
- //
151
-
152
- CUTLASS_HOST_DEVICE
153
- Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { }
154
-
155
- CUTLASS_HOST_DEVICE
156
- Params(
157
- DualGemmMode mode,
158
- cutlass::gemm::GemmCoord const & problem_size,
159
- cutlass::gemm::GemmCoord const & grid_tiled_shape,
160
- // Mma0: D0 = A @ B0 + C0
161
- typename DualMma::IteratorA::TensorRef ref_A0,
162
- typename DualMma::IteratorB0::TensorRef ref_B0,
163
- typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
164
- typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
165
- // Mma1: D1 = A @ B1 + C1
166
- typename DualMma::IteratorB1::TensorRef ref_B1,
167
- typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
168
- typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
169
-
170
- typename Epilogue1::OutputTileIterator::TensorRef ref_D2,
171
- typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
172
- typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
173
- typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(),
174
- int *workspace = nullptr,
175
- int64_t batch_stride_A = 1,
176
- int64_t batch_stride_B0 = 1,
177
- int64_t batch_stride_B1 = 1,
178
- int64_t batch_stride_C = 1,
179
- int64_t batch_stride_D = 1
180
- ):
181
- mode(mode),
182
- problem_size(problem_size),
183
- grid_tiled_shape(grid_tiled_shape),
184
- swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
185
- // Mma0
186
- params_A0(ref_A0.layout()),
187
- ref_A0(ref_A0),
188
- params_B0(ref_B0.layout()),
189
- ref_B0(ref_B0),
190
- params_C0(ref_C0.layout()),
191
- ref_C0(ref_C0),
192
- params_D0(ref_D0.layout()),
193
- ref_D0(ref_D0),
194
- // Mma1
195
- params_B1(ref_B1.layout()),
196
- ref_B1(ref_B1),
197
- params_C1(ref_C1.layout()),
198
- ref_C1(ref_C1),
199
- params_D1(ref_D1.layout()),
200
- ref_D1(ref_D1),
201
- params_D2(ref_D2.layout()),
202
- ref_D2(ref_D2),
203
- output_op_0(output_op_0),
204
- output_op_1(output_op_1),
205
- output_op_2(output_op_2),
206
- batch_stride_A(batch_stride_A),
207
- batch_stride_B0(batch_stride_B0),
208
- batch_stride_B1(batch_stride_B1),
209
- batch_stride_C(batch_stride_C),
210
- batch_stride_D(batch_stride_D) {
211
-
212
- int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
213
- int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
214
- gemm_k_size = gemm_k_iterations * DualMma::Shape::kK;
215
-
216
- semaphore = workspace;
217
- }
218
- };
219
-
220
- /// Shared memory storage structure
221
- union SharedStorage {
222
- typename DualMma::SharedStorage main_loop;
223
- typename DualEpilogue::SharedStorage epilogue;
224
- };
225
-
226
- //
227
- // Methods
228
- //
229
-
230
- CUTLASS_HOST_DEVICE
231
- DualGemm() { }
232
-
233
- /// Determines whether kernel satisfies alignment
234
- static Status can_implement(
235
- cutlass::gemm::GemmCoord const & problem_size,
236
- typename DualMma::IteratorA::TensorRef ref_A0,
237
- typename DualMma::IteratorB0::TensorRef ref_B0,
238
- typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
239
- typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
240
- typename DualMma::IteratorB1::TensorRef ref_B1,
241
- typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
242
- typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
243
- typename Epilogue1::OutputTileIterator::TensorRef ref_D2) {
244
-
245
- static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements;
246
- static int const kAlignmentB = DualMma::IteratorB0::AccessType::kElements;
247
- static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess;
248
-
249
- if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
250
- return Status::kErrorMisalignedOperand;
251
- }
252
-
253
- if (!TensorRef_aligned(ref_B0, kAlignmentB)) {
254
- return Status::kErrorMisalignedOperand;
255
- }
256
-
257
- if (!TensorRef_aligned(ref_C0, kAlignmentC)) {
258
- return Status::kErrorMisalignedOperand;
259
- }
260
-
261
- if (!TensorRef_aligned(ref_D0, kAlignmentC)) {
262
- return Status::kErrorMisalignedOperand;
263
- }
264
-
265
- if (!TensorRef_aligned(ref_B1, kAlignmentB)) {
266
- return Status::kErrorMisalignedOperand;
267
- }
268
-
269
- if (!TensorRef_aligned(ref_C1, kAlignmentC)) {
270
- return Status::kErrorMisalignedOperand;
271
- }
272
-
273
- if (!TensorRef_aligned(ref_D1, kAlignmentC)) {
274
- return Status::kErrorMisalignedOperand;
275
- }
276
-
277
- if (!TensorRef_aligned(ref_D2, kAlignmentC)) {
278
- return Status::kErrorMisalignedOperand;
279
- }
280
-
281
- return Status::kSuccess;
282
- }
283
-
284
- /// Executes one GEMM
285
- CUTLASS_DEVICE
286
- void operator()(Params const &params, SharedStorage &shared_storage) {
287
- // Compute threadblock location
288
- ThreadblockSwizzle threadblock_swizzle;
289
-
290
- cutlass::gemm::GemmCoord threadblock_tile_offset =
291
- threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
292
-
293
- // Early exit if CTA is out of range
294
- if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
295
- params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
296
-
297
- return;
298
- }
299
-
300
- int offset_k = 0;
301
- int problem_size_k = params.problem_size.k();
302
-
303
- ElementA *ptr_A0 = static_cast<ElementA *>(params.ref_A0.data());
304
- ElementB *ptr_B0 = static_cast<ElementB *>(params.ref_B0.data());
305
- ElementB *ptr_B1 = static_cast<ElementB *>(params.ref_B1.data());
306
-
307
- //
308
- // Fetch pointers based on mode.
309
- //
310
- if (params.mode == DualGemmMode::kGemm) {
311
- if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
312
- problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
313
- }
314
-
315
- offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
316
- }
317
- else if (params.mode == DualGemmMode::kBatched) {
318
- ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A;
319
- ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
320
- ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
321
- }
322
-
323
- // Compute initial location in logical coordinates
324
- cutlass::MatrixCoord tb_offset_A0{
325
- threadblock_tile_offset.m() * DualMma::Shape::kM,
326
- offset_k,
327
- };
328
-
329
- cutlass::MatrixCoord tb_offset_B0{
330
- offset_k,
331
- threadblock_tile_offset.n() * DualMma::Shape::kN
332
- };
333
-
334
- cutlass::MatrixCoord tb_offset_B1{
335
- offset_k,
336
- threadblock_tile_offset.n() * DualMma::Shape::kN
337
- };
338
-
339
- // Compute position within threadblock
340
- int thread_idx = threadIdx.x;
341
-
342
- // Construct iterators to A and B operands
343
- typename DualMma::IteratorA iterator_A0(
344
- params.params_A0,
345
- ptr_A0,
346
- {params.problem_size.m(), problem_size_k},
347
- thread_idx,
348
- tb_offset_A0);
349
-
350
- typename DualMma::IteratorB0 iterator_B0(
351
- params.params_B0,
352
- ptr_B0,
353
- {problem_size_k, params.problem_size.n()},
354
- thread_idx,
355
- tb_offset_B0);
356
-
357
- typename DualMma::IteratorB1 iterator_B1(
358
- params.params_B1,
359
- ptr_B1,
360
- {problem_size_k, params.problem_size.n()},
361
- thread_idx,
362
- tb_offset_B1);
363
-
364
-
365
- // Broadcast the warp_id computed by lane 0 to ensure dependent code
366
- // is compiled as warp-uniform.
367
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
368
- int lane_idx = threadIdx.x % 32;
369
-
370
- //
371
- // Main loop
372
- //
373
-
374
-
375
- // Construct thread-scoped matrix multiply
376
- typename DualMma::FragmentC accum0;
377
- typename DualMma::FragmentC accum1;
378
- accum0.clear();
379
- accum1.clear();
380
-
381
- // Compute threadblock-scoped matrix multiply-add
382
- int gemm_k_iterations = (problem_size_k - offset_k + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
383
-
384
- DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
385
- if (!kSplitKSerial || gemm_k_iterations > 0) {
386
- // Compute threadblock-scoped matrix multiply-add
387
- mma(gemm_k_iterations,
388
- accum0, accum1,
389
- iterator_A0, iterator_B0, iterator_B1,
390
- accum0, accum1);
391
- }
392
-
393
- //
394
- // Epilogue
395
- //
396
-
397
- OutputOp0 output_op_0(params.output_op_0);
398
- OutputOp1 output_op_1(params.output_op_1);
399
- OutputOp2 output_op_2(params.output_op_2);
400
-
401
- //
402
- // Masked tile iterators constructed from members
403
- //
404
-
405
- threadblock_tile_offset =
406
- threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
407
-
408
- //assume identity swizzle
409
- MatrixCoord threadblock_offset(
410
- threadblock_tile_offset.m() * DualMma::Shape::kM,
411
- threadblock_tile_offset.n() * DualMma::Shape::kN
412
- );
413
-
414
- int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
415
-
416
- ElementC *ptr_C0 = static_cast<ElementC *>(params.ref_C0.data());
417
- ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
418
- ElementC *ptr_D0 = static_cast<ElementC *>(params.ref_D0.data());
419
- ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
420
- ElementC *ptr_D2 = static_cast<ElementC *>(params.ref_D2.data());
421
-
422
- // Construct the semaphore.
423
- Semaphore semaphore(params.semaphore + block_idx, thread_idx);
424
-
425
- if (params.mode == DualGemmMode::kGemm) {
426
- // If performing a reduction via split-K, fetch the initial synchronization
427
- if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
428
-
429
- // Fetch the synchronization lock initially but do not block.
430
- semaphore.fetch();
431
-
432
- // Indicate which position in a serial reduction the output operator is currently updating
433
- output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
434
- output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
435
- }
436
- }
437
- else if (params.mode == DualGemmMode::kBatched) {
438
- ptr_C0 += threadblock_tile_offset.k() * params.batch_stride_C;
439
- ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C;
440
- ptr_D0 += threadblock_tile_offset.k() * params.batch_stride_D;
441
- ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D;
442
- ptr_D2 += threadblock_tile_offset.k() * params.batch_stride_D;
443
- }
444
-
445
- // Tile iterator loading from source tensor.
446
- typename Epilogue0::OutputTileIterator iterator_C0(
447
- params.params_C0,
448
- ptr_C0,
449
- params.problem_size.mn(),
450
- thread_idx,
451
- threadblock_offset
452
- );
453
- typename Epilogue1::OutputTileIterator iterator_C1(
454
- params.params_C1,
455
- ptr_C1,
456
- params.problem_size.mn(),
457
- thread_idx,
458
- threadblock_offset
459
- );
460
-
461
- // Tile iterator writing to destination tensor.
462
- typename Epilogue0::OutputTileIterator iterator_D0(
463
- params.params_D0,
464
- ptr_D0,
465
- params.problem_size.mn(),
466
- thread_idx,
467
- threadblock_offset
468
- );
469
- typename Epilogue1::OutputTileIterator iterator_D1(
470
- params.params_D1,
471
- ptr_D1,
472
- params.problem_size.mn(),
473
- thread_idx,
474
- threadblock_offset
475
- );
476
- typename Epilogue1::OutputTileIterator iterator_D2(
477
- params.params_D2,
478
- ptr_D2,
479
- params.problem_size.mn(),
480
- thread_idx,
481
- threadblock_offset
482
- );
483
-
484
- DualEpilogue epilogue(
485
- shared_storage.epilogue,
486
- thread_idx,
487
- warp_idx,
488
- lane_idx);
489
-
490
- // Wait on the semaphore - this latency may have been covered by iterator construction
491
- if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
492
-
493
- // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
494
- if (threadblock_tile_offset.k()) {
495
- iterator_C0 = iterator_D0;
496
- iterator_C1 = iterator_D1;
497
- }
498
-
499
- semaphore.wait(threadblock_tile_offset.k());
500
-
501
- __threadfence();
502
- }
503
-
504
- // Execute the epilogue operator to update the destination tensor.
505
- typename Epilogue0::OutputTileIterator source_iters[] = {
506
- iterator_C0, iterator_C1
507
- };
508
- const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1);
509
- epilogue(
510
- output_op_0, output_op_1, output_op_2,
511
- iterator_D0, iterator_D1, iterator_D2,
512
- accum0, accum1,
513
- source_iters,
514
- writeToD2
515
- );
516
-
517
- //
518
- // Release the semaphore
519
- //
520
-
521
- if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
522
-
523
- int lock = 0;
524
- if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
525
-
526
- // The final threadblock resets the semaphore for subsequent grids.
527
- lock = 0;
528
- }
529
- else {
530
- // Otherwise, the semaphore is incremented
531
- lock = threadblock_tile_offset.k() + 1;
532
- }
533
-
534
- __threadfence();
535
- semaphore.release(lock);
536
- }
537
- }
538
- };
539
-
540
- /////////////////////////////////////////////////////////////////////////////////////////////////
541
-
542
- } // namespace kernel
543
- } // namespace gemm
544
- } // namespace cutlass
545
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h DELETED
@@ -1,95 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
-
33
- #include <iostream>
34
-
35
- // Run tests on GPUs
36
-
37
- int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string & test_name) {
38
-
39
- bool supported = false;
40
-
41
- int arch_major = arch / 10;
42
- int arch_minor = arch - arch / 10 * 10;
43
-
44
- if(arch_major >= 8) {
45
- // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
46
- //
47
- // CUTLASS must be compiled with CUDA 11 Toolkit to run these examples.
48
- if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) {
49
- supported = true;
50
- }
51
- }
52
- else if(arch_major >= 7) {
53
- // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
54
- //
55
- // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
56
- if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) {
57
- supported = true;
58
- }
59
- }
60
-
61
- cudaDeviceProp props;
62
-
63
- cudaError_t error = cudaGetDeviceProperties(&props, 0);
64
- if (error != cudaSuccess) {
65
- std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
66
- return -1;
67
- }
68
-
69
- if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) {
70
- supported = false;
71
- }
72
-
73
- if (!supported) {
74
- // Return zero so the test still passes on older Toolkits or unsupported devices; in that case it is a no-op.
75
- std::cout << "This example isn't supported on current architecture" << std::endl;
76
- return 0;
77
- }
78
-
79
- bool pass = true;
80
-
81
- std::cout << "Device: " << props.name << std::endl;
82
- std::cout << "Arch: SM" << arch << std::endl;
83
- std::cout << "Test: " << test_name << std::endl;
84
- for(auto func : test_funcs) {
85
- pass &= func();
86
- }
87
-
88
-
89
- if(pass)
90
- return 0;
91
- else
92
- return -1;
93
-
94
- }
95
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h DELETED
@@ -1,150 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Functor computing SiLU(lhs) * rhs elementwise, for use in epilogues.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/cutlass.h"
38
- #include "cutlass/numeric_types.h"
39
- #include "cutlass/array.h"
40
- #include "cutlass/functional.h"
41
- #include "cutlass/numeric_conversion.h"
42
- #include "cutlass/epilogue/thread/scale_type.h"
43
- #include "cutlass/epilogue/thread/linear_combination_params.h"
44
-
45
- /////////////////////////////////////////////////////////////////////////////////////////////////
46
-
47
- namespace cutlass {
48
- namespace epilogue {
49
- namespace thread {
50
-
51
- /////////////////////////////////////////////////////////////////////////////////////////////////
52
-
53
- /// Applies SiLU to the left operand and multiplies the result elementwise by the right operand.
54
- ///
55
- /// D = SiLU(lhs) * rhs
56
- ///
57
- template <
58
- typename ElementOutput_, ///< Data type used to load and store tensors
59
- int Count, ///< Number of elements computed per operation.
60
- ///< Usually it is 128/sizeof_bits<ElementOutput_>,
61
- ///< but we use 64 or 32 sometimes when there are not enough data to store
62
- typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
63
- typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
64
- FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
65
- >
66
- class LeftSiLUAndMul {
67
- public:
68
-
69
- using ElementOutput = ElementOutput_;
70
- using ElementAccumulator = ElementAccumulator_;
71
- using ElementCompute = ElementCompute_;
72
-
73
- static int const kCount = Count;
74
- using FragmentOutput = Array<ElementOutput, kCount>;
75
- using FragmentAccumulator = Array<ElementAccumulator, kCount>;
76
- using ComputeFragment = Array<ElementCompute, kCount>;
77
-
78
- static FloatRoundStyle const kRound = Round;
79
-
80
- struct Params{};
81
-
82
- private:
83
-
84
- //
85
- // Data members
86
- //
87
-
88
- ElementCompute alpha_;
89
- ElementCompute beta_;
90
-
91
- public:
92
-
93
- /// Constructs the function object; Params is empty, so the argument is ignored
94
- CUTLASS_HOST_DEVICE
95
- LeftSiLUAndMul(Params const &/*params*/) {}
96
-
97
- /// Returns true if source is needed
98
- CUTLASS_HOST_DEVICE
99
- bool is_source_needed() const {
100
- return true;
101
- }
102
-
103
- /// Functionally required for serial reduction in the epilogue; not supported by this functor (asserts if called)
104
- CUTLASS_HOST_DEVICE
105
- void set_k_partition(int k_partition, int k_partition_count) {
106
- assert(false);
107
- }
108
-
109
- /// Computes D = SiLU(lhs) * rhs elementwise over a fragment
110
- CUTLASS_HOST_DEVICE
111
- FragmentOutput operator()(
112
- FragmentAccumulator const &lhs,
113
- FragmentAccumulator const &rhs) const {
114
-
115
- // Convert the inputs to the internal compute numeric type
116
- NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_to_compute;
117
-
118
- // Convert to destination numeric type
119
- NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> compute_to_output;
120
-
121
- ComputeFragment converted_lhs = accumulator_to_compute(lhs);
122
- ComputeFragment converted_rhs = accumulator_to_compute(rhs);
123
-
124
- cutlass::epilogue::thread::SiLu<ComputeFragment> silu;
125
- cutlass::multiplies<ComputeFragment> mul;
126
- auto silu_lhs = silu(converted_lhs);
127
- return compute_to_output(mul(silu_lhs, converted_rhs));
128
- }
129
-
130
- CUTLASS_HOST_DEVICE
131
- ElementOutput operator()(
132
- ElementAccumulator const& lhs,
133
- ElementAccumulator const& rhs
134
- ) const {
135
- ElementCompute convert_lhs(lhs);
136
- ElementCompute convert_rhs(rhs);
137
- cutlass::epilogue::thread::SiLu<ElementCompute> silu;
138
- cutlass::multiplies<ElementCompute> mul;
139
- auto silu_lhs = silu(convert_lhs);
140
- return ElementOutput(mul(silu_lhs, convert_rhs));
141
- }
142
- };
143
-
144
- /////////////////////////////////////////////////////////////////////////////////////////////////
145
-
146
- } // namespace thread
147
- } // namespace epilogue
148
- } // namespace cutlass
149
-
150
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h DELETED
@@ -1,424 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
-
34
- The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
- tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
-
37
- */
38
-
39
- #pragma once
40
- #include "cutlass/array.h"
41
- #include CUDA_STD_HEADER(cassert)
42
- #include "cutlass/cutlass.h"
43
- #include "cutlass/numeric_types.h"
44
- #include "cutlass/layout/vector.h"
45
- #include "cutlass/layout/tensor.h"
46
- #include "cutlass/tensor_coord.h"
47
- #include "cutlass/aligned_buffer.h"
48
- #include "cutlass/functional.h"
49
-
50
- #include "cutlass/gemm/gemm.h"
51
-
52
- #include "cutlass/transform/pitch_linear_thread_map.h"
53
- #include "cutlass/transform/threadblock/regular_tile_iterator.h"
54
-
55
- #include "cutlass/epilogue/threadblock/epilogue_base.h"
56
- #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
57
- #include "cutlass/numeric_types.h"
58
-
59
- ////////////////////////////////////////////////////////////////////////////////
60
-
61
- namespace cutlass {
62
- namespace epilogue {
63
- namespace threadblock {
64
-
65
- ////////////////////////////////////////////////////////////////////////////////
66
-
67
- /// Epilogue operator
68
- template <
69
- typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
70
- typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
71
- int PartitionsK, ///< Number of partitions of the K dimension
72
- typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
73
- typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
74
- typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
75
- typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
76
- ///< Output operator
77
- typename OutputOp0_,
78
- typename OutputOp1_,
79
- typename OutputOp2_,
80
- typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
81
- bool StoreD0 = true,
82
- bool StoreD1 = true,
83
- int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity
84
- int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
85
- (!IsEpilogueFunctorHeavy<OutputOp0_>::value)
86
- >
87
- class DualEpilogue {
88
-
89
- public:
90
-
91
- using Base = EpilogueBase<
92
- Shape_,
93
- typename WarpMmaOperator_::Shape,
94
- PartitionsK,
95
- AccumulatorFragmentIterator_,
96
- WarpTileIterator_,
97
- Padding_,
98
- FragmentsPerPartition>;
99
-
100
- using Shape = Shape_;
101
- using WarpMmaOperator = WarpMmaOperator_;
102
- static int const kPartitionsK = PartitionsK;
103
- static bool constexpr kStoreD0 = StoreD0;
104
- static bool constexpr kStoreD1 = StoreD1;
105
- using OutputTileIterator = OutputTileIterator_;
106
- using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
107
- using WarpTileIterator = WarpTileIterator_;
108
- using SharedLoadIterator = SharedLoadIterator_;
109
- using OutputOp0 = OutputOp0_;
110
- using OutputOp1 = OutputOp1_;
111
- using OutputOp2 = OutputOp2_;
112
- using Padding = Padding_;
113
-
114
- using Layout = layout::RowMajor;
115
- using LongIndex = typename Layout::LongIndex;
116
-
117
- /// The complete warp-level accumulator tile
118
- using AccumulatorTile = typename Base::AccumulatorTile;
119
-
120
- /// Accumulator element
121
- using ElementAccumulator = typename WarpTileIterator::Element;
122
-
123
- /// Output element
124
- using ElementOutput = typename OutputTileIterator::Element;
125
-
126
- /// Output access size
127
- static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
128
-
129
- /// Tensor reference to destination tensor
130
- using TensorRef = typename OutputTileIterator::TensorRef;
131
-
132
- /// Tensor reference to sync tensor
133
- using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
134
-
135
- /// Const tensor reference to source tensor
136
- using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
137
-
138
- /// Array type used to output
139
- using OutputAccessType = Array<
140
- typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
141
-
142
- /// Array type used by output functor
143
- using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
144
-
145
- /// Number of warps
146
- using WarpCount = typename Base::WarpCount;
147
-
148
- struct SharedStorage {
149
- using Element = typename WarpTileIterator::Element;
150
-
151
- /// Tensor reference to shared memory allocation
152
- using TensorRef = typename WarpTileIterator::TensorRef;
153
-
154
- /// Logical shape of the shared memory tile written to by all warps.
155
- using Shape = typename Base::Shape;
156
-
157
- /// Shape of the shared memory allocation for the epilogue
158
- using StorageShape = typename Base::SharedStorage::StorageShape;
159
-
160
- //
161
- // Data members
162
- //
163
-
164
- AlignedBuffer<Element, StorageShape::kCount> storage[2];
165
-
166
- //
167
- // Methods
168
- //
169
-
170
- /// Returns a tensor reference to the shared memory buffer
171
- CUTLASS_DEVICE
172
- TensorRef reference(int i) {
173
- return TensorRef(
174
- storage[i].data(),
175
- Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
176
- }
177
- };
178
-
179
- static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
180
- static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles;
181
-
182
- public:
183
-
184
- static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
185
- "Mismatch between shared load iterator and output tile iterator.");
186
-
187
- static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
188
-
189
- static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
190
- "Divisibility");
191
-
192
- private:
193
-
194
- /// Loads fragment from shared memory aligned with output tensor
195
- SharedLoadIterator shared_load_iterator0_;
196
- SharedLoadIterator shared_load_iterator1_;
197
-
198
- /// Stores a warp's fragment of accumulators to SMEM
199
- WarpTileIterator warp_tile_iterator0_;
200
- WarpTileIterator warp_tile_iterator1_;
201
-
202
- public:
203
-
204
- /// Constructor
205
- CUTLASS_DEVICE
206
- DualEpilogue(
207
- SharedStorage &shared_storage, ///< Shared storage object
208
- int thread_idx, ///< ID of a thread within the threadblock
209
- int warp_idx, ///< ID of warp within threadblock
210
- int lane_idx ///< Id of thread within warp
211
- ):
212
- shared_load_iterator0_(shared_storage.reference(0), thread_idx),
213
- shared_load_iterator1_(shared_storage.reference(1), thread_idx),
214
- warp_tile_iterator0_(shared_storage.reference(0), lane_idx),
215
- warp_tile_iterator1_(shared_storage.reference(1), lane_idx)
216
- {
217
- int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
218
- int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
219
- int warp_m = warp_mn % WarpCount::kM;
220
- int warp_n = warp_mn / WarpCount::kM;
221
-
222
- MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};
223
-
224
- warp_tile_iterator0_.add_tile_offset(warp_offset);
225
- warp_tile_iterator1_.add_tile_offset(warp_offset);
226
- }
227
-
228
- /// Streams the result to global memory
229
- CUTLASS_DEVICE
230
- void operator()(
231
- OutputOp0 const &output_op0,
232
- OutputOp1 const &output_op1,
233
- OutputOp2 const &output_op2,
234
- OutputTileIterator dest0,
235
- OutputTileIterator dest1,
236
- OutputTileIterator dest2,
237
- AccumulatorTile const &accumulator0,
238
- AccumulatorTile const &accumulator1,
239
- OutputTileIterator source_iterator[2],
240
- bool writeToD2 // true if it's the final split-k
241
- ) {
242
- // TODO: Implement when no source is needed
243
-
244
- typename OutputTileIterator::Fragment source_fragment[2];
245
- CUTLASS_PRAGMA_UNROLL
246
- for (int i = 0; i < 2; ++i) {
247
- source_fragment[i].clear();
248
- }
249
-
250
- //
251
- // Iterator over warp-level accumulator fragment
252
- //
253
-
254
- AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1};
255
-
256
- //
257
- // Iterate over accumulator tile
258
- //
259
-
260
- #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
261
- for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
262
-
263
- //
264
- // Load the source
265
- //
266
-
267
- CUTLASS_PRAGMA_UNROLL
268
- for (int i = 0; i < 2; ++i) {
269
- source_iterator[i].load(source_fragment[i]);
270
- ++source_iterator[i];
271
- }
272
-
273
- //
274
- // Convert and store fragment
275
- //
276
-
277
- __syncthreads();
278
-
279
- acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
280
- iter, accum_fragment_iterator[0], this->warp_tile_iterator0_);
281
- acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
282
- iter, accum_fragment_iterator[1], this->warp_tile_iterator1_);
283
-
284
- __syncthreads();
285
-
286
- //
287
- // Load fragments from shared memory
288
- //
289
-
290
- typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK];
291
- typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK];
292
-
293
- shared_load_iterator0_.load(aligned_accum_fragment0[0]);
294
- shared_load_iterator1_.load(aligned_accum_fragment1[0]);
295
-
296
- // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
297
- if (kPartitionsK > 1) {
298
-
299
- plus <typename SharedLoadIterator::Fragment> add_fragments;
300
-
301
- CUTLASS_PRAGMA_UNROLL
302
- for ( int i = 1; i < kPartitionsK; ++i) {
303
- shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset);
304
- shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset);
305
- shared_load_iterator0_.load(aligned_accum_fragment0[i]);
306
- shared_load_iterator1_.load(aligned_accum_fragment1[i]);
307
- aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]);
308
- aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]);
309
- }
310
-
311
- shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
312
- shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
313
- }
314
-
315
- //
316
- // Compute the output result
317
- //
318
-
319
- typename OutputTileIterator::Fragment output_fragment[3];
320
-
321
- apply_output_operator_(output_fragment,
322
- output_op0, output_op1, output_op2,
323
- aligned_accum_fragment0[0], aligned_accum_fragment1[0],
324
- source_fragment);
325
-
326
-
327
- //
328
- // Store the final result
329
- //
330
-
331
- if (kStoreD0) {
332
- dest0.store(output_fragment[0]);
333
- ++dest0;
334
- }
335
- if (kStoreD1) {
336
- dest1.store(output_fragment[1]);
337
- ++dest1;
338
- }
339
- if (writeToD2) {
340
- dest2.store(output_fragment[2]);
341
- ++dest2;
342
- }
343
- }
344
- }
345
-
346
- private:
347
-
348
- static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
349
-
350
- template<class Seq>
351
- struct acc2smem_source_needed;
352
-
353
- template <size_t... Seq>
354
- struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
355
- template<int Advance>
356
- CUTLASS_DEVICE
357
- static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
358
- WarpTileIterator &warp_tile_iterator) {
359
- CUTLASS_PRAGMA_UNROLL
360
- for (int i = 0; i < Advance; i++) {
361
- ++accum_fragment_iterator;
362
- }
363
-
364
- typename AccumulatorFragmentIterator::Fragment accum_fragment;
365
- accum_fragment_iterator.load(accum_fragment);
366
- warp_tile_iterator.store(accum_fragment);
367
- }
368
-
369
- CUTLASS_DEVICE
370
- static void push(size_t pos,
371
- AccumulatorFragmentIterator const &iterator_begin,
372
- WarpTileIterator &warp_tile_iterator) {
373
- int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
374
- }
375
- };
376
-
377
- /// Helper to invoke the output functor over each vector of output
378
- CUTLASS_DEVICE
379
- void apply_output_operator_(
380
- typename OutputTileIterator::Fragment (&output_fragment)[3],
381
- OutputOp0 const &output_op0,
382
- OutputOp1 const &output_op1,
383
- OutputOp2 const &output_op2,
384
- typename SharedLoadIterator::Fragment const& aligned_accum_fragment0,
385
- typename SharedLoadIterator::Fragment const& aligned_accum_fragment1,
386
- typename OutputTileIterator::Fragment const (&source_fragment)[2]) {
387
-
388
- OutputAccessType* output_frag_ptr[3] = {
389
- reinterpret_cast<OutputAccessType *>(&output_fragment[0]),
390
- reinterpret_cast<OutputAccessType *>(&output_fragment[1]),
391
- reinterpret_cast<OutputAccessType *>(&output_fragment[2])
392
- };
393
-
394
- AccumulatorAccessType const *compute_frag_ptr[2] = {
395
- reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment0),
396
- reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment1)
397
- };
398
-
399
- OutputAccessType const *source_frag_ptr[2] = {
400
- reinterpret_cast<OutputAccessType const *>(&source_fragment[0]),
401
- reinterpret_cast<OutputAccessType const *>(&source_fragment[1])
402
- };
403
-
404
- int const kOutputOpIterations =
405
- OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
406
-
407
- CUTLASS_PRAGMA_UNROLL
408
- for (int i = 0; i < kOutputOpIterations; ++i) {
409
-
410
- // Call the output operators
411
- output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]);
412
- output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]);
413
- output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]);
414
- }
415
- }
416
- };
417
-
418
- ////////////////////////////////////////////////////////////////////////////////
419
-
420
- } // namespace threadblock
421
- } // namespace epilogue
422
- } // namespace cutlass
423
-
424
- ////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h DELETED
@@ -1,232 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Base class for threadblock-scoped dual GEMM kernels: shared storage and warp-level iterators for the A, B0 and B1 operands.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/aligned_buffer.h"
38
- #include "cutlass/arch/memory.h"
39
- #include "cutlass/array.h"
40
- #include "cutlass/cutlass.h"
41
- #include "cutlass/gemm/gemm.h"
42
- #include "cutlass/matrix_shape.h"
43
- #include "cutlass/numeric_types.h"
44
-
45
- #include "cutlass/gemm/threadblock/mma_base.h"
46
-
47
- ////////////////////////////////////////////////////////////////////////////////
48
-
49
- namespace cutlass {
50
- namespace gemm {
51
- namespace threadblock {
52
-
53
- ////////////////////////////////////////////////////////////////////////////////
54
-
55
- /// Base class for a threadblock-scoped dual GEMM: declares the shared-memory
56
- /// storage for A, B0 and B1 and the warp-level iterators that read from it.
57
- template <
58
- /// Size of the Gemm problem - concept: gemm::GemmShape<>
59
- typename Shape_,
60
- /// Policy describing tuning details (concept: MmaPolicy)
61
- typename Policy0_,
62
- /// B1-specific version of the policy (concept: MmaPolicy)
63
- typename Policy1_,
64
- /// Number of stages,
65
- int Stages,
66
- /// Used for partial specialization
67
- typename Enable = bool>
68
- class DualMmaBase {
69
- public:
70
- ///< Size of the Gemm problem - concept: gemm::GemmShape<>
71
- using Shape = Shape_;
72
-
73
- ///< Policy describing tuning details
74
- using Policy0 = Policy0_;
75
- using Policy1 = Policy1_;
76
-
77
- //
78
- // Dependent types
79
- //
80
-
81
- /// Warp-level Mma
82
- using Operator0 = typename Policy0::Operator;
83
- using Operator1 = typename Policy1::Operator;
84
-
85
- /// Shape describing the overall GEMM computed from shared memory
86
- /// by each warp.
87
- using WarpGemm = typename Policy0::Operator::Shape;
88
-
89
- /// Shape describing the number of warps filling the CTA
90
- using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
91
- Shape::kN / WarpGemm::kN,
92
- Shape::kK / WarpGemm::kK>;
93
-
94
- /// Number of warp-level GEMM operations
95
- static int const kWarpGemmIterations =
96
- (WarpGemm::kK / Operator0::Policy::MmaShape::kK);
97
-
98
- /// Number of stages
99
- static int const kStages = Stages;
100
-
101
- /// Tensor reference to the A operand
102
- using TensorRefA = TensorRef<typename Operator0::ElementA, typename Operator0::LayoutA>;
103
-
104
- /// Tensor references to the B0 and B1 operands
105
- using TensorRefB0 = TensorRef<typename Operator0::ElementB, typename Operator0::LayoutB>;
106
- using TensorRefB1 = TensorRef<typename Operator1::ElementB, typename Operator1::LayoutB>;
107
-
108
- static_assert(kWarpGemmIterations > 1,
109
- "The pipelined structure requires at least two warp-level "
110
- "GEMM operations.");
111
-
112
- static_assert((kWarpGemmIterations % 2) == 0,
113
- "Inner loop iteration must be an even number.");
114
-
115
- //
116
- // Nested structs
117
- //
118
-
119
- /// Shared storage object needed by threadblock-scoped GEMM
120
- class SharedStorage {
121
- public:
122
- //
123
- // Type definitions
124
- //
125
-
126
- /// Shape of the A matrix operand in shared memory
127
- using ShapeA = MatrixShape<Shape::kM + Policy0::SmemPaddingA::kRow,
128
- Shape::kK * kStages +
129
- Policy0::SmemPaddingA::kColumn>;
130
-
131
- /// Shapes of the B0 and B1 matrix operands in shared memory
132
- using ShapeB0 =
133
- MatrixShape<Shape::kK * kStages + Policy0::SmemPaddingB::kRow,
134
- Shape::kN + Policy0::SmemPaddingB::kColumn>;
135
- using ShapeB1 =
136
- MatrixShape<Shape::kK * kStages + Policy1::SmemPaddingB::kRow,
137
- Shape::kN + Policy1::SmemPaddingB::kColumn>;
138
-
139
- public:
140
- //
141
- // Data members
142
- //
143
-
144
- /// Buffer for A operand
145
- AlignedBuffer<typename Operator0::ElementA, ShapeA::kCount> operand_A;
146
-
147
- /// Buffers for the B0 and B1 operands
148
- AlignedBuffer<typename Operator0::ElementB, ShapeB0::kCount> operand_B0;
149
- AlignedBuffer<typename Operator1::ElementB, ShapeB1::kCount> operand_B1;
150
-
151
- public:
152
-
153
- //
154
- // Methods
155
- //
156
-
157
- /// Returns a layout object for the A matrix. NOTE(review): marked CUTLASS_DEVICE while LayoutB0/LayoutB1 and operand_A_ref() are CUTLASS_HOST_DEVICE - confirm intended
158
- CUTLASS_DEVICE
159
- static typename Operator0::LayoutA LayoutA() {
160
- return Operator0::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
161
- }
162
-
163
- /// Returns a layout object for the B0 matrix
164
- CUTLASS_HOST_DEVICE
165
- static typename Operator0::LayoutB LayoutB0() {
166
- return Operator0::LayoutB::packed({ShapeB0::kRow, ShapeB0::kColumn});
167
- }
168
-
169
- /// Returns a layout object for the B1 matrix
170
- CUTLASS_HOST_DEVICE
171
- static typename Operator1::LayoutB LayoutB1() {
172
- return Operator1::LayoutB::packed({ShapeB1::kRow, ShapeB1::kColumn});
173
- }
174
-
175
- /// Returns a TensorRef to the A operand
176
- CUTLASS_HOST_DEVICE
177
- TensorRefA operand_A_ref() {
178
- return TensorRefA{operand_A.data(), LayoutA()};
179
- }
180
-
181
- /// Returns TensorRefs to the B0 and B1 operands
182
- CUTLASS_HOST_DEVICE
183
- TensorRefB0 operand_B0_ref() {
184
- return TensorRefB0{operand_B0.data(), LayoutB0()};
185
- }
186
- CUTLASS_HOST_DEVICE
187
- TensorRefB1 operand_B1_ref() {
188
- return TensorRefB1{operand_B1.data(), LayoutB1()};
189
- }
190
- };
191
-
192
- protected:
193
-
194
- //
195
- // Data members
196
- //
197
-
198
- /// Iterator to load a warp-scoped tile of A operand from shared memory
199
- typename Operator0::IteratorA warp_tile_iterator_A_;
200
-
201
- /// Iterators to load warp-scoped tiles of the B0 and B1 operands from shared memory
202
- typename Operator0::IteratorB warp_tile_iterator_B0_;
203
- typename Operator1::IteratorB warp_tile_iterator_B1_;
204
-
205
- public:
206
-
207
- /// Construct the warp-level iterators from shared storage and the lane index
208
- CUTLASS_DEVICE
209
- DualMmaBase(
210
- ///< Shared storage needed for internal use by threadblock-scoped GEMM
211
- SharedStorage &shared_storage,
212
- ///< ID within the threadblock (not used by this constructor)
213
- int thread_idx,
214
- ///< ID of warp (not used by this constructor)
215
- int warp_idx,
216
- ///< ID of each thread within a warp
217
- int lane_idx
218
- ):
219
- warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
220
- warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx),
221
- warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) {
222
-
223
- }
224
- };
225
-
226
- /////////////////////////////////////////////////////////////////////////////////////////////////
227
-
228
- } // namespace threadblock
229
- } // namespace gemm
230
- } // namespace cutlass
231
-
232
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h DELETED
@@ -1,775 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Template for a double-buffered threadblock-scoped GEMM kernel.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/aligned_buffer.h"
38
- #include "cutlass/arch/memory.h"
39
- #include "cutlass/array.h"
40
- #include "cutlass/cutlass.h"
41
- #include "cutlass/gemm/gemm.h"
42
- #include "cutlass/matrix_shape.h"
43
- #include "cutlass/numeric_types.h"
44
-
45
- #include "cutlass/gemm/threadblock/mma_base.h"
46
- #include "dual_mma_base.h"
47
-
48
- /////////////////////////////////////////////////////////////////////////////////////////////////
49
-
50
- namespace cutlass {
51
- namespace gemm {
52
- namespace threadblock {
53
-
54
- /////////////////////////////////////////////////////////////////////////////////////////////////
55
-
56
- /// Structure to compute the matrix product targeting CUDA cores and SIMT math
57
- /// instructions.
58
- template <
59
- /// Size of the Gemm problem - concept: gemm::GemmShape<>
60
- typename Shape_,
61
- /// Iterates over tiles of A operand in global memory
62
- // (concept: ReadableTileIterator | ForwardTileIterator |
63
- // MaskedTileIterator)
64
- typename IteratorA_,
65
- /// Iterates over tiles of A operand in shared memory
66
- /// (concept: WriteableTileIterator | RandomAccessTileIterator)
67
- typename SmemIteratorA_,
68
- /// Cache operation for operand A
69
- cutlass::arch::CacheOperation::Kind CacheOpA,
70
- /// Iterates over tiles of B0 operand in global memory
71
- // (concept: ReadableTileIterator | ForwardTileIterator |
72
- // MaskedTileIterator)
73
- typename IteratorB0_,
74
- /// Iterates over tiles of B0 operand in shared memory
75
- /// (concept: WriteableTileIterator | RandomAccessTileIterator)
76
- typename SmemIteratorB0_,
77
- /// Cache operation for operands B0 and B1
78
- cutlass::arch::CacheOperation::Kind CacheOpB,
79
- /// Iterates over tiles of B1 operand in global memory
80
- // (concept: ReadableTileIterator | ForwardTileIterator |
81
- // MaskedTileIterator)
82
- typename IteratorB1_,
83
- /// Iterates over tiles of B1 operand in shared memory
84
- /// (concept: WriteableTileIterator | RandomAccessTileIterator)
85
- typename SmemIteratorB1_,
86
- /// Data type of accumulator matrix
87
- typename ElementC_,
88
- /// Layout of accumulator matrix
89
- typename LayoutC_,
90
- /// Policy describing tuning details (concept: MmaPolicy)
91
- typename Policy0_,
92
- /// B1-specific version of the policy (concept: MmaPolicy)
93
- typename Policy1_,
94
- /// Number of stages,
95
- int Stages,
96
- /// Use zfill or predicate for out-of-bound cp.async
97
- SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
98
- /// Used for partial specialization
99
- typename Enable = bool>
100
- class DualMmaMultistage :
101
- public DualMmaBase<Shape_, Policy0_, Policy1_, Stages> {
102
- public:
103
- ///< Base class
104
- using Base = DualMmaBase<Shape_, Policy0_, Policy1_, Stages>;
105
- ///< Size of the Gemm problem - concept: gemm::GemmShape<>
106
- using Shape = Shape_;
107
- ///< Iterates over tiles of A operand in global memory
108
- using IteratorA = IteratorA_;
109
- ///< Iterates over tiles of B0 operand in global memory
110
- using IteratorB0 = IteratorB0_;
111
- ///< Iterates over tiles of B1 operand in global memory
112
- using IteratorB1 = IteratorB1_;
113
- ///< Data type of accumulator matrix
114
- using ElementC = ElementC_;
115
- ///< Layout of accumulator matrix
116
- using LayoutC = LayoutC_;
117
- ///< Policy describing tuning details
118
- using Policy0 = Policy0_;
119
- using Policy1 = Policy1_;
120
-
121
- using SmemIteratorA = SmemIteratorA_;
122
- using SmemIteratorB0 = SmemIteratorB0_;
123
- using SmemIteratorB1 = SmemIteratorB1_;
124
-
125
- static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
126
- static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
127
-
128
- //
129
- // Dependent types
130
- //
131
-
132
- /// Fragment of accumulator tile
133
- using FragmentC = typename Policy0::Operator::FragmentC;
134
-
135
- /// Warp-level Mma
136
- using Operator0 = typename Policy0::Operator;
137
- using Operator1 = typename Policy1::Operator;
138
-
139
- /// Minimum architecture is Sm80 to support cp.async
140
- using ArchTag = arch::Sm80;
141
-
142
- /// Complex transform on A operand
143
- static ComplexTransform const kTransformA = Operator0::kTransformA;
144
-
145
- /// Complex transform on B operand
146
- static ComplexTransform const kTransformB0 = Operator0::kTransformB;
147
- static ComplexTransform const kTransformB1 = Operator1::kTransformB;
148
-
149
- /// Internal structure exposed for introspection.
150
- struct Detail {
151
-
152
- /// Number of cp.async instructions to load one stage of operand A
153
- static int const AsyncCopyIterationsPerStageA =
154
- IteratorA::ThreadMap::Iterations::kCount;
155
-
156
- /// Number of cp.async instructions to load one stage of operand B
157
- static int const AsyncCopyIterationsPerStageB =
158
- IteratorB0::ThreadMap::Iterations::kCount;
159
-
160
- /// Number of stages
161
- static int const kStages = Stages;
162
-
163
- /// Number of cp.async instructions to load one group of operand A
164
- static int const kAccessesPerGroupA =
165
- (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
166
-
167
- /// Number of cp.async instructions to load one group of operand B
168
- static int const kAccessesPerGroupB =
169
- (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
170
- };
171
-
172
- private:
173
-
174
- using WarpLoadedFragmentA = typename Operator0::FragmentA;
175
- using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
176
- using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
177
- using WarpTransformedFragmentA = typename Operator0::TransformedFragmentA;
178
- using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
179
- using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
180
-
181
- private:
182
-
183
- //
184
- // Data members
185
- //
186
-
187
- /// Iterator to write threadblock-scoped tile of A operand to shared memory
188
- SmemIteratorA smem_iterator_A_;
189
-
190
- /// Iterator to write threadblock-scoped tile of B operand to shared memory
191
- SmemIteratorB0 smem_iterator_B0_;
192
- SmemIteratorB1 smem_iterator_B1_;
193
-
194
- public:
195
-
196
- /// Construct from tensor references
197
- CUTLASS_DEVICE
198
- DualMmaMultistage(
199
- ///< Shared storage needed for internal use by threadblock-scoped GEMM
200
- typename Base::SharedStorage &shared_storage,
201
- ///< ID within the threadblock
202
- int thread_idx,
203
- ///< ID of warp
204
- int warp_idx,
205
- ///< ID of each thread within a warp
206
- int lane_idx
207
- ):
208
- Base(shared_storage, thread_idx, warp_idx, lane_idx),
209
- smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
210
- smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx),
211
- smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx)
212
- {
213
- // Compute warp location within threadblock tile by mapping the warp_id to
214
- // three coordinates:
215
- // _m: the warp's position within the threadblock along the M dimension
216
- // _n: the warp's position within the threadblock along the N dimension
217
- // _k: the warp's position within the threadblock along the K dimension
218
-
219
- int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
220
- int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
221
-
222
- int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
223
- int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
224
-
225
- // Add per-warp offsets in units of warp-level tiles
226
- this->warp_tile_iterator_A_.add_tile_offset(
227
- {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
228
- this->warp_tile_iterator_B0_.add_tile_offset(
229
- {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
230
- this->warp_tile_iterator_B1_.add_tile_offset(
231
- {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
232
- }
233
-
234
- CUTLASS_DEVICE
235
- void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB0 &iterator_B0, IteratorB1 &iterator_B1,
236
- int group_start_A = 0, int group_start_B = 0) {
237
- iterator_A.set_iteration_index(group_start_A *
238
- IteratorA::kAccessesPerVector);
239
- this->smem_iterator_A_.set_iteration_index(group_start_A);
240
-
241
- // Async Copy for operand A
242
- CUTLASS_PRAGMA_UNROLL
243
- for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
244
- if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
245
- typename IteratorA::AccessType *dst_ptr =
246
- reinterpret_cast<typename IteratorA::AccessType *>(
247
- this->smem_iterator_A_.get());
248
-
249
- int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
250
- IteratorA::ThreadMap::kElementsPerAccess /
251
- IteratorA::kAccessesPerVector / 8;
252
-
253
- CUTLASS_PRAGMA_UNROLL
254
- for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
255
- auto gmem_ptr = iterator_A.get();
256
-
257
- if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
258
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
259
- dst_ptr + v, gmem_ptr, iterator_A.valid());
260
- } else {
261
- cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
262
- dst_ptr + v, gmem_ptr, iterator_A.valid());
263
- }
264
-
265
- ++iterator_A;
266
- }
267
-
268
- ++this->smem_iterator_A_;
269
- }
270
- }
271
-
272
- iterator_B0.set_iteration_index(group_start_B *
273
- IteratorB0::kAccessesPerVector);
274
- iterator_B1.set_iteration_index(group_start_B *
275
- IteratorB1::kAccessesPerVector);
276
- this->smem_iterator_B0_.set_iteration_index(group_start_B);
277
- this->smem_iterator_B1_.set_iteration_index(group_start_B);
278
-
279
- // Async Copy for operand B0
280
- CUTLASS_PRAGMA_UNROLL
281
- for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
282
- if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
283
- typename IteratorB0::AccessType *dst_ptr =
284
- reinterpret_cast<typename IteratorB0::AccessType *>(
285
- this->smem_iterator_B0_.get());
286
-
287
- int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
288
- IteratorB0::ThreadMap::kElementsPerAccess /
289
- IteratorB0::kAccessesPerVector / 8;
290
-
291
- CUTLASS_PRAGMA_UNROLL
292
- for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
293
- auto gmem_ptr = iterator_B0.get();
294
-
295
- if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
296
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
297
- dst_ptr + v, gmem_ptr, iterator_B0.valid());
298
- } else {
299
- cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
300
- dst_ptr + v, gmem_ptr, iterator_B0.valid());
301
- }
302
-
303
- ++iterator_B0;
304
- }
305
- ++this->smem_iterator_B0_;
306
- }
307
- }
308
- // Async Copy for operand B1
309
- CUTLASS_PRAGMA_UNROLL
310
- for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
311
- if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
312
- typename IteratorB1::AccessType *dst_ptr =
313
- reinterpret_cast<typename IteratorB1::AccessType *>(
314
- this->smem_iterator_B1_.get());
315
-
316
- int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
317
- IteratorB1::ThreadMap::kElementsPerAccess /
318
- IteratorB1::kAccessesPerVector / 8;
319
-
320
- CUTLASS_PRAGMA_UNROLL
321
- for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
322
- auto gmem_ptr = iterator_B1.get();
323
-
324
- if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
325
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
326
- dst_ptr + v, gmem_ptr, iterator_B1.valid());
327
- } else {
328
- cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
329
- dst_ptr + v, gmem_ptr, iterator_B1.valid());
330
- }
331
-
332
- ++iterator_B1;
333
- }
334
- ++this->smem_iterator_B1_;
335
- }
336
- }
337
- }
338
-
339
- /// Perform a threadblock-scoped matrix multiply-accumulate
340
- CUTLASS_DEVICE
341
- void operator()(
342
- ///< problem size of GEMM
343
- int gemm_k_iterations,
344
- ///< destination accumulator tile
345
- FragmentC &accum0,
346
- FragmentC &accum1,
347
- ///< iterator over A operand in global memory
348
- IteratorA iterator_A,
349
- ///< iterator over B operand in global memory
350
- IteratorB0 iterator_B0,
351
- IteratorB1 iterator_B1,
352
- ///< initial value of accumulator
353
- FragmentC const &src_accum0,
354
- FragmentC const &src_accum1
355
- ) {
356
-
357
- //
358
- // Prologue
359
- //
360
-
361
- // Issue several complete stages
362
- CUTLASS_PRAGMA_UNROLL
363
- for (int stage = 0; stage < Base::kStages - 1;
364
- ++stage, --gemm_k_iterations) {
365
-
366
- iterator_A.clear_mask(gemm_k_iterations == 0);
367
- iterator_B0.clear_mask(gemm_k_iterations == 0);
368
- iterator_B1.clear_mask(gemm_k_iterations == 0);
369
-
370
- iterator_A.set_iteration_index(0);
371
- this->smem_iterator_A_.set_iteration_index(0);
372
-
373
- // Async Copy for operand A
374
- CUTLASS_PRAGMA_UNROLL
375
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
376
- typename IteratorA::AccessType *dst_ptr =
377
- reinterpret_cast<typename IteratorA::AccessType *>(
378
- this->smem_iterator_A_.get());
379
-
380
- CUTLASS_PRAGMA_UNROLL
381
- for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
382
- int const kSrcBytes =
383
- sizeof_bits<typename IteratorA::Element>::value *
384
- IteratorA::ThreadMap::kElementsPerAccess /
385
- IteratorA::kAccessesPerVector / 8;
386
-
387
- int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
388
-
389
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
390
- dst_ptr + v, iterator_A.get(), iterator_A.valid());
391
-
392
- ++iterator_A;
393
- }
394
-
395
- ++this->smem_iterator_A_;
396
- }
397
-
398
- iterator_B0.set_iteration_index(0);
399
- iterator_B1.set_iteration_index(0);
400
- this->smem_iterator_B0_.set_iteration_index(0);
401
- this->smem_iterator_B1_.set_iteration_index(0);
402
-
403
- // Async Copy for operand B0
404
- CUTLASS_PRAGMA_UNROLL
405
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
406
- typename IteratorB0::AccessType *dst_ptr =
407
- reinterpret_cast<typename IteratorB0::AccessType *>(
408
- this->smem_iterator_B0_.get());
409
-
410
- CUTLASS_PRAGMA_UNROLL
411
- for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
412
- int const kSrcBytes =
413
- sizeof_bits<typename IteratorB0::Element>::value *
414
- IteratorB0::ThreadMap::kElementsPerAccess /
415
- IteratorB0::kAccessesPerVector / 8;
416
-
417
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
418
- dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
419
-
420
- ++iterator_B0;
421
- }
422
-
423
- ++this->smem_iterator_B0_;
424
- }
425
- // Async Copy for operand B1
426
- CUTLASS_PRAGMA_UNROLL
427
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
428
- typename IteratorB1::AccessType *dst_ptr =
429
- reinterpret_cast<typename IteratorB1::AccessType *>(
430
- this->smem_iterator_B1_.get());
431
-
432
- CUTLASS_PRAGMA_UNROLL
433
- for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
434
- int const kSrcBytes =
435
- sizeof_bits<typename IteratorB1::Element>::value *
436
- IteratorB1::ThreadMap::kElementsPerAccess /
437
- IteratorB1::kAccessesPerVector / 8;
438
-
439
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
440
- dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
441
-
442
- ++iterator_B1;
443
- }
444
-
445
- ++this->smem_iterator_B1_;
446
- }
447
-
448
- // Move to the next stage
449
- iterator_A.add_tile_offset({0, 1});
450
- iterator_B0.add_tile_offset({1, 0});
451
- iterator_B1.add_tile_offset({1, 0});
452
-
453
- this->smem_iterator_A_.add_tile_offset({0, 1});
454
- this->smem_iterator_B0_.add_tile_offset({1, 0});
455
- this->smem_iterator_B1_.add_tile_offset({1, 0});
456
-
457
- // Defines the boundary of a stage of cp.async.
458
- cutlass::arch::cp_async_fence();
459
- }
460
-
461
- // Perform accumulation in the 'd' output operand
462
- accum0 = src_accum0;
463
- accum1 = src_accum1;
464
-
465
- //
466
- // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
467
- // so that all accumulator elements outside the GEMM footprint are zero.
468
- //
469
-
470
- if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
471
-
472
- /// Iterator to write threadblock-scoped tile of A operand to shared memory
473
- SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
474
-
475
- typename IteratorA::AccessType zero_A;
476
- zero_A.clear();
477
-
478
- last_smem_iterator_A.set_iteration_index(0);
479
-
480
- // Async Copy for operand A
481
- CUTLASS_PRAGMA_UNROLL
482
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
483
-
484
- typename IteratorA::AccessType *dst_ptr =
485
- reinterpret_cast<typename IteratorA::AccessType *>(
486
- last_smem_iterator_A.get());
487
-
488
- *dst_ptr = zero_A;
489
-
490
- ++last_smem_iterator_A;
491
- }
492
-
493
- typename IteratorB0::AccessType zero_B;
494
- zero_B.clear();
495
-
496
- /// Iterator to write threadblock-scoped tile of B0 operand to shared memory
497
- SmemIteratorB0 last_smem_iterator_B0(this->smem_iterator_B0_);
498
- last_smem_iterator_B0.set_iteration_index(0);
499
-
500
- // Async Copy for operand B0
501
- CUTLASS_PRAGMA_UNROLL
502
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
503
- typename IteratorB0::AccessType *dst_ptr =
504
- reinterpret_cast<typename IteratorB0::AccessType *>(
505
- last_smem_iterator_B0.get());
506
-
507
- *dst_ptr = zero_B;
508
-
509
- ++last_smem_iterator_B0;
510
- }
511
-
512
- /// Iterator to write threadblock-scoped tile of B1 operand to shared memory
513
- SmemIteratorB1 last_smem_iterator_B1(this->smem_iterator_B1_);
514
- last_smem_iterator_B1.set_iteration_index(0);
515
-
516
- // Async Copy for operand B1
517
- CUTLASS_PRAGMA_UNROLL
518
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
519
-
520
- typename IteratorB1::AccessType *dst_ptr =
521
- reinterpret_cast<typename IteratorB1::AccessType *>(
522
- last_smem_iterator_B1.get());
523
-
524
- *dst_ptr = zero_B;
525
-
526
- ++last_smem_iterator_B1;
527
- }
528
- }
529
-
530
- // Waits until stages up to the previous (kStages-2)th stage have committed.
531
- cutlass::arch::cp_async_wait<Base::kStages - 2>();
532
- __syncthreads();
533
-
534
- // Pair of fragments used to overlap shared memory loads and math
535
- // instructions
536
- WarpLoadedFragmentA warp_loaded_frag_A[2];
537
- WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
538
- WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
539
- WarpTransformedFragmentA warp_transformed_frag_A[2];
540
- WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
541
- WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
542
-
543
- Operator0 warp_mma0;
544
- Operator1 warp_mma1;
545
-
546
- this->warp_tile_iterator_A_.set_kgroup_index(0);
547
- this->warp_tile_iterator_B0_.set_kgroup_index(0);
548
- this->warp_tile_iterator_B1_.set_kgroup_index(0);
549
-
550
- this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
551
- this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
552
- this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
553
-
554
- ++this->warp_tile_iterator_A_;
555
- ++this->warp_tile_iterator_B0_;
556
- ++this->warp_tile_iterator_B1_;
557
-
558
- iterator_A.clear_mask(gemm_k_iterations == 0);
559
- iterator_B0.clear_mask(gemm_k_iterations == 0);
560
- iterator_B1.clear_mask(gemm_k_iterations == 0);
561
-
562
- int smem_write_stage_idx = Base::kStages - 1;
563
- int smem_read_stage_idx = 0;
564
-
565
- warp_mma0.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0],
566
- warp_loaded_frag_A[0], warp_loaded_frag_B0[0]);
567
- warp_mma1.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0],
568
- warp_loaded_frag_A[0], warp_loaded_frag_B1[0]);
569
-
570
- // tf32x3 kernels use staging accumulation. warp_mma uses a temporary
571
- // accumulator and this temporary accumulator is added to the final
572
- // accumulator once in every mainloop iteration.
573
- plus<FragmentC> plus_accum;
574
-
575
- FragmentC tmp_accum0, tmp_accum1;
576
-
577
- if (platform::is_same<typename Operator0::MathOperator,
578
- arch::OpMultiplyAddFastF32>::value
579
- || platform::is_same<typename Operator0::MathOperator,
580
- arch::OpMultiplyAddComplexFastF32>::value) {
581
-
582
- tmp_accum0.clear();
583
- tmp_accum1.clear();
584
- }
585
-
586
- //
587
- // Mainloop
588
- //
589
-
590
- CUTLASS_GEMM_LOOP
591
- for (; gemm_k_iterations > (-Base::kStages + 1);) {
592
- //
593
- // Loop over GEMM K dimension
594
- //
595
-
596
- // Computes a warp-level GEMM on data held in shared memory
597
- // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
598
- CUTLASS_PRAGMA_UNROLL
599
- for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
600
- ++warp_mma_k) {
601
-
602
- // Load warp-level tiles from shared memory, wrapping to k offset if
603
- // this is the last group as the case may be.
604
-
605
- this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
606
- this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
607
- this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
608
-
609
- this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
610
- this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
611
- this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
612
-
613
- ++this->warp_tile_iterator_A_;
614
- ++this->warp_tile_iterator_B0_;
615
- ++this->warp_tile_iterator_B1_;
616
-
617
- if (warp_mma_k > 0) {
618
- warp_mma0.transform(warp_transformed_frag_A[warp_mma_k % 2],
619
- warp_transformed_frag_B0[warp_mma_k % 2],
620
- warp_loaded_frag_A[warp_mma_k % 2],
621
- warp_loaded_frag_B0[warp_mma_k % 2]);
622
- warp_mma1.transform(warp_transformed_frag_A[warp_mma_k % 2],
623
- warp_transformed_frag_B1[warp_mma_k % 2],
624
- warp_loaded_frag_A[warp_mma_k % 2],
625
- warp_loaded_frag_B1[warp_mma_k % 2]);
626
- }
627
-
628
- if (platform::is_same<typename Operator0::MathOperator,
629
- arch::OpMultiplyAddFastF32>::value
630
- || platform::is_same<typename Operator0::MathOperator,
631
- arch::OpMultiplyAddComplexFastF32>::value) {
632
-
633
- warp_mma0(
634
- tmp_accum0,
635
- warp_transformed_frag_A[warp_mma_k % 2],
636
- warp_transformed_frag_B0[warp_mma_k % 2],
637
- tmp_accum0
638
- );
639
- warp_mma1(
640
- tmp_accum1,
641
- warp_transformed_frag_A[warp_mma_k % 2],
642
- warp_transformed_frag_B1[warp_mma_k % 2],
643
- tmp_accum1
644
- );
645
-
646
- if (warp_mma_k == 0) {
647
- accum0 = plus_accum(accum0, tmp_accum0);
648
- accum1 = plus_accum(accum1, tmp_accum1);
649
- tmp_accum0.clear();
650
- tmp_accum1.clear();
651
- }
652
- } else {
653
- warp_mma0(
654
- accum0,
655
- warp_transformed_frag_A[warp_mma_k % 2],
656
- warp_transformed_frag_B0[warp_mma_k % 2],
657
- accum0
658
- );
659
- warp_mma1(
660
- accum1,
661
- warp_transformed_frag_A[warp_mma_k % 2],
662
- warp_transformed_frag_B1[warp_mma_k % 2],
663
- accum1
664
- );
665
- }
666
-
667
- // Issue global->shared copies for the this stage
668
- if (warp_mma_k < Base::kWarpGemmIterations - 1) {
669
- int group_start_iteration_A, group_start_iteration_B;
670
-
671
- group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
672
- group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
673
-
674
- copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A,
675
- group_start_iteration_B);
676
- }
677
-
678
- if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
679
- int group_start_iteration_A, group_start_iteration_B;
680
- group_start_iteration_A =
681
- (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
682
- group_start_iteration_B =
683
- (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
684
-
685
- copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A,
686
- group_start_iteration_B);
687
-
688
- // Inserts a memory fence between stages of cp.async instructions.
689
- cutlass::arch::cp_async_fence();
690
-
691
- // Waits until stages up to the previous (kStages-2)th stage have committed.
692
- arch::cp_async_wait<Base::kStages - 2>();
693
- __syncthreads();
694
-
695
- // Move to the next stage
696
- iterator_A.add_tile_offset({0, 1});
697
- iterator_B0.add_tile_offset({1, 0});
698
- iterator_B1.add_tile_offset({1, 0});
699
-
700
- this->smem_iterator_A_.add_tile_offset({0, 1});
701
- this->smem_iterator_B0_.add_tile_offset({1, 0});
702
- this->smem_iterator_B1_.add_tile_offset({1, 0});
703
-
704
- // Add negative offsets to return iterators to the 'start' of the
705
- // circular buffer in shared memory
706
- if (smem_write_stage_idx == (Base::kStages - 1)) {
707
- this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
708
- this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
709
- this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
710
- smem_write_stage_idx = 0;
711
- } else {
712
- ++smem_write_stage_idx;
713
- }
714
-
715
- if (smem_read_stage_idx == (Base::kStages - 1)) {
716
- this->warp_tile_iterator_A_.add_tile_offset(
717
- {0, -Base::kStages * Policy0::kPartitionsK *
718
- Base::kWarpGemmIterations});
719
- this->warp_tile_iterator_B0_.add_tile_offset(
720
- {-Base::kStages * Policy0::kPartitionsK *
721
- Base::kWarpGemmIterations,
722
- 0});
723
- this->warp_tile_iterator_B1_.add_tile_offset(
724
- {-Base::kStages * Policy1::kPartitionsK *
725
- Base::kWarpGemmIterations,
726
- 0});
727
- smem_read_stage_idx = 0;
728
- } else {
729
- ++smem_read_stage_idx;
730
- }
731
-
732
- --gemm_k_iterations;
733
- iterator_A.clear_mask(gemm_k_iterations == 0);
734
- iterator_B0.clear_mask(gemm_k_iterations == 0);
735
- iterator_B1.clear_mask(gemm_k_iterations == 0);
736
- }
737
-
738
- // Do any conversions feeding the first stage at the end of the loop so
739
- // we can start right away on mma instructions
740
- if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
741
- warp_mma0.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
742
- warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
743
- warp_loaded_frag_A[(warp_mma_k + 1) % 2],
744
- warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
745
- warp_mma1.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
746
- warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
747
- warp_loaded_frag_A[(warp_mma_k + 1) % 2],
748
- warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
749
- }
750
- }
751
-
752
- }
753
-
754
- if (platform::is_same<typename Operator0::MathOperator,
755
- arch::OpMultiplyAddFastF32>::value
756
- || platform::is_same<typename Operator0::MathOperator,
757
- arch::OpMultiplyAddComplexFastF32>::value) {
758
- accum0 = plus_accum(accum0, tmp_accum0);
759
- accum1 = plus_accum(accum1, tmp_accum1);
760
- }
761
-
762
- // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
763
- cutlass::arch::cp_async_fence();
764
- cutlass::arch::cp_async_wait<0>();
765
- __syncthreads();
766
- }
767
- };
768
-
769
- /////////////////////////////////////////////////////////////////////////////////////////////////
770
-
771
- } // namespace threadblock
772
- } // namespace gemm
773
- } // namespace cutlass
774
-
775
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh DELETED
@@ -1,139 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include "cute/tensor.hpp"
34
-
35
- #include "cutlass/arch/arch.h"
36
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
37
- #include "cutlass/gemm/kernel/gemm_universal.hpp"
38
- #include "cutlass/gemm/collective/collective_builder.hpp"
39
-
40
- #include "cutlass/epilogue/collective/collective_epilogue.hpp"
41
- #include "cutlass/epilogue/thread/linear_combination.h"
42
-
43
- namespace example {
44
-
45
- //
46
- // GETT entry point
47
- //
48
- template <
49
- class ProblemShapeMNKL,
50
- class ElementA,
51
- class StrideA,
52
- class ElementB,
53
- class StrideB,
54
- class ElementAccumulator,
55
- class ElementC,
56
- class StrideC,
57
- class ElementD,
58
- class StrideD,
59
- class ElementEpilogue>
60
- cutlass::Status
61
- gett_kernel(
62
- ProblemShapeMNKL problem_shape_mnkl,
63
- ElementA const* ptr_A, StrideA stride_a_mkl,
64
- ElementB const* ptr_B, StrideB stride_b_nkl,
65
- ElementAccumulator _,
66
- ElementC const* ptr_C, StrideC stride_c_mnl,
67
- ElementD * ptr_D, StrideD stride_d_mnl,
68
- ElementEpilogue alpha, ElementEpilogue beta,
69
- cudaStream_t stream = 0) {
70
- using namespace cute;
71
-
72
- // TileShape -- GETT configuration
73
- // Specify the number of elements to take from each mode
74
- // BLK_M = (M0,M1,...) BLK_N = (M0,M1,...) BLK_K = (K0,K1,...)
75
-
76
- // Take 128 from m0, 128 from n0, 64 from k0
77
- using TileShape = Shape<Shape<_128>, Shape<_128>, Shape<_64>>;
78
-
79
- /* Other examples:
80
- * Take 32 elements from m0 and 4 elements from m1
81
- * Take 64 elements from n0 and 2 elements from n1
82
- * Take 8 elements from k0 and 8 elements from k1
83
- **/
84
- // using TileShape = Shape<Shape<_32,_4>, Shape<_64,_2>, Shape<_8,_8>>;
85
-
86
- using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination<
87
- ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default,
88
- cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
89
-
90
- // No changes are required to the default epilogue
91
- using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
92
- cutlass::epilogue::collective::DefaultEpilogue<
93
- ElementC,
94
- StrideC,
95
- StrideD,
96
- EpilogueThreadOp,
97
- cutlass::gemm::EpilogueDefault>>;
98
-
99
- // CollectiveMma for GETTs can be built using the CollectiveBuilders
100
- using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
101
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
102
- ElementA, StrideA, 128 / cutlass::sizeof_bits<ElementA>::value,
103
- ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
104
- ElementAccumulator,
105
- TileShape, Shape<_1,_2,_1>,
106
- cutlass::gemm::collective::StageCountAutoCarveout<
107
- static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
108
- cutlass::gemm::collective::KernelScheduleAuto
109
- >::CollectiveOp;
110
-
111
- // The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM
112
- using GettKernel = cutlass::gemm::kernel::GemmUniversal<
113
- ProblemShapeMNKL,
114
- CollectiveMainloop,
115
- CollectiveEpilogue>;
116
-
117
- using GettOperator = cutlass::gemm::device::GemmUniversalAdapter<GettKernel>;
118
-
119
- typename GettOperator::Arguments args {
120
- cutlass::gemm::GemmUniversalMode::kBatched,
121
- problem_shape_mnkl,
122
- { ptr_A, stride_a_mkl, ptr_B, stride_b_nkl },
123
- { {alpha, beta}, ptr_C, stride_c_mnl, ptr_D, stride_d_mnl }
124
- };
125
-
126
- #if CUTLASS_DEBUG_TRACE_LEVEL > 0
127
- print("Problem shape:");
128
- print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n");
129
- print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n");
130
- print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n");
131
- print("\tL: "); print(cute::get<3>(problem_shape_mnkl)); print("\n");
132
- print("TileSape:"); print(TileShape{}); print("\n");
133
- #endif
134
-
135
- GettOperator op;
136
- return op(args, stream);
137
- }
138
-
139
- } // namespace example
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp DELETED
@@ -1,421 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include "cutlass/cutlass.h"
34
- #include "cutlass/kernel_hardware_info.hpp"
35
- #include "cutlass/gemm/gemm.h"
36
- #include "cutlass/gemm/dispatch_policy.hpp"
37
-
38
- #include "cute/tensor.hpp"
39
-
40
- #include "gather_tensor.hpp"
41
-
42
- namespace cutlass {
43
- ///Forward declaration
44
- struct CudaHostAdapter;
45
- }
46
-
47
- namespace cutlass::gemm::kernel {
48
-
49
- ///////////////////////////////////////////////////////////////////////////////
50
-
51
- template <
52
- class ProblemShape_,
53
- class CollectiveMainloop_,
54
- class CollectiveEpilogue_,
55
- class TileScheduler_,
56
- class GatherA_,
57
- class GatherB_
58
- >
59
- class GemmGather
60
- {
61
- public:
62
- //
63
- // Type Aliases
64
- //
65
- using ProblemShape = ProblemShape_;
66
- static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
67
- "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
68
-
69
- // Mainloop derived types
70
- using CollectiveMainloop = CollectiveMainloop_;
71
- using TileShape = typename CollectiveMainloop::TileShape;
72
- using TiledMma = typename CollectiveMainloop::TiledMma;
73
- using ArchTag = typename CollectiveMainloop::ArchTag;
74
- using ElementA = typename CollectiveMainloop::ElementA;
75
- using StrideA = typename CollectiveMainloop::StrideA;
76
- using ElementB = typename CollectiveMainloop::ElementB;
77
- using StrideB = typename CollectiveMainloop::StrideB;
78
- using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
79
- using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
80
- using ClusterShape = typename DispatchPolicy::ClusterShape;
81
- using MainloopArguments = typename CollectiveMainloop::Arguments;
82
- using MainloopParams = typename CollectiveMainloop::Params;
83
- static_assert(ArchTag::kMinComputeCapability >= 90);
84
-
85
- // Epilogue derived types
86
- using CollectiveEpilogue = CollectiveEpilogue_;
87
- using ElementC = typename CollectiveEpilogue::ElementC;
88
- using StrideC = typename CollectiveEpilogue::StrideC;
89
- using ElementD = typename CollectiveEpilogue::ElementD;
90
- using StrideD = typename CollectiveEpilogue::StrideD;
91
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
92
- using EpilogueParams = typename CollectiveEpilogue::Params;
93
-
94
- static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
95
- "Non-persistent warp-specialized kernel does not support specializing the tile scheduler.");
96
- using TileSchedulerTag = TileScheduler_;
97
- using TileScheduler = typename detail::TileSchedulerSelector<
98
- TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
99
- using TileSchedulerArguments = typename TileScheduler::Arguments;
100
-
101
- using GatherA = GatherA_;
102
- using GatherB = GatherB_;
103
-
104
- // Kernel level shared memory storage
105
- struct SharedStorage {
106
- union TensorStorage {
107
- using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
108
- using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
109
-
110
- MainloopTensorStorage mainloop;
111
- EpilogueTensorStorage epilogue;
112
- } tensors;
113
-
114
- struct PipelineStorage : cute::aligned_struct<16, _2> {
115
- using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
116
- using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
117
-
118
- alignas(16) MainloopPipelineStorage mainloop;
119
- alignas(16) EpiLoadPipelineStorage epi_load;
120
- } pipelines;
121
- };
122
-
123
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
124
-
125
- using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA;
126
- using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB;
127
- static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same.");
128
-
129
- static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup;
130
- static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(cute::size(TiledMma{})) / NumThreadsPerWarpGroup;
131
- static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups;
132
- static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance.");
133
-
134
- static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup;
135
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
136
-
137
- // Device side arguments
138
- struct Arguments {
139
- GemmUniversalMode mode{};
140
- ProblemShape problem_shape{};
141
- MainloopArguments mainloop{};
142
- EpilogueArguments epilogue{};
143
- KernelHardwareInfo hw_info{};
144
- TileSchedulerArguments scheduler{};
145
- GatherA gather_A{};
146
- GatherB gather_B{};
147
- };
148
-
149
- // Kernel entry point API
150
- struct Params {
151
- GemmUniversalMode mode{};
152
- ProblemShape problem_shape{};
153
- MainloopParams mainloop{};
154
- EpilogueParams epilogue{};
155
- GatherA gather_A{};
156
- GatherB gather_B{};
157
- };
158
-
159
- //
160
- // Methods
161
- //
162
-
163
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
164
- static
165
- Params
166
- to_underlying_arguments(Arguments const& args, void* workspace) {
167
- (void) workspace;
168
- auto problem_shape = args.problem_shape;
169
- if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
170
- // swap M/N
171
- get<0>(problem_shape) = get<1>(args.problem_shape);
172
- get<1>(problem_shape) = get<0>(args.problem_shape);
173
- }
174
- return {
175
- args.mode,
176
- problem_shape,
177
- CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
178
- CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
179
- args.gather_A,
180
- args.gather_B
181
- };
182
- }
183
-
184
- static bool
185
- can_implement(Arguments const& args) {
186
- bool implementable = (args.mode == GemmUniversalMode::kGemm) or
187
- (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
188
- if (!implementable) {
189
- CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
190
- return implementable;
191
- }
192
- implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
193
- implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
194
- return implementable;
195
- }
196
-
197
- static
198
- size_t
199
- get_workspace_size(Arguments const& args) {
200
- return 0;
201
- }
202
-
203
- static
204
- cutlass::Status
205
- initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
206
- CudaHostAdapter* cuda_adapter = nullptr) {
207
- return Status::kSuccess;
208
- }
209
-
210
- // Computes the kernel launch grid shape based on runtime parameters
211
- static dim3
212
- get_grid_shape(Params const& params) {
213
- auto cluster_shape = Shape<_1,_1,_1>{};
214
- auto tile_shape = TileShape{};
215
- auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
216
- return TileScheduler::get_tiled_cta_shape_mnl(
217
- problem_shape_MNKL, tile_shape, cluster_shape);
218
- }
219
-
220
- static dim3
221
- get_block_shape() {
222
- return dim3(MaxThreadsPerBlock, 1, 1);
223
- }
224
-
225
- CUTLASS_DEVICE
226
- void
227
- operator()(Params const& params, char* smem_buf) {
228
- using namespace cute;
229
- using X = Underscore;
230
-
231
- // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
232
- #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
233
- if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
234
- printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
235
- return;
236
- }
237
- #endif
238
-
239
- enum class WarpGroupRole {
240
- Producer = 0,
241
- Consumer = 1,
242
- };
243
-
244
- // Kernel level shared memory storage
245
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
246
-
247
- int thread_idx = int(threadIdx.x);
248
- int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
249
- int warp_group_idx = canonical_warp_group_idx();
250
- CUTLASS_ASSERT(warp_group_idx < NumWarpGroups);
251
- WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer;
252
-
253
- // Mainloop Load pipeline
254
- using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
255
- typename MainloopPipeline::Params mainloop_pipeline_params;
256
- if (warp_group_role == WarpGroupRole::Producer) {
257
- mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
258
- }
259
- if (warp_group_role == WarpGroupRole::Consumer) {
260
- mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
261
- }
262
- mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup;
263
- mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup;
264
- MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params);
265
-
266
- // Epilogue Load pipeline
267
- using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
268
- typename EpiLoadPipeline::Params epi_load_pipeline_params;
269
- if (warp_group_role == WarpGroupRole::Producer) {
270
- epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
271
- }
272
- if (warp_group_role == WarpGroupRole::Consumer) {
273
- epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
274
- }
275
- epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup;
276
- epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup;
277
- EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
278
-
279
- // Epilogue Store pipeline
280
- using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
281
- typename EpiStorePipeline::Params epi_store_pipeline_params;
282
- epi_store_pipeline_params.always_wait = true;
283
- EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
284
-
285
- // Initialize starting pipeline states for the collectives
286
- typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
287
- typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
288
-
289
- // For the DMA Load (producer) we start with an opposite phase
290
- // i.e., we skip all waits since we know that the buffer is indeed empty
291
- PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
292
- PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
293
- PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
294
-
295
- // Preconditions
296
- static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
297
- static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
298
- static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
299
- static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
300
-
301
- // Separate out problem shape for convenience
302
- // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
303
- auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
304
- auto M = get<0>(problem_shape_MNKL);
305
- auto N = get<1>(problem_shape_MNKL);
306
- auto K = get<2>(problem_shape_MNKL);
307
- auto L = get<3>(problem_shape_MNKL);
308
-
309
- // Represent the full tensors
310
- Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l)
311
- Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l)
312
-
313
- // Get the appropriate blocks for this thread block -- potential for thread block locality
314
- auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
315
- TiledMma tiled_mma;
316
-
317
- // Make tiled views, defer the slice
318
- Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
319
- Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
320
-
321
- // Compute m_coord, n_coord, and l_coord with their post-tiled shapes
322
- auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl));
323
- auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl));
324
- auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl));
325
- auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
326
-
327
- // Slice with m_coord and n_coord
328
- Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
329
- Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
330
-
331
- // Get pipeline iterators and increments from tensor shapes
332
- auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
333
- auto k_tile_count = size<2>(gA);
334
- auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
335
- auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
336
-
337
- // Wait for all threads in the thread block
338
- __syncthreads();
339
-
340
- // In a warp specialized kernel, collectives expose data movement and compute operations separately
341
- CollectiveMainloop collective_mainloop;
342
- CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue};
343
-
344
- if (warp_group_role == WarpGroupRole::Producer) {
345
- // Compute tile residues for predication
346
- auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord
347
- auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord
348
- auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max
349
- auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
350
-
351
- collective_mainloop.load(
352
- mainloop_pipeline,
353
- mainloop_pipe_producer_state,
354
- gA,
355
- gB,
356
- k_tile_iter, k_tile_count,
357
- residue_mnk,
358
- thread_idx,
359
- shared_storage.tensors.mainloop
360
- );
361
- // Update starting mainloop pipeline state for the pipeline drain
362
- mainloop_pipe_producer_state.advance(k_tile_count);
363
- // Make sure mainloop consumer has been waited upon before issuing epilogue load
364
- collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
365
-
366
- if (collective_epilogue.is_producer_load_needed()) {
367
- epi_load_pipe_producer_state =
368
- collective_epilogue.load(
369
- epi_load_pipeline,
370
- epi_load_pipe_producer_state,
371
- problem_shape_MNKL,
372
- blk_shape,
373
- blk_coord,
374
- tiled_mma,
375
- thread_idx,
376
- shared_storage.tensors.epilogue
377
- );
378
- collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
379
- }
380
- }
381
- else if (warp_group_role == WarpGroupRole::Consumer) {
382
- Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
383
-
384
- collective_mainloop.mma(
385
- mainloop_pipeline,
386
- mainloop_pipe_consumer_state,
387
- accumulators,
388
- k_tile_count,
389
- warp_group_thread_idx,
390
- shared_storage.tensors.mainloop,
391
- params.mainloop
392
- );
393
-
394
- // Make sure the math instructions are done and free buffers before entering the epilogue
395
- collective_mainloop.mma_tail(
396
- mainloop_pipeline,
397
- mainloop_pipe_consumer_state,
398
- k_tile_count
399
- );
400
-
401
- // Epilogue and write to gD
402
- collective_epilogue.store(
403
- epi_load_pipeline,
404
- epi_load_pipe_consumer_state,
405
- epi_store_pipeline,
406
- epi_store_pipe_producer_state,
407
- problem_shape_MNKL,
408
- blk_shape,
409
- blk_coord,
410
- accumulators,
411
- tiled_mma,
412
- warp_group_thread_idx,
413
- shared_storage.tensors.epilogue
414
- );
415
- }
416
- }
417
- };
418
-
419
- ///////////////////////////////////////////////////////////////////////////////
420
-
421
- } // namespace cutlass::gemm::kernel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh DELETED
@@ -1,136 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include "cute/numeric/math.hpp"
34
-
35
- namespace example
36
- {
37
-
38
- // Naive grid-stride loop implementation of gather
39
- template<typename Element, typename Func>
40
- __global__ void
41
- gather_kernel(Element const * __restrict__ input,
42
- Element * __restrict__ output,
43
- Func func,
44
- int num_elems_input,
45
- int num_elems_output,
46
- cutlass::FastDivmod stride_divmod)
47
- {
48
- Element const * input_b = input + blockIdx.z * num_elems_input;
49
- Element * output_b = output + blockIdx.z * num_elems_output;
50
- int tidx = threadIdx.x + blockIdx.x * blockDim.x;
51
- for (int k = tidx; k < num_elems_output; k += blockDim.x * gridDim.x) {
52
- int i,j;
53
- stride_divmod(j, i, k);
54
- output_b[k] = input_b[i + func(j) * stride_divmod.divisor];
55
- }
56
- }
57
-
58
- // Gather elements along strided dimension of the tensor according to given indices
59
- template<typename Element, typename Func>
60
- void
61
- gather(Element const * input,
62
- Element * output,
63
- Func func,
64
- int batch_size,
65
- int num_elems_input,
66
- int num_elems_output,
67
- int stride,
68
- cutlass::KernelHardwareInfo const& hw_info)
69
- {
70
- // Upcast to uint128_t data type
71
- int factor = 128 / cutlass::sizeof_bits<Element>::value;
72
- assert(stride % factor == 0);
73
- int stride_upcast = stride/factor;
74
- int num_elems_input_upcast = num_elems_input / factor;
75
- int num_elems_output_upcast = num_elems_output / factor;
76
-
77
- cutlass::FastDivmod stride_divmod(stride_upcast);
78
- dim3 blocks(hw_info.sm_count, 1, batch_size);
79
- gather_kernel<<<blocks, 1024>>>(reinterpret_cast<cute::uint128_t const *>(input),
80
- reinterpret_cast<cute::uint128_t *>(output),
81
- func,
82
- num_elems_input_upcast,
83
- num_elems_output_upcast,
84
- stride_divmod);
85
- }
86
-
87
- // Naive grid-stride loop implementation of scatter
88
- template<typename Element, typename Func>
89
- __global__ void
90
- scatter_kernel(Element const * __restrict__ input,
91
- Element * __restrict__ output,
92
- Func func,
93
- int num_elems_input,
94
- int num_elems_output,
95
- cutlass::FastDivmod stride_divmod)
96
- {
97
- Element const * input_b = input + blockIdx.z * num_elems_input;
98
- Element * output_b = output + blockIdx.z * num_elems_output;
99
- int tidx = threadIdx.x + blockIdx.x * blockDim.x;
100
- for (int k = tidx; k < num_elems_input; k += blockDim.x * gridDim.x) {
101
- int i,j;
102
- stride_divmod(j, i, k);
103
- output_b[i + func(j) * stride_divmod.divisor] = input_b[k];
104
- }
105
- }
106
-
107
- // Gather elements along strided dimension of the tensor according to given indices
108
- template<typename Element, typename Func>
109
- void
110
- scatter(Element const * input,
111
- Element * output,
112
- Func func,
113
- int batch_size,
114
- int num_elems_input,
115
- int num_elems_output,
116
- int stride,
117
- cutlass::KernelHardwareInfo const& hw_info)
118
- {
119
- // Upcast to uint128_t data type
120
- int factor = 128 / cutlass::sizeof_bits<Element>::value;
121
- assert(stride % factor == 0);
122
- int stride_upcast = stride/factor;
123
- int num_elems_input_upcast = num_elems_input / factor;
124
- int num_elems_output_upcast = num_elems_output / factor;
125
-
126
- cutlass::FastDivmod stride_divmod(stride_upcast);
127
- dim3 blocks(hw_info.sm_count, 1, batch_size);
128
- scatter_kernel<<<blocks, 1024>>>(reinterpret_cast<cute::uint128_t const *>(input),
129
- reinterpret_cast<cute::uint128_t *>(output),
130
- func,
131
- num_elems_input_upcast,
132
- num_elems_output_upcast,
133
- stride_divmod);
134
- }
135
-
136
- } // namespace example
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp DELETED
@@ -1,222 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Functor performing elementwise operations used by epilogues.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/cutlass.h"
38
- #include "cutlass/gemm/dispatch_policy.hpp"
39
- #include "cutlass/epilogue/collective/detail.hpp"
40
-
41
- #include "cute/tensor.hpp"
42
- #include "cute/numeric/numeric_types.hpp"
43
-
44
- #include "gather_tensor.hpp"
45
-
46
- namespace cutlass::epilogue::collective {
47
-
48
- /// Applies an element wise operation to all elements within the fragment
49
- /// and scatter-writes them out to destination storage.
50
- /// GatherC and ScatterD are types of user-defined functions that apply the
51
- /// transoformation of the strided coordinate (e.g. through an index array).
52
- template <
53
- class StrideC_,
54
- class StrideD_,
55
- class ThreadEpilogueOp_,
56
- class EpilogueSchedule_,
57
- class GatherC_,
58
- class ScatterD_
59
- >
60
- class EpilogueGatherScatter {
61
- public:
62
- //
63
- // Type Aliases
64
- //
65
- using EpilogueSchedule = EpilogueSchedule_;
66
-
67
- // derived types of output thread level operator
68
- using ThreadEpilogueOp = ThreadEpilogueOp_;
69
- using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
70
- using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
71
- using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
72
- using ElementScalar = ElementCompute;
73
- using ElementC = typename ThreadEpilogueOp::ElementC;
74
- using StrideC = StrideC_;
75
- using ElementD = typename ThreadEpilogueOp::ElementD;
76
- using StrideD = StrideD_;
77
-
78
- // Every epilogue needs these two GmemTiledCopy{C,D} aliases.
79
- // If you don't know what they should be, just use void.
80
- using GmemTiledCopyC = void;
81
- using GmemTiledCopyD = void;
82
-
83
- using GatherC = GatherC_;
84
- using ScatterD = ScatterD_;
85
-
86
- static const int kOutputAlignment = ThreadEpilogueOp::kCount;
87
- using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
88
-
89
- static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
90
- static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
91
-
92
- struct SharedStorage { };
93
-
94
- // Host side epilogue arguments
95
- struct Arguments {
96
- typename ThreadEpilogueOp::Params thread_params{};
97
- ElementC const* ptr_C = nullptr;
98
- StrideC dC{};
99
- ElementD* ptr_D = nullptr;
100
- StrideD dD{};
101
- GatherC gather_C{};
102
- ScatterD scatter_D{};
103
- };
104
-
105
- // Device side epilogue params
106
- using Params = Arguments;
107
-
108
- //
109
- // Methods
110
- //
111
-
112
- template <class ProblemShape>
113
- static constexpr Params
114
- to_underlying_arguments(
115
- [[maybe_unused]] ProblemShape const& _,
116
- Arguments const& args,
117
- [[maybe_unused]] void* workspace) {
118
- return args;
119
- }
120
-
121
- template<class ProblemShape>
122
- static bool
123
- can_implement(
124
- [[maybe_unused]] ProblemShape const& problem_shape,
125
- [[maybe_unused]] Arguments const& args) {
126
- return true;
127
- }
128
-
129
- CUTLASS_HOST_DEVICE
130
- EpilogueGatherScatter(Params const& params_) : params(params_) { }
131
-
132
- template<
133
- class ProblemShapeMNKL,
134
- class BlockShapeMNK,
135
- class BlockCoordMNKL,
136
- class FrgEngine, class FrgLayout,
137
- class TiledMma,
138
- class ResidueMNK
139
- >
140
- CUTLASS_DEVICE void
141
- operator()(
142
- ProblemShapeMNKL problem_shape_mnkl,
143
- BlockShapeMNK blk_shape_MNK,
144
- BlockCoordMNKL blk_coord_mnkl,
145
- cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
146
- TiledMma tiled_mma,
147
- ResidueMNK residue_mnk,
148
- int thread_idx,
149
- char* smem_buf)
150
- {
151
- using namespace cute;
152
- using X = Underscore;
153
-
154
- static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
155
- static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
156
- static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
157
- static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
158
-
159
- (void) smem_buf;
160
- ThreadEpilogueOp epilogue_op{params.thread_params};
161
-
162
- // Separate out problem shape for convenience
163
- auto M = get<0>(problem_shape_mnkl);
164
- auto N = get<1>(problem_shape_mnkl);
165
- auto L = get<3>(problem_shape_mnkl);
166
-
167
- auto stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC);
168
- auto stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD);
169
-
170
- // Represent the full output tensor
171
- Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C); // (m,n,l)
172
- Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l)
173
-
174
- Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
175
- Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
176
-
177
- // Slice to get the tile this CTA is responsible for
178
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
179
- Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
180
- Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
181
-
182
- // Partition source and destination tiles to match the accumulator partitioning
183
- auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
184
- Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N)
185
- Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N)
186
-
187
- static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static");
188
- CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD),
189
- "Source and destination must have the same number of elements.");
190
- CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators),
191
- "Accumulator count must have the same destination element count.");
192
-
193
- // Make an identity coordinate tensor for predicating our output MN tile
194
- auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
195
- Tensor tCcD = thr_mma.partition_C(cD);
196
-
197
- // source is needed
198
- if (epilogue_op.is_source_needed()) {
199
- CUTLASS_PRAGMA_UNROLL
200
- for (int i = 0; i < size(accumulators); ++i) {
201
- if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
202
- tCgD(i) = epilogue_op(accumulators(i), tCgC(i));
203
- }
204
- }
205
- }
206
- // source is not needed, avoid load
207
- else {
208
- CUTLASS_PRAGMA_UNROLL
209
- for (int i = 0; i < size(accumulators); ++i) {
210
- if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
211
- tCgD(i) = epilogue_op(accumulators(i));
212
- }
213
- }
214
- }
215
- }
216
-
217
- private:
218
- Params params;
219
- };
220
-
221
- } // namespace cutlass::epilogue::collective
222
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh DELETED
@@ -1,92 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief Simple permutation kernel implementation.
34
- */
35
-
36
- #include "cutlass/layout/pitch_linear.h"
37
- #include "cutlass/layout/matrix.h"
38
- #include "cutlass/tensor_view.h"
39
- #include "cutlass/fast_math.h"
40
- #include "cute/numeric/numeric_types.hpp"
41
-
42
- namespace example
43
- {
44
-
45
- /**
46
- * Assumes column-major input (M mode is contiguous, N mode is strided).
47
- * For row major, the inputs must be switched accordingly.
48
- */
49
- template<bool Batched, typename Element, typename Permute>
50
- __global__ void
51
- permute_kernel(Element const* __restrict__ input,
52
- Element* __restrict__ output,
53
- Permute permute,
54
- int64_t num_elems,
55
- cutlass::FastDivmod stride_divmod)
56
- {
57
- // CUTLASS 2.x batched permute functions assume 0 batch stride for target tensor
58
- Element const * input_b = input + blockIdx.z * num_elems;
59
- Element * output_b = output + (Batched ? 0 : blockIdx.z * num_elems);
60
- for (int64_t k = threadIdx.x + blockIdx.x * blockDim.x; k < num_elems; k += blockDim.x * gridDim.x)
61
- {
62
- int i, j;
63
- stride_divmod(j, i, k);
64
- output_b[permute(cutlass::PitchLinearCoord(i, j))] = input_b[i + j * stride_divmod.divisor];
65
- }
66
- }
67
-
68
- template<bool Batched, typename Permute, typename Element>
69
- void permute(Element const* input,
70
- Element * output,
71
- int64_t num_elems,
72
- int stride,
73
- int batch_count,
74
- cutlass::KernelHardwareInfo const& hw_info)
75
- {
76
- // Upcast to uint128_t data type
77
- int factor = 128 / cutlass::sizeof_bits<Element>::value;
78
- assert(stride % factor == 0);
79
- int stride_upcast = stride/factor;
80
- int64_t num_elems_upcast = num_elems / factor;
81
- Permute permute_upcast(cutlass::PitchLinearCoord(stride_upcast, int(num_elems_upcast/stride_upcast)), stride_upcast);
82
-
83
- cutlass::FastDivmod stride_divmod(stride);
84
- dim3 blocks(hw_info.sm_count, 1, batch_count);
85
- permute_kernel<Batched><<<blocks, 1024>>>(reinterpret_cast<cute::uint128_t const *>(input),
86
- reinterpret_cast<cute::uint128_t *>(output),
87
- permute_upcast,
88
- num_elems_upcast,
89
- stride_upcast);
90
- }
91
-
92
- } // namespace example
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp DELETED
@@ -1,274 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief Additional permutation information for the example.
34
- */
35
-
36
- #include "cutlass/layout/permute.h"
37
- #include "cutlass/gemm/gemm.h"
38
-
39
- namespace example
40
- {
41
-
42
- using namespace cute;
43
-
44
- // This struct is specialized below for different CUTLASS 2.x permutation ops
45
- // to describe the operation in terms of target CuTe shape and stride order.
46
- template<class Permute>
47
- struct PermuteTraits {};
48
-
49
- // Use X as a placeholder for shape division result
50
- using X = Underscore;
51
-
52
- // Reshape a rank-2 shape into a multidimensional shape.
53
- // Input:
54
- // shape = (A, B, ...)
55
- // target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...)
56
- // Output:
57
- // ((A1, ..., A/prod(A1..Am), ..., Am), (B1, ..., B/prod(B1..Bn), ..., Bn), ...)
58
- template<class Shape, class TargetShape>
59
- constexpr auto
60
- reshape(Shape const& shape, TargetShape const& target_shape)
61
- {
62
- if constexpr (is_tuple<Shape>::value) {
63
- return cute::transform(shape, target_shape, [](auto && s, auto && t){ return reshape(s, t); });
64
- }
65
- else {
66
- auto idx = find_if(target_shape, [](auto x){ return is_underscore<decltype(x)>{}; });
67
- constexpr int I = decltype(idx)::value;
68
- static_assert(I < tuple_size_v<TargetShape>, "Each mode of TargetShape must contain a placeholder X");
69
- auto divisors = remove<I>(target_shape);
70
- assert(shape % product(divisors) == 0);
71
- return replace<I>(target_shape, shape / product(divisors));
72
- }
73
- }
74
-
75
- // Given a tensor layout, compute a permutation layout consisting of:
76
- // - sub-modes corresponding to the implied multidimensional shape of the source tensor
77
- // - strides accounting for the permutation operation being performed
78
- template<class Permute, bool Transpose, class Shape, class Stride>
79
- constexpr auto
80
- make_permute_layout(Layout<Shape,Stride> const& layout) {
81
- static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
82
- if constexpr (Transpose) {
83
- // Deal with tensor B by transposing appropriately before and after computing the permute layout.
84
- // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
85
- return select<1,0,2>(make_permute_layout<Permute, false>(select<1,0,2>(layout)));
86
- }
87
- else {
88
- if constexpr (cutlass::layout::is_trivial_permute<Permute>) {
89
- // Special case for NoPermute. Use a depth-2 layout for consistency with other permutations.
90
- using ShapeProfile = tuple<tuple<X>, tuple<X>, tuple<X>>;
91
- return unflatten(layout, ShapeProfile{});
92
- }
93
- else {
94
- // Here's where the permutation layout is actually built
95
- using ShapeProfile = typename PermuteTraits<Permute>::ShapeProfile;
96
- using StrideOrder = typename PermuteTraits<Permute>::StrideOrder;
97
- return make_ordered_layout(reshape(layout.shape(), ShapeProfile{}), StrideOrder{});
98
- }
99
- }
100
- }
101
-
102
- namespace detail
103
- {
104
-
105
- template<int I>
106
- struct is_constant_pred {
107
- template <class T>
108
- constexpr auto operator()(T) {
109
- return is_constant<I, T>{};
110
- }
111
- };
112
-
113
- template<class Permutation, int... I>
114
- constexpr auto
115
- inverse_impl(Permutation const & perm, seq<I...>) {
116
- return cute::make_tuple(Int<find_if(Permutation{}, is_constant_pred<I>{})>{}...);
117
- }
118
-
119
- } // namespace detail
120
-
121
- // Compute an inverse of a permutation represented as a tuple of cute::Int<>
122
- template<class Permutation>
123
- constexpr auto
124
- inverse(Permutation const & perm) {
125
- auto flat_perm = flatten(perm);
126
- return unflatten(detail::inverse_impl(flat_perm, tuple_seq<decltype(flat_perm)>{}), perm);
127
- }
128
-
129
- template<class T>
130
- using inverse_t = decltype(inverse(T{}));
131
-
132
- // Given a rank-2 layout of tensor that is assumed to have been permuted,
133
- // compute the original rank-2 layout of the tensor prior to the permutation.
134
- // This is needed to form the correct input to the standalone permutation kernel.
135
- template<class Permute, bool Transpose, class Shape, class Stride>
136
- constexpr auto
137
- make_original_layout(Layout<Shape,Stride> const& layout) {
138
- static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
139
- if constexpr (Transpose) {
140
- // Deal with tensor B by transposing appropriately before and after computing the permute layout.
141
- // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
142
- return select<1,0,2>(make_original_layout<Permute, false>(select<1,0,2>(layout)));
143
- }
144
- else {
145
- using ShapeProfile = typename PermuteTraits<Permute>::ShapeProfile;
146
- auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{}));
147
- using IndexOrder = typename PermuteTraits<Permute>::IndexOrder;
148
- auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get<i>(re_shape); });
149
- using OrigOrder = conditional_t<cutlass::gemm::detail::is_major<0,Stride>(), seq<0,1,2>, seq<1,0,2>>;
150
- // print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n");
151
- // print("Original shape: "); print(orig_shape); print("\n");
152
- return make_ordered_layout(product_each(orig_shape), OrigOrder{});
153
- }
154
- }
155
-
156
- /////////////// Tensor4DPermute0213 ////////////////////
157
-
158
- template<int D1, int D2>
159
- struct PermuteTraits<cutlass::layout::Tensor4DPermute0213ColumnMajor<D1, D2>>
160
- {
161
- static constexpr bool kBatched = false;
162
- using ShapeProfile = Shape<Shape<X,Int<D1>>, Shape<Int<D2>,X>, Shape<X>>;
163
- using IndexOrder = Step<Step<_0,_2>, Step<_1,_3>, Step<_4>>;
164
- using StrideOrder = inverse_t<IndexOrder>; // Step<Step<_0,_2>, Step<_1,_3>, Step<_4>>;
165
- };
166
-
167
- template<int D1, int D2>
168
- struct PermuteTraits<cutlass::layout::Tensor4DPermute0213ColumnMajorInverse<D1, D2>>
169
- {
170
- static constexpr bool kBatched = false;
171
- using ShapeProfile = Shape<Shape<X,Int<D2>>, Shape<Int<D1>,X>, Shape<X>>;
172
- using IndexOrder = Step<Step<_0,_2>, Step<_1,_3>, Step<_4>>;
173
- using StrideOrder = inverse_t<IndexOrder>; // Step<Step<_0,_2>, Step<_1,_3>, Step<_4>>;
174
- };
175
-
176
- template<int D1, int D2>
177
- struct PermuteTraits<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>>
178
- {
179
- static constexpr bool kBatched = false;
180
- using ShapeProfile = Shape<Shape<Int<D1>,X>, Shape<X,Int<D2>>, Shape<X>>;
181
- using IndexOrder = Step<Step<_1,_3>, Step<_0,_2>, Step<_4>>;
182
- using StrideOrder = Step<Step<_1,_3>, Step<_0,_2>, Step<_4>>;
183
- };
184
-
185
- template<int D1, int D2>
186
- struct PermuteTraits<cutlass::layout::Tensor4DPermute0213RowMajorInverse<D1, D2>>
187
- {
188
- static constexpr bool kBatched = false;
189
- using ShapeProfile = Shape<Shape<Int<D2>,X>, Shape<X,Int<D1>>, Shape<X>>;
190
- using IndexOrder = Step<Step<_1,_3>, Step<_0,_2>, Step<_4>>;
191
- using StrideOrder = Step<Step<_1,_3>, Step<_0,_2>, Step<_4>>;
192
- };
193
-
194
- /////////////// Tensor4DPermuteBMM0321 ////////////////////
195
-
196
- template<int D>
197
- struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D>>
198
- {
199
- static constexpr bool kBatched = true;
200
- using ShapeProfile = Shape<Shape<X>, Shape<X>, Shape<Int<D>,X>>;
201
- using IndexOrder = Step<Step<_0,_2>, Step<_1>, Step<_3>>;
202
- using StrideOrder = Step<Step<_0>, Step<_2>, Step<_1,_3>>;
203
- };
204
-
205
- template<int D>
206
- struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajorInverse<D>>
207
- {
208
- static constexpr bool kBatched = true;
209
- using ShapeProfile = Shape<Shape<X,Int<D>>, Shape<X>, Shape<X>>;
210
- using IndexOrder = Step<Step<_0>, Step<_2>, Step<_1,_3>>;
211
- using StrideOrder = Step<Step<_0,_2>, Step<_1>, Step<_3>>;
212
- };
213
-
214
- /////////////// Tensor4DPermuteBMM0213 ////////////////////
215
-
216
- template<int D>
217
- struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D>>
218
- {
219
- static constexpr bool kBatched = true;
220
- using ShapeProfile = Shape<Shape<X>, Shape<X>, Shape<Int<D>,X>>;
221
- using IndexOrder = Step<Step<_0>, Step<_1,_2>, Step<_3>>;
222
- using StrideOrder = Step<Step<_2>, Step<_0>, Step<_1,_3>>;
223
- };
224
-
225
- template<int D>
226
- struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajorInverse<D>>
227
- {
228
- static constexpr bool kBatched = true;
229
- using ShapeProfile = Shape<Shape<X>, Shape<X,Int<D>>, Shape<X>>;
230
- using IndexOrder = Step<Step<_0>, Step<_1>, Step<_2,_3>>;
231
- using StrideOrder = Step<Step<_1>, Step<_0,_2>, Step<_3>>;
232
- };
233
-
234
- /////////////// Tensor5DPermute02413 ////////////////////
235
-
236
- template<int D1, int D2, int D3>
237
- struct PermuteTraits<cutlass::layout::Tensor5DPermute02413ColumnMajor<D1, D2, D3>>
238
- {
239
- static constexpr bool kBatched = false;
240
- using ShapeProfile = Shape<Shape<X,Int<D1>>, Shape<Int<D2>,Int<D3>,X>, Shape<X>>;
241
- using IndexOrder = Step<Step<_0,_2>, Step<_4,_1,_3>, Step<_5>>;
242
- using StrideOrder = inverse_t<IndexOrder>; // Step<Step<_0,_3>, Step<_1,_4,_2>, Step<_5>>;
243
- };
244
-
245
- template<int D1, int D2, int D3>
246
- struct PermuteTraits<cutlass::layout::Tensor5DPermute02413ColumnMajorInverse<D1, D2, D3>>
247
- {
248
- static constexpr bool kBatched = false;
249
- using ShapeProfile = Shape<Shape<X,Int<D2>>, Shape<X,Int<D1>,Int<D3>>, Shape<X>>;
250
- using IndexOrder = Step<Step<_0,_3>, Step<_1,_4,_2>, Step<_5>>;
251
- using StrideOrder = inverse_t<IndexOrder>; // Step<Step<_0,_2>, Step<_4,_1,_3>, Step<_5>>;
252
- };
253
-
254
- /////////////// Tensor5DPermute20314 ////////////////////
255
-
256
- template<int D1, int D2, int D3>
257
- struct PermuteTraits<cutlass::layout::Tensor5DPermute20314RowMajor<D1, D2, D3>>
258
- {
259
- static constexpr bool kBatched = false;
260
- using ShapeProfile = Shape<Shape<Int<D1>,X>, Shape<X,Int<D3>,Int<D2>>, Shape<X>>;
261
- using IndexOrder = Step<Step<_2,_0>, Step<_3,_1,_4>, Step<_5>>;
262
- using StrideOrder = Step<Step<_1,_3>, Step<_0,_2,_4>, Step<_5>>;
263
- };
264
-
265
- template<int D1, int D2, int D3>
266
- struct PermuteTraits<cutlass::layout::Tensor5DPermute20314RowMajorInverse<D1, D2, D3>>
267
- {
268
- static constexpr bool kBatched = false;
269
- using ShapeProfile = Shape<Shape<X,Int<D2>>, Shape<X,Int<D1>,Int<D3>>, Shape<X>>;
270
- using IndexOrder = Step<Step<_3,_0>, Step<_2,_4,_1>, Step<_5>>;
271
- using StrideOrder = Step<Step<_4,_2>, Step<_0,_3,_1>, Step<_5>>;
272
- };
273
-
274
- } // namespace example
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp DELETED
@@ -1,129 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- // Command line options parsing
33
- template<typename RasterOrderOptions>
34
- struct Options {
35
-
36
- bool help = false;
37
-
38
- float alpha = 1.f, beta = 0.f;
39
- float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f;
40
- bool device_scale = false;
41
- bool save_aux = true;
42
- bool save_amax = true;
43
- int iterations = 1000;
44
- int m = 1024, n = 512, k = 1024, l = 1;
45
- RasterOrderOptions raster;
46
- int swizzle;
47
-
48
- // Parses the command line
49
- void parse(int argc, char const **args) {
50
- cutlass::CommandLine cmd(argc, args);
51
-
52
- if (cmd.check_cmd_line_flag("help")) {
53
- help = true;
54
- return;
55
- }
56
-
57
- cmd.get_cmd_line_argument("m", m);
58
- cmd.get_cmd_line_argument("n", n);
59
- cmd.get_cmd_line_argument("k", k);
60
- cmd.get_cmd_line_argument("l", l);
61
- cmd.get_cmd_line_argument("alpha", alpha, 1.f);
62
- cmd.get_cmd_line_argument("beta", beta, 0.f);
63
- cmd.get_cmd_line_argument("scale_a", scale_a, 1.f);
64
- cmd.get_cmd_line_argument("scale_b", scale_b, 1.f);
65
- cmd.get_cmd_line_argument("scale_c", scale_c, 1.f);
66
- cmd.get_cmd_line_argument("scale_d", scale_d, 1.f);
67
- cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f);
68
- cmd.get_cmd_line_argument("device_scale", device_scale, false);
69
- cmd.get_cmd_line_argument("save_aux", save_aux, true);
70
- cmd.get_cmd_line_argument("save_amax", save_amax, true);
71
- cmd.get_cmd_line_argument("iterations", iterations);
72
-
73
- char raster_char;
74
- cmd.get_cmd_line_argument("raster", raster_char);
75
-
76
- if (raster_char == 'N' || raster_char == 'n') {
77
- raster = RasterOrderOptions::AlongN;
78
- }
79
- else if (raster_char == 'M' || raster_char == 'm') {
80
- raster = RasterOrderOptions::AlongM;
81
- }
82
- else if (raster_char == 'H' || raster_char == 'h') {
83
- raster = RasterOrderOptions::Heuristic;
84
- }
85
-
86
- cmd.get_cmd_line_argument("swizzle", swizzle, 1);
87
- }
88
-
89
- /// Prints the usage statement.
90
- std::ostream & print_usage(std::ostream &out) const {
91
-
92
- out << "54_fp8_hopper_warp_specialized_gemm\n\n"
93
- << " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
94
- << "Options:\n\n"
95
- << " --help If specified, displays this usage statement\n\n"
96
- << " --m=<int> Sets the M extent of the GEMM\n"
97
- << " --n=<int> Sets the N extent of the GEMM\n"
98
- << " --k=<int> Sets the K extent of the GEMM\n"
99
- << " --l=<int> Sets the l extent (batch) of the GEMM\n"
100
- << " --alpha=<f32> Epilogue scalar alpha\n"
101
- << " --beta=<f32> Epilogue scalar beta\n"
102
- << " --scale_a=<f32> Scaling factor for A\n"
103
- << " --scale_b=<f32> Scaling factor for B\n"
104
- << " --scale_c=<f32> Scaling factor for C\n"
105
- << " --scale_d=<f32> Scaling factor for D (ignored for non-fp8 D)\n"
106
- << " --scale_aux=<f32> Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n"
107
- << " --device_scale=<bool> Copy scalars to device memory before kernel launch (default: false)\n"
108
- << " --save_aux=<bool> Save the pre-activation as an auxiliary tensor (default: true)\n"
109
- << " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
110
- << " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
111
- << " --swizzle=<int> CTA Rasterization swizzle\n\n"
112
- << " --iterations=<int> Number of profiling iterations to perform.\n\n";
113
-
114
- out
115
- << "\n\nExamples:\n\n"
116
- << "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
117
-
118
- return out;
119
- }
120
-
121
- /// Compute performance in GFLOP/s
122
- double gflops(double runtime_s) const
123
- {
124
- // Two flops per multiply-add
125
- uint64_t flop = uint64_t(2) * m * n * k;
126
- double gflop = double(flop) / double(1.0e9);
127
- return gflop / runtime_s;
128
- }
129
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp DELETED
@@ -1,246 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include "cutlass/cutlass.h"
35
- #include "cutlass/gemm/dispatch_policy.hpp"
36
- #include "cutlass/epilogue/collective/default_epilogue.hpp"
37
- #include "cutlass/epilogue/collective/collective_builder.hpp"
38
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
39
- #include "cutlass/gemm/kernel/gemm_universal.hpp"
40
- #include "cutlass/util/command_line.h"
41
- #include "cutlass/util/reference/device/tensor_fill.h"
42
- #include "cutlass/util/reference/device/tensor_compare.h"
43
-
44
- #include "cute/tensor.hpp"
45
-
46
- #include <cuda.h>
47
- #include <numeric>
48
- #include "helper.h"
49
-
50
- enum MixedDtypeGemmMode {
51
- ConvertOnly,
52
- ScaleOnly,
53
- ScaleWithZeroPoint
54
- };
55
-
56
- /// Command line options parsing
57
- struct MixedDtypeOptions {
58
-
59
- bool help = false;
60
-
61
- float alpha = 1.0f;
62
- float beta = 0.0f;
63
- int iterations = 100;
64
- int warmup = 10;
65
- int mode = 1;
66
- int m = 5120, n = 4096, k = 4096;
67
- int g = 128;
68
- int l = 1;
69
-
70
- // Parses the command line
71
- void parse(int argc, char const **args) {
72
- cutlass::CommandLine cmd(argc, args);
73
-
74
- if (cmd.check_cmd_line_flag("help")) {
75
- help = true;
76
- return;
77
- }
78
-
79
- cmd.get_cmd_line_argument("m", m);
80
- cmd.get_cmd_line_argument("n", n);
81
- cmd.get_cmd_line_argument("k", k);
82
- cmd.get_cmd_line_argument("l", l);
83
- cmd.get_cmd_line_argument("g", g);
84
- cmd.get_cmd_line_argument("mode", mode);
85
- cmd.get_cmd_line_argument("alpha", alpha, 1.f);
86
- cmd.get_cmd_line_argument("beta", beta, 0.f);
87
- cmd.get_cmd_line_argument("iterations", iterations);
88
- cmd.get_cmd_line_argument("warmup", warmup);
89
- }
90
-
91
- /// Prints the usage statement.
92
- std::ostream & print_usage(std::ostream &out) const {
93
-
94
- out << "55_hopper_mixed_dtype_gemm\n\n"
95
- << " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n"
96
- << "Options:\n\n"
97
- << " --help If specified, displays this usage statement\n\n"
98
- << " --m=<int> Sets the M extent of the GEMM\n"
99
- << " --n=<int> Sets the N extent of the GEMM\n"
100
- << " --k=<int> Sets the K extent of the GEMM\n"
101
- << " --l=<int> The number of independent gemm problems with mnk shape\n"
102
- << " --g=<int> The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n"
103
- << " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
104
- << " --alpha=<f32> Epilogue scalar alpha\n"
105
- << " --beta=<f32> Epilogue scalar beta\n\n"
106
- << " --iterations=<int> Number of profiling iterations to perform.\n\n"
107
- << " --warmup=<int> Number of warmup iterations to perform.\n\n";
108
-
109
- out
110
- << "\n\nExamples:\n\n"
111
- << "$ " << "55_hopper_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
112
-
113
- return out;
114
- }
115
-
116
- /// Compute performance in GFLOP/s
117
- double gflops(double runtime_s) const
118
- {
119
- // Two flops per multiply-add
120
- uint64_t flop = uint64_t(2) * m * n * k * l;
121
- double gflop = double(flop) / double(1.0e9);
122
- return gflop / runtime_s;
123
- }
124
- };
125
-
126
- /// Result structure
127
- struct MixedDtypeResult
128
- {
129
- double avg_runtime_ms = 0.0;
130
- double gflops = 0.0;
131
- cutlass::Status status = cutlass::Status::kSuccess;
132
- cudaError_t error = cudaSuccess;
133
- bool passed = false;
134
-
135
- };
136
-
137
- /// Profiling Loop
138
- template <class Gemm>
139
- void mixed_dtype_profiling(
140
- Gemm& gemm,
141
- MixedDtypeOptions const& options,
142
- MixedDtypeResult& result) {
143
-
144
- if (options.iterations <= 0) return;
145
-
146
- cudaEvent_t start, stop;
147
- cudaEventCreate(&start);
148
- cudaEventCreate(&stop);
149
-
150
- std::vector<float> runtimes;
151
- runtimes.reserve(options.iterations);
152
-
153
- for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
154
- cudaEventRecord(start);
155
- CUTLASS_CHECK(gemm.run());
156
- cudaEventRecord(stop);
157
- cudaEventSynchronize(stop);
158
-
159
- if (iter >= options.warmup) {
160
- float milliseconds = 0;
161
- cudaEventElapsedTime(&milliseconds, start, stop);
162
- runtimes.push_back(milliseconds);
163
- }
164
- }
165
-
166
- cudaEventDestroy(start);
167
- cudaEventDestroy(stop);
168
-
169
- // Compute average setup and runtime and GFLOPs.
170
- result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
171
- result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
172
-
173
- std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
174
- std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
175
- std::cout << " GFLOPS: " << result.gflops << std::endl;
176
-
177
- }
178
-
179
- /// Helpers to initialize a block of device data
180
- template <class Element>
181
- bool initialize_tensor(
182
- cutlass::DeviceAllocation<Element>& block,
183
- uint64_t seed = 2023) {
184
-
185
- double scope_max, scope_min;
186
- int bits_input = cutlass::sizeof_bits<Element>::value;
187
- int bits_output = cutlass::sizeof_bits<Element>::value;
188
-
189
- if (bits_input == 1) {
190
- scope_max = 2;
191
- scope_min = 0;
192
- }
193
- else if (bits_input <= 8) {
194
- scope_max = 2;
195
- scope_min = -2;
196
- }
197
- else if (bits_output == 16) {
198
- scope_max = 5;
199
- scope_min = -5;
200
- }
201
- else {
202
- scope_max = 8;
203
- scope_min = -8;
204
- }
205
- cutlass::reference::device::BlockFillRandomUniform(
206
- block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
207
-
208
- return true;
209
- }
210
-
211
- template <class Element>
212
- bool initialize_scale(
213
- cutlass::DeviceAllocation<Element>& block,
214
- MixedDtypeOptions const& options,
215
- uint64_t seed = 2023) {
216
-
217
- // If no scales, initialize with 1 so we can use the same kernel to dequantize the data
218
- float scope_max = 1.0f, scope_min = 1.0f;
219
- if (options.mode != MixedDtypeGemmMode::ConvertOnly) {
220
- float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
221
- scope_max = 2.f;
222
- scope_min = 0.1f;
223
- }
224
- cutlass::reference::device::BlockFillRandomUniform(
225
- block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
226
-
227
- return true;
228
- }
229
-
230
- template <class Element>
231
- bool initialize_zero(
232
- cutlass::DeviceAllocation<Element>& block,
233
- MixedDtypeOptions const& options,
234
- uint64_t seed = 2023) {
235
-
236
- // If no bias, initialize with 0 so we can use the same kernel to dequantize the data
237
- float scope_max = 0.0f, scope_min = 0.0f;
238
- if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
239
- scope_max = 2.0f;
240
- scope_min = -2.0f;
241
- }
242
- cutlass::reference::device::BlockFillRandomUniform(
243
- block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
244
-
245
- return true;
246
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h DELETED
@@ -1,320 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include "cute/tensor.hpp"
34
- #include "cute/atom/mma_atom.hpp"
35
- #include "cute/atom/copy_atom.hpp"
36
- #include <random>
37
-
38
- #include "cutlass/util/print_error.hpp"
39
-
40
- #include "cutlass/gemm/dispatch_policy.hpp"
41
- #include "cutlass/gemm/collective/collective_mma.hpp"
42
-
43
- using namespace cute;
44
-
45
- struct AmpereUnpredicatedFprop {
46
- //
47
- // Static config for conv problem shape
48
- //
49
- using D = _6;
50
- using H = _4;
51
- using W = _4;
52
-
53
- using T = _3;
54
- using R = _3;
55
- using S = _3;
56
-
57
- using Z = _4;
58
- using P = _2;
59
- using Q = _2;
60
-
61
- using C = _64;
62
- using K = _128;
63
-
64
- // Tiler config
65
- using Tiler_K = decltype(cute::min(K{}, _128{}));
66
- using Tiler_C = decltype(cute::min(C{}, _32{}));
67
- using Tiler_N = _4;
68
- using TileM = Tiler_K;
69
- using TileN = Shape<Tiler_N, Z, P, Q>;
70
- using TileK = Shape<Tiler_C,_1,_1,_1>;
71
- using PIPE = _3;
72
- using TilerFlt = Shape<TileM, TileK>;
73
- using TilerAct = Shape<TileN, TileK>;
74
- using TilerOut = Shape<TileM, TileN>;
75
-
76
- using TileSizeM = Int<size(TileM{})>;
77
- using TileSizeN = Int<size(TileN{})>;
78
- using TileSizeK = Int<size(TileK{})>;
79
- static constexpr int Stages = PIPE::value;
80
-
81
- using ElementFlt = tfloat32_t;
82
- using ElementAct = tfloat32_t;
83
- using ElementOut = float;
84
-
85
- using TiledMma = TiledMMA<
86
- MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
87
- Layout<Shape<_2,_2,_1>>,
88
- Tile<_32,_32,Underscore>>;
89
-
90
- static constexpr int MaxThreadsPerBlock = size(TiledMma{});
91
- static constexpr int MinBlocksPerMultiprocessor = 1;
92
-
93
- union SharedStorage {
94
- struct {
95
- ElementFlt sAMatrix[size(TileM{}) * size(TileK{}) * size(PIPE{})];
96
- ElementAct sBMatrix[size(TileN{}) * size(TileK{}) * size(PIPE{})];
97
- } mainloop;
98
-
99
- struct {
100
- ElementOut sCMatrix[size(TileM{}) * size(TileN{})];
101
- } epilogue;
102
- };
103
-
104
- //
105
- // Stencil tensor
106
- //
107
-
108
- using GmemLayoutFlt = decltype(make_ordered_layout(
109
- Shape< K, Shape< C, T, R, S>>{},
110
- tuple<_4, tuple<_0,_3,_2,_1>>{}));
111
-
112
- // We have 64 elements * 32b each in the major mode that we can vectorize
113
- // Max vector size is 128b, so lay 16 threads along the major mode with a vector size of 4
114
- // Rest along the minor mode
115
- using GmemTiledCopyFlt = decltype(make_tiled_copy(
116
- Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementFlt>{},
117
- Layout<Shape <_16, _8>,
118
- Stride< _8, _1>>{},
119
- Layout<Shape < _1, _4>>{}));
120
-
121
- // Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
122
- // using SmemLayoutFlt = decltype(
123
- // composition(Swizzle<3,2,3>{},
124
- // make_ordered_layout(
125
- // Shape<TileSizeM,TileSizeK,PIPE>{},
126
- // tuple< _1, _0, _2>{})));
127
-
128
- using SmemLayoutAtomFlt = decltype(
129
- composition(Swizzle<1,2,3>{},
130
- Layout<Shape <_8,Shape <_4, _2>>,
131
- Stride<_4,Stride<_1,_32>>>{}));
132
-
133
- using SmemCopyAtomFlt = Copy_Atom<SM75_U32x4_LDSM_N, ElementFlt>;
134
-
135
- //
136
- // Activation tensor
137
- //
138
-
139
- // Activation tensor is major in the contraction mode, so vectorize that mode first
140
- // Then lay out the rest of the threads along the other mode
141
- using GmemTiledCopyAct = decltype(make_tiled_copy(
142
- Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementAct>{},
143
- Layout<Shape <_16, _8>,
144
- Stride< _8, _1>>{},
145
- Layout<Shape < _1, _4>>{}));
146
-
147
- // Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
148
- // using SmemLayoutAct = decltype(
149
- // composition(Swizzle<3,2,3>{},
150
- // make_ordered_layout(
151
- // Shape<TileSizeN,TileSizeK,PIPE>{},
152
- // tuple< _1, _0, _2>{})));
153
-
154
- using SmemLayoutAtomAct = decltype(
155
- composition(Swizzle<1,2,3>{},
156
- Layout<Shape <_8,Shape <_4, _2>>,
157
- Stride<_4,Stride<_1,_32>>>{}));
158
-
159
- using SmemCopyAtomAct = Copy_Atom<SM75_U32x4_LDSM_N, ElementAct>;
160
-
161
- //
162
- // Output tensor
163
- //
164
-
165
- using GmemTiledCopyOut = decltype(make_tiled_copy(
166
- Copy_Atom<UniversalCopy<uint128_t>, ElementAct>{},
167
- Layout<Shape <_8, _16>,
168
- Stride<_1, _8>>{},
169
- Layout<Shape <_4, _1>>{}));
170
-
171
- using SmemCopyAtomOut = Copy_Atom<UniversalCopy<uint32_t>, ElementOut>;
172
-
173
- // This can be optimized to make accesses BCF, but we use a col-major layout here to show off composability
174
- using SmemLayoutOut = Layout<Shape<TileSizeM, TileSizeN>>;
175
-
176
- //
177
- // Conv functor
178
- //
179
- template <class EngineFlt, class TensorActivation, class TensorOutput>
180
- void __device__
181
- operator()(cute::Tensor<EngineFlt, GmemLayoutFlt> mFlt, // ( K, (C,T,R,S))
182
- TensorActivation mAct, // ((N,Z,P,Q), (C,T,R,S))
183
- TensorOutput mOut, // ( K, (N,Z,P,Q))
184
- char* smem_buf) const {
185
- using namespace cute;
186
- using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveMma<
187
- cutlass::gemm::MainloopSm80CpAsyncUnpredicated<PIPE::value>,
188
- Shape<TileM,TileN,TileK>,
189
- ElementFlt,
190
- Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
191
- ElementAct,
192
- Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
193
- TiledMma,
194
- GmemTiledCopyFlt,
195
- SmemLayoutAtomFlt,
196
- SmemCopyAtomFlt,
197
- cute::identity,
198
- GmemTiledCopyAct,
199
- SmemLayoutAtomAct,
200
- SmemCopyAtomAct,
201
- cute::identity>;
202
-
203
- TiledMma tiled_mma;
204
- Tensor accum = partition_fragment_C(tiled_mma, TilerOut{});
205
- clear(accum);
206
-
207
- // Set up tensors
208
- // NOTE: blockIdx.x projects onto act-NDHW mode, y along the flt-K mode for the sake of higher dynamic range in NDHW
209
- Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_)); // (BLK_M,BLK_K,m',k')
210
- Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_)); // (BLK_N,BLK_K,n',_1)
211
- Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_)); // (BLK_M,BLK_N,m',n')
212
-
213
- // Compute m_coord and n_coord with their post-tiled shapes
214
- auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk));
215
- auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk));
216
- Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k')
217
- Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,_1)
218
- Tensor gC = gC_mn(_,_,m_coord,n_coord); // (BLK_M,BLK_N)
219
-
220
- auto k_tile_iter = cute::make_coord_iterator(size<2>(gA));
221
- int k_tile_count = size<2>(gA);
222
-
223
- CollectiveMainloop collective_mma;
224
- collective_mma(
225
- accum,
226
- gA,
227
- gB,
228
- accum,
229
- k_tile_iter, k_tile_count,
230
- Underscore{}, // no residue since we do not support predication
231
- threadIdx.x,
232
- smem_buf);
233
-
234
- //
235
- // Epilogue
236
- //
237
- SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
238
- Tensor sC = make_tensor(make_smem_ptr(&storage.epilogue.sCMatrix[0]), SmemLayoutOut{});
239
-
240
- auto smem_tiled_copy_C = make_tiled_copy_C(SmemCopyAtomOut{}, tiled_mma);
241
- auto smem_thr_copy_C = smem_tiled_copy_C.get_slice(threadIdx.x);
242
- auto tCrC = smem_thr_copy_C.retile_S(accum);
243
- auto tCsC = smem_thr_copy_C.partition_D(sC);
244
- copy(smem_tiled_copy_C, tCrC, tCsC);
245
-
246
- __syncthreads();
247
-
248
- GmemTiledCopyOut gmem_tiled_copy_C;
249
- auto gmem_thr_copy_C = gmem_tiled_copy_C.get_slice(threadIdx.x);
250
- auto tDsC = gmem_thr_copy_C.partition_S(sC);
251
- auto tDgC = gmem_thr_copy_C.partition_D(gC);
252
- copy(gmem_tiled_copy_C, tDsC, tDgC);
253
-
254
- #if 0
255
- if (thread0()) {
256
- print("mAct = "); print(mAct); print('\n');
257
- print("mFlt = "); print(mFlt); print('\n');
258
- print("mOut = "); print(mOut); print('\n');
259
- print("gA = "); print(gA); print('\n');
260
- print("gB = "); print(gB); print('\n');
261
- print("gC = "); print(gC); print('\n');
262
- print("sA = "); print(sA.layout()); print('\n');
263
- print("sB = "); print(sB.layout()); print('\n');
264
- print("sC = "); print(sC.layout()); print('\n');
265
- print("tAgA = "); print(tAgA.layout()); print('\n');
266
- print("tBgB = "); print(tBgB.layout()); print('\n');
267
- print("tAsA = "); print(tAsA.layout()); print('\n');
268
- print("tBsB = "); print(tBsB.layout()); print('\n');
269
- print("tCsA = "); print(tCsA.layout()); print('\n');
270
- print("tCsB = "); print(tCsB.layout()); print('\n');
271
- print("tCrC = "); print(tCrC.layout()); print('\n');
272
- print("tCsC = "); print(tCsC.layout()); print('\n');
273
- print("tDsC = "); print(tDsC.layout()); print('\n');
274
- print("tDgC = "); print(tDgC.layout()); print('\n');
275
- print("gmem tiled copy A = "); print(gmem_tiled_copy_A); print('\n');
276
- print("gmem tiled copy B = "); print(gmem_tiled_copy_B); print('\n');
277
- print("gmem tiled copy C = "); print(gmem_tiled_copy_C); print('\n');
278
- print("k_tile_count = "); print(size<2>(gA)); print('\n');
279
- print("k_tile_iter = "); print(*k_tile_iter); print('\n');
280
- print("K_BLOCK_MAX = "); print(K_BLOCK_MAX); print('\n');
281
- }
282
- #endif
283
- }
284
- };
285
-
286
- template <class TensorFlt, class TensorAct, class TensorOut>
287
- inline int
288
- fprop_reference(
289
- TensorFlt mStencil, // Logical MK: ( K, (C,T,R,S))
290
- TensorAct mActivation, // Logical NK: ((N,Z,P,Q), (C,T,R,S))
291
- TensorOut mOutput, // Logical MN: ( K, (N,Z,P,Q))
292
- TensorOut mOutputRef) {
293
- int32_t N = size<1,0>(mOutputRef);
294
- int32_t Z = size<1,1>(mOutputRef);
295
- int32_t P = size<1,2>(mOutputRef);
296
- int32_t Q = size<1,3>(mOutputRef);
297
- int32_t T = size<1,3>(mStencil);
298
- int32_t R = size<1,2>(mStencil);
299
- int32_t S = size<1,1>(mStencil);
300
- int32_t C = size<1,0>(mStencil);
301
-
302
- size_t K = static_cast<size_t>(size<0>(mOutputRef));
303
- size_t NZPQ = static_cast<size_t>(size<1>(mOutputRef));
304
- size_t CTRS = static_cast<size_t>(size<1>(mStencil));
305
-
306
- #if defined(_OPENMP)
307
- #pragma omp parallel for
308
- #endif
309
- for (size_t logical_m = 0; logical_m < K; ++logical_m) {
310
- for (size_t logical_n = 0; logical_n < NZPQ; ++logical_n) {
311
- auto accumulator = float(0);
312
- for (size_t logical_k = 0; logical_k < CTRS; ++logical_k) {
313
- accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k);
314
- }
315
- mOutputRef(logical_m, logical_n) = accumulator;
316
- }
317
- }
318
-
319
- return print_relative_error(mOutput, mOutputRef, /*print_verbose*/ false, /*print_error*/ true, /*error_margin*/ 0.01);
320
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp DELETED
@@ -1,242 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include "cutlass/gemm/collective/collective_builder.hpp"
35
-
36
- #include "dispatch_policy_extra.hpp"
37
- #include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp"
38
- #include "../pipeline/prefetch_pipeline_sm90.hpp"
39
-
40
- namespace cutlass::gemm::collective {
41
-
42
- namespace detail {
43
-
44
- // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
45
- template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int stages>
46
- constexpr int
47
- compute_stage_count_or_override_prefetch(StageCount<stages> stage_count) {
48
- return stages;
49
- }
50
-
51
- // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
52
- template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int carveout_bytes>
53
- constexpr int
54
- compute_stage_count_or_override_prefetch(StageCountAutoCarveout<carveout_bytes> stage_count) {
55
- constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
56
- constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage<PrefetchStages>);
57
- constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
58
- constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
59
- constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size
60
- constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}));
61
- constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast<int>(mainloop_pipeline_bytes);
62
-
63
- return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes;
64
- }
65
-
66
- } // namespace detail
67
-
68
- // GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch
69
- template <
70
- class ElementA,
71
- class GmemLayoutATag,
72
- int AlignmentA,
73
- class ElementB,
74
- class GmemLayoutBTag,
75
- int AlignmentB,
76
- class ElementAccumulator,
77
- class TileShape_MNK,
78
- class ClusterShape_MNK,
79
- class StageCountType,
80
- class KernelScheduleType
81
- >
82
- struct CollectiveBuilder<
83
- arch::Sm90,
84
- arch::OpClassTensorOp,
85
- ElementA,
86
- GmemLayoutATag,
87
- AlignmentA,
88
- ElementB,
89
- GmemLayoutBTag,
90
- AlignmentB,
91
- ElementAccumulator,
92
- TileShape_MNK,
93
- ClusterShape_MNK,
94
- StageCountType,
95
- KernelScheduleType,
96
- cute::enable_if_t<
97
- cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>>
98
- > {
99
- static_assert(is_static<TileShape_MNK>::value);
100
- static_assert(is_static<ClusterShape_MNK>::value);
101
- static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
102
- "Not meet TMA alignment requirement yet\n");
103
- static_assert(detail::is_input_fp8<ElementA, ElementB>(),
104
- "Only FP8 datatypes are compatible with these kernel schedules\n");
105
- // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
106
- static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
107
- "Not supported for fp8 non-TN warp specialized kernels yet\n");
108
- #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
109
- static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
110
- #endif
111
-
112
- static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
113
- static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
114
-
115
- using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
116
-
117
- using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
118
- ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
119
-
120
- using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
121
- using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
122
-
123
- using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
124
- GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
125
- using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
126
- GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
127
-
128
- static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
129
- ElementA, ElementB, TileShape_MNK>(StageCountType{});
130
- using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
131
-
132
- using SmemCopyAtomA = void;
133
- using SmemCopyAtomB = void;
134
-
135
- using CollectiveOp = CollectiveMma<
136
- DispatchPolicy,
137
- TileShape_MNK,
138
- ElementA,
139
- TagToStrideA_t<GmemLayoutATag>,
140
- ElementB,
141
- TagToStrideB_t<GmemLayoutBTag>,
142
- TiledMma,
143
- GmemTiledCopyA,
144
- SmemLayoutAtomA,
145
- SmemCopyAtomA,
146
- cute::identity,
147
- GmemTiledCopyB,
148
- SmemLayoutAtomB,
149
- SmemCopyAtomB,
150
- cute::identity
151
- >;
152
- };
153
-
154
- // GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps
155
- template <
156
- class ElementA,
157
- class GmemLayoutATag,
158
- int AlignmentA,
159
- class ElementB,
160
- class GmemLayoutBTag,
161
- int AlignmentB,
162
- class ElementAccumulator,
163
- class TileShape_MNK,
164
- class ClusterShape_MNK,
165
- class StageCountType,
166
- class KernelScheduleType
167
- >
168
- struct CollectiveBuilder<
169
- arch::Sm90,
170
- arch::OpClassTensorOp,
171
- ElementA,
172
- GmemLayoutATag,
173
- AlignmentA,
174
- ElementB,
175
- GmemLayoutBTag,
176
- AlignmentB,
177
- ElementAccumulator,
178
- TileShape_MNK,
179
- ClusterShape_MNK,
180
- StageCountType,
181
- KernelScheduleType,
182
- cute::enable_if_t<
183
- cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>>
184
- > {
185
- static_assert(is_static<TileShape_MNK>::value);
186
- static_assert(is_static<ClusterShape_MNK>::value);
187
- static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
188
- "Not meet TMA alignment requirement yet\n");
189
- static_assert(detail::is_input_fp8<ElementA, ElementB>(),
190
- "Only FP8 datatypes are compatible with these kernel schedules\n");
191
- // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
192
- static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
193
- "Not supported for fp8 non-TN warp specialized kernels yet\n");
194
- #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
195
- static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
196
- #endif
197
-
198
- static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
199
- static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
200
-
201
- using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
202
-
203
- using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
204
- ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
205
-
206
- using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
207
- using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
208
-
209
- using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
210
- GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
211
- using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
212
- GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
213
-
214
- static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
215
- ElementA, ElementB, TileShape_MNK>(StageCountType{});
216
- using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
217
-
218
- using SmemCopyAtomA = void;
219
- using SmemCopyAtomB = void;
220
-
221
- using CollectiveOp = CollectiveMma<
222
- DispatchPolicy,
223
- TileShape_MNK,
224
- ElementA,
225
- TagToStrideA_t<GmemLayoutATag>,
226
- ElementB,
227
- TagToStrideB_t<GmemLayoutBTag>,
228
- TiledMma,
229
- GmemTiledCopyA,
230
- SmemLayoutAtomA,
231
- SmemCopyAtomA,
232
- cute::identity,
233
- GmemTiledCopyB,
234
- SmemLayoutAtomB,
235
- SmemCopyAtomB,
236
- cute::identity
237
- >;
238
- };
239
-
240
- } // namespace cutlass::gemm::collective
241
-
242
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp DELETED
@@ -1,61 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- namespace cutlass::gemm {
35
-
36
- // Standard non-persistent kernel with a single producer warp, and one prefetch warp.
37
- // `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
38
- // while the producer warp is waiting on griddepcontrol.
39
- // GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and
40
- // according to prefetch ratio.
41
- struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { };
42
-
43
- // Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp.
44
- // `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
45
- // while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not
46
- // wait on griddepcontrol and loads immediately.
47
- struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { };
48
-
49
- template<
50
- int Stages_,
51
- class ClusterShape_ = Shape<_1,_1,_1>,
52
- class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch
53
- >
54
- struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch {
55
- constexpr static int Stages = Stages_;
56
- using ClusterShape = ClusterShape_;
57
- using ArchTag = arch::Sm90;
58
- using Schedule = KernelSchedule;
59
- };
60
-
61
- } // namespace cutlass::gemm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp DELETED
@@ -1,871 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include "cutlass/cutlass.h"
35
- #include "cutlass/gemm/dispatch_policy.hpp"
36
- #include "cutlass/numeric_types.h"
37
- #include "cutlass/pipeline/pipeline.hpp"
38
- #include "cutlass/trace.h"
39
-
40
- #include "cute/arch/cluster_sm90.hpp"
41
- #include "cute/arch/copy_sm90.hpp"
42
- #include "cute/algorithm/functional.hpp"
43
- #include "cute/atom/mma_atom.hpp"
44
- #include "cute/algorithm/gemm.hpp"
45
- #include "cute/numeric/arithmetic_tuple.hpp"
46
- #include "cutlass/arch/grid_dependency_control.h"
47
-
48
- #include "dispatch_policy_extra.hpp"
49
-
50
- #include "../pipeline/prefetch_pipeline_sm90.hpp"
51
-
52
- /////////////////////////////////////////////////////////////////////////////////////////////////
53
-
54
- namespace cutlass::gemm::collective {
55
- using namespace cute;
56
-
57
- /////////////////////////////////////////////////////////////////////////////////////////////////
58
-
59
- namespace detail {
60
-
61
- constexpr int PrefetchStages = 4;
62
- constexpr int PrefetchInitialStages = 1;
63
- // This determines how much shmem we set aside for prefetch.
64
- // We don't reuse anything loaded by prefetcher, so we can keep
65
- // loading into the same place -- there will be a conflict when
66
- // writing, but it doesn't affect performance as much as the doors
67
- // that this opens.
68
- constexpr int PrefetchStagesActual = 1;
69
-
70
- } // namespace detail
71
-
72
- // WarpSpecialized Mainloop
73
- template <
74
- int Stages,
75
- class ClusterShape,
76
- class KernelSchedule,
77
- class TileShape_,
78
- class ElementA_,
79
- class StrideA_,
80
- class ElementB_,
81
- class StrideB_,
82
- class TiledMma_,
83
- class GmemTiledCopyA_,
84
- class SmemLayoutAtomA_,
85
- class SmemCopyAtomA_,
86
- class TransformA_,
87
- class GmemTiledCopyB_,
88
- class SmemLayoutAtomB_,
89
- class SmemCopyAtomB_,
90
- class TransformB_>
91
- struct CollectiveMma<
92
- MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>,
93
- TileShape_,
94
- ElementA_,
95
- StrideA_,
96
- ElementB_,
97
- StrideB_,
98
- TiledMma_,
99
- GmemTiledCopyA_,
100
- SmemLayoutAtomA_,
101
- SmemCopyAtomA_,
102
- TransformA_,
103
- GmemTiledCopyB_,
104
- SmemLayoutAtomB_,
105
- SmemCopyAtomB_,
106
- TransformB_>
107
- {
108
- //
109
- // Type Aliases
110
- //
111
- using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>;
112
- using TileShape = TileShape_;
113
- using ElementA = ElementA_;
114
- using StrideA = StrideA_;
115
- using ElementB = ElementB_;
116
- using StrideB = StrideB_;
117
- using TiledMma = TiledMma_;
118
- using ElementAccumulator = typename TiledMma::ValTypeC;
119
- using GmemTiledCopyA = GmemTiledCopyA_;
120
- using GmemTiledCopyB = GmemTiledCopyB_;
121
- using SmemLayoutAtomA = SmemLayoutAtomA_;
122
- using SmemLayoutAtomB = SmemLayoutAtomB_;
123
- using SmemCopyAtomA = SmemCopyAtomA_;
124
- using SmemCopyAtomB = SmemCopyAtomB_;
125
- using TransformA = TransformA_;
126
- using TransformB = TransformB_;
127
- using ArchTag = typename DispatchPolicy::ArchTag;
128
-
129
- static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1");
130
- using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
131
-
132
- using PrefetcherPipeline = cutlass::PrefetchPipeline<detail::PrefetchStages>;
133
-
134
- using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
135
- using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
136
- using PipelineParams = typename MainloopPipeline::Params;
137
-
138
- static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
139
- static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
140
- static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
141
-
142
- static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
143
- static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
144
- static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
145
-
146
- // Tile along modes in a way that maximizes the TMA box size.
147
- using SmemLayoutA = decltype(tile_to_shape(
148
- SmemLayoutAtomA{},
149
- make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
150
- cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
151
- using SmemLayoutB = decltype(tile_to_shape(
152
- SmemLayoutAtomB{},
153
- make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
154
- cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
155
-
156
- static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages);
157
- static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages);
158
-
159
- using PrefetchSmemLayoutA = decltype(make_layout(make_shape(
160
- cute::Int<size<0>(SmemLayoutA{})>{},
161
- cute::Int<size<1>(SmemLayoutA{})>{},
162
- cute::Int<detail::PrefetchStagesActual>{})));
163
-
164
- static constexpr auto prefetch_smem_size = cute::cosize_v<PrefetchSmemLayoutA>;
165
-
166
- static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
167
- static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
168
- cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
169
- "MMA atom must source both A and B operand from smem_desc for this mainloop.");
170
- static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
171
- "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
172
- static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
173
- "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
174
-
175
- // TMA converts f32 input to tf32 when copying from GMEM to SMEM
176
- // For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
177
- static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
178
- static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
179
- using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
180
- using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
181
-
182
- // Defined outside the class where it's used, to work around MSVC issues
183
- using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage<detail::PrefetchStages>;
184
-
185
- struct SharedStorage {
186
- struct TensorStorage : cute::aligned_struct<128, _0> {
187
- cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
188
- cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
189
- cute::array_aligned<typename TiledMma::ValTypeA, prefetch_smem_size> smem_prefetch;
190
- } tensors;
191
-
192
- using PipelineStorage = typename MainloopPipeline::SharedStorage;
193
- PipelineStorage pipeline;
194
- PrefetcherPipelineStorage prefetcher_pipeline;
195
- };
196
- using TensorStorage = typename SharedStorage::TensorStorage;
197
- using PipelineStorage = typename SharedStorage::PipelineStorage;
198
-
199
- // Host side kernel arguments
200
- struct Arguments {
201
- ElementA const* ptr_A;
202
- StrideA dA;
203
- ElementB const* ptr_B;
204
- StrideB dB;
205
- uint32_t mma_promotion_interval = 4;
206
- float overlap_ratio = 0.5;
207
- float prefetch_ratio = -1.0;
208
- };
209
-
210
- // Device side kernel params
211
- struct Params {
212
- // Assumption: StrideA is congruent with Problem_MK
213
- using TMA_A = decltype(make_tma_copy_A_sm90(
214
- GmemTiledCopyA{},
215
- make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
216
- SmemLayoutA{}(_,_,cute::Int<0>{}),
217
- TileShape{},
218
- ClusterShape{}));
219
- // Assumption: StrideB is congruent with Problem_NK
220
- using TMA_B = decltype(make_tma_copy_B_sm90(
221
- GmemTiledCopyB{},
222
- make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
223
- SmemLayoutB{}(_,_,cute::Int<0>{}),
224
- TileShape{},
225
- ClusterShape{}));
226
-
227
- TMA_A tma_load_a;
228
- TMA_B tma_load_b;
229
- uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
230
- uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
231
- uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
232
- float overlap_ratio = 0.5;
233
- float prefetch_ratio = -1.0;
234
- };
235
-
236
- //
237
- // Methods
238
- //
239
-
240
- template <class ProblemShape>
241
- static constexpr Params
242
- to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
243
- (void) workspace;
244
-
245
- // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
246
- auto problem_shape_MNKL = append<4>(problem_shape, 1);
247
- auto [M,N,K,L] = problem_shape_MNKL;
248
-
249
- auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
250
- auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
251
-
252
- Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
253
- Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
254
-
255
- typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
256
- GmemTiledCopyA{},
257
- tensor_a,
258
- SmemLayoutA{}(_,_,cute::Int<0>{}),
259
- TileShape{},
260
- ClusterShape{});
261
- typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
262
- GmemTiledCopyB{},
263
- tensor_b,
264
- SmemLayoutB{}(_,_,cute::Int<0>{}),
265
- TileShape{},
266
- ClusterShape{});
267
- uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
268
- uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
269
- uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
270
-
271
- return {
272
- tma_load_a,
273
- tma_load_b,
274
- transaction_bytes,
275
- transaction_bytes_mk,
276
- transaction_bytes_nk,
277
- args.overlap_ratio,
278
- args.prefetch_ratio
279
- };
280
- }
281
-
282
- template<class ProblemShape>
283
- static bool
284
- can_implement(
285
- ProblemShape const& problem_shape,
286
- [[maybe_unused]] Arguments const& args) {
287
- constexpr int tma_alignment_bits = 128;
288
- auto problem_shape_MNKL = append<4>(problem_shape, 1);
289
- auto [M,N,K,L] = problem_shape_MNKL;
290
-
291
- constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
292
- bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
293
- constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
294
- implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
295
-
296
- if (!implementable) {
297
- CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
298
- return false;
299
- }
300
-
301
- if (args.overlap_ratio > 1.0) {
302
- CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n");
303
- return false;
304
- }
305
-
306
- if (args.prefetch_ratio > 1.0) {
307
- CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n");
308
- return false;
309
- }
310
-
311
- return true;
312
- }
313
-
314
- static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
315
- static constexpr int K_PIPE_MMAS = 1;
316
- static constexpr uint32_t TmaTransactionBytesMK =
317
- cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
318
- static constexpr uint32_t TmaTransactionBytesNK =
319
- cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
320
-
321
- /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
322
- CUTLASS_DEVICE
323
- static void prefetch_tma_descriptors(Params const& mainloop_params) {
324
- cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
325
- cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
326
- }
327
-
328
- /// Set up the data needed by this collective for load and mma.
329
- /// Returns a tuple of tensors. The collective and the kernel layer have the contract
330
- /// Returned tuple must contain at least two elements, with the first two elements being:
331
- /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
332
- /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
333
- /// The rest of the tensors can be specified as needed by this collective.
334
- template <class ProblemShape_MNKL>
335
- CUTLASS_DEVICE auto
336
- load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
337
- using X = Underscore;
338
- // Separate out problem shape for convenience
339
- auto [M,N,K,L] = problem_shape_MNKL;
340
-
341
- // TMA requires special handling of strides to deal with coord codomain mapping
342
- // Represent the full tensors -- get these from TMA
343
- Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
344
- Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
345
-
346
- // Make tiled views, defer the slice
347
- Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
348
- Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
349
-
350
- return cute::make_tuple(gA_mkl, gB_nkl);
351
- }
352
-
353
- template <
354
- class TensorA, class TensorB,
355
- class KTileIterator, class BlockCoord
356
- >
357
- CUTLASS_DEVICE void
358
- load(
359
- Params const& mainloop_params,
360
- MainloopPipeline pipeline,
361
- PrefetcherPipeline prefetcher_pipeline,
362
- PipelineState smem_pipe_write,
363
- TensorA const& gA_mkl,
364
- TensorB const& gB_nkl,
365
- BlockCoord const& blk_coord,
366
- KTileIterator k_tile_iter, int k_tile_count,
367
- int thread_idx,
368
- uint32_t block_rank_in_cluster,
369
- TensorStorage& shared_tensors) {
370
- int lane_predicate = cute::elect_one_sync();
371
-
372
- if (lane_predicate) {
373
- bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
374
- float overlap_ratio = mainloop_params.overlap_ratio;
375
- int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);
376
-
377
- Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
378
- Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
379
-
380
- //
381
- // Prepare the TMA loads for A
382
- //
383
-
384
- constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
385
- uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
386
-
387
- auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
388
- auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
389
-
390
- // Partition the inputs based on the current block coordinates.
391
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
392
- Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
393
- Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
394
-
395
- // Applies the mapping from cta_tma_a
396
- Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
397
- Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
398
-
399
- // Applies the mapping from cta_tma_b
400
- Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
401
- Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
402
-
403
- uint16_t mcast_mask_a = 0;
404
- uint16_t mcast_mask_b = 0;
405
-
406
- // Issue TmaLoads
407
- // Maps the tile -> block, value
408
- if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
409
- auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
410
- for (int n = 0; n < size<1>(block_layout); ++n) {
411
- mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
412
- }
413
- }
414
-
415
- if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
416
- auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
417
- for (int m = 0; m < size<0>(block_layout); ++m) {
418
- mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
419
- }
420
- }
421
-
422
- // We have to wait on dependent grids because of B.
423
- cutlass::arch::wait_on_dependent_grids();
424
-
425
- // Signal prefetcher to stop
426
- prefetcher_pipeline.producer_arrive();
427
-
428
- bool launch_dep_grids = false;
429
- // Mainloop
430
- CUTLASS_PRAGMA_NO_UNROLL
431
- for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
432
- // LOCK smem_pipe_write for _writing_
433
- pipeline.producer_acquire(smem_pipe_write);
434
-
435
- //
436
- // Copy gmem to smem for *k_tile_iter
437
- //
438
-
439
- using BarrierType = typename MainloopPipeline::ProducerBarrierType;
440
- BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
441
-
442
- int write_stage = smem_pipe_write.index();
443
- copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
444
- copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
445
- ++k_tile_iter;
446
-
447
- if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
448
- launch_dep_grids = true;
449
- cutlass::arch::launch_dependent_grids();
450
- }
451
-
452
- // Advance smem_pipe_write
453
- ++smem_pipe_write;
454
- }
455
- if (!disable_gdc && !launch_dep_grids) {
456
- cutlass::arch::launch_dependent_grids();
457
- }
458
- }
459
- }
460
-
461
- template <
462
- class TensorA,
463
- class KTileIterator, class BlockCoord
464
- >
465
- CUTLASS_DEVICE void
466
- load_MK(
467
- Params const& mainloop_params,
468
- MainloopPipeline pipeline,
469
- PrefetcherPipeline prefetcher_pipeline,
470
- PipelineState smem_pipe_write,
471
- TensorA const& gA_mkl,
472
- BlockCoord const& blk_coord,
473
- KTileIterator k_tile_iter, int k_tile_count,
474
- int thread_idx,
475
- uint32_t block_rank_in_cluster,
476
- TensorStorage& shared_tensors) {
477
- int lane_predicate = cute::elect_one_sync();
478
-
479
- if (lane_predicate) {
480
- bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
481
- float overlap_ratio = mainloop_params.overlap_ratio;
482
- int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);
483
-
484
- Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
485
-
486
- //
487
- // Prepare the TMA loads for A
488
- //
489
-
490
- constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
491
- uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
492
-
493
- auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
494
-
495
- // Partition the inputs based on the current block coordinates.
496
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
497
- Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
498
-
499
- // Applies the mapping from cta_tma_a
500
- Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
501
- Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
502
-
503
- uint16_t mcast_mask_a = 0;
504
-
505
- // Issue TmaLoads
506
- // Maps the tile -> block, value
507
- if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
508
- auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
509
- for (int n = 0; n < size<1>(block_layout); ++n) {
510
- mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
511
- }
512
- }
513
-
514
- // Don't wait on dependent grids when loading `A`, because
515
- // we assume `A` (weights) are static.
516
-
517
- bool launch_dep_grids = false;
518
- // Mainloop
519
- CUTLASS_PRAGMA_NO_UNROLL
520
- for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
521
- // LOCK smem_pipe_write for _writing_
522
- pipeline.producer_acquire(smem_pipe_write);
523
-
524
- //
525
- // Copy gmem to smem for *k_tile_iter
526
- //
527
-
528
- using BarrierType = typename MainloopPipeline::ProducerBarrierType;
529
- BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
530
-
531
- int write_stage = smem_pipe_write.index();
532
- copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
533
- ++k_tile_iter;
534
-
535
- if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
536
- launch_dep_grids = true;
537
- cutlass::arch::launch_dependent_grids();
538
- }
539
-
540
- // Advance smem_pipe_write
541
- ++smem_pipe_write;
542
- }
543
- if (!disable_gdc && !launch_dep_grids) {
544
- cutlass::arch::launch_dependent_grids();
545
- }
546
- }
547
- }
548
-
549
- template <
550
- class TensorB,
551
- class KTileIterator, class BlockCoord
552
- >
553
- CUTLASS_DEVICE void
554
- load_NK(
555
- Params const& mainloop_params,
556
- MainloopPipeline pipeline,
557
- PrefetcherPipeline prefetcher_pipeline,
558
- PipelineState smem_pipe_write,
559
- TensorB const& gB_nkl,
560
- BlockCoord const& blk_coord,
561
- KTileIterator k_tile_iter, int k_tile_count,
562
- int thread_idx,
563
- uint32_t block_rank_in_cluster,
564
- TensorStorage& shared_tensors) {
565
- int lane_predicate = cute::elect_one_sync();
566
-
567
- if (lane_predicate) {
568
- Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
569
-
570
- //
571
- // Prepare the TMA loads for B
572
- //
573
-
574
- constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
575
- uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
576
-
577
- auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
578
-
579
- // Partition the inputs based on the current block coordinates.
580
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
581
- Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
582
-
583
- // Applies the mapping from cta_tma_b
584
- Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
585
- Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
586
-
587
- uint16_t mcast_mask_b = 0;
588
-
589
- // Issue TmaLoads
590
- // Maps the tile -> block, value
591
- if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
592
- auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
593
- for (int m = 0; m < size<0>(block_layout); ++m) {
594
- mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
595
- }
596
- }
597
-
598
- // Ensure that the prefetched kernel does not touch
599
- // unflushed global memory prior to this instruction
600
- cutlass::arch::wait_on_dependent_grids();
601
-
602
- // Signal prefetcher to stop
603
- prefetcher_pipeline.producer_arrive();
604
-
605
- // Mainloop
606
- CUTLASS_PRAGMA_NO_UNROLL
607
- for (; k_tile_count > 0; --k_tile_count) {
608
- // LOCK smem_pipe_write for _writing_
609
- pipeline.producer_acquire(smem_pipe_write);
610
-
611
- //
612
- // Copy gmem to smem for *k_tile_iter
613
- //
614
-
615
- using BarrierType = typename MainloopPipeline::ProducerBarrierType;
616
- BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
617
-
618
- int write_stage = smem_pipe_write.index();
619
- copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
620
- ++k_tile_iter;
621
-
622
- // Advance smem_pipe_write
623
- ++smem_pipe_write;
624
- }
625
- }
626
- }
627
-
628
- /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
629
- CUTLASS_DEVICE void
630
- load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
631
- int lane_predicate = cute::elect_one_sync();
632
-
633
- // Issue the epilogue waits
634
- if (lane_predicate) {
635
- /* This helps avoid early exit of blocks in Cluster
636
- * Waits for all stages to either be released (all
637
- * Consumer UNLOCKs), or if the stage was never used
638
- * then would just be acquired since the phase was
639
- * still inverted from make_producer_start_state
640
- */
641
- pipeline.producer_tail(smem_pipe_write);
642
- }
643
- }
644
-
645
-
646
- template <
647
- class TensorA,
648
- class KTileIterator, class BlockCoord
649
- >
650
- CUTLASS_DEVICE void
651
- prefetch_MK(
652
- Params const& mainloop_params,
653
- PrefetcherPipeline prefetcher_pipeline,
654
- PipelineState smem_pipe_write,
655
- TensorA const& gA_mkl,
656
- BlockCoord const& blk_coord,
657
- KTileIterator k_tile_iter, int k_tile_count,
658
- int thread_idx,
659
- uint32_t block_rank_in_cluster,
660
- TensorStorage& shared_tensors) {
661
- int lane_predicate = cute::elect_one_sync();
662
-
663
- if (lane_predicate) {
664
- bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0;
665
- float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio;
666
- int prefetch_iters = static_cast<int>(static_cast<float>(k_tile_count) * 0.5 * prefetch_ratio);
667
- prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages);
668
-
669
- Tensor sA = make_tensor(
670
- make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
671
-
672
- //
673
- // Prepare the TMA loads for A
674
- //
675
-
676
- constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
677
- uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
678
-
679
- auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
680
-
681
- // Partition the inputs based on the current block coordinates.
682
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
683
- Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
684
-
685
- // Applies the mapping from cta_tma_a
686
- Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
687
- Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
688
-
689
- uint16_t mcast_mask_a = 0;
690
-
691
- // Issue TmaLoads
692
- // Maps the tile -> block, value
693
- if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
694
- auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
695
- for (int n = 0; n < size<1>(block_layout); ++n) {
696
- mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
697
- }
698
- }
699
-
700
- uint32_t prefetcher_stage = 0;
701
- uint32_t prefetcher_phase = 0;
702
- CUTLASS_PRAGMA_NO_UNROLL
703
- for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) {
704
-
705
- if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) {
706
- break;
707
- }
708
-
709
- prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages);
710
- using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType;
711
- BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage);
712
-
713
- int write_stage = 0;
714
- copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
715
- ++k_tile_iter;
716
- ++k_tile_iter;
717
-
718
- prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase);
719
- }
720
- prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase);
721
- }
722
- }
723
-
724
- /// Perform a collective-scoped matrix multiply-accumulate
725
- /// Consumer Perspective
726
- template <
727
- class FrgTensorC
728
- >
729
- CUTLASS_DEVICE void
730
- mma(MainloopPipeline pipeline,
731
- PipelineState smem_pipe_read,
732
- FrgTensorC& accum,
733
- int k_tile_count,
734
- int thread_idx,
735
- TensorStorage& shared_tensors,
736
- Params const& mainloop_params) {
737
- static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
738
- static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
739
- static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
740
- static_assert(cute::is_void_v<SmemCopyAtomA>,
741
- "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
742
- static_assert(cute::is_void_v<SmemCopyAtomB>,
743
- "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
744
-
745
- Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
746
- Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
747
-
748
- //
749
- // Define C accumulators and A/B partitioning
750
- //
751
-
752
- TiledMma tiled_mma;
753
- auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
754
-
755
- Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
756
- Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
757
-
758
- // Allocate "fragments/descriptors"
759
- Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
760
- Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
761
-
762
- CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
763
- CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
764
- CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
765
- CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
766
- CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
767
- CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
768
-
769
- //
770
- // PIPELINED MAIN LOOP
771
- //
772
- static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
773
- "ERROR : Incorrect number of MMAs in flight");
774
-
775
- // We release buffers to producer warps(dma load) with some mmas in flight
776
- PipelineState smem_pipe_release = smem_pipe_read;
777
-
778
- // Prologue GMMAs
779
- int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
780
-
781
- tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
782
-
783
- warpgroup_fence_operand(accum);
784
- CUTLASS_PRAGMA_UNROLL
785
- for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
786
- {
787
- // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
788
- auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
789
- pipeline.consumer_wait(smem_pipe_read, barrier_token);
790
-
791
- int read_stage = smem_pipe_read.index();
792
- warpgroup_arrive();
793
- // Unroll the K mode manually to set scale D to 1
794
- CUTLASS_PRAGMA_UNROLL
795
- for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
796
- // (V,M,K) x (V,N,K) => (V,M,N)
797
- cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
798
- tiled_mma.accumulate_ = GMMA::ScaleOut::One;
799
- }
800
-
801
- warpgroup_commit_batch();
802
-
803
- ++smem_pipe_read;
804
- }
805
-
806
- warpgroup_fence_operand(accum);
807
- // Mainloop GMMAs
808
- k_tile_count -= prologue_mma_count;
809
-
810
- CUTLASS_PRAGMA_NO_UNROLL
811
- for ( ; k_tile_count > 0; --k_tile_count)
812
- {
813
- // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
814
- auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
815
- pipeline.consumer_wait(smem_pipe_read, barrier_token);
816
-
817
- //
818
- // Compute on k_tile
819
- //
820
-
821
- int read_stage = smem_pipe_read.index();
822
- warpgroup_fence_operand(accum);
823
- warpgroup_arrive();
824
- // Unroll the K mode manually to set scale D to 1
825
- CUTLASS_PRAGMA_UNROLL
826
- for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
827
- // (V,M,K) x (V,N,K) => (V,M,N)
828
- cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
829
- tiled_mma.accumulate_ = GMMA::ScaleOut::One;
830
- }
831
- warpgroup_commit_batch();
832
-
833
- /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
834
- warpgroup_wait<K_PIPE_MMAS>();
835
- warpgroup_fence_operand(accum);
836
-
837
- // UNLOCK smem_pipe_release, done _computing_ on it
838
- pipeline.consumer_release(smem_pipe_release);
839
-
840
- // Advance smem_pipe_read and smem_pipe_release
841
- ++smem_pipe_read;
842
- ++smem_pipe_release;
843
- }
844
-
845
- warpgroup_fence_operand(accum);
846
- }
847
-
848
- /// Perform a Consumer Epilogue to release all buffers
849
- CUTLASS_DEVICE void
850
- mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
851
- // Prologue GMMAs
852
- int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
853
- k_tile_count -= prologue_mma_count;
854
-
855
- smem_pipe_release.advance(k_tile_count);
856
-
857
- // Wait on all GMMAs to complete
858
- warpgroup_wait<0>();
859
-
860
- for (int count = 0; count < prologue_mma_count; ++count) {
861
- pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
862
- ++smem_pipe_release;
863
- }
864
- }
865
- };
866
-
867
- /////////////////////////////////////////////////////////////////////////////////////////////////
868
-
869
- } // namespace cutlass::gemm::collective
870
-
871
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp DELETED
@@ -1,117 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- // Command line options parsing
33
- struct Options {
34
-
35
- bool help = false;
36
-
37
- float alpha = 1.f, beta = 0.f;
38
- float overlap_ratio = 0.5f, prefetch_ratio = 0.5f;
39
- int iterations = 1000;
40
- int n = 64, m = 1280, k = 8192, l = 1;
41
-
42
- // Parses the command line
43
- void parse(int argc, char const **args) {
44
- cutlass::CommandLine cmd(argc, args);
45
-
46
- if (cmd.check_cmd_line_flag("help")) {
47
- help = true;
48
- return;
49
- }
50
-
51
- cmd.get_cmd_line_argument("m", m);
52
- cmd.get_cmd_line_argument("n", n);
53
- cmd.get_cmd_line_argument("k", k);
54
- cmd.get_cmd_line_argument("l", l);
55
- cmd.get_cmd_line_argument("alpha", alpha, 1.f);
56
- cmd.get_cmd_line_argument("beta", beta, 0.f);
57
- cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f);
58
- cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f);
59
- cmd.get_cmd_line_argument("iterations", iterations);
60
- }
61
-
62
- /// Prints the usage statement.
63
- std::ostream & print_usage(std::ostream &out) const {
64
-
65
- out << "63_hopper_gemm_with_weight_prefetch\n\n"
66
- << " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. \n"
67
- << " For more details please refer to the source file.\n\n"
68
- << "Options:\n\n"
69
- << " --help If specified, displays this usage statement\n\n"
70
- << " --m=<int> Sets the M extent of the GEMM\n"
71
- << " --n=<int> Sets the N extent of the GEMM\n"
72
- << " --k=<int> Sets the K extent of the GEMM\n"
73
- << " --l=<int> Sets the l extent (batch) of the GEMM\n"
74
- << " --alpha=<f32> Epilogue scalar alpha\n"
75
- << " --beta=<f32> Epilogue scalar beta\n"
76
- << " --p=<f32> Prefetch ratio\n"
77
- << " --o=<f32> Overlap ratio\n"
78
- << " --iterations=<int> Number of profiling iterations to perform.\n\n";
79
-
80
- out
81
- << "\n\nExamples:\n\n"
82
- << "$ " << "63_hopper_gemm_with_weight_prefetch" <<
83
- " --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n";
84
-
85
- return out;
86
- }
87
-
88
- /// Compute performance in GFLOP/s
89
- double gflops(double runtime_s) const
90
- {
91
- // Two flops per multiply-add
92
- uint64_t flop = uint64_t(2) * m * n * k * l;
93
- double gflop = double(flop) / double(1.0e9);
94
- return gflop / runtime_s;
95
- }
96
-
97
- /// Compute effective bandwidth in GB/sec
98
- double effective_bandwidth(
99
- double runtime_s,
100
- size_t bytes_a,
101
- size_t bytes_b,
102
- size_t bytes_c,
103
- size_t bytes_d
104
- ) const
105
- {
106
- static double const kBytesPerGiB = double(1ull << 30);
107
-
108
- double bytes_in =
109
- (double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A
110
- (double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B
111
- (beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C
112
- double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D
113
-
114
- double gb_total = (bytes_in + bytes_out) / kBytesPerGiB;
115
- return gb_total / runtime_s;
116
- }
117
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp DELETED
@@ -1,561 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- #include "cutlass/cutlass.h"
35
- #include "cutlass/fast_math.h"
36
- #include "cutlass/kernel_hardware_info.hpp"
37
- #include "cute/arch/cluster_sm90.hpp"
38
- #include "cutlass/arch/reg_reconfig.h"
39
- #include "cutlass/arch/mma_sm90.h"
40
- #include "cutlass/epilogue/collective/detail.hpp"
41
- #include "cutlass/gemm/gemm.h"
42
- #include "cutlass/gemm/dispatch_policy.hpp"
43
- #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
44
- #include "cutlass/pipeline/pipeline.hpp"
45
- #include "cutlass/trace.h"
46
-
47
- #include "cute/tensor.hpp"
48
-
49
- #include "../collective/dispatch_policy_extra.hpp"
50
-
51
- ///////////////////////////////////////////////////////////////////////////////
52
-
53
- namespace cutlass::gemm::kernel {
54
-
55
- ///////////////////////////////////////////////////////////////////////////////
56
-
57
- // GEMM + Prefetch for the A tensor + (optional) split DMA warps
58
- template <
59
- class ProblemShape_,
60
- class CollectiveMainloop_,
61
- class CollectiveEpilogue_,
62
- class TileScheduler_
63
- >
64
- class GemmUniversal<
65
- ProblemShape_,
66
- CollectiveMainloop_,
67
- CollectiveEpilogue_,
68
- TileScheduler_,
69
- cute::enable_if_t<
70
- cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA> ||
71
- cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>
72
- >
73
- >
74
- {
75
- public:
76
- //
77
- // Type Aliases
78
- //
79
- using ProblemShape = ProblemShape_;
80
- static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
81
- "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
82
- static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled;
83
-
84
- static constexpr bool SplitWarps = cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>;
85
-
86
- // Mainloop derived types
87
- using CollectiveMainloop = CollectiveMainloop_;
88
- using TileShape = typename CollectiveMainloop::TileShape;
89
- using TiledMma = typename CollectiveMainloop::TiledMma;
90
- using ArchTag = typename CollectiveMainloop::ArchTag;
91
- using ElementA = typename CollectiveMainloop::ElementA;
92
- using StrideA = typename CollectiveMainloop::StrideA;
93
- using ElementB = typename CollectiveMainloop::ElementB;
94
- using StrideB = typename CollectiveMainloop::StrideB;
95
- using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
96
- using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
97
- using ClusterShape = typename DispatchPolicy::ClusterShape;
98
- using MainloopArguments = typename CollectiveMainloop::Arguments;
99
- using MainloopParams = typename CollectiveMainloop::Params;
100
- static_assert(ArchTag::kMinComputeCapability >= 90);
101
-
102
- // Epilogue derived types
103
- using CollectiveEpilogue = CollectiveEpilogue_;
104
- using ElementC = typename CollectiveEpilogue::ElementC;
105
- using StrideC = typename CollectiveEpilogue::StrideC;
106
- using ElementD = typename CollectiveEpilogue::ElementD;
107
- using StrideD = typename CollectiveEpilogue::StrideD;
108
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
109
- using EpilogueParams = typename CollectiveEpilogue::Params;
110
-
111
- static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
112
- "TMA warp-specialized kernel does not support specializing the tile scheduler.");
113
- using TileSchedulerTag = TileScheduler_;
114
- using TileScheduler = typename detail::TileSchedulerSelector<
115
- TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
116
- using TileSchedulerArguments = typename TileScheduler::Arguments;
117
-
118
- // Kernel level shared memory storage
119
- struct SharedStorage {
120
- // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union
121
- union TensorStorage {
122
- using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
123
- using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
124
-
125
- MainloopTensorStorage mainloop;
126
- EpilogueTensorStorage epilogue;
127
- } tensors;
128
-
129
- struct PipelineStorage : cute::aligned_struct<16, _1> {
130
- using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
131
- using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage;
132
- using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
133
-
134
- alignas(16) MainloopPipelineStorage mainloop;
135
- alignas(16) EpiLoadPipelineStorage epi_load;
136
- alignas(16) PrefetcherPipelineStorage prefetcher;
137
- } pipelines;
138
- };
139
-
140
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
141
-
142
- static constexpr uint32_t NumLoadWarpGroups = 1;
143
- static constexpr uint32_t NumMmaWarpGroups = 1;
144
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
145
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
146
-
147
- // Device side arguments
148
- struct Arguments {
149
- GemmUniversalMode mode{};
150
- ProblemShape problem_shape{};
151
- MainloopArguments mainloop{};
152
- EpilogueArguments epilogue{};
153
- KernelHardwareInfo hw_info{};
154
- TileSchedulerArguments scheduler{};
155
- };
156
-
157
- // Kernel entry point API
158
- struct Params {
159
- GemmUniversalMode mode{};
160
- ProblemShape problem_shape{};
161
- MainloopParams mainloop{};
162
- EpilogueParams epilogue{};
163
- };
164
-
165
- //
166
- // Methods
167
- //
168
-
169
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
170
- static
171
- Params
172
- to_underlying_arguments(Arguments const& args, void* workspace) {
173
- (void) workspace;
174
- auto problem_shape = args.problem_shape;
175
- if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
176
- // swap M/N
177
- get<0>(problem_shape) = get<1>(args.problem_shape);
178
- get<1>(problem_shape) = get<0>(args.problem_shape);
179
- }
180
- return {
181
- args.mode,
182
- problem_shape,
183
- CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
184
- CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace)
185
- };
186
- }
187
-
188
- static bool
189
- can_implement(Arguments const& args) {
190
- bool implementable = (args.mode == GemmUniversalMode::kGemm) or
191
- (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
192
- if (!implementable) {
193
- CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
194
- return implementable;
195
- }
196
- implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
197
- implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
198
- implementable &= TileScheduler::can_implement(args.scheduler);
199
-
200
- return implementable;
201
- }
202
-
203
- static
204
- size_t
205
- get_workspace_size(Arguments const& args) {
206
- return 0;
207
- }
208
-
209
- static
210
- cutlass::Status
211
- initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
212
- CudaHostAdapter* cuda_adapter = nullptr) {
213
- return Status::kSuccess;
214
- }
215
-
216
- // Computes the kernel launch grid shape based on runtime parameters
217
- static dim3
218
- get_grid_shape(Params const& params) {
219
- auto cluster_shape = ClusterShape{};
220
- auto tile_shape = TileShape{};
221
- auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
222
- return TileScheduler::get_tiled_cta_shape_mnl(
223
- problem_shape_MNKL, tile_shape, cluster_shape);
224
- }
225
-
226
- static dim3
227
- get_block_shape() {
228
- return dim3(MaxThreadsPerBlock, 1, 1);
229
- }
230
-
231
- CUTLASS_DEVICE
232
- void
233
- operator()(Params const& params, char* smem_buf) {
234
- using namespace cute;
235
- using X = Underscore;
236
-
237
- #if defined(__CUDA_ARCH_FEAT_SM90_ALL)
238
- # define ENABLE_SM90_KERNEL_LEVEL 1
239
- #endif
240
-
241
- // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
242
- #if ! defined(ENABLE_SM90_KERNEL_LEVEL)
243
- printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
244
- #else
245
-
246
- enum class WarpGroupRole {
247
- Producer = 0,
248
- Consumer = 1,
249
- };
250
- // Split mode: use Warp0 to load NK and epilogue, Warp2 to load MK.
251
- // Non-split mode: use Warp0 to load MK, NK and epilogue, Warp2 is unused.
252
- // Both modes use Warp1 to prefetch.
253
- enum class ProducerWarpRole {
254
- Warp0 = 0,
255
- PrefetchMK = 1,
256
- Warp2 = 2,
257
- UnusedWarp = 3
258
- };
259
-
260
- // Kernel level shared memory storage
261
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
262
-
263
- int thread_idx = int(threadIdx.x);
264
- int lane_idx = canonical_lane_idx();
265
- int warp_idx = canonical_warp_idx_sync();
266
- int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
267
- int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
268
- auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
269
- auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
270
- int lane_predicate = cute::elect_one_sync();
271
- uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
272
-
273
-
274
- // Issue Tma Descriptor Prefetch from a single thread
275
- if ((warp_idx == 0) && lane_predicate) {
276
- CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
277
- CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
278
- }
279
-
280
- // Mainloop Load pipeline
281
- using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
282
- typename MainloopPipeline::Params mainloop_pipeline_params;
283
- mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
284
- if (warp_group_role == WarpGroupRole::Producer && (
285
- producer_warp_role == ProducerWarpRole::Warp0 ||
286
- producer_warp_role == ProducerWarpRole::Warp2)) {
287
- mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
288
- mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
289
- }
290
- if (warp_group_role == WarpGroupRole::Consumer) {
291
- mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
292
- }
293
- mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
294
- MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
295
- bool should_prefetch = params.mainloop.prefetch_ratio > 0;
296
- using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline;
297
- typename PrefetcherPipeline::Params prefetcher_pipeline_params;
298
- prefetcher_pipeline_params.num_prefetchers = 1;
299
- if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
300
- prefetcher_pipeline_params.should_prefetch = should_prefetch;
301
- prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk;
302
- }
303
- PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params);
304
-
305
- // Epilogue Load pipeline
306
- using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
307
- typename EpiLoadPipeline::Params epi_load_pipeline_params;
308
- if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) {
309
- epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
310
- }
311
- if (warp_group_role == WarpGroupRole::Consumer) {
312
- epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
313
- }
314
- epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
315
- epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
316
- epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
317
- if constexpr (CollectiveEpilogue::RequiresTransactionBytes) {
318
- epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes;
319
- }
320
- EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
321
-
322
- // Epilogue Store pipeline
323
- using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
324
- typename EpiStorePipeline::Params epi_store_pipeline_params;
325
- epi_store_pipeline_params.always_wait = true;
326
- EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
327
-
328
- // Initialize starting pipeline states for the collectives
329
- // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
330
- typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
331
- typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
332
-
333
- // For the DMA Load (producer) we start with an opposite phase
334
- // i.e., we skip all waits since we know that the buffer is indeed empty
335
- PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
336
- PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
337
- PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
338
-
339
- auto cluster_wait_fn = [&] () {
340
- // We need this to guarantee that the Pipeline init is visible
341
- // To all producers and consumer thread blocks in the Cluster
342
- if constexpr (size(ClusterShape{}) > 1) {
343
- // Non-prefetcher warps arrive and wait,
344
- // Prefetcher warp can go ahead without waiting.
345
- cute::cluster_arrive_relaxed();
346
- if (warp_group_role != WarpGroupRole::Producer ||
347
- producer_warp_role != ProducerWarpRole::PrefetchMK) {
348
- cute::cluster_wait();
349
- }
350
- return [] () {};
351
- }
352
- else {
353
- // __syncthreads() but only for non prefetcher warps
354
- if (should_prefetch) {
355
-
356
- // Use a named barrier to let the prefetcher warp start loading into the L2
357
- // without waiting to sync with all other warps.
358
- // All other warps need to sync because the mainloop pipeline init
359
- // should be visible to all of them.
360
- // Prefetcher has its own barriers, and the only warps it would need to sync
361
- // with would be the DMA warps.
362
- using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
363
- auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
364
- blockDim.x * blockDim.y * blockDim.z,
365
- /*id*/ 0);
366
- // Prefetcher warp doesn't arrive on this barrier.
367
- auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
368
- blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
369
- /*id*/ 1);
370
-
371
- if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
372
- __syncwarp();
373
- prefetcher_arrive_barrier.arrive();
374
- }
375
- else if (warp_group_role == WarpGroupRole::Producer) {
376
- prefetcher_arrive_barrier.arrive_and_wait();
377
- cluster_arrive_barrier.arrive_and_wait();
378
- }
379
- else {
380
- prefetcher_arrive_barrier.arrive();
381
- cluster_arrive_barrier.arrive_and_wait();
382
- }
383
- } else {
384
- __syncthreads();
385
- }
386
- return [] () {};
387
- }
388
- } ();
389
-
390
- // Preconditions
391
- static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
392
- static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
393
- static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
394
- static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
395
-
396
- // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
397
- auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
398
-
399
- // Get the appropriate blocks for this thread block -- potential for thread block locality
400
- auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
401
- TiledMma tiled_mma;
402
-
403
- // In a warp specialized kernel, collectives expose data movement and compute operations separately
404
- CollectiveMainloop collective_mainloop;
405
- CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
406
-
407
- // Prepare and partition the input tensors. Expects a tuple of tensors where:
408
- // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
409
- // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
410
- auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
411
- static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 2, "Output of load_init must have at least two elements (A, B)");
412
-
413
- // Extract out partitioned A and B.
414
- Tensor gA_mkl = get<0>(load_inputs);
415
- Tensor gB_nkl = get<1>(load_inputs);
416
-
417
- // Compute m_coord, n_coord, and l_coord with their post-tiled shapes
418
- auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl));
419
- auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl));
420
- auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl));
421
- auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
422
-
423
- // Get pipeline iterators and increments from tensor shapes
424
- auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
425
- auto k_tile_count = size<3>(gA_mkl);
426
-
427
- // Wait for all thread blocks in the Cluster
428
- cluster_wait_fn();
429
-
430
- if (warp_group_role == WarpGroupRole::Producer) {
431
- if (producer_warp_role == ProducerWarpRole::Warp0) {
432
- if constexpr(SplitWarps) {
433
- collective_mainloop.load_NK(
434
- params.mainloop,
435
- mainloop_pipeline,
436
- prefetcher_pipeline,
437
- mainloop_pipe_producer_state,
438
- gB_nkl,
439
- blk_coord,
440
- k_tile_iter, k_tile_count,
441
- lane_idx,
442
- block_rank_in_cluster,
443
- shared_storage.tensors.mainloop
444
- );
445
- }
446
- else {
447
- collective_mainloop.load(
448
- params.mainloop,
449
- mainloop_pipeline,
450
- prefetcher_pipeline,
451
- mainloop_pipe_producer_state,
452
- gA_mkl, gB_nkl,
453
- blk_coord,
454
- k_tile_iter, k_tile_count,
455
- lane_idx,
456
- block_rank_in_cluster,
457
- shared_storage.tensors.mainloop
458
- );
459
- }
460
- // Update starting mainloop pipeline state for the pipeline drain
461
- mainloop_pipe_producer_state.advance(k_tile_count);
462
- // Make sure mainloop consumer has been waited upon before issuing epilogue load
463
- collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
464
-
465
- if (collective_epilogue.is_producer_load_needed()) {
466
- // Ensure warp is converged before issuing epilogue loads
467
- __syncwarp();
468
- epi_load_pipe_producer_state = collective_epilogue.load(
469
- epi_load_pipeline,
470
- epi_load_pipe_producer_state,
471
- problem_shape_MNKL,
472
- blk_shape,
473
- blk_coord,
474
- tiled_mma,
475
- lane_idx,
476
- shared_storage.tensors.epilogue
477
- );
478
- collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
479
- }
480
- }
481
- else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) {
482
- collective_mainloop.load_MK(
483
- params.mainloop,
484
- mainloop_pipeline,
485
- prefetcher_pipeline,
486
- mainloop_pipe_producer_state,
487
- gA_mkl,
488
- blk_coord,
489
- k_tile_iter, k_tile_count,
490
- lane_idx,
491
- block_rank_in_cluster,
492
- shared_storage.tensors.mainloop
493
- );
494
- // Update starting mainloop pipeline state for the pipeline drain
495
- mainloop_pipe_producer_state.advance(k_tile_count);
496
- // Make sure mainloop consumer has been waited upon before issuing epilogue load
497
- collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
498
- } else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) {
499
- collective_mainloop.prefetch_MK(
500
- params.mainloop,
501
- prefetcher_pipeline,
502
- mainloop_pipe_producer_state,
503
- gA_mkl,
504
- blk_coord,
505
- k_tile_iter, k_tile_count,
506
- lane_idx,
507
- block_rank_in_cluster,
508
- shared_storage.tensors.mainloop
509
- );
510
- }
511
- }
512
- else if (warp_group_role == WarpGroupRole::Consumer) {
513
- Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
514
-
515
- collective_mainloop.mma(
516
- mainloop_pipeline,
517
- mainloop_pipe_consumer_state,
518
- accumulators,
519
- k_tile_count,
520
- warp_group_thread_idx,
521
- shared_storage.tensors.mainloop,
522
- params.mainloop
523
- );
524
-
525
- // Make sure the math instructions are done and free buffers before entering the epilogue
526
- collective_mainloop.mma_tail(
527
- mainloop_pipeline,
528
- mainloop_pipe_consumer_state,
529
- k_tile_count
530
- );
531
-
532
- // Epilogue and write to gD
533
- auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
534
- collective_epilogue.store(
535
- epi_load_pipeline,
536
- epi_load_pipe_consumer_state,
537
- epi_store_pipeline,
538
- epi_store_pipe_producer_state,
539
- problem_shape_MNKL,
540
- blk_shape,
541
- blk_coord,
542
- accumulators,
543
- tiled_mma,
544
- warp_group_thread_idx,
545
- shared_storage.tensors.epilogue
546
- );
547
-
548
- collective_epilogue.store_tail(
549
- epi_load_pipeline,
550
- epi_load_pipe_consumer_state_next,
551
- epi_store_pipeline,
552
- epi_store_pipe_producer_state_next
553
- );
554
- }
555
- #endif
556
- }
557
- };
558
-
559
- ///////////////////////////////////////////////////////////////////////////////
560
-
561
- } // namespace cutlass::gemm::kernel