Kernels:
Trusted publisher
Uploaded using `kernel-builder` (batch 11/32).
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff.
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h +0 -2119
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h +0 -55
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h +0 -283
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h +0 -0
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h +0 -1322
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py +0 -144
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h +0 -90
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h +0 -154
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h +0 -113
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h +0 -213
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h +0 -311
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h +0 -189
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h +0 -427
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py +0 -129
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py +0 -131
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py +0 -120
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py +0 -469
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py +0 -249
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py +0 -476
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py +0 -232
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py +0 -1013
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py +0 -456
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py +0 -92
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py +0 -135
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py +0 -67
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h +0 -292
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h +0 -94
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h +0 -499
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h +0 -52
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h +0 -938
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h +0 -545
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h +0 -95
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h +0 -150
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h +0 -424
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h +0 -232
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h +0 -775
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh +0 -139
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp +0 -421
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh +0 -136
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp +0 -222
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh +0 -92
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp +0 -274
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp +0 -129
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp +0 -246
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h +0 -320
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp +0 -242
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp +0 -61
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp +0 -871
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp +0 -117
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp +0 -561
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h
DELETED
|
@@ -1,2119 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Templates implementing loading of tiles from pitch-linear rank=2
|
| 33 |
-
tensors.
|
| 34 |
-
|
| 35 |
-
This iterator uses masks to guard out-of-bounds accesses. The first tile
|
| 36 |
-
this iterator visits may be partial, then the remaining tiles are complete.
|
| 37 |
-
So, we only need to compute the predicates twice, once before the first tile
|
| 38 |
-
and once for the remaining full tiles which can share the same predicates.
|
| 39 |
-
|
| 40 |
-
A precomputed "Params" object minimizes the amount of state that must be
|
| 41 |
-
stored in registers, and integer addition is used to advance the pointer
|
| 42 |
-
through memory.
|
| 43 |
-
*/
|
| 44 |
-
|
| 45 |
-
#pragma once
|
| 46 |
-
|
| 47 |
-
#include "cutlass/arch/memory.h"
|
| 48 |
-
#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
|
| 49 |
-
|
| 50 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 51 |
-
|
| 52 |
-
namespace cutlass {
|
| 53 |
-
namespace transform {
|
| 54 |
-
namespace threadblock {
|
| 55 |
-
|
| 56 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 57 |
-
|
| 58 |
-
/// PredicatedTileIteratorResidualLast
|
| 59 |
-
///
|
| 60 |
-
/// Satisfies: ForwardTileIteratorConcept |
|
| 61 |
-
/// ReadableContiguousTileIteratorConcept |
|
| 62 |
-
/// WriteableContiguousTileIteratorConcept |
|
| 63 |
-
/// MaskedTileIteratorConcept
|
| 64 |
-
///
|
| 65 |
-
/// Regular tile iterator using a precomputed control structure to minimize
|
| 66 |
-
/// register liveness and integer arithmetic.
|
| 67 |
-
///
|
| 68 |
-
/// Layout is assumed to be invariant at the time the precomputed "Params"
|
| 69 |
-
/// object is constructed.
|
| 70 |
-
///
|
| 71 |
-
/// Base pointer and tensor extents may be specified at the time the iterator is
|
| 72 |
-
/// constructed. Subsequently, they are assumed to be immutable.
|
| 73 |
-
///
|
| 74 |
-
/// Adding a logical coordinate offset may be performed at the time the iterator
|
| 75 |
-
/// is constructed. Subsequent additions to logical coordinate offset may be
|
| 76 |
-
/// performed but are relatively expensive.
|
| 77 |
-
///
|
| 78 |
-
/// Visitation order is intended to first visit a "residual" tile that may be
|
| 79 |
-
/// partially full in both the advance dimension and the steady-state dimension.
|
| 80 |
-
/// This is assumed to be the last tile in the iteration sequence. Advancing an
|
| 81 |
-
/// iterator that has just been constructed moves to the first tile that is full
|
| 82 |
-
/// in the advance dimension and recomputes predicates. Subsequent accesses may
|
| 83 |
-
/// be performed without updating internal predicates and are efficient in terms
|
| 84 |
-
/// of live register state and pointer arithmetic instructions.
|
| 85 |
-
///
|
| 86 |
-
/// To be efficient, this assumes the iterator will be dereferenced and advanced
|
| 87 |
-
/// at least once outside any looping structure to minimize integer arithmetic.
|
| 88 |
-
///
|
| 89 |
-
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to
|
| 90 |
-
/// dereferencing the iterator.
|
| 91 |
-
///
|
| 92 |
-
///
|
| 93 |
-
/// Example:
|
| 94 |
-
///
|
| 95 |
-
/// An efficient pipeline structure may be constructed as follows:
|
| 96 |
-
///
|
| 97 |
-
// template <typename Iterator>
|
| 98 |
-
// __global__ void kernel(
|
| 99 |
-
// typename Iterator::Params params,
|
| 100 |
-
// typename Iterator::Element *ptr,
|
| 101 |
-
// TensorCoord extent) {
|
| 102 |
-
//
|
| 103 |
-
// typename Iterator::Fragment fragment;
|
| 104 |
-
//
|
| 105 |
-
// TensorCoord threadblock_offset(0, 0);
|
| 106 |
-
//
|
| 107 |
-
// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
|
| 108 |
-
//
|
| 109 |
-
//
|
| 110 |
-
// fragment = *iter; // load "residue" tile first
|
| 111 |
-
// ++iter; // advance to first "steady state" tile and update
|
| 112 |
-
// internal masks
|
| 113 |
-
//
|
| 114 |
-
//
|
| 115 |
-
// #pragma unroll
|
| 116 |
-
// for (int i = Remaining - 1; i >= 0; --i) {
|
| 117 |
-
//
|
| 118 |
-
// f(fragment);
|
| 119 |
-
//
|
| 120 |
-
// if (!i) {
|
| 121 |
-
// iter.clear_mask(); // light-weight operation to clear masks -
|
| 122 |
-
// subsequent loads become NO-OPs.
|
| 123 |
-
// }
|
| 124 |
-
//
|
| 125 |
-
// fragment = *iter; // load tile during "steady state" phase
|
| 126 |
-
// ++iter; // advance to next tile - lightweight due to
|
| 127 |
-
// steady-state masks
|
| 128 |
-
// }
|
| 129 |
-
// }
|
| 130 |
-
//
|
| 131 |
-
// void host(TensorView<Element, 2, layout::PitchLinear> view) {
|
| 132 |
-
//
|
| 133 |
-
// using Iterator =
|
| 134 |
-
// transform::threadblock::PredicatedTileIteratorResidualLast;
|
| 135 |
-
//
|
| 136 |
-
// typename Iterator::Params params(view.layout());
|
| 137 |
-
//
|
| 138 |
-
// kernel<Iterator>(params, view.data());
|
| 139 |
-
// }
|
| 140 |
-
///
|
| 141 |
-
///
|
| 142 |
-
template <
|
| 143 |
-
typename Shape,
|
| 144 |
-
typename Element,
|
| 145 |
-
typename Layout,
|
| 146 |
-
int AdvanceRank,
|
| 147 |
-
typename ThreadMap,
|
| 148 |
-
int AccessSize = ThreadMap::kElementsPerAccess,
|
| 149 |
-
bool Gather = false>
|
| 150 |
-
class PredicatedTileIteratorResidualLast;
|
| 151 |
-
|
| 152 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 153 |
-
|
| 154 |
-
/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data.
|
| 155 |
-
///
|
| 156 |
-
/// Satisfies: ForwardTileIteratorConcept |
|
| 157 |
-
/// ReadableContiguousTileIteratorConcept |
|
| 158 |
-
/// WriteableContiguousTileIteratorConcept |
|
| 159 |
-
/// MaskedTileIteratorConcept
|
| 160 |
-
///
|
| 161 |
-
template <
|
| 162 |
-
typename Shape_,
|
| 163 |
-
typename Element_,
|
| 164 |
-
int AdvanceRank,
|
| 165 |
-
typename ThreadMap_,
|
| 166 |
-
int AccessSize,
|
| 167 |
-
bool Gather>
|
| 168 |
-
class PredicatedTileIteratorResidualLast<
|
| 169 |
-
Shape_,
|
| 170 |
-
Element_,
|
| 171 |
-
layout::PitchLinear,
|
| 172 |
-
AdvanceRank,
|
| 173 |
-
ThreadMap_,
|
| 174 |
-
AccessSize,
|
| 175 |
-
Gather> {
|
| 176 |
-
public:
|
| 177 |
-
static_assert(
|
| 178 |
-
AdvanceRank == 0 || AdvanceRank == 1,
|
| 179 |
-
"Specialization for pitch-linear iterator may advance along the "
|
| 180 |
-
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 181 |
-
|
| 182 |
-
using Shape = Shape_;
|
| 183 |
-
using Element = Element_;
|
| 184 |
-
using Layout = layout::PitchLinear;
|
| 185 |
-
static int const kAdvanceRank = AdvanceRank;
|
| 186 |
-
using ThreadMap = ThreadMap_;
|
| 187 |
-
|
| 188 |
-
using Index = typename Layout::Index;
|
| 189 |
-
using LongIndex = typename Layout::LongIndex;
|
| 190 |
-
|
| 191 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 192 |
-
using TensorView = TensorView<Element, Layout>;
|
| 193 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 194 |
-
|
| 195 |
-
using Pointer = Element*;
|
| 196 |
-
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
| 197 |
-
|
| 198 |
-
/// Type used for internal memory accesses
|
| 199 |
-
using AccessType = AlignedArray<
|
| 200 |
-
Element,
|
| 201 |
-
AccessSize,
|
| 202 |
-
(AccessSize * sizeof_bits<Element>::value / 8)>;
|
| 203 |
-
|
| 204 |
-
/// Underlying iterator to compute the addresses
|
| 205 |
-
using TileAccessIterator = PredicatedTileAccessIteratorResidualLast<
|
| 206 |
-
Shape,
|
| 207 |
-
Element,
|
| 208 |
-
Layout,
|
| 209 |
-
kAdvanceRank,
|
| 210 |
-
ThreadMap,
|
| 211 |
-
AccessType,
|
| 212 |
-
Gather>;
|
| 213 |
-
|
| 214 |
-
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
|
| 215 |
-
|
| 216 |
-
/// Fragment object to be loaded or stored
|
| 217 |
-
using Fragment = cutlass::Array<
|
| 218 |
-
Element,
|
| 219 |
-
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 220 |
-
|
| 221 |
-
/// Predicate vector stores mask to guard accesses
|
| 222 |
-
using Mask = typename TileAccessIterator::Mask;
|
| 223 |
-
|
| 224 |
-
/// Parameters object is precomputed state and is host-constructible
|
| 225 |
-
class Params {
|
| 226 |
-
public:
|
| 227 |
-
using Base = typename TileAccessIterator::Params::Base;
|
| 228 |
-
|
| 229 |
-
friend PredicatedTileIteratorResidualLast;
|
| 230 |
-
|
| 231 |
-
private:
|
| 232 |
-
/// Parameters object
|
| 233 |
-
typename TileAccessIterator::Params params_;
|
| 234 |
-
|
| 235 |
-
public:
|
| 236 |
-
/// Construct the Params object given a pitch-linear tensor's layout
|
| 237 |
-
CUTLASS_HOST_DEVICE
|
| 238 |
-
Params(Layout const& layout) : params_(layout) {}
|
| 239 |
-
|
| 240 |
-
CUTLASS_HOST_DEVICE
|
| 241 |
-
Params() {}
|
| 242 |
-
|
| 243 |
-
CUTLASS_HOST_DEVICE
|
| 244 |
-
Params(Base const& base) : params_(base) {}
|
| 245 |
-
};
|
| 246 |
-
|
| 247 |
-
private:
|
| 248 |
-
/// Internal pointer type permits fast address arithmetic
|
| 249 |
-
using BytePointer = char*;
|
| 250 |
-
|
| 251 |
-
private:
|
| 252 |
-
//
|
| 253 |
-
// Data members
|
| 254 |
-
//
|
| 255 |
-
|
| 256 |
-
/// Data member to the tile access iterator
|
| 257 |
-
TileAccessIterator address_iterator_;
|
| 258 |
-
|
| 259 |
-
public:
|
| 260 |
-
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 261 |
-
/// and thread ID
|
| 262 |
-
CUTLASS_HOST_DEVICE
|
| 263 |
-
PredicatedTileIteratorResidualLast(
|
| 264 |
-
/// Precomputed parameters object
|
| 265 |
-
Params const& params,
|
| 266 |
-
/// Pointer to start of tensor
|
| 267 |
-
Pointer pointer,
|
| 268 |
-
/// Extent of tensor
|
| 269 |
-
TensorCoord extent,
|
| 270 |
-
/// ID of each participating thread
|
| 271 |
-
int thread_id,
|
| 272 |
-
/// Initial offset of threadblock
|
| 273 |
-
TensorCoord const& threadblock_offset,
|
| 274 |
-
/// Gather indices
|
| 275 |
-
int const* indices = nullptr)
|
| 276 |
-
: address_iterator_(
|
| 277 |
-
params.params_,
|
| 278 |
-
pointer,
|
| 279 |
-
extent,
|
| 280 |
-
thread_id,
|
| 281 |
-
threadblock_offset,
|
| 282 |
-
indices) {}
|
| 283 |
-
|
| 284 |
-
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
|
| 285 |
-
/// offset
|
| 286 |
-
CUTLASS_HOST_DEVICE
|
| 287 |
-
PredicatedTileIteratorResidualLast(
|
| 288 |
-
Params const& params, ///< Precomputed parameters object
|
| 289 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 290 |
-
TensorCoord extent, ///< Extent of tensor
|
| 291 |
-
int thread_id ///< ID of each participating thread
|
| 292 |
-
)
|
| 293 |
-
: PredicatedTileIteratorResidualLast(
|
| 294 |
-
params,
|
| 295 |
-
pointer,
|
| 296 |
-
extent,
|
| 297 |
-
thread_id,
|
| 298 |
-
make_Coord(0, 0)) {}
|
| 299 |
-
|
| 300 |
-
/// Adds a pointer offset in units of Element
|
| 301 |
-
CUTLASS_HOST_DEVICE
|
| 302 |
-
void add_pointer_offset(LongIndex pointer_offset) {
|
| 303 |
-
address_iterator_.add_pointer_offset(pointer_offset);
|
| 304 |
-
}
|
| 305 |
-
|
| 306 |
-
/// Advances to the next tile in memory.
|
| 307 |
-
///
|
| 308 |
-
/// The first time this method is called, predicates are updated, and the
|
| 309 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 310 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 311 |
-
/// pointer.
|
| 312 |
-
CUTLASS_HOST_DEVICE
|
| 313 |
-
PredicatedTileIteratorResidualLast& operator++() {
|
| 314 |
-
if (kAdvanceRank)
|
| 315 |
-
address_iterator_.add_tile_offset({0, 1});
|
| 316 |
-
else
|
| 317 |
-
address_iterator_.add_tile_offset({1, 0});
|
| 318 |
-
|
| 319 |
-
return *this;
|
| 320 |
-
}
|
| 321 |
-
|
| 322 |
-
/// Advances to the next tile in memory.
|
| 323 |
-
///
|
| 324 |
-
/// The first time this method is called, predicates are updated, and the
|
| 325 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 326 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 327 |
-
/// pointer.
|
| 328 |
-
CUTLASS_HOST_DEVICE
|
| 329 |
-
PredicatedTileIteratorResidualLast operator++(int) {
|
| 330 |
-
PredicatedTileIteratorResidualLast self(*this);
|
| 331 |
-
operator++();
|
| 332 |
-
return self;
|
| 333 |
-
}
|
| 334 |
-
|
| 335 |
-
/// Clears the predicate set efficiently
|
| 336 |
-
CUTLASS_HOST_DEVICE
|
| 337 |
-
void clear_mask(bool enable = true) {
|
| 338 |
-
address_iterator_.clear_mask(enable);
|
| 339 |
-
}
|
| 340 |
-
|
| 341 |
-
CUTLASS_HOST_DEVICE
|
| 342 |
-
void set_residual_tile(bool enable) {
|
| 343 |
-
address_iterator_.set_residual_tile(enable);
|
| 344 |
-
}
|
| 345 |
-
|
| 346 |
-
/// Enables the predicate mask (forwards to the underlying access iterator's enable_mask())
|
| 347 |
-
CUTLASS_HOST_DEVICE
|
| 348 |
-
void enable_mask() {
|
| 349 |
-
address_iterator_.enable_mask();
|
| 350 |
-
}
|
| 351 |
-
|
| 352 |
-
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 353 |
-
CUTLASS_HOST_DEVICE
|
| 354 |
-
void set_mask(Mask const& mask) {
|
| 355 |
-
address_iterator_.set_mask(mask);
|
| 356 |
-
}
|
| 357 |
-
|
| 358 |
-
/// Gets the mask
|
| 359 |
-
CUTLASS_HOST_DEVICE
|
| 360 |
-
void get_mask(Mask& mask) {
|
| 361 |
-
address_iterator_.get_mask(mask);
|
| 362 |
-
}
|
| 363 |
-
|
| 364 |
-
CUTLASS_DEVICE
|
| 365 |
-
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
|
| 366 |
-
load_with_byte_offset(
|
| 367 |
-
frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 368 |
-
}
|
| 369 |
-
|
| 370 |
-
CUTLASS_DEVICE
|
| 371 |
-
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
|
| 372 |
-
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
|
| 373 |
-
|
| 374 |
-
CUTLASS_PRAGMA_UNROLL
|
| 375 |
-
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 376 |
-
CUTLASS_PRAGMA_UNROLL
|
| 377 |
-
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 378 |
-
CUTLASS_PRAGMA_UNROLL
|
| 379 |
-
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 380 |
-
int idx = v +
|
| 381 |
-
kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 382 |
-
|
| 383 |
-
address_iterator_.set_iteration_index(idx);
|
| 384 |
-
char const* byte_ptr =
|
| 385 |
-
reinterpret_cast<char const*>(address_iterator_.get()) +
|
| 386 |
-
byte_offset;
|
| 387 |
-
|
| 388 |
-
AccessType const* access_ptr =
|
| 389 |
-
reinterpret_cast<AccessType const*>(byte_ptr);
|
| 390 |
-
|
| 391 |
-
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
|
| 392 |
-
frag_ptr[idx], access_ptr, address_iterator_.valid());
|
| 393 |
-
|
| 394 |
-
++address_iterator_;
|
| 395 |
-
}
|
| 396 |
-
}
|
| 397 |
-
}
|
| 398 |
-
}
|
| 399 |
-
|
| 400 |
-
/// Loads a fragment from memory
|
| 401 |
-
CUTLASS_DEVICE
|
| 402 |
-
void load(Fragment& frag) {
|
| 403 |
-
load_with_byte_offset(frag, 0);
|
| 404 |
-
}
|
| 405 |
-
|
| 406 |
-
/// Store a fragment to memory
|
| 407 |
-
CUTLASS_DEVICE
|
| 408 |
-
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
|
| 409 |
-
store_with_byte_offset(
|
| 410 |
-
frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 411 |
-
}
|
| 412 |
-
|
| 413 |
-
/// Store a fragment to memory
|
| 414 |
-
CUTLASS_DEVICE
|
| 415 |
-
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
|
| 416 |
-
address_iterator_.set_iteration_index(0);
|
| 417 |
-
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
|
| 418 |
-
|
| 419 |
-
CUTLASS_PRAGMA_UNROLL
|
| 420 |
-
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 421 |
-
CUTLASS_PRAGMA_UNROLL
|
| 422 |
-
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 423 |
-
CUTLASS_PRAGMA_UNROLL
|
| 424 |
-
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 425 |
-
int idx = v +
|
| 426 |
-
kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 427 |
-
|
| 428 |
-
char* byte_ptr =
|
| 429 |
-
reinterpret_cast<char*>(address_iterator_.get()) + byte_offset;
|
| 430 |
-
AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr);
|
| 431 |
-
|
| 432 |
-
if (address_iterator_.valid()) {
|
| 433 |
-
*access_ptr = frag_ptr[idx];
|
| 434 |
-
}
|
| 435 |
-
++address_iterator_;
|
| 436 |
-
}
|
| 437 |
-
}
|
| 438 |
-
}
|
| 439 |
-
}
|
| 440 |
-
|
| 441 |
-
/// Store a fragment to memory
|
| 442 |
-
CUTLASS_DEVICE
|
| 443 |
-
void store(Fragment const& frag) {
|
| 444 |
-
store_with_byte_offset(frag, 0);
|
| 445 |
-
}
|
| 446 |
-
};
|
| 447 |
-
|
| 448 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 449 |
-
|
| 450 |
-
/// Specialization of PredicatedTileIteratorResidualLast for column-major data.
|
| 451 |
-
///
|
| 452 |
-
/// Satisfies: ForwardTileIteratorConcept |
|
| 453 |
-
/// ReadableContiguousTileIteratorConcept |
|
| 454 |
-
/// WriteableContiguousTileIteratorConcept |
|
| 455 |
-
/// MaskedTileIteratorConcept
|
| 456 |
-
///
|
| 457 |
-
template <
|
| 458 |
-
typename Shape_,
|
| 459 |
-
typename Element_,
|
| 460 |
-
int AdvanceRank,
|
| 461 |
-
typename ThreadMap_,
|
| 462 |
-
int AccessSize,
|
| 463 |
-
bool Gather>
|
| 464 |
-
class PredicatedTileIteratorResidualLast<
|
| 465 |
-
Shape_,
|
| 466 |
-
Element_,
|
| 467 |
-
layout::ColumnMajor,
|
| 468 |
-
AdvanceRank,
|
| 469 |
-
ThreadMap_,
|
| 470 |
-
AccessSize,
|
| 471 |
-
Gather> {
|
| 472 |
-
public:
|
| 473 |
-
static_assert(
|
| 474 |
-
AdvanceRank == 0 || AdvanceRank == 1,
|
| 475 |
-
"Specialization for pitch-linear iterator may along advance along the "
|
| 476 |
-
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 477 |
-
|
| 478 |
-
using Shape = Shape_;
|
| 479 |
-
using Element = Element_;
|
| 480 |
-
using Layout = layout::ColumnMajor;
|
| 481 |
-
static int const kAdvanceRank = AdvanceRank;
|
| 482 |
-
using ThreadMap = ThreadMap_;
|
| 483 |
-
|
| 484 |
-
using Index = typename Layout::Index;
|
| 485 |
-
using LongIndex = typename Layout::LongIndex;
|
| 486 |
-
|
| 487 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 488 |
-
using TensorView = TensorView<Element, Layout>;
|
| 489 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 490 |
-
|
| 491 |
-
using Pointer = Element*;
|
| 492 |
-
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
| 493 |
-
|
| 494 |
-
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
|
| 495 |
-
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
|
| 496 |
-
Element,
|
| 497 |
-
layout::PitchLinear,
|
| 498 |
-
(kAdvanceRank == 0 ? 0 : 1),
|
| 499 |
-
ThreadMap,
|
| 500 |
-
AccessSize,
|
| 501 |
-
Gather>;
|
| 502 |
-
|
| 503 |
-
using AccessType = typename UnderlyingIterator::AccessType;
|
| 504 |
-
|
| 505 |
-
/// Fragment object to be loaded or stored
|
| 506 |
-
using Fragment = cutlass::Array<
|
| 507 |
-
Element,
|
| 508 |
-
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 509 |
-
|
| 510 |
-
/// Predicate vector stores mask to guard accesses
|
| 511 |
-
using Mask = typename UnderlyingIterator::Mask;
|
| 512 |
-
|
| 513 |
-
/// Parameters object is precomputed state and is host-constructible
|
| 514 |
-
class Params {
|
| 515 |
-
private:
|
| 516 |
-
friend PredicatedTileIteratorResidualLast;
|
| 517 |
-
|
| 518 |
-
/// Parameters object
|
| 519 |
-
typename UnderlyingIterator::Params params_;
|
| 520 |
-
|
| 521 |
-
public:
|
| 522 |
-
CUTLASS_HOST_DEVICE
|
| 523 |
-
Params() {}
|
| 524 |
-
|
| 525 |
-
/// Construct the Params object given a column-major tensor's layout
|
| 526 |
-
CUTLASS_HOST_DEVICE
|
| 527 |
-
Params(Layout const& layout)
|
| 528 |
-
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 529 |
-
|
| 530 |
-
CUTLASS_HOST_DEVICE
|
| 531 |
-
Params(typename UnderlyingIterator::Params::Base const& base)
|
| 532 |
-
: params_(base) {}
|
| 533 |
-
};
|
| 534 |
-
|
| 535 |
-
private:
|
| 536 |
-
//
|
| 537 |
-
// Data members
|
| 538 |
-
//
|
| 539 |
-
|
| 540 |
-
/// Underlying pitch-linear tile iterator
|
| 541 |
-
UnderlyingIterator iterator_;
|
| 542 |
-
|
| 543 |
-
public:
|
| 544 |
-
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 545 |
-
/// and thread ID
|
| 546 |
-
CUTLASS_HOST_DEVICE
|
| 547 |
-
PredicatedTileIteratorResidualLast(
|
| 548 |
-
Params const& params, ///< Precomputed parameters object
|
| 549 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 550 |
-
TensorCoord extent, ///< Extent of tensor
|
| 551 |
-
int thread_id, ///< ID of each participating thread
|
| 552 |
-
TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
|
| 553 |
-
int const* indices =
|
| 554 |
-
nullptr ///< gather/scatter indices, note no support for
|
| 555 |
-
///< gather/scatter at this specialization
|
| 556 |
-
)
|
| 557 |
-
: iterator_(
|
| 558 |
-
params.params_,
|
| 559 |
-
pointer,
|
| 560 |
-
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 561 |
-
thread_id,
|
| 562 |
-
layout::PitchLinearCoord(
|
| 563 |
-
threadblock_offset.row(),
|
| 564 |
-
threadblock_offset.column()),
|
| 565 |
-
indices) {}
|
| 566 |
-
|
| 567 |
-
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
|
| 568 |
-
/// offset
|
| 569 |
-
CUTLASS_HOST_DEVICE
|
| 570 |
-
PredicatedTileIteratorResidualLast(
|
| 571 |
-
Params const& params, ///< Precomputed parameters object
|
| 572 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 573 |
-
TensorCoord extent, ///< Extent of tensor
|
| 574 |
-
int thread_id ///< ID of each participating thread
|
| 575 |
-
)
|
| 576 |
-
: PredicatedTileIteratorResidualLast(
|
| 577 |
-
params,
|
| 578 |
-
pointer,
|
| 579 |
-
extent,
|
| 580 |
-
thread_id,
|
| 581 |
-
make_Coord(0, 0)) {}
|
| 582 |
-
|
| 583 |
-
/// Adds a pointer offset in units of Element
|
| 584 |
-
CUTLASS_HOST_DEVICE
|
| 585 |
-
void add_pointer_offset(LongIndex pointer_offset) {
|
| 586 |
-
iterator_.add_pointer_offset(pointer_offset);
|
| 587 |
-
}
|
| 588 |
-
|
| 589 |
-
/// Advances to the next tile in memory.
|
| 590 |
-
///
|
| 591 |
-
/// The first time this method is called, predicates are updated, and the
|
| 592 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 593 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 594 |
-
/// pointer.
|
| 595 |
-
CUTLASS_HOST_DEVICE
|
| 596 |
-
PredicatedTileIteratorResidualLast& operator++() {
|
| 597 |
-
++iterator_;
|
| 598 |
-
return *this;
|
| 599 |
-
}
|
| 600 |
-
|
| 601 |
-
/// Advances to the next tile in memory.
|
| 602 |
-
///
|
| 603 |
-
/// The first time this method is called, predicates are updated, and the
|
| 604 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 605 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 606 |
-
/// pointer.
|
| 607 |
-
CUTLASS_HOST_DEVICE
|
| 608 |
-
PredicatedTileIteratorResidualLast operator++(int) {
|
| 609 |
-
PredicatedTileIteratorResidualLast self(*this);
|
| 610 |
-
operator++();
|
| 611 |
-
return self;
|
| 612 |
-
}
|
| 613 |
-
|
| 614 |
-
/// Clears the predicate set efficiently
|
| 615 |
-
CUTLASS_HOST_DEVICE
|
| 616 |
-
void clear_mask(bool enable = true) {
|
| 617 |
-
iterator_.clear_mask(enable);
|
| 618 |
-
}
|
| 619 |
-
|
| 620 |
-
/// Forwards the residual-tile flag to the wrapped iterator.
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) {
  this->iterator_.set_residual_tile(enable);
}
|
| 624 |
-
|
| 625 |
-
/// Clears the predicate set efficiently
|
| 626 |
-
CUTLASS_HOST_DEVICE
|
| 627 |
-
void enable_mask() {
|
| 628 |
-
iterator_.enable_mask();
|
| 629 |
-
}
|
| 630 |
-
|
| 631 |
-
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 632 |
-
CUTLASS_HOST_DEVICE
|
| 633 |
-
void set_mask(Mask const& mask) {
|
| 634 |
-
iterator_.set_mask(mask);
|
| 635 |
-
}
|
| 636 |
-
|
| 637 |
-
/// Gets the mask
|
| 638 |
-
CUTLASS_HOST_DEVICE
|
| 639 |
-
void get_mask(Mask& mask) {
|
| 640 |
-
iterator_.get_mask(mask);
|
| 641 |
-
}
|
| 642 |
-
|
| 643 |
-
/// Loads a fragment from memory
|
| 644 |
-
CUTLASS_DEVICE
|
| 645 |
-
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
|
| 646 |
-
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 647 |
-
}
|
| 648 |
-
|
| 649 |
-
/// Loads a fragment from memory
|
| 650 |
-
CUTLASS_DEVICE
|
| 651 |
-
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
|
| 652 |
-
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 653 |
-
}
|
| 654 |
-
|
| 655 |
-
/// Loads a fragment from memory
|
| 656 |
-
CUTLASS_DEVICE
|
| 657 |
-
void load(Fragment& frag) {
|
| 658 |
-
load_with_pointer_offset(frag, 0);
|
| 659 |
-
}
|
| 660 |
-
|
| 661 |
-
/// Store a fragment to memory
|
| 662 |
-
CUTLASS_DEVICE
|
| 663 |
-
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
|
| 664 |
-
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 665 |
-
}
|
| 666 |
-
|
| 667 |
-
/// Store a fragment to memory
|
| 668 |
-
CUTLASS_DEVICE
|
| 669 |
-
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
|
| 670 |
-
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 671 |
-
}
|
| 672 |
-
|
| 673 |
-
/// Store a fragment to memory
|
| 674 |
-
CUTLASS_DEVICE
|
| 675 |
-
void store(Fragment const& frag) {
|
| 676 |
-
store_with_pointer_offset(frag, 0);
|
| 677 |
-
}
|
| 678 |
-
};
|
| 679 |
-
|
| 680 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 681 |
-
|
| 682 |
-
/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data.
|
| 683 |
-
///
|
| 684 |
-
/// Satisfies: ForwardTileIteratorConcept |
|
| 685 |
-
/// ReadableContiguousTileIteratorConcept |
|
| 686 |
-
/// WriteableContiguousTileIteratorConcept |
|
| 687 |
-
/// MaskedTileIteratorConcept
|
| 688 |
-
///
|
| 689 |
-
template <
|
| 690 |
-
typename Shape_,
|
| 691 |
-
typename Element_,
|
| 692 |
-
int AdvanceRank,
|
| 693 |
-
typename ThreadMap_,
|
| 694 |
-
int AccessSize,
|
| 695 |
-
bool Gather>
|
| 696 |
-
class PredicatedTileIteratorResidualLast<
|
| 697 |
-
Shape_,
|
| 698 |
-
Element_,
|
| 699 |
-
layout::RowMajor,
|
| 700 |
-
AdvanceRank,
|
| 701 |
-
ThreadMap_,
|
| 702 |
-
AccessSize,
|
| 703 |
-
Gather> {
|
| 704 |
-
public:
|
| 705 |
-
static_assert(
|
| 706 |
-
AdvanceRank == 0 || AdvanceRank == 1,
|
| 707 |
-
"Specialization for pitch-linear iterator may along advance along the "
|
| 708 |
-
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 709 |
-
|
| 710 |
-
using Shape = Shape_;
|
| 711 |
-
using Element = Element_;
|
| 712 |
-
using Layout = layout::RowMajor;
|
| 713 |
-
static int const kAdvanceRank = AdvanceRank;
|
| 714 |
-
using ThreadMap = ThreadMap_;
|
| 715 |
-
|
| 716 |
-
using Index = typename Layout::Index;
|
| 717 |
-
using LongIndex = typename Layout::LongIndex;
|
| 718 |
-
|
| 719 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 720 |
-
using TensorView = TensorView<Element, Layout>;
|
| 721 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 722 |
-
|
| 723 |
-
using Pointer = Element*;
|
| 724 |
-
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
| 725 |
-
|
| 726 |
-
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
|
| 727 |
-
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
|
| 728 |
-
Element,
|
| 729 |
-
layout::PitchLinear,
|
| 730 |
-
(kAdvanceRank == 0 ? 1 : 0),
|
| 731 |
-
ThreadMap,
|
| 732 |
-
AccessSize,
|
| 733 |
-
Gather>;
|
| 734 |
-
|
| 735 |
-
using AccessType = typename UnderlyingIterator::AccessType;
|
| 736 |
-
|
| 737 |
-
/// Fragment object to be loaded or stored
|
| 738 |
-
using Fragment = cutlass::Array<
|
| 739 |
-
Element,
|
| 740 |
-
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 741 |
-
|
| 742 |
-
/// Predicate vector stores mask to guard accesses
|
| 743 |
-
using Mask = typename UnderlyingIterator::Mask;
|
| 744 |
-
|
| 745 |
-
/// Parameters object is precomputed state and is host-constructible
|
| 746 |
-
class Params {
|
| 747 |
-
private:
|
| 748 |
-
friend PredicatedTileIteratorResidualLast;
|
| 749 |
-
|
| 750 |
-
/// Parameters object
|
| 751 |
-
typename UnderlyingIterator::Params params_;
|
| 752 |
-
|
| 753 |
-
public:
|
| 754 |
-
CUTLASS_HOST_DEVICE
|
| 755 |
-
Params() {}
|
| 756 |
-
|
| 757 |
-
/// Construct the Params object given a pitch-linear tensor's layout
|
| 758 |
-
CUTLASS_HOST_DEVICE
|
| 759 |
-
Params(Layout const& layout)
|
| 760 |
-
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 761 |
-
|
| 762 |
-
CUTLASS_HOST_DEVICE
|
| 763 |
-
Params(typename UnderlyingIterator::Params::Base const& base)
|
| 764 |
-
: params_(base) {}
|
| 765 |
-
};
|
| 766 |
-
|
| 767 |
-
private:
|
| 768 |
-
//
|
| 769 |
-
// Data members
|
| 770 |
-
//
|
| 771 |
-
|
| 772 |
-
/// Underlying pitch-linear tile iterator
|
| 773 |
-
UnderlyingIterator iterator_;
|
| 774 |
-
|
| 775 |
-
public:
|
| 776 |
-
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 777 |
-
/// and thread ID
|
| 778 |
-
CUTLASS_HOST_DEVICE
|
| 779 |
-
PredicatedTileIteratorResidualLast(
|
| 780 |
-
Params const& params, ///< Precomputed parameters object
|
| 781 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 782 |
-
TensorCoord extent, ///< Extent of tensor
|
| 783 |
-
int thread_id, ///< ID of each participating thread
|
| 784 |
-
TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
|
| 785 |
-
int const* indices = nullptr ///< Gather indices
|
| 786 |
-
)
|
| 787 |
-
: iterator_(
|
| 788 |
-
params.params_,
|
| 789 |
-
pointer,
|
| 790 |
-
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 791 |
-
thread_id,
|
| 792 |
-
layout::PitchLinearCoord(
|
| 793 |
-
threadblock_offset.column(),
|
| 794 |
-
threadblock_offset.row()),
|
| 795 |
-
indices) {}
|
| 796 |
-
|
| 797 |
-
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
|
| 798 |
-
/// offset
|
| 799 |
-
CUTLASS_HOST_DEVICE
|
| 800 |
-
PredicatedTileIteratorResidualLast(
|
| 801 |
-
Params const& params, ///< Precomputed parameters object
|
| 802 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 803 |
-
TensorCoord extent, ///< Extent of tensor
|
| 804 |
-
int thread_id ///< ID of each participating thread
|
| 805 |
-
)
|
| 806 |
-
: PredicatedTileIteratorResidualLast(
|
| 807 |
-
params,
|
| 808 |
-
pointer,
|
| 809 |
-
extent,
|
| 810 |
-
thread_id,
|
| 811 |
-
make_Coord(0, 0)) {}
|
| 812 |
-
|
| 813 |
-
/// Adds a pointer offset in units of Element
|
| 814 |
-
CUTLASS_HOST_DEVICE
|
| 815 |
-
void add_pointer_offset(LongIndex pointer_offset) {
|
| 816 |
-
iterator_.add_pointer_offset(pointer_offset);
|
| 817 |
-
}
|
| 818 |
-
|
| 819 |
-
/// Advances to the next tile in memory.
|
| 820 |
-
///
|
| 821 |
-
/// The first time this method is called, predicates are updated, and the
|
| 822 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 823 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 824 |
-
/// pointer.
|
| 825 |
-
CUTLASS_HOST_DEVICE
|
| 826 |
-
PredicatedTileIteratorResidualLast& operator++() {
|
| 827 |
-
++iterator_;
|
| 828 |
-
return *this;
|
| 829 |
-
}
|
| 830 |
-
|
| 831 |
-
/// Advances to the next tile in memory.
|
| 832 |
-
///
|
| 833 |
-
/// The first time this method is called, predicates are updated, and the
|
| 834 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 835 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 836 |
-
/// pointer.
|
| 837 |
-
CUTLASS_HOST_DEVICE
|
| 838 |
-
PredicatedTileIteratorResidualLast operator++(int) {
|
| 839 |
-
PredicatedTileIteratorResidualLast self(*this);
|
| 840 |
-
operator++();
|
| 841 |
-
return self;
|
| 842 |
-
}
|
| 843 |
-
|
| 844 |
-
/// Clears the predicate set efficiently
|
| 845 |
-
CUTLASS_HOST_DEVICE
|
| 846 |
-
void clear_mask(bool enable = true) {
|
| 847 |
-
iterator_.clear_mask(enable);
|
| 848 |
-
}
|
| 849 |
-
|
| 850 |
-
CUTLASS_HOST_DEVICE
|
| 851 |
-
void set_residual_tile(bool enable) {
|
| 852 |
-
iterator_.set_residual_tile(enable);
|
| 853 |
-
}
|
| 854 |
-
|
| 855 |
-
/// Clears the predicate set efficiently
|
| 856 |
-
CUTLASS_HOST_DEVICE
|
| 857 |
-
void enable_mask() {
|
| 858 |
-
iterator_.enable_mask();
|
| 859 |
-
}
|
| 860 |
-
|
| 861 |
-
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 862 |
-
CUTLASS_HOST_DEVICE
|
| 863 |
-
void set_mask(Mask const& mask) {
|
| 864 |
-
iterator_.set_mask(mask);
|
| 865 |
-
}
|
| 866 |
-
|
| 867 |
-
/// Gets the mask
|
| 868 |
-
CUTLASS_HOST_DEVICE
|
| 869 |
-
void get_mask(Mask& mask) {
|
| 870 |
-
iterator_.get_mask(mask);
|
| 871 |
-
}
|
| 872 |
-
|
| 873 |
-
/// Loads a fragment from memory
|
| 874 |
-
CUTLASS_DEVICE
|
| 875 |
-
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
|
| 876 |
-
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 877 |
-
}
|
| 878 |
-
|
| 879 |
-
/// Loads a fragment from memory
|
| 880 |
-
CUTLASS_DEVICE
|
| 881 |
-
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
|
| 882 |
-
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 883 |
-
}
|
| 884 |
-
|
| 885 |
-
/// Loads a fragment from memory
|
| 886 |
-
CUTLASS_DEVICE
|
| 887 |
-
void load(Fragment& frag) {
|
| 888 |
-
load_with_pointer_offset(frag, 0);
|
| 889 |
-
}
|
| 890 |
-
|
| 891 |
-
/// Store a fragment to memory
|
| 892 |
-
CUTLASS_DEVICE
|
| 893 |
-
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
|
| 894 |
-
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 895 |
-
}
|
| 896 |
-
|
| 897 |
-
/// Store a fragment to memory
|
| 898 |
-
CUTLASS_DEVICE
|
| 899 |
-
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
|
| 900 |
-
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 901 |
-
}
|
| 902 |
-
|
| 903 |
-
/// Store a fragment to memory
|
| 904 |
-
CUTLASS_DEVICE
|
| 905 |
-
void store(Fragment const& frag) {
|
| 906 |
-
store_with_pointer_offset(frag, 0);
|
| 907 |
-
}
|
| 908 |
-
};
|
| 909 |
-
|
| 910 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 911 |
-
|
| 912 |
-
/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data.
|
| 913 |
-
///
|
| 914 |
-
/// Satisfies: ForwardTileIteratorConcept |
|
| 915 |
-
/// ReadableContiguousTileIteratorConcept |
|
| 916 |
-
/// WriteableContiguousTileIteratorConcept |
|
| 917 |
-
/// MaskedTileIteratorConcept
|
| 918 |
-
///
|
| 919 |
-
template <
|
| 920 |
-
typename Shape_,
|
| 921 |
-
typename Element_,
|
| 922 |
-
int AdvanceRank,
|
| 923 |
-
typename ThreadMap_,
|
| 924 |
-
int AccessSize>
|
| 925 |
-
class PredicatedTileIteratorResidualLast<
|
| 926 |
-
Shape_,
|
| 927 |
-
Element_,
|
| 928 |
-
layout::AffineRankN<2>,
|
| 929 |
-
AdvanceRank,
|
| 930 |
-
ThreadMap_,
|
| 931 |
-
AccessSize,
|
| 932 |
-
false> {
|
| 933 |
-
public:
|
| 934 |
-
static_assert(
|
| 935 |
-
AdvanceRank == 0 || AdvanceRank == 1,
|
| 936 |
-
"Specialization for pitch-linear iterator may advance along the "
|
| 937 |
-
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 938 |
-
|
| 939 |
-
using Shape = Shape_;
|
| 940 |
-
using Element = Element_;
|
| 941 |
-
using Layout = layout::AffineRankN<2>;
|
| 942 |
-
static int const kAdvanceRank = AdvanceRank;
|
| 943 |
-
using ThreadMap = ThreadMap_;
|
| 944 |
-
|
| 945 |
-
using Index = typename Layout::Index;
|
| 946 |
-
using LongIndex = typename Layout::LongIndex;
|
| 947 |
-
|
| 948 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 949 |
-
using TensorView = TensorView<Element, Layout>;
|
| 950 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 951 |
-
|
| 952 |
-
using Pointer = Element*;
|
| 953 |
-
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
| 954 |
-
|
| 955 |
-
/// Type used for internal memory accesses
|
| 956 |
-
using AccessType = AlignedArray<
|
| 957 |
-
Element,
|
| 958 |
-
AccessSize,
|
| 959 |
-
(AccessSize * sizeof_bits<Element>::value / 8)>;
|
| 960 |
-
|
| 961 |
-
/// Underlying iterator to compute the addresses
|
| 962 |
-
using TileAccessIterator = PredicatedTileAccessIteratorResidualLast<
|
| 963 |
-
Shape,
|
| 964 |
-
Element,
|
| 965 |
-
Layout,
|
| 966 |
-
kAdvanceRank,
|
| 967 |
-
ThreadMap,
|
| 968 |
-
AccessType>;
|
| 969 |
-
|
| 970 |
-
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
|
| 971 |
-
|
| 972 |
-
/// Fragment object to be loaded or stored
|
| 973 |
-
using Fragment = cutlass::Array<
|
| 974 |
-
Element,
|
| 975 |
-
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 976 |
-
|
| 977 |
-
/// Predicate vector stores mask to guard accesses
|
| 978 |
-
using Mask = typename TileAccessIterator::Mask;
|
| 979 |
-
|
| 980 |
-
/// Parameters object is precomputed state and is host-constructible
|
| 981 |
-
class Params {
|
| 982 |
-
public:
|
| 983 |
-
friend PredicatedTileIteratorResidualLast;
|
| 984 |
-
|
| 985 |
-
private:
|
| 986 |
-
/// Parameters object
|
| 987 |
-
typename TileAccessIterator::Params params_;
|
| 988 |
-
|
| 989 |
-
public:
|
| 990 |
-
/// Construct the Params object given a pitch-linear tensor's layout
|
| 991 |
-
CUTLASS_HOST_DEVICE
|
| 992 |
-
Params(Layout const& layout) : params_(layout) {}
|
| 993 |
-
|
| 994 |
-
CUTLASS_HOST_DEVICE
|
| 995 |
-
Params() {}
|
| 996 |
-
};
|
| 997 |
-
|
| 998 |
-
private:
|
| 999 |
-
/// Internal pointer type permits fast address arithmetic
|
| 1000 |
-
using BytePointer = char*;
|
| 1001 |
-
|
| 1002 |
-
private:
|
| 1003 |
-
//
|
| 1004 |
-
// Data members
|
| 1005 |
-
//
|
| 1006 |
-
|
| 1007 |
-
/// Data member to the tile access iterator
|
| 1008 |
-
TileAccessIterator address_iterator_;
|
| 1009 |
-
|
| 1010 |
-
public:
|
| 1011 |
-
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1012 |
-
/// and thread ID
|
| 1013 |
-
CUTLASS_HOST_DEVICE
|
| 1014 |
-
PredicatedTileIteratorResidualLast(
|
| 1015 |
-
/// Precomputed parameters object
|
| 1016 |
-
Params const& params,
|
| 1017 |
-
/// Pointer to start of tensor
|
| 1018 |
-
Pointer pointer,
|
| 1019 |
-
/// Extent of tensor
|
| 1020 |
-
TensorCoord extent,
|
| 1021 |
-
/// ID of each participating thread
|
| 1022 |
-
int thread_id,
|
| 1023 |
-
/// Initial offset of threadblock
|
| 1024 |
-
TensorCoord const& threadblock_offset,
|
| 1025 |
-
int const* indices =
|
| 1026 |
-
nullptr ///< gather/scatter indices, note no support for
|
| 1027 |
-
///< gather/scatter at this specialization
|
| 1028 |
-
)
|
| 1029 |
-
: address_iterator_(
|
| 1030 |
-
params.params_,
|
| 1031 |
-
pointer,
|
| 1032 |
-
extent,
|
| 1033 |
-
thread_id,
|
| 1034 |
-
threadblock_offset) {}
|
| 1035 |
-
|
| 1036 |
-
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
|
| 1037 |
-
/// offset
|
| 1038 |
-
CUTLASS_HOST_DEVICE
|
| 1039 |
-
PredicatedTileIteratorResidualLast(
|
| 1040 |
-
Params const& params, ///< Precomputed parameters object
|
| 1041 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 1042 |
-
TensorCoord extent, ///< Extent of tensor
|
| 1043 |
-
int thread_id ///< ID of each participating thread
|
| 1044 |
-
)
|
| 1045 |
-
: PredicatedTileIteratorResidualLast(
|
| 1046 |
-
params,
|
| 1047 |
-
pointer,
|
| 1048 |
-
extent,
|
| 1049 |
-
thread_id,
|
| 1050 |
-
make_Coord(0, 0)) {}
|
| 1051 |
-
|
| 1052 |
-
/// Adds a pointer offset in units of Element
|
| 1053 |
-
CUTLASS_HOST_DEVICE
|
| 1054 |
-
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1055 |
-
address_iterator_.add_pointer_offset(pointer_offset);
|
| 1056 |
-
}
|
| 1057 |
-
|
| 1058 |
-
/// Advances to the next tile in memory.
|
| 1059 |
-
///
|
| 1060 |
-
/// The first time this method is called, predicates are updated, and the
|
| 1061 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1062 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 1063 |
-
/// pointer.
|
| 1064 |
-
CUTLASS_HOST_DEVICE
|
| 1065 |
-
PredicatedTileIteratorResidualLast& operator++() {
|
| 1066 |
-
if (kAdvanceRank)
|
| 1067 |
-
address_iterator_.add_tile_offset(make_Coord(0, 1));
|
| 1068 |
-
else
|
| 1069 |
-
address_iterator_.add_tile_offset(make_Coord(1, 0));
|
| 1070 |
-
|
| 1071 |
-
return *this;
|
| 1072 |
-
}
|
| 1073 |
-
|
| 1074 |
-
/// Advances to the next tile in memory.
|
| 1075 |
-
///
|
| 1076 |
-
/// The first time this method is called, predicates are updated, and the
|
| 1077 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1078 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 1079 |
-
/// pointer.
|
| 1080 |
-
CUTLASS_HOST_DEVICE
|
| 1081 |
-
PredicatedTileIteratorResidualLast operator++(int) {
|
| 1082 |
-
PredicatedTileIteratorResidualLast self(*this);
|
| 1083 |
-
operator++();
|
| 1084 |
-
return self;
|
| 1085 |
-
}
|
| 1086 |
-
|
| 1087 |
-
/// Clears the predicate set efficiently
|
| 1088 |
-
CUTLASS_HOST_DEVICE
|
| 1089 |
-
void clear_mask(bool enable = true) {
|
| 1090 |
-
address_iterator_.clear_mask(enable);
|
| 1091 |
-
}
|
| 1092 |
-
|
| 1093 |
-
CUTLASS_HOST_DEVICE
|
| 1094 |
-
void set_residual_tile(bool enable) {
|
| 1095 |
-
address_iterator_.set_residual_tile(enable);
|
| 1096 |
-
}
|
| 1097 |
-
|
| 1098 |
-
/// Clears the predicate set efficiently
|
| 1099 |
-
CUTLASS_HOST_DEVICE
|
| 1100 |
-
void enable_mask() {
|
| 1101 |
-
address_iterator_.enable_mask();
|
| 1102 |
-
}
|
| 1103 |
-
|
| 1104 |
-
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1105 |
-
CUTLASS_HOST_DEVICE
|
| 1106 |
-
void set_mask(Mask const& mask) {
|
| 1107 |
-
address_iterator_.set_mask(mask);
|
| 1108 |
-
}
|
| 1109 |
-
|
| 1110 |
-
/// Gets the mask
|
| 1111 |
-
CUTLASS_HOST_DEVICE
|
| 1112 |
-
void get_mask(Mask& mask) {
|
| 1113 |
-
address_iterator_.get_mask(mask);
|
| 1114 |
-
}
|
| 1115 |
-
|
| 1116 |
-
CUTLASS_DEVICE
|
| 1117 |
-
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
|
| 1118 |
-
load_with_byte_offset(
|
| 1119 |
-
frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 1120 |
-
}
|
| 1121 |
-
|
| 1122 |
-
CUTLASS_DEVICE
|
| 1123 |
-
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
|
| 1124 |
-
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
|
| 1125 |
-
|
| 1126 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1127 |
-
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 1128 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1129 |
-
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 1130 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1131 |
-
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 1132 |
-
int idx = v +
|
| 1133 |
-
kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 1134 |
-
|
| 1135 |
-
address_iterator_.set_iteration_index(idx);
|
| 1136 |
-
char const* byte_ptr =
|
| 1137 |
-
reinterpret_cast<char const*>(address_iterator_.get()) +
|
| 1138 |
-
byte_offset;
|
| 1139 |
-
|
| 1140 |
-
AccessType const* access_ptr =
|
| 1141 |
-
reinterpret_cast<AccessType const*>(byte_ptr);
|
| 1142 |
-
|
| 1143 |
-
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
|
| 1144 |
-
frag_ptr[idx], access_ptr, address_iterator_.valid());
|
| 1145 |
-
|
| 1146 |
-
++address_iterator_;
|
| 1147 |
-
}
|
| 1148 |
-
}
|
| 1149 |
-
}
|
| 1150 |
-
}
|
| 1151 |
-
|
| 1152 |
-
/// Loads a fragment from memory
|
| 1153 |
-
CUTLASS_DEVICE
|
| 1154 |
-
void load(Fragment& frag) {
|
| 1155 |
-
load_with_byte_offset(frag, 0);
|
| 1156 |
-
}
|
| 1157 |
-
|
| 1158 |
-
/// Store a fragment to memory
|
| 1159 |
-
CUTLASS_DEVICE
|
| 1160 |
-
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
|
| 1161 |
-
store_with_byte_offset(
|
| 1162 |
-
frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 1163 |
-
}
|
| 1164 |
-
|
| 1165 |
-
/// Store a fragment to memory
|
| 1166 |
-
CUTLASS_DEVICE
|
| 1167 |
-
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
|
| 1168 |
-
address_iterator_.set_iteration_index(0);
|
| 1169 |
-
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
|
| 1170 |
-
|
| 1171 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1172 |
-
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 1173 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1174 |
-
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 1175 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1176 |
-
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 1177 |
-
int idx = v +
|
| 1178 |
-
kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 1179 |
-
|
| 1180 |
-
char* byte_ptr =
|
| 1181 |
-
reinterpret_cast<char*>(address_iterator_.get()) + byte_offset;
|
| 1182 |
-
AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr);
|
| 1183 |
-
|
| 1184 |
-
if (address_iterator_.valid()) {
|
| 1185 |
-
*access_ptr = frag_ptr[idx];
|
| 1186 |
-
}
|
| 1187 |
-
++address_iterator_;
|
| 1188 |
-
}
|
| 1189 |
-
}
|
| 1190 |
-
}
|
| 1191 |
-
}
|
| 1192 |
-
|
| 1193 |
-
/// Store a fragment to memory
|
| 1194 |
-
CUTLASS_DEVICE
|
| 1195 |
-
void store(Fragment const& frag) {
|
| 1196 |
-
store_with_byte_offset(frag, 0);
|
| 1197 |
-
}
|
| 1198 |
-
};
|
| 1199 |
-
|
| 1200 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 1201 |
-
|
| 1202 |
-
/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2
|
| 1203 |
-
/// column-major data.
|
| 1204 |
-
///
|
| 1205 |
-
/// Satisfies: ForwardTileIteratorConcept |
|
| 1206 |
-
/// ReadableContiguousTileIteratorConcept |
|
| 1207 |
-
/// WriteableContiguousTileIteratorConcept |
|
| 1208 |
-
/// MaskedTileIteratorConcept
|
| 1209 |
-
///
|
| 1210 |
-
template <
|
| 1211 |
-
typename Shape_,
|
| 1212 |
-
typename Element_,
|
| 1213 |
-
int AdvanceRank,
|
| 1214 |
-
typename ThreadMap_,
|
| 1215 |
-
int AccessSize>
|
| 1216 |
-
class PredicatedTileIteratorResidualLast<
|
| 1217 |
-
Shape_,
|
| 1218 |
-
Element_,
|
| 1219 |
-
layout::AffineRank2ColumnMajor,
|
| 1220 |
-
AdvanceRank,
|
| 1221 |
-
ThreadMap_,
|
| 1222 |
-
AccessSize,
|
| 1223 |
-
false> {
|
| 1224 |
-
public:
|
| 1225 |
-
static_assert(
|
| 1226 |
-
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1227 |
-
"Specialization for pitch-linear iterator may along advance along the "
|
| 1228 |
-
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1229 |
-
|
| 1230 |
-
using Shape = Shape_;
|
| 1231 |
-
using Element = Element_;
|
| 1232 |
-
using Layout = layout::AffineRank2ColumnMajor;
|
| 1233 |
-
static int const kAdvanceRank = AdvanceRank;
|
| 1234 |
-
using ThreadMap = ThreadMap_;
|
| 1235 |
-
|
| 1236 |
-
using Index = typename Layout::Index;
|
| 1237 |
-
using LongIndex = typename Layout::LongIndex;
|
| 1238 |
-
|
| 1239 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 1240 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1241 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 1242 |
-
|
| 1243 |
-
using Pointer = Element*;
|
| 1244 |
-
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
| 1245 |
-
|
| 1246 |
-
// Map to the underlying AffineRankN<2> layout
|
| 1247 |
-
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
|
| 1248 |
-
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
|
| 1249 |
-
Element,
|
| 1250 |
-
layout::AffineRankN<2>,
|
| 1251 |
-
(kAdvanceRank == 0 ? 0 : 1),
|
| 1252 |
-
ThreadMap,
|
| 1253 |
-
AccessSize>;
|
| 1254 |
-
|
| 1255 |
-
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1256 |
-
|
| 1257 |
-
/// Fragment object to be loaded or stored
|
| 1258 |
-
using Fragment = cutlass::Array<
|
| 1259 |
-
Element,
|
| 1260 |
-
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 1261 |
-
|
| 1262 |
-
/// Predicate vector stores mask to guard accesses
|
| 1263 |
-
using Mask = typename UnderlyingIterator::Mask;
|
| 1264 |
-
|
| 1265 |
-
/// Parameters object is precomputed state and is host-constructible
|
| 1266 |
-
class Params {
|
| 1267 |
-
private:
|
| 1268 |
-
friend PredicatedTileIteratorResidualLast;
|
| 1269 |
-
|
| 1270 |
-
/// Parameters object
|
| 1271 |
-
typename UnderlyingIterator::Params params_;
|
| 1272 |
-
|
| 1273 |
-
public:
|
| 1274 |
-
CUTLASS_HOST_DEVICE
|
| 1275 |
-
Params() {}
|
| 1276 |
-
|
| 1277 |
-
/// Construct the Params object given an AffineRankN<2> tensor's layout
|
| 1278 |
-
CUTLASS_HOST_DEVICE
|
| 1279 |
-
Params(Layout const& layout)
|
| 1280 |
-
: params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {}
|
| 1281 |
-
};
|
| 1282 |
-
|
| 1283 |
-
private:
|
| 1284 |
-
//
|
| 1285 |
-
// Data members
|
| 1286 |
-
//
|
| 1287 |
-
|
| 1288 |
-
/// Underlying AffineRankN<2> tile iterator
|
| 1289 |
-
UnderlyingIterator iterator_;
|
| 1290 |
-
|
| 1291 |
-
public:
|
| 1292 |
-
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1293 |
-
/// and thread ID
|
| 1294 |
-
CUTLASS_HOST_DEVICE
|
| 1295 |
-
PredicatedTileIteratorResidualLast(
|
| 1296 |
-
Params const& params, ///< Precomputed parameters object
|
| 1297 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 1298 |
-
TensorCoord extent, ///< Extent of tensor
|
| 1299 |
-
int thread_id, ///< ID of each participating thread
|
| 1300 |
-
TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
|
| 1301 |
-
int const* indices =
|
| 1302 |
-
nullptr ///< gather/scatter indices, note no support for
|
| 1303 |
-
///< gather/scatter at this specialization
|
| 1304 |
-
)
|
| 1305 |
-
: iterator_(
|
| 1306 |
-
params.params_,
|
| 1307 |
-
pointer,
|
| 1308 |
-
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 1309 |
-
thread_id,
|
| 1310 |
-
layout::PitchLinearCoord(
|
| 1311 |
-
threadblock_offset.row(),
|
| 1312 |
-
threadblock_offset.column())) {}
|
| 1313 |
-
|
| 1314 |
-
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
|
| 1315 |
-
/// offset
|
| 1316 |
-
CUTLASS_HOST_DEVICE
|
| 1317 |
-
PredicatedTileIteratorResidualLast(
|
| 1318 |
-
Params const& params, ///< Precomputed parameters object
|
| 1319 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 1320 |
-
TensorCoord extent, ///< Extent of tensor
|
| 1321 |
-
int thread_id ///< ID of each participating thread
|
| 1322 |
-
)
|
| 1323 |
-
: PredicatedTileIteratorResidualLast(
|
| 1324 |
-
params,
|
| 1325 |
-
pointer,
|
| 1326 |
-
extent,
|
| 1327 |
-
thread_id,
|
| 1328 |
-
make_Coord(0, 0)) {}
|
| 1329 |
-
|
| 1330 |
-
/// Adds a pointer offset in units of Element
|
| 1331 |
-
CUTLASS_HOST_DEVICE
|
| 1332 |
-
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1333 |
-
iterator_.add_pointer_offset(pointer_offset);
|
| 1334 |
-
}
|
| 1335 |
-
|
| 1336 |
-
/// Advances to the next tile in memory.
|
| 1337 |
-
///
|
| 1338 |
-
/// The first time this method is called, predicates are updated, and the
|
| 1339 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1340 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 1341 |
-
/// pointer.
|
| 1342 |
-
CUTLASS_HOST_DEVICE
|
| 1343 |
-
PredicatedTileIteratorResidualLast& operator++() {
|
| 1344 |
-
++iterator_;
|
| 1345 |
-
return *this;
|
| 1346 |
-
}
|
| 1347 |
-
|
| 1348 |
-
/// Advances to the next tile in memory.
|
| 1349 |
-
///
|
| 1350 |
-
/// The first time this method is called, predicates are updated, and the
|
| 1351 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1352 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 1353 |
-
/// pointer.
|
| 1354 |
-
CUTLASS_HOST_DEVICE
|
| 1355 |
-
PredicatedTileIteratorResidualLast operator++(int) {
|
| 1356 |
-
PredicatedTileIteratorResidualLast self(*this);
|
| 1357 |
-
operator++();
|
| 1358 |
-
return self;
|
| 1359 |
-
}
|
| 1360 |
-
|
| 1361 |
-
/// Clears the predicate set efficiently
|
| 1362 |
-
CUTLASS_HOST_DEVICE
|
| 1363 |
-
void clear_mask(bool enable = true) {
|
| 1364 |
-
iterator_.clear_mask(enable);
|
| 1365 |
-
}
|
| 1366 |
-
|
| 1367 |
-
CUTLASS_HOST_DEVICE
|
| 1368 |
-
void set_residual_tile(bool enable) {
|
| 1369 |
-
iterator_.set_residual_tile(enable);
|
| 1370 |
-
}
|
| 1371 |
-
|
| 1372 |
-
/// Clears the predicate set efficiently
|
| 1373 |
-
CUTLASS_HOST_DEVICE
|
| 1374 |
-
void enable_mask() {
|
| 1375 |
-
iterator_.enable_mask();
|
| 1376 |
-
}
|
| 1377 |
-
|
| 1378 |
-
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1379 |
-
CUTLASS_HOST_DEVICE
|
| 1380 |
-
void set_mask(Mask const& mask) {
|
| 1381 |
-
iterator_.set_mask(mask);
|
| 1382 |
-
}
|
| 1383 |
-
|
| 1384 |
-
/// Gets the mask
|
| 1385 |
-
CUTLASS_HOST_DEVICE
|
| 1386 |
-
void get_mask(Mask& mask) {
|
| 1387 |
-
iterator_.get_mask(mask);
|
| 1388 |
-
}
|
| 1389 |
-
|
| 1390 |
-
/// Loads a fragment from memory
|
| 1391 |
-
CUTLASS_DEVICE
|
| 1392 |
-
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
|
| 1393 |
-
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 1394 |
-
}
|
| 1395 |
-
|
| 1396 |
-
/// Loads a fragment from memory
|
| 1397 |
-
CUTLASS_DEVICE
|
| 1398 |
-
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
|
| 1399 |
-
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 1400 |
-
}
|
| 1401 |
-
|
| 1402 |
-
/// Loads a fragment from memory
|
| 1403 |
-
CUTLASS_DEVICE
|
| 1404 |
-
void load(Fragment& frag) {
|
| 1405 |
-
load_with_pointer_offset(frag, 0);
|
| 1406 |
-
}
|
| 1407 |
-
|
| 1408 |
-
/// Store a fragment to memory
|
| 1409 |
-
CUTLASS_DEVICE
|
| 1410 |
-
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
|
| 1411 |
-
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 1412 |
-
}
|
| 1413 |
-
|
| 1414 |
-
/// Store a fragment to memory
|
| 1415 |
-
CUTLASS_DEVICE
|
| 1416 |
-
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
|
| 1417 |
-
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 1418 |
-
}
|
| 1419 |
-
|
| 1420 |
-
/// Store a fragment to memory
|
| 1421 |
-
CUTLASS_DEVICE
|
| 1422 |
-
void store(Fragment const& frag) {
|
| 1423 |
-
store_with_pointer_offset(frag, 0);
|
| 1424 |
-
}
|
| 1425 |
-
};
|
| 1426 |
-
|
| 1427 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 1428 |
-
|
| 1429 |
-
/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2
|
| 1430 |
-
/// row-major data.
|
| 1431 |
-
///
|
| 1432 |
-
/// Satisfies: ForwardTileIteratorConcept |
|
| 1433 |
-
/// ReadableContiguousTileIteratorConcept |
|
| 1434 |
-
/// WriteableContiguousTileIteratorConcept |
|
| 1435 |
-
/// MaskedTileIteratorConcept
|
| 1436 |
-
///
|
| 1437 |
-
template <
|
| 1438 |
-
typename Shape_,
|
| 1439 |
-
typename Element_,
|
| 1440 |
-
int AdvanceRank,
|
| 1441 |
-
typename ThreadMap_,
|
| 1442 |
-
int AccessSize>
|
| 1443 |
-
class PredicatedTileIteratorResidualLast<
|
| 1444 |
-
Shape_,
|
| 1445 |
-
Element_,
|
| 1446 |
-
layout::AffineRank2RowMajor,
|
| 1447 |
-
AdvanceRank,
|
| 1448 |
-
ThreadMap_,
|
| 1449 |
-
AccessSize,
|
| 1450 |
-
false> {
|
| 1451 |
-
public:
|
| 1452 |
-
static_assert(
|
| 1453 |
-
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1454 |
-
"Specialization for pitch-linear iterator may along advance along the "
|
| 1455 |
-
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1456 |
-
|
| 1457 |
-
using Shape = Shape_;
|
| 1458 |
-
using Element = Element_;
|
| 1459 |
-
using Layout = layout::AffineRank2RowMajor;
|
| 1460 |
-
static int const kAdvanceRank = AdvanceRank;
|
| 1461 |
-
using ThreadMap = ThreadMap_;
|
| 1462 |
-
|
| 1463 |
-
using Index = typename Layout::Index;
|
| 1464 |
-
using LongIndex = typename Layout::LongIndex;
|
| 1465 |
-
|
| 1466 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 1467 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1468 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 1469 |
-
|
| 1470 |
-
using Pointer = Element*;
|
| 1471 |
-
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
| 1472 |
-
|
| 1473 |
-
// Map to the underlying AffineRankN<2> layout
|
| 1474 |
-
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
|
| 1475 |
-
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
|
| 1476 |
-
Element,
|
| 1477 |
-
layout::AffineRankN<2>,
|
| 1478 |
-
(kAdvanceRank == 0 ? 1 : 0),
|
| 1479 |
-
ThreadMap,
|
| 1480 |
-
AccessSize>;
|
| 1481 |
-
|
| 1482 |
-
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1483 |
-
|
| 1484 |
-
/// Fragment object to be loaded or stored
|
| 1485 |
-
using Fragment = cutlass::Array<
|
| 1486 |
-
Element,
|
| 1487 |
-
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 1488 |
-
|
| 1489 |
-
/// Predicate vector stores mask to guard accesses
|
| 1490 |
-
using Mask = typename UnderlyingIterator::Mask;
|
| 1491 |
-
|
| 1492 |
-
/// Parameters object is precomputed state and is host-constructible
|
| 1493 |
-
class Params {
|
| 1494 |
-
private:
|
| 1495 |
-
friend PredicatedTileIteratorResidualLast;
|
| 1496 |
-
|
| 1497 |
-
/// Parameters object
|
| 1498 |
-
typename UnderlyingIterator::Params params_;
|
| 1499 |
-
|
| 1500 |
-
public:
|
| 1501 |
-
CUTLASS_HOST_DEVICE
|
| 1502 |
-
Params() {}
|
| 1503 |
-
|
| 1504 |
-
/// Construct the Params object given an AffineRankN<2> tensor's layout
|
| 1505 |
-
CUTLASS_HOST_DEVICE
|
| 1506 |
-
Params(Layout const& layout)
|
| 1507 |
-
: params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}
|
| 1508 |
-
};
|
| 1509 |
-
|
| 1510 |
-
private:
|
| 1511 |
-
//
|
| 1512 |
-
// Data members
|
| 1513 |
-
//
|
| 1514 |
-
|
| 1515 |
-
/// Underlying AffineRankN<2> tile iterator
|
| 1516 |
-
UnderlyingIterator iterator_;
|
| 1517 |
-
|
| 1518 |
-
public:
|
| 1519 |
-
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1520 |
-
/// and thread ID
|
| 1521 |
-
CUTLASS_HOST_DEVICE
|
| 1522 |
-
PredicatedTileIteratorResidualLast(
|
| 1523 |
-
Params const& params, ///< Precomputed parameters object
|
| 1524 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 1525 |
-
TensorCoord extent, ///< Extent of tensor
|
| 1526 |
-
int thread_id, ///< ID of each participating thread
|
| 1527 |
-
TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
|
| 1528 |
-
int const* indices =
|
| 1529 |
-
nullptr ///< gather/scatter indices, note no support for
|
| 1530 |
-
///< gather/scatter at this specialization
|
| 1531 |
-
)
|
| 1532 |
-
: iterator_(
|
| 1533 |
-
params.params_,
|
| 1534 |
-
pointer,
|
| 1535 |
-
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 1536 |
-
thread_id,
|
| 1537 |
-
layout::PitchLinearCoord(
|
| 1538 |
-
threadblock_offset.column(),
|
| 1539 |
-
threadblock_offset.row())) {}
|
| 1540 |
-
|
| 1541 |
-
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
|
| 1542 |
-
/// offset
|
| 1543 |
-
CUTLASS_HOST_DEVICE
|
| 1544 |
-
PredicatedTileIteratorResidualLast(
|
| 1545 |
-
Params const& params, ///< Precomputed parameters object
|
| 1546 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 1547 |
-
TensorCoord extent, ///< Extent of tensor
|
| 1548 |
-
int thread_id ///< ID of each participating thread
|
| 1549 |
-
)
|
| 1550 |
-
: PredicatedTileIteratorResidualLast(
|
| 1551 |
-
params,
|
| 1552 |
-
pointer,
|
| 1553 |
-
extent,
|
| 1554 |
-
thread_id,
|
| 1555 |
-
make_Coord(0, 0)) {}
|
| 1556 |
-
|
| 1557 |
-
/// Adds a pointer offset in units of Element
|
| 1558 |
-
CUTLASS_HOST_DEVICE
|
| 1559 |
-
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1560 |
-
iterator_.add_pointer_offset(pointer_offset);
|
| 1561 |
-
}
|
| 1562 |
-
|
| 1563 |
-
/// Advances to the next tile in memory.
|
| 1564 |
-
///
|
| 1565 |
-
/// The first time this method is called, predicates are updated, and the
|
| 1566 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1567 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 1568 |
-
/// pointer.
|
| 1569 |
-
CUTLASS_HOST_DEVICE
|
| 1570 |
-
PredicatedTileIteratorResidualLast& operator++() {
|
| 1571 |
-
++iterator_;
|
| 1572 |
-
return *this;
|
| 1573 |
-
}
|
| 1574 |
-
|
| 1575 |
-
/// Advances to the next tile in memory.
|
| 1576 |
-
///
|
| 1577 |
-
/// The first time this method is called, predicates are updated, and the
|
| 1578 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1579 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 1580 |
-
/// pointer.
|
| 1581 |
-
CUTLASS_HOST_DEVICE
|
| 1582 |
-
PredicatedTileIteratorResidualLast operator++(int) {
|
| 1583 |
-
PredicatedTileIteratorResidualLast self(*this);
|
| 1584 |
-
operator++();
|
| 1585 |
-
return self;
|
| 1586 |
-
}
|
| 1587 |
-
|
| 1588 |
-
/// Clears the predicate set efficiently
|
| 1589 |
-
CUTLASS_HOST_DEVICE
|
| 1590 |
-
void clear_mask(bool enable = true) {
|
| 1591 |
-
iterator_.clear_mask(enable);
|
| 1592 |
-
}
|
| 1593 |
-
|
| 1594 |
-
CUTLASS_HOST_DEVICE
|
| 1595 |
-
void set_residual_tile(bool enable) {
|
| 1596 |
-
iterator_.set_residual_tile(enable);
|
| 1597 |
-
}
|
| 1598 |
-
|
| 1599 |
-
/// Clears the predicate set efficiently
|
| 1600 |
-
CUTLASS_HOST_DEVICE
|
| 1601 |
-
void enable_mask() {
|
| 1602 |
-
iterator_.enable_mask();
|
| 1603 |
-
}
|
| 1604 |
-
|
| 1605 |
-
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1606 |
-
CUTLASS_HOST_DEVICE
|
| 1607 |
-
void set_mask(Mask const& mask) {
|
| 1608 |
-
iterator_.set_mask(mask);
|
| 1609 |
-
}
|
| 1610 |
-
|
| 1611 |
-
/// Gets the mask
|
| 1612 |
-
CUTLASS_HOST_DEVICE
|
| 1613 |
-
void get_mask(Mask& mask) {
|
| 1614 |
-
iterator_.get_mask(mask);
|
| 1615 |
-
}
|
| 1616 |
-
|
| 1617 |
-
/// Loads a fragment from memory
|
| 1618 |
-
CUTLASS_DEVICE
|
| 1619 |
-
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
|
| 1620 |
-
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 1621 |
-
}
|
| 1622 |
-
|
| 1623 |
-
/// Loads a fragment from memory
|
| 1624 |
-
CUTLASS_DEVICE
|
| 1625 |
-
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
|
| 1626 |
-
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 1627 |
-
}
|
| 1628 |
-
|
| 1629 |
-
/// Loads a fragment from memory
|
| 1630 |
-
CUTLASS_DEVICE
|
| 1631 |
-
void load(Fragment& frag) {
|
| 1632 |
-
load_with_pointer_offset(frag, 0);
|
| 1633 |
-
}
|
| 1634 |
-
|
| 1635 |
-
/// Store a fragment to memory
|
| 1636 |
-
CUTLASS_DEVICE
|
| 1637 |
-
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
|
| 1638 |
-
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 1639 |
-
}
|
| 1640 |
-
|
| 1641 |
-
/// Store a fragment to memory
|
| 1642 |
-
CUTLASS_DEVICE
|
| 1643 |
-
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
|
| 1644 |
-
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 1645 |
-
}
|
| 1646 |
-
|
| 1647 |
-
/// Store a fragment to memory
|
| 1648 |
-
CUTLASS_DEVICE
|
| 1649 |
-
void store(Fragment const& frag) {
|
| 1650 |
-
store_with_pointer_offset(frag, 0);
|
| 1651 |
-
}
|
| 1652 |
-
};
|
| 1653 |
-
|
| 1654 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 1655 |
-
|
| 1656 |
-
/// Specialization of PredicatedTileIteratorResidualLast for interleaved data.
|
| 1657 |
-
/// It is mapped to the congruous layout.
|
| 1658 |
-
///
|
| 1659 |
-
/// Satisfies: ForwardTileIteratorConcept |
|
| 1660 |
-
/// ReadableContiguousTileIteratorConcept |
|
| 1661 |
-
/// WriteableContiguousTileIteratorConcept |
|
| 1662 |
-
/// MaskedTileIteratorConcept
|
| 1663 |
-
///
|
| 1664 |
-
|
| 1665 |
-
template <
|
| 1666 |
-
typename Shape_,
|
| 1667 |
-
typename Element_,
|
| 1668 |
-
int AdvanceRank,
|
| 1669 |
-
typename ThreadMap_,
|
| 1670 |
-
int AccessSize,
|
| 1671 |
-
int InterleavedK>
|
| 1672 |
-
class PredicatedTileIteratorResidualLast<
|
| 1673 |
-
Shape_,
|
| 1674 |
-
Element_,
|
| 1675 |
-
layout::ColumnMajorInterleaved<InterleavedK>,
|
| 1676 |
-
AdvanceRank,
|
| 1677 |
-
ThreadMap_,
|
| 1678 |
-
AccessSize,
|
| 1679 |
-
false> {
|
| 1680 |
-
public:
|
| 1681 |
-
static_assert(
|
| 1682 |
-
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1683 |
-
"Specialization for pitch-linear iterator may along advance along the "
|
| 1684 |
-
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1685 |
-
|
| 1686 |
-
using Shape = Shape_;
|
| 1687 |
-
using Element = Element_;
|
| 1688 |
-
static int const kInterleavedK = InterleavedK;
|
| 1689 |
-
using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
|
| 1690 |
-
static int const kAdvanceRank = AdvanceRank;
|
| 1691 |
-
using ThreadMap = ThreadMap_;
|
| 1692 |
-
|
| 1693 |
-
using Index = typename Layout::Index;
|
| 1694 |
-
using LongIndex = typename Layout::LongIndex;
|
| 1695 |
-
|
| 1696 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 1697 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1698 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 1699 |
-
|
| 1700 |
-
using Pointer = Element*;
|
| 1701 |
-
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
| 1702 |
-
|
| 1703 |
-
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
|
| 1704 |
-
layout::PitchLinearShape<
|
| 1705 |
-
Shape::kRow * kInterleavedK,
|
| 1706 |
-
Shape::kColumn / kInterleavedK>,
|
| 1707 |
-
Element,
|
| 1708 |
-
layout::PitchLinear,
|
| 1709 |
-
(kAdvanceRank == 0 ? 0 : 1),
|
| 1710 |
-
ThreadMap,
|
| 1711 |
-
AccessSize>;
|
| 1712 |
-
|
| 1713 |
-
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1714 |
-
|
| 1715 |
-
/// Fragment object to be loaded or stored
|
| 1716 |
-
using Fragment = cutlass::Array<
|
| 1717 |
-
Element,
|
| 1718 |
-
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 1719 |
-
|
| 1720 |
-
/// Predicate vector stores mask to guard accesses
|
| 1721 |
-
using Mask = typename UnderlyingIterator::Mask;
|
| 1722 |
-
|
| 1723 |
-
/// Parameters object is precomputed state and is host-constructible
|
| 1724 |
-
class Params {
|
| 1725 |
-
private:
|
| 1726 |
-
friend PredicatedTileIteratorResidualLast;
|
| 1727 |
-
|
| 1728 |
-
/// Parameters object
|
| 1729 |
-
typename UnderlyingIterator::Params params_;
|
| 1730 |
-
|
| 1731 |
-
public:
|
| 1732 |
-
CUTLASS_HOST_DEVICE
|
| 1733 |
-
Params() {}
|
| 1734 |
-
|
| 1735 |
-
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1736 |
-
CUTLASS_HOST_DEVICE
|
| 1737 |
-
Params(Layout const& layout)
|
| 1738 |
-
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 1739 |
-
|
| 1740 |
-
CUTLASS_HOST_DEVICE
|
| 1741 |
-
Params(typename UnderlyingIterator::Params::Base const& base)
|
| 1742 |
-
: params_(base) {}
|
| 1743 |
-
};
|
| 1744 |
-
|
| 1745 |
-
private:
|
| 1746 |
-
//
|
| 1747 |
-
// Data members
|
| 1748 |
-
//
|
| 1749 |
-
|
| 1750 |
-
/// Underlying pitch-linear tile iterator
|
| 1751 |
-
UnderlyingIterator iterator_;
|
| 1752 |
-
|
| 1753 |
-
public:
|
| 1754 |
-
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1755 |
-
/// and thread ID
|
| 1756 |
-
CUTLASS_HOST_DEVICE
|
| 1757 |
-
PredicatedTileIteratorResidualLast(
|
| 1758 |
-
/// Precomputed parameters object
|
| 1759 |
-
Params const& params,
|
| 1760 |
-
/// Pointer to start of tensor
|
| 1761 |
-
Pointer pointer,
|
| 1762 |
-
/// Extent of tensor
|
| 1763 |
-
TensorCoord extent,
|
| 1764 |
-
/// ID of each participating thread
|
| 1765 |
-
int thread_id,
|
| 1766 |
-
/// Initial offset of threadblock
|
| 1767 |
-
TensorCoord const& threadblock_offset,
|
| 1768 |
-
int const* indices =
|
| 1769 |
-
nullptr ///< gather/scatter indices, note no support for
|
| 1770 |
-
///< gather/scatter at this specialization
|
| 1771 |
-
)
|
| 1772 |
-
: iterator_(
|
| 1773 |
-
params.params_,
|
| 1774 |
-
pointer,
|
| 1775 |
-
layout::PitchLinearCoord(
|
| 1776 |
-
extent.row() * kInterleavedK,
|
| 1777 |
-
extent.column() / kInterleavedK),
|
| 1778 |
-
thread_id,
|
| 1779 |
-
layout::PitchLinearCoord(
|
| 1780 |
-
threadblock_offset.row() * kInterleavedK,
|
| 1781 |
-
threadblock_offset.column() / kInterleavedK)) {}
|
| 1782 |
-
|
| 1783 |
-
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
|
| 1784 |
-
/// offset
|
| 1785 |
-
CUTLASS_HOST_DEVICE
|
| 1786 |
-
PredicatedTileIteratorResidualLast(
|
| 1787 |
-
Params const& params, ///< Precomputed parameters object
|
| 1788 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 1789 |
-
TensorCoord extent, ///< Extent of tensor
|
| 1790 |
-
int thread_id ///< ID of each participating thread
|
| 1791 |
-
)
|
| 1792 |
-
: PredicatedTileIteratorResidualLast(
|
| 1793 |
-
params,
|
| 1794 |
-
pointer,
|
| 1795 |
-
extent,
|
| 1796 |
-
thread_id,
|
| 1797 |
-
make_Coord(0, 0)) {}
|
| 1798 |
-
|
| 1799 |
-
/// Adds a pointer offset in units of Element
|
| 1800 |
-
CUTLASS_HOST_DEVICE
|
| 1801 |
-
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1802 |
-
iterator_.add_pointer_offset(pointer_offset);
|
| 1803 |
-
}
|
| 1804 |
-
|
| 1805 |
-
/// Advances to the next tile in memory.
|
| 1806 |
-
///
|
| 1807 |
-
/// The first time this method is called, predicates are updated, and the
|
| 1808 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1809 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 1810 |
-
/// pointer.
|
| 1811 |
-
CUTLASS_HOST_DEVICE
|
| 1812 |
-
PredicatedTileIteratorResidualLast& operator++() {
|
| 1813 |
-
++iterator_;
|
| 1814 |
-
return *this;
|
| 1815 |
-
}
|
| 1816 |
-
|
| 1817 |
-
/// Advances to the next tile in memory.
|
| 1818 |
-
///
|
| 1819 |
-
/// The first time this method is called, predicates are updated, and the
|
| 1820 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1821 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 1822 |
-
/// pointer.
|
| 1823 |
-
CUTLASS_HOST_DEVICE
|
| 1824 |
-
PredicatedTileIteratorResidualLast operator++(int) {
|
| 1825 |
-
PredicatedTileIteratorResidualLast self(*this);
|
| 1826 |
-
operator++();
|
| 1827 |
-
return self;
|
| 1828 |
-
}
|
| 1829 |
-
|
| 1830 |
-
/// Clears the predicate set efficiently
|
| 1831 |
-
CUTLASS_HOST_DEVICE
|
| 1832 |
-
void clear_mask(bool enable = true) {
|
| 1833 |
-
iterator_.clear_mask(enable);
|
| 1834 |
-
}
|
| 1835 |
-
|
| 1836 |
-
CUTLASS_HOST_DEVICE
|
| 1837 |
-
void set_residual_tile(bool enable) {
|
| 1838 |
-
iterator_.set_residual_tile(enable);
|
| 1839 |
-
}
|
| 1840 |
-
|
| 1841 |
-
/// Clears the predicate set efficiently
|
| 1842 |
-
CUTLASS_HOST_DEVICE
|
| 1843 |
-
void enable_mask() {
|
| 1844 |
-
iterator_.enable_mask();
|
| 1845 |
-
}
|
| 1846 |
-
|
| 1847 |
-
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1848 |
-
CUTLASS_HOST_DEVICE
|
| 1849 |
-
void set_mask(Mask const& mask) {
|
| 1850 |
-
iterator_.set_mask(mask);
|
| 1851 |
-
}
|
| 1852 |
-
|
| 1853 |
-
/// Gets the mask
|
| 1854 |
-
CUTLASS_HOST_DEVICE
|
| 1855 |
-
void get_mask(Mask& mask) {
|
| 1856 |
-
iterator_.get_mask(mask);
|
| 1857 |
-
}
|
| 1858 |
-
|
| 1859 |
-
/// Loads a fragment from memory
|
| 1860 |
-
CUTLASS_DEVICE
|
| 1861 |
-
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
|
| 1862 |
-
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 1863 |
-
}
|
| 1864 |
-
|
| 1865 |
-
/// Loads a fragment from memory
|
| 1866 |
-
CUTLASS_DEVICE
|
| 1867 |
-
void load(Fragment& frag) {
|
| 1868 |
-
load_with_pointer_offset(frag, 0);
|
| 1869 |
-
}
|
| 1870 |
-
|
| 1871 |
-
/// Store a fragment to memory
|
| 1872 |
-
CUTLASS_DEVICE
|
| 1873 |
-
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
|
| 1874 |
-
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 1875 |
-
}
|
| 1876 |
-
|
| 1877 |
-
/// Store a fragment to memory
|
| 1878 |
-
CUTLASS_DEVICE
|
| 1879 |
-
void store(Fragment const& frag) {
|
| 1880 |
-
store_with_pointer_offset(frag, 0);
|
| 1881 |
-
}
|
| 1882 |
-
};
|
| 1883 |
-
|
| 1884 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 1885 |
-
|
| 1886 |
-
/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32
|
| 1887 |
-
/// data. It is mapped to the congruous layout.
|
| 1888 |
-
///
|
| 1889 |
-
/// Satisfies: ForwardTileIteratorConcept |
|
| 1890 |
-
/// ReadableContiguousTileIteratorConcept |
|
| 1891 |
-
/// WriteableContiguousTileIteratorConcept |
|
| 1892 |
-
/// MaskedTileIteratorConcept
|
| 1893 |
-
///
|
| 1894 |
-
template <
|
| 1895 |
-
typename Shape_,
|
| 1896 |
-
typename Element_,
|
| 1897 |
-
int AdvanceRank,
|
| 1898 |
-
typename ThreadMap_,
|
| 1899 |
-
int AccessSize,
|
| 1900 |
-
int InterleavedK>
|
| 1901 |
-
class PredicatedTileIteratorResidualLast<
|
| 1902 |
-
Shape_,
|
| 1903 |
-
Element_,
|
| 1904 |
-
layout::RowMajorInterleaved<InterleavedK>,
|
| 1905 |
-
AdvanceRank,
|
| 1906 |
-
ThreadMap_,
|
| 1907 |
-
AccessSize,
|
| 1908 |
-
false> {
|
| 1909 |
-
public:
|
| 1910 |
-
static_assert(
|
| 1911 |
-
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1912 |
-
"Specialization for pitch-linear iterator may along advance along the "
|
| 1913 |
-
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1914 |
-
|
| 1915 |
-
using Shape = Shape_;
|
| 1916 |
-
using Element = Element_;
|
| 1917 |
-
static int const kInterleavedK = InterleavedK;
|
| 1918 |
-
using Layout = layout::RowMajorInterleaved<kInterleavedK>;
|
| 1919 |
-
static int const kAdvanceRank = AdvanceRank;
|
| 1920 |
-
using ThreadMap = ThreadMap_;
|
| 1921 |
-
|
| 1922 |
-
using Index = typename Layout::Index;
|
| 1923 |
-
using LongIndex = typename Layout::LongIndex;
|
| 1924 |
-
|
| 1925 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 1926 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1927 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 1928 |
-
|
| 1929 |
-
using Pointer = Element*;
|
| 1930 |
-
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
| 1931 |
-
|
| 1932 |
-
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
|
| 1933 |
-
layout::PitchLinearShape<
|
| 1934 |
-
Shape::kColumn * kInterleavedK,
|
| 1935 |
-
Shape::kRow / kInterleavedK>,
|
| 1936 |
-
Element,
|
| 1937 |
-
layout::PitchLinear,
|
| 1938 |
-
(kAdvanceRank == 0 ? 1 : 0),
|
| 1939 |
-
ThreadMap,
|
| 1940 |
-
AccessSize>;
|
| 1941 |
-
|
| 1942 |
-
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1943 |
-
|
| 1944 |
-
/// Fragment object to be loaded or stored
|
| 1945 |
-
using Fragment = cutlass::Array<
|
| 1946 |
-
Element,
|
| 1947 |
-
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 1948 |
-
|
| 1949 |
-
/// Predicate vector stores mask to guard accesses
|
| 1950 |
-
using Mask = typename UnderlyingIterator::Mask;
|
| 1951 |
-
|
| 1952 |
-
/// Parameters object is precomputed state and is host-constructible
|
| 1953 |
-
class Params {
|
| 1954 |
-
private:
|
| 1955 |
-
friend PredicatedTileIteratorResidualLast;
|
| 1956 |
-
|
| 1957 |
-
/// Parameters object
|
| 1958 |
-
typename UnderlyingIterator::Params params_;
|
| 1959 |
-
|
| 1960 |
-
public:
|
| 1961 |
-
CUTLASS_HOST_DEVICE
|
| 1962 |
-
Params() {}
|
| 1963 |
-
|
| 1964 |
-
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1965 |
-
CUTLASS_HOST_DEVICE
|
| 1966 |
-
Params(Layout const& layout)
|
| 1967 |
-
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 1968 |
-
|
| 1969 |
-
CUTLASS_HOST_DEVICE
|
| 1970 |
-
Params(typename UnderlyingIterator::Params::Base const& base)
|
| 1971 |
-
: params_(base) {}
|
| 1972 |
-
};
|
| 1973 |
-
|
| 1974 |
-
private:
|
| 1975 |
-
//
|
| 1976 |
-
// Data members
|
| 1977 |
-
//
|
| 1978 |
-
|
| 1979 |
-
/// Underlying pitch-linear tile iterator
|
| 1980 |
-
UnderlyingIterator iterator_;
|
| 1981 |
-
|
| 1982 |
-
public:
|
| 1983 |
-
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1984 |
-
/// and thread ID
|
| 1985 |
-
CUTLASS_HOST_DEVICE
|
| 1986 |
-
PredicatedTileIteratorResidualLast(
|
| 1987 |
-
/// Precomputed parameters object
|
| 1988 |
-
Params const& params,
|
| 1989 |
-
/// Pointer to start of tensor
|
| 1990 |
-
Pointer pointer,
|
| 1991 |
-
/// Extent of tensor
|
| 1992 |
-
TensorCoord extent,
|
| 1993 |
-
/// ID of each participating thread
|
| 1994 |
-
int thread_id,
|
| 1995 |
-
/// Initial offset of threadblock
|
| 1996 |
-
TensorCoord const& threadblock_offset,
|
| 1997 |
-
int const* indices =
|
| 1998 |
-
nullptr ///< gather/scatter indices, note no support for
|
| 1999 |
-
///< gather/scatter at this specialization
|
| 2000 |
-
)
|
| 2001 |
-
: iterator_(
|
| 2002 |
-
params.params_,
|
| 2003 |
-
pointer,
|
| 2004 |
-
layout::PitchLinearCoord(
|
| 2005 |
-
extent.column() * kInterleavedK,
|
| 2006 |
-
extent.row() / kInterleavedK),
|
| 2007 |
-
thread_id,
|
| 2008 |
-
layout::PitchLinearCoord(
|
| 2009 |
-
threadblock_offset.column() * kInterleavedK,
|
| 2010 |
-
threadblock_offset.row() / kInterleavedK)) {}
|
| 2011 |
-
|
| 2012 |
-
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
|
| 2013 |
-
/// offset
|
| 2014 |
-
CUTLASS_HOST_DEVICE
|
| 2015 |
-
PredicatedTileIteratorResidualLast(
|
| 2016 |
-
Params const& params, ///< Precomputed parameters object
|
| 2017 |
-
Pointer pointer, ///< Pointer to start of tensor
|
| 2018 |
-
TensorCoord extent, ///< Extent of tensor
|
| 2019 |
-
int thread_id ///< ID of each participating thread
|
| 2020 |
-
)
|
| 2021 |
-
: PredicatedTileIteratorResidualLast(
|
| 2022 |
-
params,
|
| 2023 |
-
pointer,
|
| 2024 |
-
extent,
|
| 2025 |
-
thread_id,
|
| 2026 |
-
make_Coord(0, 0)) {}
|
| 2027 |
-
|
| 2028 |
-
/// Adds a pointer offset in units of Element
|
| 2029 |
-
CUTLASS_HOST_DEVICE
|
| 2030 |
-
void add_pointer_offset(LongIndex pointer_offset) {
|
| 2031 |
-
iterator_.add_pointer_offset(pointer_offset);
|
| 2032 |
-
}
|
| 2033 |
-
|
| 2034 |
-
/// Advances to the next tile in memory.
|
| 2035 |
-
///
|
| 2036 |
-
/// The first time this method is called, predicates are updated, and the
|
| 2037 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 2038 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 2039 |
-
/// pointer.
|
| 2040 |
-
CUTLASS_HOST_DEVICE
|
| 2041 |
-
PredicatedTileIteratorResidualLast& operator++() {
|
| 2042 |
-
++iterator_;
|
| 2043 |
-
return *this;
|
| 2044 |
-
}
|
| 2045 |
-
|
| 2046 |
-
/// Advances to the next tile in memory.
|
| 2047 |
-
///
|
| 2048 |
-
/// The first time this method is called, predicates are updated, and the
|
| 2049 |
-
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 2050 |
-
/// Subsequent calls are lightweight and must only update the internal
|
| 2051 |
-
/// pointer.
|
| 2052 |
-
CUTLASS_HOST_DEVICE
|
| 2053 |
-
PredicatedTileIteratorResidualLast operator++(int) {
|
| 2054 |
-
PredicatedTileIteratorResidualLast self(*this);
|
| 2055 |
-
operator++();
|
| 2056 |
-
return self;
|
| 2057 |
-
}
|
| 2058 |
-
|
| 2059 |
-
/// Clears the predicate set efficiently
|
| 2060 |
-
CUTLASS_HOST_DEVICE
|
| 2061 |
-
void clear_mask(bool enable = true) {
|
| 2062 |
-
iterator_.clear_mask(enable);
|
| 2063 |
-
}
|
| 2064 |
-
|
| 2065 |
-
CUTLASS_HOST_DEVICE
|
| 2066 |
-
void set_residual_tile(bool enable) {
|
| 2067 |
-
iterator_.set_residual_tile(enable);
|
| 2068 |
-
}
|
| 2069 |
-
|
| 2070 |
-
/// Clears the predicate set efficiently
|
| 2071 |
-
CUTLASS_HOST_DEVICE
|
| 2072 |
-
void enable_mask() {
|
| 2073 |
-
iterator_.enable_mask();
|
| 2074 |
-
}
|
| 2075 |
-
|
| 2076 |
-
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 2077 |
-
CUTLASS_HOST_DEVICE
|
| 2078 |
-
void set_mask(Mask const& mask) {
|
| 2079 |
-
iterator_.set_mask(mask);
|
| 2080 |
-
}
|
| 2081 |
-
|
| 2082 |
-
/// Gets the mask
|
| 2083 |
-
CUTLASS_HOST_DEVICE
|
| 2084 |
-
void get_mask(Mask& mask) {
|
| 2085 |
-
iterator_.get_mask(mask);
|
| 2086 |
-
}
|
| 2087 |
-
|
| 2088 |
-
/// Loads a fragment from memory
|
| 2089 |
-
CUTLASS_DEVICE
|
| 2090 |
-
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
|
| 2091 |
-
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 2092 |
-
}
|
| 2093 |
-
|
| 2094 |
-
/// Loads a fragment from memory
|
| 2095 |
-
CUTLASS_DEVICE
|
| 2096 |
-
void load(Fragment& frag) {
|
| 2097 |
-
load_with_pointer_offset(frag, 0);
|
| 2098 |
-
}
|
| 2099 |
-
|
| 2100 |
-
/// Store a fragment to memory
|
| 2101 |
-
CUTLASS_DEVICE
|
| 2102 |
-
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
|
| 2103 |
-
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 2104 |
-
}
|
| 2105 |
-
|
| 2106 |
-
/// Store a fragment to memory
|
| 2107 |
-
CUTLASS_DEVICE
|
| 2108 |
-
void store(Fragment const& frag) {
|
| 2109 |
-
store_with_pointer_offset(frag, 0);
|
| 2110 |
-
}
|
| 2111 |
-
};
|
| 2112 |
-
|
| 2113 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 2114 |
-
|
| 2115 |
-
} // namespace threadblock
|
| 2116 |
-
} // namespace transform
|
| 2117 |
-
} // namespace cutlass
|
| 2118 |
-
|
| 2119 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h
DELETED
|
@@ -1,55 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include "warp_iterator_from_smem.h"
|
| 35 |
-
|
| 36 |
-
template <typename WarpIterator>
|
| 37 |
-
struct TransposeWarpIterator {
|
| 38 |
-
using Iterator = char;
|
| 39 |
-
static bool constexpr kSupportsTranspose = false;
|
| 40 |
-
};
|
| 41 |
-
|
| 42 |
-
template <
|
| 43 |
-
/// Operand identity
|
| 44 |
-
cutlass::gemm::Operand Operand,
|
| 45 |
-
/// Data type of A elements
|
| 46 |
-
typename Element,
|
| 47 |
-
typename InstructionShape,
|
| 48 |
-
bool kTranspose>
|
| 49 |
-
struct TransposeWarpIterator<
|
| 50 |
-
cutlass::gemm::warp::
|
| 51 |
-
WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> {
|
| 52 |
-
using Iterator = cutlass::gemm::warp::
|
| 53 |
-
WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>;
|
| 54 |
-
static bool constexpr kSupportsTranspose = true;
|
| 55 |
-
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h
DELETED
|
@@ -1,283 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Inspired from
|
| 33 |
-
"cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM
|
| 34 |
-
operands from a RowMajor shared-memory layout into registers to use by A100
|
| 35 |
-
TensorCores.
|
| 36 |
-
|
| 37 |
-
The difference with "mma_tensor_op_tile_access_iterator.h" is that:
|
| 38 |
-
(1) We use "ldmatrix" to load tiles, rather than manual loads (slightly
|
| 39 |
-
faster) (2) We support to transpose the operand (eg read `A.transpose()` when
|
| 40 |
-
the shared memory holds `A`)
|
| 41 |
-
|
| 42 |
-
This is only implemented for the specific shapes.
|
| 43 |
-
*/
|
| 44 |
-
#pragma once
|
| 45 |
-
|
| 46 |
-
#include <cutlass/gemm/gemm.h>
|
| 47 |
-
|
| 48 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 49 |
-
namespace cutlass {
|
| 50 |
-
namespace gemm {
|
| 51 |
-
namespace warp {
|
| 52 |
-
|
| 53 |
-
template <
|
| 54 |
-
/// Operand identity
|
| 55 |
-
Operand Operand_,
|
| 56 |
-
/// Data type of A elements
|
| 57 |
-
typename Element_,
|
| 58 |
-
typename InstructionShape_,
|
| 59 |
-
bool kTranspose = false>
|
| 60 |
-
class WarpIteratorFromSmem {
|
| 61 |
-
public:
|
| 62 |
-
/// Shape of tile to load (concept: MatrixShape)
|
| 63 |
-
using Shape = cutlass::MatrixShape<32, 32>;
|
| 64 |
-
|
| 65 |
-
/// Operand tag
|
| 66 |
-
static Operand const kOperand = Operand_;
|
| 67 |
-
static_assert(
|
| 68 |
-
kOperand == Operand::kA,
|
| 69 |
-
"No support for OperandB at the moment");
|
| 70 |
-
|
| 71 |
-
/// Basic check
|
| 72 |
-
static_assert(
|
| 73 |
-
kOperand == Operand::kA || kOperand == Operand::kB,
|
| 74 |
-
"WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma.");
|
| 75 |
-
|
| 76 |
-
/// Element type
|
| 77 |
-
using Element = Element_;
|
| 78 |
-
static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");
|
| 79 |
-
|
| 80 |
-
/// Layout of source tile
|
| 81 |
-
using Layout = cutlass::layout::RowMajor;
|
| 82 |
-
|
| 83 |
-
/// Shape of one matrix product operation (concept: MatrixShape)
|
| 84 |
-
using InstructionShape = InstructionShape_;
|
| 85 |
-
static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
|
| 86 |
-
static_assert(
|
| 87 |
-
InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
|
| 88 |
-
"Only supports 16x8x8 / 16x8x16");
|
| 89 |
-
|
| 90 |
-
/// Delta between *MMA operations (in units of *MMA operations, concept:
|
| 91 |
-
/// MatrixShape)
|
| 92 |
-
static int const kOpDelta = 1;
|
| 93 |
-
|
| 94 |
-
/// Number of participating threads
|
| 95 |
-
static int const kThreads = 32;
|
| 96 |
-
|
| 97 |
-
/// TensorRef type for loading element from a tensor
|
| 98 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 99 |
-
|
| 100 |
-
/// Index type
|
| 101 |
-
using Index = typename TensorRef::Index;
|
| 102 |
-
|
| 103 |
-
/// Long Index type
|
| 104 |
-
using LongIndex = typename TensorRef::LongIndex;
|
| 105 |
-
|
| 106 |
-
/// Coordinate for an element in the tensor
|
| 107 |
-
using TensorCoord = typename TensorRef::TensorCoord;
|
| 108 |
-
|
| 109 |
-
/// Number of elements accessed per Shared Memory load
|
| 110 |
-
static int const kElementsPerAccess =
|
| 111 |
-
(sizeof_bits<Element>::value >= 32 ? 1
|
| 112 |
-
: 32 / sizeof_bits<Element>::value);
|
| 113 |
-
|
| 114 |
-
using InstructionCount = MatrixShape<
|
| 115 |
-
Shape::kRow / InstructionShape::kRow,
|
| 116 |
-
Shape::kColumn / InstructionShape::kColumn>;
|
| 117 |
-
|
| 118 |
-
static int const kIterations = (kOperand == Operand::kA)
|
| 119 |
-
? InstructionCount::kColumn
|
| 120 |
-
: InstructionCount::kRow;
|
| 121 |
-
|
| 122 |
-
public:
|
| 123 |
-
//
|
| 124 |
-
// Derived quantities
|
| 125 |
-
//
|
| 126 |
-
|
| 127 |
-
/// Fragment object holding a thread's part of a tile
|
| 128 |
-
using Fragment = Array<
|
| 129 |
-
Element,
|
| 130 |
-
(kOperand == Operand::kA)
|
| 131 |
-
? (Shape::kRow* InstructionShape::kColumn / kThreads)
|
| 132 |
-
: (Shape::kColumn* InstructionShape::kRow / kThreads)>;
|
| 133 |
-
|
| 134 |
-
/// Memory access type
|
| 135 |
-
// using AccessType = AlignedArray<Element, kElementsPerAccess>;
|
| 136 |
-
using AccessType = Array<unsigned, 4>;
|
| 137 |
-
|
| 138 |
-
static int constexpr kWarpShapeDivisibleInner =
|
| 139 |
-
(kOperand == Operand::kA ? InstructionShape::kColumn
|
| 140 |
-
: InstructionShape::kRow);
|
| 141 |
-
static int constexpr kAccessesInner =
|
| 142 |
-
(kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
|
| 143 |
-
// Number of 32bits tiles to load per `ldmatrix`
|
| 144 |
-
static int const kTilesPerInstruction = InstructionShape::kRow / 8;
|
| 145 |
-
static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");
|
| 146 |
-
|
| 147 |
-
private:
|
| 148 |
-
/// Underlying tensor reference
|
| 149 |
-
TensorRef ref_;
|
| 150 |
-
|
| 151 |
-
/// Origin
|
| 152 |
-
MatrixCoord origin_;
|
| 153 |
-
|
| 154 |
-
/// Iterations in a tile
|
| 155 |
-
int iterations_;
|
| 156 |
-
|
| 157 |
-
public:
|
| 158 |
-
/// Constructor from TensorRef
|
| 159 |
-
CUTLASS_HOST_DEVICE
|
| 160 |
-
WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
|
| 161 |
-
: WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {}
|
| 162 |
-
CUTLASS_HOST_DEVICE
|
| 163 |
-
WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
|
| 164 |
-
: ref_(ref), iterations_(0) {
|
| 165 |
-
// See also:
|
| 166 |
-
// https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
|
| 167 |
-
// 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
|
| 168 |
-
// 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
|
| 169 |
-
int ldsm_vec_num = (lane_id >> 3);
|
| 170 |
-
if (kOperand == Operand::kA) {
|
| 171 |
-
origin_ = MatrixCoord(lane_id % 8, 0);
|
| 172 |
-
static_assert(
|
| 173 |
-
InstructionCount::kRow * kTilesPerInstruction == 4,
|
| 174 |
-
"can't use ldmatrix.x4");
|
| 175 |
-
int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
|
| 176 |
-
int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
|
| 177 |
-
int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
|
| 178 |
-
MatrixCoord offset(
|
| 179 |
-
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
|
| 180 |
-
inner_idx * 4 * kElementsPerAccess);
|
| 181 |
-
if (kTranspose) {
|
| 182 |
-
offset = MatrixCoord(offset.column(), offset.row());
|
| 183 |
-
}
|
| 184 |
-
origin_ += offset;
|
| 185 |
-
} else {
|
| 186 |
-
// Note: This is not tested or used
|
| 187 |
-
origin_ = MatrixCoord(0, lane_id % 8);
|
| 188 |
-
static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
|
| 189 |
-
CUTLASS_PRAGMA_UNROLL
|
| 190 |
-
for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn;
|
| 191 |
-
++inst_n_idx) {
|
| 192 |
-
CUTLASS_PRAGMA_UNROLL
|
| 193 |
-
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
|
| 194 |
-
int access_idx = inner_idx + kAccessesInner * inst_n_idx;
|
| 195 |
-
|
| 196 |
-
MatrixCoord offset(
|
| 197 |
-
inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8);
|
| 198 |
-
|
| 199 |
-
if (access_idx == ldsm_vec_num) {
|
| 200 |
-
if (kTranspose) {
|
| 201 |
-
offset = MatrixCoord(offset.column(), offset.row());
|
| 202 |
-
}
|
| 203 |
-
origin_ += offset;
|
| 204 |
-
}
|
| 205 |
-
}
|
| 206 |
-
}
|
| 207 |
-
}
|
| 208 |
-
|
| 209 |
-
ref_.add_coord_offset(origin_);
|
| 210 |
-
}
|
| 211 |
-
|
| 212 |
-
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 213 |
-
/// tiles
|
| 214 |
-
CUTLASS_HOST_DEVICE
|
| 215 |
-
WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) {
|
| 216 |
-
TensorCoord coord_offset(
|
| 217 |
-
tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
|
| 218 |
-
if (kTranspose) {
|
| 219 |
-
coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()};
|
| 220 |
-
}
|
| 221 |
-
origin_ += coord_offset;
|
| 222 |
-
|
| 223 |
-
ref_.add_coord_offset(coord_offset);
|
| 224 |
-
|
| 225 |
-
return *this;
|
| 226 |
-
}
|
| 227 |
-
|
| 228 |
-
/// Advances the iterator along the advance dimension
|
| 229 |
-
CUTLASS_DEVICE
|
| 230 |
-
void advance() {
|
| 231 |
-
if (kOperand == Operand::kA) {
|
| 232 |
-
add_tile_offset({0, 1});
|
| 233 |
-
} else {
|
| 234 |
-
add_tile_offset({1, 0});
|
| 235 |
-
}
|
| 236 |
-
|
| 237 |
-
iterations_ = 0;
|
| 238 |
-
}
|
| 239 |
-
|
| 240 |
-
/// increase iterations in a tile
|
| 241 |
-
CUTLASS_HOST_DEVICE
|
| 242 |
-
WarpIteratorFromSmem& operator++() {
|
| 243 |
-
iterations_++;
|
| 244 |
-
|
| 245 |
-
if (iterations_ >= kIterations)
|
| 246 |
-
advance();
|
| 247 |
-
|
| 248 |
-
return *this;
|
| 249 |
-
}
|
| 250 |
-
|
| 251 |
-
/// Loads a fragment from memory at the location pointed to by the iterator.
|
| 252 |
-
CUTLASS_DEVICE
|
| 253 |
-
void load(Fragment& frag) const {
|
| 254 |
-
AccessType* access_ptr = reinterpret_cast<AccessType*>(&frag);
|
| 255 |
-
using LoadLayout = typename platform::
|
| 256 |
-
conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;
|
| 257 |
-
|
| 258 |
-
CUTLASS_PRAGMA_UNROLL
|
| 259 |
-
for (int access_m_idx = 0; access_m_idx <
|
| 260 |
-
(InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
|
| 261 |
-
++access_m_idx) {
|
| 262 |
-
MatrixCoord offset;
|
| 263 |
-
if (kOperand == Operand::kA) {
|
| 264 |
-
offset = MatrixCoord(
|
| 265 |
-
access_m_idx * 16, iterations_ * InstructionShape::kColumn);
|
| 266 |
-
} else {
|
| 267 |
-
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
|
| 268 |
-
}
|
| 269 |
-
if (kTranspose) {
|
| 270 |
-
offset = MatrixCoord(offset.column(), offset.row());
|
| 271 |
-
}
|
| 272 |
-
cutlass::arch::ldsm<LoadLayout, 4>(
|
| 273 |
-
access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
|
| 274 |
-
}
|
| 275 |
-
}
|
| 276 |
-
};
|
| 277 |
-
|
| 278 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 279 |
-
|
| 280 |
-
} // namespace warp
|
| 281 |
-
} // namespace gemm
|
| 282 |
-
} // namespace cutlass
|
| 283 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h
DELETED
|
@@ -1,1322 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#ifdef HAS_PYTORCH
|
| 35 |
-
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
| 36 |
-
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
| 37 |
-
#endif
|
| 38 |
-
|
| 39 |
-
#include <curand_kernel.h>
|
| 40 |
-
#include <cmath>
|
| 41 |
-
#include <cinttypes>
|
| 42 |
-
#include <vector>
|
| 43 |
-
|
| 44 |
-
#include "cutlass/fast_math.h"
|
| 45 |
-
#include "cutlass/gemm/gemm.h"
|
| 46 |
-
#include "cutlass/layout/matrix.h"
|
| 47 |
-
#include "cutlass/layout/vector.h"
|
| 48 |
-
#include "cutlass/matrix.h"
|
| 49 |
-
#include "cutlass/numeric_types.h"
|
| 50 |
-
#include "cutlass/tensor_ref.h"
|
| 51 |
-
|
| 52 |
-
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
| 53 |
-
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 54 |
-
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
| 55 |
-
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
| 56 |
-
#include "cutlass/gemm/kernel/default_gemm.h"
|
| 57 |
-
#include "cutlass/gemm/threadblock/default_mma.h"
|
| 58 |
-
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
| 59 |
-
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
| 60 |
-
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
| 61 |
-
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
| 62 |
-
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
| 63 |
-
#include "cutlass/matrix_shape.h"
|
| 64 |
-
#include "cutlass/platform/platform.h"
|
| 65 |
-
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
| 66 |
-
#include "debug_utils.h"
|
| 67 |
-
#include "epilogue/epilogue_pipelined.h"
|
| 68 |
-
#include "epilogue/epilogue_rescale_output.h"
|
| 69 |
-
#include "gemm/custom_mma.h"
|
| 70 |
-
#include "gemm/find_default_mma.h"
|
| 71 |
-
#include "gemm/mma_from_smem.h"
|
| 72 |
-
#include "gemm_kernel_utils.h"
|
| 73 |
-
#include "transform/tile_smem_loader.h"
|
| 74 |
-
|
| 75 |
-
using namespace gemm_kernel_utils;
|
| 76 |
-
|
| 77 |
-
namespace {
|
| 78 |
-
template <typename scalar_t, typename Arch>
|
| 79 |
-
constexpr int getWarpsPerSmFw() {
|
| 80 |
-
return (
|
| 81 |
-
Arch::kMinComputeCapability >= 80 &&
|
| 82 |
-
!cutlass::platform::is_same<scalar_t, float>::value
|
| 83 |
-
? 16
|
| 84 |
-
: 12);
|
| 85 |
-
}
|
| 86 |
-
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
|
| 87 |
-
// source: https://stackoverflow.com/a/51549250
|
| 88 |
-
return (value >= 0)
|
| 89 |
-
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
| 90 |
-
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
| 91 |
-
}
|
| 92 |
-
} // namespace
|
| 93 |
-
|
| 94 |
-
// If ToBatchHookType_ is supplied other than this default (which is
|
| 95 |
-
// never the case in the xformers library) then the user is
|
| 96 |
-
// defining the logic which each block uses to find its data to work on,
|
| 97 |
-
// with the advance_to_batch function with the following signature.
|
| 98 |
-
// It should return false if there is no work to do for this block.
|
| 99 |
-
// In general this will not work with saving for backward due to fixed layout
|
| 100 |
-
// for logsumexp and incompatible rngs for dropout, so is likely only useful for
|
| 101 |
-
// custom inference.
|
| 102 |
-
struct DefaultToBatchHook {
|
| 103 |
-
template <typename Params>
|
| 104 |
-
CUTLASS_DEVICE static bool advance_to_batch(
|
| 105 |
-
Params&,
|
| 106 |
-
int64_t& /* q_start */,
|
| 107 |
-
int64_t& /* k_start */) {
|
| 108 |
-
return true;
|
| 109 |
-
}
|
| 110 |
-
};
|
| 111 |
-
|
| 112 |
-
template <
|
| 113 |
-
// The datatype of Q/K/V
|
| 114 |
-
typename scalar_t_,
|
| 115 |
-
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
|
| 116 |
-
typename ArchTag,
|
| 117 |
-
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
|
| 118 |
-
bool isAligned_,
|
| 119 |
-
int kQueriesPerBlock_,
|
| 120 |
-
int kKeysPerBlock_,
|
| 121 |
-
// upperbound on `max(value.shape[-1], query.shape[-1])`
|
| 122 |
-
int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
|
| 123 |
-
// This is quite slower on V100 for some reason
|
| 124 |
-
// Set to false if you know at compile-time you will never need dropout
|
| 125 |
-
bool kSupportsDropout_ = true,
|
| 126 |
-
bool kSupportsBias_ = true,
|
| 127 |
-
typename ToBatchHookType_ = DefaultToBatchHook>
|
| 128 |
-
struct AttentionKernel {
|
| 129 |
-
// Masking modes accepted in `Params::custom_mask_type`.
// `NumCustomMaskTypes` is a sentinel used for range validation, not a mode.
enum CustomMaskType {
  NoCustomMask = 0,
  CausalFromTopLeft = 1,
  // For this mode the kernel sets causal_diagonal_offset to
  // num_keys - num_queries.
  CausalFromBottomRight = 2,
  NumCustomMaskTypes,
};
|
| 135 |
-
|
| 136 |
-
using scalar_t = scalar_t_;
|
| 137 |
-
using accum_t = float;
|
| 138 |
-
using lse_scalar_t = float;
|
| 139 |
-
using output_t = scalar_t;
|
| 140 |
-
// Accumulator between 2 iterations
|
| 141 |
-
// Using `accum_t` improves perf on f16 at the cost of
|
| 142 |
-
// numerical errors
|
| 143 |
-
using output_accum_t = accum_t;
|
| 144 |
-
static constexpr bool kSupportsDropout = kSupportsDropout_;
|
| 145 |
-
static constexpr bool kSupportsBias = kSupportsBias_;
|
| 146 |
-
static constexpr int kKeysPerBlock = kKeysPerBlock_;
|
| 147 |
-
static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
|
| 148 |
-
static constexpr int kMaxK = kMaxK_;
|
| 149 |
-
static constexpr bool kIsAligned = isAligned_;
|
| 150 |
-
static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
|
| 151 |
-
static constexpr int32_t kAlignLSE = 32; // block size of backward
|
| 152 |
-
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
|
| 153 |
-
static constexpr bool kPreloadV =
|
| 154 |
-
ArchTag::kMinComputeCapability >= 80 && kIsHalf;
|
| 155 |
-
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
|
| 156 |
-
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
|
| 157 |
-
!cutlass::platform::is_same<output_accum_t, output_t>::value;
|
| 158 |
-
|
| 159 |
-
static_assert(kQueriesPerBlock % 32 == 0, "");
|
| 160 |
-
static_assert(kKeysPerBlock % 32 == 0, "");
|
| 161 |
-
static constexpr int kNumWarpsPerBlock =
|
| 162 |
-
kQueriesPerBlock * kKeysPerBlock / (32 * 32);
|
| 163 |
-
static constexpr int kWarpSize = 32;
|
| 164 |
-
|
| 165 |
-
// Launch bounds
|
| 166 |
-
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
|
| 167 |
-
static constexpr int kMinBlocksPerSm =
|
| 168 |
-
getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
|
| 169 |
-
|
| 170 |
-
struct Params {
|
| 171 |
-
// Input tensors
|
| 172 |
-
scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
|
| 173 |
-
scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
|
| 174 |
-
scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
|
| 175 |
-
scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
|
| 176 |
-
int32_t* seqstart_q_ptr = nullptr;
|
| 177 |
-
int32_t* seqstart_k_ptr = nullptr;
|
| 178 |
-
|
| 179 |
-
int32_t* seqlen_k_ptr = nullptr;
|
| 180 |
-
uint32_t causal_diagonal_offset = 0;
|
| 181 |
-
|
| 182 |
-
// Output tensors
|
| 183 |
-
output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
|
| 184 |
-
// [num_queries, num_heads, head_dim_value]
|
| 185 |
-
output_accum_t* output_accum_ptr = nullptr;
|
| 186 |
-
// [num_heads, num_queries] - can be null
|
| 187 |
-
lse_scalar_t* logsumexp_ptr = nullptr;
|
| 188 |
-
|
| 189 |
-
// Scale
|
| 190 |
-
accum_t scale = 0.0;
|
| 191 |
-
|
| 192 |
-
// Dimensions/strides
|
| 193 |
-
int32_t head_dim = 0;
|
| 194 |
-
int32_t head_dim_value = 0;
|
| 195 |
-
int32_t num_queries = 0;
|
| 196 |
-
int32_t num_keys = 0;
|
| 197 |
-
int32_t num_keys_absolute = 0;
|
| 198 |
-
|
| 199 |
-
uint8_t custom_mask_type = NoCustomMask;
|
| 200 |
-
|
| 201 |
-
int32_t q_strideM = 0;
|
| 202 |
-
int32_t k_strideM = 0;
|
| 203 |
-
int32_t v_strideM = 0;
|
| 204 |
-
int32_t bias_strideM = 0;
|
| 205 |
-
|
| 206 |
-
int32_t o_strideM = 0;
|
| 207 |
-
|
| 208 |
-
// Everything below is only used in `advance_to_block`
|
| 209 |
-
// and shouldn't use registers
|
| 210 |
-
int32_t q_strideH = 0;
|
| 211 |
-
int32_t k_strideH = 0;
|
| 212 |
-
int32_t v_strideH = 0;
|
| 213 |
-
int64_t bias_strideH = 0;
|
| 214 |
-
|
| 215 |
-
int64_t q_strideB = 0;
|
| 216 |
-
int64_t k_strideB = 0;
|
| 217 |
-
int64_t v_strideB = 0;
|
| 218 |
-
int64_t bias_strideB = 0;
|
| 219 |
-
|
| 220 |
-
int32_t num_batches = 0;
|
| 221 |
-
int32_t num_heads = 0;
|
| 222 |
-
|
| 223 |
-
// dropout
|
| 224 |
-
bool use_dropout = false;
|
| 225 |
-
unsigned long long dropout_batch_head_rng_offset = 0;
|
| 226 |
-
float dropout_prob = 0.0f;
|
| 227 |
-
#ifdef HAS_PYTORCH
|
| 228 |
-
at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
|
| 229 |
-
#endif
|
| 230 |
-
|
| 231 |
-
// Moves pointers to what we should process
|
| 232 |
-
// Returns "false" if there is no work to do
|
| 233 |
-
CUTLASS_DEVICE bool advance_to_block() {
|
| 234 |
-
auto batch_id = blockIdx.z;
|
| 235 |
-
auto head_id = blockIdx.y;
|
| 236 |
-
auto query_start = blockIdx.x * kQueriesPerBlock;
|
| 237 |
-
|
| 238 |
-
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
|
| 239 |
-
|
| 240 |
-
if (kSupportsDropout) {
|
| 241 |
-
dropout_batch_head_rng_offset =
|
| 242 |
-
batch_id * num_heads * num_queries * num_keys +
|
| 243 |
-
head_id * num_queries * num_keys;
|
| 244 |
-
}
|
| 245 |
-
|
| 246 |
-
int64_t q_start = 0, k_start = 0;
|
| 247 |
-
// Advance to current batch - in case of different sequence lengths
|
| 248 |
-
constexpr bool kToBatchHook =
|
| 249 |
-
!cutlass::platform::is_same<ToBatchHookType_, DefaultToBatchHook>::
|
| 250 |
-
value;
|
| 251 |
-
if (kToBatchHook) {
|
| 252 |
-
// Call out to a custom implementation.
|
| 253 |
-
if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) {
|
| 254 |
-
return false;
|
| 255 |
-
}
|
| 256 |
-
} else if (seqstart_q_ptr != nullptr) {
|
| 257 |
-
assert(seqstart_k_ptr != nullptr);
|
| 258 |
-
seqstart_q_ptr += batch_id;
|
| 259 |
-
|
| 260 |
-
q_start = seqstart_q_ptr[0];
|
| 261 |
-
int64_t q_next_start = seqstart_q_ptr[1];
|
| 262 |
-
int64_t k_end;
|
| 263 |
-
seqstart_k_ptr += batch_id;
|
| 264 |
-
|
| 265 |
-
if (seqlen_k_ptr) {
|
| 266 |
-
k_start = seqstart_k_ptr[0];
|
| 267 |
-
k_end = k_start + seqlen_k_ptr[batch_id];
|
| 268 |
-
} else {
|
| 269 |
-
k_start = seqstart_k_ptr[0];
|
| 270 |
-
k_end = seqstart_k_ptr[1];
|
| 271 |
-
}
|
| 272 |
-
|
| 273 |
-
num_queries = q_next_start - q_start;
|
| 274 |
-
num_keys = k_end - k_start;
|
| 275 |
-
|
| 276 |
-
if (query_start >= num_queries) {
|
| 277 |
-
return false;
|
| 278 |
-
}
|
| 279 |
-
} else {
|
| 280 |
-
query_ptr += batch_id * q_strideB;
|
| 281 |
-
key_ptr += batch_id * k_strideB;
|
| 282 |
-
value_ptr += batch_id * v_strideB;
|
| 283 |
-
output_ptr += int64_t(batch_id * num_queries) * o_strideM;
|
| 284 |
-
if (output_accum_ptr != nullptr) {
|
| 285 |
-
output_accum_ptr +=
|
| 286 |
-
int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
|
| 287 |
-
}
|
| 288 |
-
q_start = 0;
|
| 289 |
-
k_start = 0;
|
| 290 |
-
}
|
| 291 |
-
|
| 292 |
-
// Advance to the current batch / head / query_start
|
| 293 |
-
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
|
| 294 |
-
key_ptr += k_start * k_strideM + head_id * k_strideH;
|
| 295 |
-
|
| 296 |
-
value_ptr += k_start * v_strideM + head_id * v_strideH;
|
| 297 |
-
output_ptr +=
|
| 298 |
-
int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
|
| 299 |
-
|
| 300 |
-
if (kSupportsBias && attn_bias_ptr != nullptr) {
|
| 301 |
-
attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
|
| 302 |
-
}
|
| 303 |
-
if (output_accum_ptr != nullptr) {
|
| 304 |
-
output_accum_ptr +=
|
| 305 |
-
int64_t(q_start + query_start) * (head_dim_value * num_heads) +
|
| 306 |
-
head_id * head_dim_value;
|
| 307 |
-
} else {
|
| 308 |
-
// Accumulate directly in the destination buffer (eg for f32)
|
| 309 |
-
output_accum_ptr = (accum_t*)output_ptr;
|
| 310 |
-
}
|
| 311 |
-
|
| 312 |
-
if (logsumexp_ptr != nullptr) {
|
| 313 |
-
// lse[batch_id, head_id, query_start]
|
| 314 |
-
logsumexp_ptr +=
|
| 315 |
-
batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
|
| 316 |
-
}
|
| 317 |
-
|
| 318 |
-
// Custom masking
|
| 319 |
-
if (custom_mask_type == CausalFromBottomRight) {
|
| 320 |
-
causal_diagonal_offset = num_keys - num_queries;
|
| 321 |
-
}
|
| 322 |
-
// We use num_keys_absolute to index into the rng_state
|
| 323 |
-
// We need this index to match between forward and backwards
|
| 324 |
-
num_keys_absolute = num_keys;
|
| 325 |
-
if (custom_mask_type == CausalFromTopLeft ||
|
| 326 |
-
custom_mask_type == CausalFromBottomRight) {
|
| 327 |
-
// the bottom row of the current block is query_start + kQueriesPerBlock
|
| 328 |
-
// the last active key is then query_start + causal_diagonal_offset +
|
| 329 |
-
// kQueriesPerBlock so num_keys is the min between actual num_keys and
|
| 330 |
-
// this to avoid extra computations
|
| 331 |
-
num_keys = cutlass::fast_min(
|
| 332 |
-
int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock),
|
| 333 |
-
num_keys);
|
| 334 |
-
}
|
| 335 |
-
|
| 336 |
-
num_queries -= query_start;
|
| 337 |
-
num_batches = 0; // no longer used after
|
| 338 |
-
|
| 339 |
-
// If num_queries == 1, and there is only one key head we're wasting
|
| 340 |
-
// 15/16th of tensor core compute In that case :
|
| 341 |
-
// - we only launch kernels for head_id % kQueriesPerBlock == 0
|
| 342 |
-
// - we iterate over heads instead of queries (strideM = strideH)
|
| 343 |
-
if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) {
|
| 344 |
-
if (head_id % kQueriesPerBlock != 0)
|
| 345 |
-
return false;
|
| 346 |
-
q_strideM = q_strideH;
|
| 347 |
-
num_queries = num_heads;
|
| 348 |
-
num_heads = 1; // unused but here for intent
|
| 349 |
-
// remove causal since n_query = 1
|
| 350 |
-
// otherwise, offset would change with head !
|
| 351 |
-
custom_mask_type = NoCustomMask;
|
| 352 |
-
o_strideM = head_dim_value;
|
| 353 |
-
}
|
| 354 |
-
|
| 355 |
-
// Make sure the compiler knows these variables are the same on all
|
| 356 |
-
// the threads of the warp.
|
| 357 |
-
// Only worth doing if they could have been modified above.
|
| 358 |
-
query_ptr = warp_uniform(query_ptr);
|
| 359 |
-
key_ptr = warp_uniform(key_ptr);
|
| 360 |
-
value_ptr = warp_uniform(value_ptr);
|
| 361 |
-
if (kSupportsBias) {
|
| 362 |
-
attn_bias_ptr = warp_uniform(attn_bias_ptr);
|
| 363 |
-
}
|
| 364 |
-
output_ptr = warp_uniform(output_ptr);
|
| 365 |
-
output_accum_ptr = warp_uniform(output_accum_ptr);
|
| 366 |
-
logsumexp_ptr = warp_uniform(logsumexp_ptr);
|
| 367 |
-
num_queries = warp_uniform(num_queries);
|
| 368 |
-
num_keys = warp_uniform(num_keys);
|
| 369 |
-
num_heads = warp_uniform(num_heads);
|
| 370 |
-
o_strideM = warp_uniform(o_strideM);
|
| 371 |
-
custom_mask_type = warp_uniform(custom_mask_type);
|
| 372 |
-
return true;
|
| 373 |
-
}
|
| 374 |
-
|
| 375 |
-
__host__ dim3 getBlocksGrid() const {
|
| 376 |
-
return dim3(
|
| 377 |
-
ceil_div(num_queries, (int32_t)kQueriesPerBlock),
|
| 378 |
-
num_heads,
|
| 379 |
-
num_batches);
|
| 380 |
-
}
|
| 381 |
-
|
| 382 |
-
__host__ dim3 getThreadsGrid() const {
|
| 383 |
-
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
|
| 384 |
-
}
|
| 385 |
-
};
|
| 386 |
-
|
| 387 |
-
struct MM0 {
|
| 388 |
-
/*
|
| 389 |
-
In this first matmul, we compute a block of `Q @ K.T`.
|
| 390 |
-
While the calculation result is still hot in registers, we update
|
| 391 |
-
`mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
|
| 392 |
-
into a shared-memory ("AccumulatorSharedStorage") that is used later as
|
| 393 |
-
operand A for the second matmul (see MM1)
|
| 394 |
-
*/
|
| 395 |
-
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
|
| 396 |
-
|
| 397 |
-
using OpClass = typename GemmType::OpClass;
|
| 398 |
-
using DefaultConfig =
|
| 399 |
-
typename cutlass::gemm::device::DefaultGemmConfiguration<
|
| 400 |
-
OpClass,
|
| 401 |
-
ArchTag,
|
| 402 |
-
scalar_t,
|
| 403 |
-
scalar_t,
|
| 404 |
-
scalar_t, // ElementC
|
| 405 |
-
accum_t // ElementAccumulator
|
| 406 |
-
>;
|
| 407 |
-
static constexpr int kAlignmentA =
|
| 408 |
-
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
|
| 409 |
-
static constexpr int kAlignmentB =
|
| 410 |
-
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
|
| 411 |
-
using ThreadblockShape = cutlass::gemm::
|
| 412 |
-
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
|
| 413 |
-
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
|
| 414 |
-
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
|
| 415 |
-
scalar_t, // ElementA,
|
| 416 |
-
cutlass::layout::RowMajor, // LayoutA,
|
| 417 |
-
kAlignmentA,
|
| 418 |
-
scalar_t, // ElementB,
|
| 419 |
-
cutlass::layout::ColumnMajor, // LayoutB,
|
| 420 |
-
kAlignmentB,
|
| 421 |
-
accum_t,
|
| 422 |
-
cutlass::layout::RowMajor, // LayoutC,
|
| 423 |
-
OpClass,
|
| 424 |
-
ArchTag, // ArchTag
|
| 425 |
-
ThreadblockShape, // ThreadblockShape
|
| 426 |
-
WarpShape, // WarpShape
|
| 427 |
-
typename GemmType::InstructionShape, // InstructionShape
|
| 428 |
-
ArchTag::kMinComputeCapability >= 80 && kIsHalf
|
| 429 |
-
? 4
|
| 430 |
-
: DefaultConfig::kStages,
|
| 431 |
-
typename GemmType::Operator // Operator
|
| 432 |
-
>::DefaultMma;
|
| 433 |
-
using MmaCore = typename DefaultMma::MmaCore;
|
| 434 |
-
using IteratorA = typename DefaultMma::IteratorA;
|
| 435 |
-
using IteratorB = typename DefaultMma::IteratorB;
|
| 436 |
-
using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
|
| 437 |
-
using Mma = typename cutlass::platform::conditional<
|
| 438 |
-
kSingleValueIteration,
|
| 439 |
-
typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
|
| 440 |
-
DefaultThreadblockMma>::type;
|
| 441 |
-
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
|
| 442 |
-
typename Mma::Operator::IteratorC,
|
| 443 |
-
accum_t,
|
| 444 |
-
kWarpSize>::Iterator;
|
| 445 |
-
static_assert(
|
| 446 |
-
MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
|
| 447 |
-
MmaCore::WarpCount::kK ==
|
| 448 |
-
kNumWarpsPerBlock,
|
| 449 |
-
"");
|
| 450 |
-
|
| 451 |
-
// used for efficient load of bias tile Bij from global to shared memory
|
| 452 |
-
using BiasLoader = TileSmemLoader<
|
| 453 |
-
scalar_t,
|
| 454 |
-
cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
|
| 455 |
-
MmaCore::kThreads,
|
| 456 |
-
// input restriction: kv_len has to be a multiple of this value
|
| 457 |
-
128 / cutlass::sizeof_bits<scalar_t>::value>;
|
| 458 |
-
|
| 459 |
-
// Epilogue to store to shared-memory in a format that we can use later for
|
| 460 |
-
// the second matmul
|
| 461 |
-
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
|
| 462 |
-
typename Mma::Operator::IteratorC,
|
| 463 |
-
typename Mma::Operator,
|
| 464 |
-
scalar_t,
|
| 465 |
-
WarpShape,
|
| 466 |
-
ThreadblockShape>;
|
| 467 |
-
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
|
| 468 |
-
};
|
| 469 |
-
|
| 470 |
-
struct MM1 {
|
| 471 |
-
/**
|
| 472 |
-
Second matmul: perform `attn @ V` where `attn` is the attention (not
|
| 473 |
-
normalized) and stored in shared memory
|
| 474 |
-
*/
|
| 475 |
-
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
|
| 476 |
-
|
| 477 |
-
using OpClass = typename GemmType::OpClass;
|
| 478 |
-
using DefaultConfig =
|
| 479 |
-
typename cutlass::gemm::device::DefaultGemmConfiguration<
|
| 480 |
-
OpClass,
|
| 481 |
-
ArchTag,
|
| 482 |
-
scalar_t,
|
| 483 |
-
scalar_t,
|
| 484 |
-
output_accum_t, // ElementC
|
| 485 |
-
accum_t // ElementAccumulator
|
| 486 |
-
>;
|
| 487 |
-
static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
|
| 488 |
-
static constexpr int kAlignmentB =
|
| 489 |
-
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
|
| 490 |
-
using ThreadblockShape = cutlass::gemm::
|
| 491 |
-
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
|
| 492 |
-
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
|
| 493 |
-
using InstructionShape = typename GemmType::InstructionShape;
|
| 494 |
-
|
| 495 |
-
using LayoutB = cutlass::layout::RowMajor;
|
| 496 |
-
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
|
| 497 |
-
scalar_t, // ElementA,
|
| 498 |
-
cutlass::layout::RowMajor, // LayoutA,
|
| 499 |
-
kAlignmentA,
|
| 500 |
-
scalar_t, // ElementB,
|
| 501 |
-
LayoutB, // LayoutB,
|
| 502 |
-
kAlignmentB,
|
| 503 |
-
output_accum_t,
|
| 504 |
-
cutlass::layout::RowMajor, // LayoutC,
|
| 505 |
-
accum_t,
|
| 506 |
-
OpClass,
|
| 507 |
-
ArchTag,
|
| 508 |
-
ThreadblockShape,
|
| 509 |
-
WarpShape,
|
| 510 |
-
typename GemmType::InstructionShape,
|
| 511 |
-
typename DefaultConfig::EpilogueOutputOp,
|
| 512 |
-
void, // ThreadblockSwizzle - not used
|
| 513 |
-
ArchTag::kMinComputeCapability >= 80 && kIsHalf
|
| 514 |
-
? 4
|
| 515 |
-
: DefaultConfig::kStages,
|
| 516 |
-
false, // SplitKSerial
|
| 517 |
-
typename GemmType::Operator>;
|
| 518 |
-
|
| 519 |
-
using WarpIteratorA = typename cutlass::gemm::threadblock::
|
| 520 |
-
DefaultWarpIteratorAFromSharedMemory<
|
| 521 |
-
typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
|
| 522 |
-
typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
|
| 523 |
-
typename DefaultGemm::Mma::Policy::Operator::IteratorA,
|
| 524 |
-
typename DefaultGemm::Mma::Policy>::WarpIterator;
|
| 525 |
-
using DefaultMmaFromSmem =
|
| 526 |
-
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
|
| 527 |
-
typename DefaultGemm::Mma,
|
| 528 |
-
MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
|
| 529 |
-
WarpIteratorA,
|
| 530 |
-
false>; // kScaleOperandA
|
| 531 |
-
using Mma = typename DefaultMmaFromSmem::Mma;
|
| 532 |
-
using IteratorB = typename Mma::IteratorB;
|
| 533 |
-
using WarpCount = typename Mma::WarpCount;
|
| 534 |
-
static_assert(
|
| 535 |
-
WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
|
| 536 |
-
"");
|
| 537 |
-
|
| 538 |
-
using DefaultEpilogue = typename DefaultGemm::Epilogue;
|
| 539 |
-
using OutputTileIterator =
|
| 540 |
-
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 541 |
-
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
|
| 542 |
-
output_t>;
|
| 543 |
-
using OutputTileIteratorAccum =
|
| 544 |
-
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 545 |
-
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
|
| 546 |
-
output_accum_t>;
|
| 547 |
-
};
|
| 548 |
-
|
| 549 |
-
// Alignment requirements checked by `check_supported`: Q and K follow the
// operand alignments of the first matmul, V uses an alignment of 1.
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
static constexpr int64_t kAlignmentV = 1;
|
| 552 |
-
|
| 553 |
-
// Shared storage - depends on kernel params
|
| 554 |
-
struct ScalingCoefs {
|
| 555 |
-
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
|
| 556 |
-
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
|
| 557 |
-
cutlass::Array<accum_t, kQueriesPerBlock> mi;
|
| 558 |
-
cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
|
| 559 |
-
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
|
| 560 |
-
addition_storage;
|
| 561 |
-
};
|
| 562 |
-
|
| 563 |
-
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
|
| 564 |
-
struct SharedStorageAfterMM0 {
|
| 565 |
-
// Everything here might be overwritten during MM0
|
| 566 |
-
union {
|
| 567 |
-
typename MM0::BiasLoader::SmemTile bias;
|
| 568 |
-
typename MM0::AccumulatorSharedStorage si;
|
| 569 |
-
};
|
| 570 |
-
typename MM1::Mma::SharedStorage mm1;
|
| 571 |
-
};
|
| 572 |
-
|
| 573 |
-
union {
|
| 574 |
-
typename MM0::Mma::SharedStorage mm0;
|
| 575 |
-
SharedStorageAfterMM0 after_mm0;
|
| 576 |
-
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
| 577 |
-
};
|
| 578 |
-
|
| 579 |
-
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
|
| 580 |
-
epilogue_shared_storage() {
|
| 581 |
-
return epilogue;
|
| 582 |
-
}
|
| 583 |
-
};
|
| 584 |
-
|
| 585 |
-
struct SharedStorageEpilogueInLoop : ScalingCoefs {
|
| 586 |
-
struct SharedStorageAfterMM0 {
|
| 587 |
-
// Everything here might be overwritten during MM0
|
| 588 |
-
union {
|
| 589 |
-
typename MM0::BiasLoader::SmemTile bias;
|
| 590 |
-
typename MM0::AccumulatorSharedStorage si;
|
| 591 |
-
};
|
| 592 |
-
typename MM1::Mma::SharedStorage mm1;
|
| 593 |
-
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
| 594 |
-
};
|
| 595 |
-
|
| 596 |
-
union {
|
| 597 |
-
typename MM0::Mma::SharedStorage mm0;
|
| 598 |
-
SharedStorageAfterMM0 after_mm0;
|
| 599 |
-
};
|
| 600 |
-
|
| 601 |
-
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
|
| 602 |
-
epilogue_shared_storage() {
|
| 603 |
-
return after_mm0.epilogue;
|
| 604 |
-
}
|
| 605 |
-
};
|
| 606 |
-
|
| 607 |
-
// Select the shared-memory layout: the "in loop" layout is used only when
// neither kSingleValueIteration nor kKeepOutputInRF is set; otherwise the
// "at end" layout is used.
using SharedStorage = typename cutlass::platform::conditional<
    !kSingleValueIteration && !kKeepOutputInRF,
    SharedStorageEpilogueInLoop,
    SharedStorageEpilogueAtEnd>::type;
|
| 611 |
-
|
| 612 |
-
// Host-side validation of `p` against this kernel's alignment requirements
// and the valid range of `custom_mask_type`.
// Returns true when every check passes. What happens on a failed check is
// decided by the CHECK_ALIGNED_PTR / XFORMERS_CHECK macros, which are
// defined outside this file section.
static bool __host__ check_supported(Params const& p) {
  // Base pointers.
  CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
  CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
  CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);

  // Attention bias: pointer and all strides use the Q alignment.
  // Batch/head strides only matter when there is more than one batch/head.
  if (kSupportsBias) {
    CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
    XFORMERS_CHECK(
        p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
        "attn_bias is not correctly aligned (strideB)");
    XFORMERS_CHECK(
        p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
        "attn_bias is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.bias_strideM % kAlignmentQ == 0,
        "attn_bias is not correctly aligned");
  }

  // Row strides.
  XFORMERS_CHECK(
      p.q_strideM % kAlignmentQ == 0,
      "query is not correctly aligned (strideM)");
  XFORMERS_CHECK(
      p.k_strideM % kAlignmentK == 0,
      "key is not correctly aligned (strideM)");
  XFORMERS_CHECK(
      p.v_strideM % kAlignmentV == 0,
      "value is not correctly aligned (strideM)");

  // Head strides, only relevant with more than one head.
  XFORMERS_CHECK(
      p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
      "query is not correctly aligned (strideH)");
  XFORMERS_CHECK(
      p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
      "key is not correctly aligned (strideH)");
  XFORMERS_CHECK(
      p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
      "value is not correctly aligned (strideH)");

  // Mask type must be one of the CustomMaskType values.
  XFORMERS_CHECK(
      p.custom_mask_type < NumCustomMaskTypes,
      "invalid value for `custom_mask_type`");
  return true;
}
|
| 651 |
-
|
| 652 |
-
static void CUTLASS_DEVICE attention_kernel(Params& p) {
|
| 653 |
-
// In this block, we will only ever:
|
| 654 |
-
// - read query[query_start:query_end, :]
|
| 655 |
-
// - write to output[query_start:query_end, :]
|
| 656 |
-
|
| 657 |
-
extern __shared__ char smem_buffer[];
|
| 658 |
-
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
|
| 659 |
-
auto& m_prime = shared_storage.m_prime;
|
| 660 |
-
auto& s_prime = shared_storage.s_prime;
|
| 661 |
-
auto& mi = shared_storage.mi;
|
| 662 |
-
auto& out_rescale = shared_storage.out_rescale;
|
| 663 |
-
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
|
| 664 |
-
|
| 665 |
-
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
| 666 |
-
if (thread_id() < kQueriesPerBlock) {
|
| 667 |
-
s_prime[thread_id()] = accum_t(0);
|
| 668 |
-
out_rescale[thread_id()] = accum_t(1.0);
|
| 669 |
-
m_prime[thread_id()] =
|
| 670 |
-
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
| 671 |
-
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
| 672 |
-
}
|
| 673 |
-
typename MM1::Mma::FragmentC accum_o;
|
| 674 |
-
accum_o.clear();
|
| 675 |
-
|
| 676 |
-
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
|
| 677 |
-
using OutputTileIterator = typename MM1::OutputTileIterator;
|
| 678 |
-
return OutputTileIterator(
|
| 679 |
-
typename OutputTileIterator::Params{(int32_t)p.o_strideM},
|
| 680 |
-
p.output_ptr,
|
| 681 |
-
typename OutputTileIterator::TensorCoord{
|
| 682 |
-
p.num_queries, p.head_dim_value},
|
| 683 |
-
thread_id(),
|
| 684 |
-
{0, col});
|
| 685 |
-
};
|
| 686 |
-
|
| 687 |
-
auto createOutputAccumIter = [&](int col) ->
|
| 688 |
-
typename MM1::OutputTileIteratorAccum {
|
| 689 |
-
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
|
| 690 |
-
return OutputTileIteratorAccum(
|
| 691 |
-
typename OutputTileIteratorAccum::Params{
|
| 692 |
-
(int32_t)(p.head_dim_value * p.num_heads)},
|
| 693 |
-
p.output_accum_ptr,
|
| 694 |
-
typename OutputTileIteratorAccum::TensorCoord{
|
| 695 |
-
p.num_queries, p.head_dim_value},
|
| 696 |
-
thread_id(),
|
| 697 |
-
{0, col});
|
| 698 |
-
};
|
| 699 |
-
|
| 700 |
-
#ifdef HAS_PYTORCH
|
| 701 |
-
curandStatePhilox4_32_10_t curand_state_init;
|
| 702 |
-
if (kSupportsDropout && p.use_dropout) {
|
| 703 |
-
const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
|
| 704 |
-
|
| 705 |
-
// each element of the attention matrix P with shape
|
| 706 |
-
// (batch_sz, n_heads, n_queries, n_keys) is associated with a single
|
| 707 |
-
// offset in RNG sequence. we initialize the RNG state with offset that
|
| 708 |
-
// starts at the beginning of a (n_queries, n_keys) matrix for this
|
| 709 |
-
// block's batch_id and head_id
|
| 710 |
-
// initializing rng state is very expensive, so we run once per kernel,
|
| 711 |
-
// rather than once per iteration. each iteration takes a copy of the
|
| 712 |
-
// initialized RNG state and offsets it as needed.
|
| 713 |
-
curand_init(
|
| 714 |
-
std::get<0>(seeds),
|
| 715 |
-
0,
|
| 716 |
-
std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
|
| 717 |
-
&curand_state_init);
|
| 718 |
-
}
|
| 719 |
-
#endif
|
| 720 |
-
|
| 721 |
-
// Iterate through keys
|
| 722 |
-
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
|
| 723 |
-
iter_key_start += kKeysPerBlock) {
|
| 724 |
-
int32_t problem_size_0_m =
|
| 725 |
-
cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
|
| 726 |
-
int32_t problem_size_0_n = cutlass::fast_min(
|
| 727 |
-
int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
|
| 728 |
-
int32_t const& problem_size_0_k = p.head_dim;
|
| 729 |
-
int32_t const& problem_size_1_n = p.head_dim_value;
|
| 730 |
-
int32_t const& problem_size_1_k = problem_size_0_n;
|
| 731 |
-
|
| 732 |
-
auto prologueV = [&](int blockN) {
|
| 733 |
-
typename MM1::Mma::IteratorB iterator_V(
|
| 734 |
-
typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)},
|
| 735 |
-
p.value_ptr + iter_key_start * p.v_strideM,
|
| 736 |
-
{problem_size_1_k, problem_size_1_n},
|
| 737 |
-
thread_id(),
|
| 738 |
-
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
| 739 |
-
MM1::Mma::prologue(
|
| 740 |
-
shared_storage.after_mm0.mm1,
|
| 741 |
-
iterator_V,
|
| 742 |
-
thread_id(),
|
| 743 |
-
problem_size_1_k);
|
| 744 |
-
};
|
| 745 |
-
|
| 746 |
-
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
|
| 747 |
-
// updated from end of prev iter
|
| 748 |
-
//
|
| 749 |
-
// MATMUL: Q.K_t
|
| 750 |
-
//
|
| 751 |
-
// Computes the block-matrix product of:
|
| 752 |
-
// (a) query[query_start:query_end, :]
|
| 753 |
-
// with
|
| 754 |
-
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
|
| 755 |
-
// and stores that into `shared_storage.si`
|
| 756 |
-
//
|
| 757 |
-
|
| 758 |
-
// Compute threadblock location
|
| 759 |
-
cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
|
| 760 |
-
|
| 761 |
-
cutlass::MatrixCoord tb_offset_A{
|
| 762 |
-
tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
|
| 763 |
-
|
| 764 |
-
cutlass::MatrixCoord tb_offset_B{
|
| 765 |
-
tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
|
| 766 |
-
|
| 767 |
-
// Construct iterators to A and B operands
|
| 768 |
-
typename MM0::IteratorA iterator_A(
|
| 769 |
-
typename MM0::IteratorA::Params(
|
| 770 |
-
typename MM0::MmaCore::LayoutA(p.q_strideM)),
|
| 771 |
-
p.query_ptr,
|
| 772 |
-
{problem_size_0_m, problem_size_0_k},
|
| 773 |
-
thread_id(),
|
| 774 |
-
tb_offset_A);
|
| 775 |
-
|
| 776 |
-
typename MM0::IteratorB iterator_B(
|
| 777 |
-
typename MM0::IteratorB::Params(
|
| 778 |
-
typename MM0::MmaCore::LayoutB(p.k_strideM)),
|
| 779 |
-
p.key_ptr + iter_key_start * p.k_strideM,
|
| 780 |
-
{problem_size_0_k, problem_size_0_n},
|
| 781 |
-
thread_id(),
|
| 782 |
-
tb_offset_B);
|
| 783 |
-
|
| 784 |
-
auto my_warp_id = warp_uniform(warp_id());
|
| 785 |
-
auto my_lane_id = lane_id();
|
| 786 |
-
|
| 787 |
-
// Construct thread-scoped matrix multiply
|
| 788 |
-
typename MM0::Mma mma(
|
| 789 |
-
shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
|
| 790 |
-
|
| 791 |
-
typename MM0::Mma::FragmentC accum;
|
| 792 |
-
|
| 793 |
-
accum.clear();
|
| 794 |
-
|
| 795 |
-
auto gemm_k_iterations =
|
| 796 |
-
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
|
| 797 |
-
|
| 798 |
-
// Compute threadblock-scoped matrix multiply-add
|
| 799 |
-
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
|
| 800 |
-
__syncthreads();
|
| 801 |
-
|
| 802 |
-
if (kPreloadV) {
|
| 803 |
-
prologueV(0);
|
| 804 |
-
} else {
|
| 805 |
-
MM1::Mma::drain_cp_asyncs();
|
| 806 |
-
}
|
| 807 |
-
|
| 808 |
-
typename MM0::Mma::Operator::IteratorC::TensorCoord
|
| 809 |
-
iteratorC_tile_offset = {
|
| 810 |
-
(tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
|
| 811 |
-
(my_warp_id % MM0::Mma::WarpCount::kM),
|
| 812 |
-
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
|
| 813 |
-
(my_warp_id / MM0::Mma::WarpCount::kM)};
|
| 814 |
-
|
| 815 |
-
// multiply by scaling factor
|
| 816 |
-
if (kSupportsBias) {
|
| 817 |
-
accum =
|
| 818 |
-
cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
|
| 819 |
-
}
|
| 820 |
-
|
| 821 |
-
// apply attention bias if applicable
|
| 822 |
-
if (kSupportsBias && p.attn_bias_ptr != nullptr) {
|
| 823 |
-
// load bias tile Bij into shared memory
|
| 824 |
-
typename MM0::BiasLoader::GmemTileIterator bias_iter(
|
| 825 |
-
{cutlass::layout::RowMajor(p.bias_strideM)},
|
| 826 |
-
// attn_bias_pointer points to matrix of size (n_queries, n_keys)
|
| 827 |
-
// for the relevant batch_id and head_id
|
| 828 |
-
p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start,
|
| 829 |
-
{problem_size_0_m, problem_size_0_n},
|
| 830 |
-
thread_id());
|
| 831 |
-
cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
|
| 832 |
-
shared_storage.after_mm0.bias.data(),
|
| 833 |
-
cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
|
| 834 |
-
typename MM0::BiasLoader::SmemTileIterator smem_tile_iter(
|
| 835 |
-
bias_tensor_ref, thread_id());
|
| 836 |
-
MM0::BiasLoader::load(bias_iter, smem_tile_iter);
|
| 837 |
-
|
| 838 |
-
// Pij += Bij, Pij is in register fragment and Bij is in shared memory
|
| 839 |
-
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
|
| 840 |
-
my_lane_id, my_warp_id, iteratorC_tile_offset);
|
| 841 |
-
MM0::AccumLambdaIterator::iterateRows(
|
| 842 |
-
lane_offset,
|
| 843 |
-
[&](int accum_m) {},
|
| 844 |
-
[&](int accum_m, int accum_n, int idx) {
|
| 845 |
-
if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
|
| 846 |
-
accum[idx] += bias_tensor_ref.at({accum_m, accum_n});
|
| 847 |
-
}
|
| 848 |
-
},
|
| 849 |
-
[&](int accum_m) {});
|
| 850 |
-
}
|
| 851 |
-
|
| 852 |
-
// Mask out last if causal
|
| 853 |
-
// This is only needed if upper-right corner of current query / key block
|
| 854 |
-
// intersects the mask Coordinates of upper-right corner of current block
|
| 855 |
-
// is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The
|
| 856 |
-
// first masked element is x = y + offset -> query_start + offset There is
|
| 857 |
-
// intersection (and we need to mask) if min(iter_key_start +
|
| 858 |
-
// kKeysPerBlock, num_keys)) >= query_start + offset
|
| 859 |
-
if (p.custom_mask_type &&
|
| 860 |
-
cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >=
|
| 861 |
-
(query_start + p.causal_diagonal_offset)) {
|
| 862 |
-
auto query_start = blockIdx.x * kQueriesPerBlock;
|
| 863 |
-
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
|
| 864 |
-
my_lane_id, my_warp_id, iteratorC_tile_offset);
|
| 865 |
-
int32_t last_col;
|
| 866 |
-
MM0::AccumLambdaIterator::iterateRows(
|
| 867 |
-
lane_offset,
|
| 868 |
-
[&](int accum_m) {
|
| 869 |
-
// last absolute col is (last absolute query + offset)
|
| 870 |
-
// last local col is (last absolute query + offset -
|
| 871 |
-
// iter_key_start)
|
| 872 |
-
last_col = query_start + accum_m + p.causal_diagonal_offset -
|
| 873 |
-
iter_key_start;
|
| 874 |
-
},
|
| 875 |
-
[&](int accum_m, int accum_n, int idx) {
|
| 876 |
-
if (accum_n > last_col) {
|
| 877 |
-
accum[idx] =
|
| 878 |
-
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
| 879 |
-
}
|
| 880 |
-
},
|
| 881 |
-
[&](int accum_m) {});
|
| 882 |
-
}
|
| 883 |
-
// Update `mi` from accum stored in registers
|
| 884 |
-
// Also does accum[i] <- exp(accum[i] - mi)
|
| 885 |
-
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
|
| 886 |
-
accum_o,
|
| 887 |
-
accum,
|
| 888 |
-
mi,
|
| 889 |
-
m_prime,
|
| 890 |
-
s_prime,
|
| 891 |
-
out_rescale,
|
| 892 |
-
shared_storage.addition_storage,
|
| 893 |
-
my_lane_id,
|
| 894 |
-
thread_id(),
|
| 895 |
-
my_warp_id,
|
| 896 |
-
p.num_keys - iter_key_start,
|
| 897 |
-
iter_key_start == 0,
|
| 898 |
-
iteratorC_tile_offset,
|
| 899 |
-
kSupportsBias ? 1.0f : p.scale);
|
| 900 |
-
|
| 901 |
-
// Output results to shared-memory
|
| 902 |
-
int warp_idx_mn_0 = my_warp_id %
|
| 903 |
-
(MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
|
| 904 |
-
auto output_tile_coords = cutlass::MatrixCoord{
|
| 905 |
-
warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
|
| 906 |
-
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
|
| 907 |
-
|
| 908 |
-
MM0::B2bGemm::accumToSmem(
|
| 909 |
-
shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
|
| 910 |
-
|
| 911 |
-
__syncthreads();
|
| 912 |
-
|
| 913 |
-
#ifdef HAS_PYTORCH
|
| 914 |
-
// apply dropout (if applicable) after we've written Pij to smem.
|
| 915 |
-
// dropout is applied by multiplying each element of Pij by:
|
| 916 |
-
// - 0 with probability dropout_p
|
| 917 |
-
// - 1 / (1 - dropout_p) with probability 1 - dropout_p
|
| 918 |
-
//
|
| 919 |
-
// for backward purposes we want to be able to map each element of the
|
| 920 |
-
// attention matrix to the same random uniform number as the one we used
|
| 921 |
-
// in forward, without needing to use the same iteration order or having
|
| 922 |
-
// to store the dropout matrix. its possible to do this in registers but
|
| 923 |
-
// it ends up being very slow because each thread having noncontiguous
|
| 924 |
-
// strips of the Pij tile means we have to skip around a lot, and also
|
| 925 |
-
// have to generate a single random number at a time
|
| 926 |
-
if (kSupportsDropout && p.use_dropout) {
|
| 927 |
-
auto si = shared_storage.after_mm0.si.accum_ref();
|
| 928 |
-
// each thread handles a contiguous sequence of elements from Sij, all
|
| 929 |
-
// coming from the same row. the reason they have to come from the same
|
| 930 |
-
// row is that the sampling random numbers from a contiguous random
|
| 931 |
-
// number sequence is much more efficient than jumping around, and the
|
| 932 |
-
// linear offset of each element of S (the global matrix) maps to an
|
| 933 |
-
// offset in a random number sequence. for S, the end of a row and the
|
| 934 |
-
// beginning of the next have adjacent offsets, but for Sij, this is not
|
| 935 |
-
// necessarily the case.
|
| 936 |
-
const int num_threads = blockDim.x * blockDim.y * blockDim.z;
|
| 937 |
-
const int threads_per_row =
|
| 938 |
-
cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n);
|
| 939 |
-
const int elts_per_thread = cutlass::round_nearest(
|
| 940 |
-
cutlass::ceil_div(problem_size_0_n, threads_per_row), 4);
|
| 941 |
-
|
| 942 |
-
const int thread_i = thread_id() / threads_per_row;
|
| 943 |
-
const int thread_start_j =
|
| 944 |
-
(thread_id() % threads_per_row) * elts_per_thread;
|
| 945 |
-
|
| 946 |
-
if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
|
| 947 |
-
curandStatePhilox4_32_10_t curand_state = curand_state_init;
|
| 948 |
-
skipahead(
|
| 949 |
-
static_cast<unsigned long long>(
|
| 950 |
-
(query_start + thread_i) * p.num_keys_absolute +
|
| 951 |
-
(iter_key_start + thread_start_j)),
|
| 952 |
-
&curand_state);
|
| 953 |
-
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
|
| 954 |
-
|
| 955 |
-
// apply dropout scaling to elements this thread is responsible for,
|
| 956 |
-
// in chunks of 4
|
| 957 |
-
for (int sij_start_col_idx = thread_start_j; sij_start_col_idx <
|
| 958 |
-
cutlass::fast_min(thread_start_j + elts_per_thread,
|
| 959 |
-
problem_size_0_n);
|
| 960 |
-
sij_start_col_idx += 4) {
|
| 961 |
-
const float4 rand_uniform_quad = curand_uniform4(&curand_state);
|
| 962 |
-
|
| 963 |
-
CUTLASS_PRAGMA_UNROLL
|
| 964 |
-
for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
|
| 965 |
-
si.at({thread_i, sij_start_col_idx + quad_idx}) *=
|
| 966 |
-
static_cast<scalar_t>(
|
| 967 |
-
dropout_scale *
|
| 968 |
-
((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
|
| 969 |
-
}
|
| 970 |
-
}
|
| 971 |
-
}
|
| 972 |
-
__syncthreads(); // p.use_dropout should have same value kernel-wide
|
| 973 |
-
}
|
| 974 |
-
#endif
|
| 975 |
-
|
| 976 |
-
//
|
| 977 |
-
// MATMUL: Attn . V
|
| 978 |
-
// Run the matmul `attn @ V` for a block of attn and V.
|
| 979 |
-
// `attn` is read from shared memory (in `shared_storage_si`)
|
| 980 |
-
// `V` is read from global memory (with iterator_B)
|
| 981 |
-
//
|
| 982 |
-
|
| 983 |
-
const int64_t nBlockN = kSingleValueIteration
|
| 984 |
-
? 1
|
| 985 |
-
: ceil_div(
|
| 986 |
-
(int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
|
| 987 |
-
for (int blockN = 0; blockN < nBlockN; ++blockN) {
|
| 988 |
-
int gemm_k_iterations =
|
| 989 |
-
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
|
| 990 |
-
|
| 991 |
-
// Compute threadblock-scoped matrix multiply-add and store it in accum
|
| 992 |
-
// (in registers)
|
| 993 |
-
if (!kPreloadV) {
|
| 994 |
-
__syncthreads(); // we share shmem between mma and epilogue
|
| 995 |
-
}
|
| 996 |
-
|
| 997 |
-
typename MM1::Mma::IteratorB iterator_V(
|
| 998 |
-
typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)},
|
| 999 |
-
p.value_ptr + iter_key_start * p.v_strideM,
|
| 1000 |
-
{problem_size_1_k, problem_size_1_n},
|
| 1001 |
-
thread_id(),
|
| 1002 |
-
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
| 1003 |
-
typename MM1::Mma mma_pv(
|
| 1004 |
-
// operand A: Pij_dropped in shared memory
|
| 1005 |
-
shared_storage.after_mm0.si.accum_ref(),
|
| 1006 |
-
// operand B: shared memory staging area for Vj, which is loaded
|
| 1007 |
-
// from global memory
|
| 1008 |
-
shared_storage.after_mm0.mm1.operand_B_ref(),
|
| 1009 |
-
(int)thread_id(),
|
| 1010 |
-
(int)my_warp_id,
|
| 1011 |
-
(int)my_lane_id);
|
| 1012 |
-
mma_pv.set_prologue_done(kPreloadV);
|
| 1013 |
-
if (!kKeepOutputInRF) {
|
| 1014 |
-
accum_o.clear();
|
| 1015 |
-
}
|
| 1016 |
-
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
|
| 1017 |
-
__syncthreads();
|
| 1018 |
-
|
| 1019 |
-
if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
|
| 1020 |
-
prologueV(blockN + 1);
|
| 1021 |
-
}
|
| 1022 |
-
|
| 1023 |
-
if (!kKeepOutputInRF) {
|
| 1024 |
-
MM1::Mma::drain_cp_asyncs();
|
| 1025 |
-
DISPATCH_BOOL(
|
| 1026 |
-
iter_key_start == 0, kIsFirst, ([&] {
|
| 1027 |
-
DISPATCH_BOOL(
|
| 1028 |
-
(iter_key_start + kKeysPerBlock) >= p.num_keys,
|
| 1029 |
-
kIsLast,
|
| 1030 |
-
([&] {
|
| 1031 |
-
using DefaultEpilogue = typename MM1::DefaultEpilogue;
|
| 1032 |
-
using DefaultOp =
|
| 1033 |
-
typename MM1::DefaultConfig::EpilogueOutputOp;
|
| 1034 |
-
using ElementCompute = typename DefaultOp::ElementCompute;
|
| 1035 |
-
using EpilogueOutputOp = typename cutlass::epilogue::
|
| 1036 |
-
thread::MemoryEfficientAttentionNormalize<
|
| 1037 |
-
typename cutlass::platform::conditional<
|
| 1038 |
-
kIsLast::value,
|
| 1039 |
-
output_t,
|
| 1040 |
-
output_accum_t>::type,
|
| 1041 |
-
output_accum_t,
|
| 1042 |
-
DefaultOp::kCount,
|
| 1043 |
-
typename DefaultOp::ElementAccumulator,
|
| 1044 |
-
ElementCompute,
|
| 1045 |
-
kIsFirst::value,
|
| 1046 |
-
kIsLast::value,
|
| 1047 |
-
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
| 1048 |
-
using Epilogue = typename cutlass::epilogue::threadblock::
|
| 1049 |
-
EpiloguePipelined<
|
| 1050 |
-
typename DefaultEpilogue::Shape,
|
| 1051 |
-
typename MM1::Mma::Operator,
|
| 1052 |
-
DefaultEpilogue::kPartitionsK,
|
| 1053 |
-
typename cutlass::platform::conditional<
|
| 1054 |
-
kIsLast::value,
|
| 1055 |
-
typename MM1::OutputTileIterator,
|
| 1056 |
-
typename MM1::OutputTileIteratorAccum>::type,
|
| 1057 |
-
typename DefaultEpilogue::
|
| 1058 |
-
AccumulatorFragmentIterator,
|
| 1059 |
-
typename DefaultEpilogue::WarpTileIterator,
|
| 1060 |
-
typename DefaultEpilogue::SharedLoadIterator,
|
| 1061 |
-
EpilogueOutputOp,
|
| 1062 |
-
typename DefaultEpilogue::Padding,
|
| 1063 |
-
DefaultEpilogue::kFragmentsPerIteration,
|
| 1064 |
-
true, // IterationsUnroll
|
| 1065 |
-
typename MM1::OutputTileIteratorAccum // Read
|
| 1066 |
-
// iterator
|
| 1067 |
-
>;
|
| 1068 |
-
|
| 1069 |
-
int col = blockN * MM1::Mma::Shape::kN;
|
| 1070 |
-
auto source_iter = createOutputAccumIter(col);
|
| 1071 |
-
auto dest_iter = call_conditional<
|
| 1072 |
-
kIsLast::value,
|
| 1073 |
-
decltype(createOutputIter),
|
| 1074 |
-
decltype(createOutputAccumIter)>::
|
| 1075 |
-
apply(createOutputIter, createOutputAccumIter, col);
|
| 1076 |
-
EpilogueOutputOp rescale(s_prime, out_rescale);
|
| 1077 |
-
Epilogue epilogue(
|
| 1078 |
-
shared_storage.epilogue_shared_storage(),
|
| 1079 |
-
thread_id(),
|
| 1080 |
-
my_warp_id,
|
| 1081 |
-
my_lane_id);
|
| 1082 |
-
epilogue(rescale, dest_iter, accum_o, source_iter);
|
| 1083 |
-
}));
|
| 1084 |
-
}));
|
| 1085 |
-
if (!kSingleValueIteration) {
|
| 1086 |
-
__syncthreads();
|
| 1087 |
-
}
|
| 1088 |
-
}
|
| 1089 |
-
}
|
| 1090 |
-
__syncthreads(); // we modify `m_prime` after
|
| 1091 |
-
}
|
| 1092 |
-
|
| 1093 |
-
if (kKeepOutputInRF) {
|
| 1094 |
-
constexpr bool kIsFirst = true;
|
| 1095 |
-
constexpr bool kIsLast = true;
|
| 1096 |
-
using DefaultEpilogue = typename MM1::DefaultEpilogue;
|
| 1097 |
-
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
|
| 1098 |
-
using ElementCompute = typename DefaultOp::ElementCompute;
|
| 1099 |
-
using EpilogueOutputOp =
|
| 1100 |
-
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
|
| 1101 |
-
output_t, // output
|
| 1102 |
-
output_accum_t, // source
|
| 1103 |
-
DefaultOp::kCount,
|
| 1104 |
-
typename DefaultOp::ElementAccumulator, // accum
|
| 1105 |
-
output_accum_t, // compute
|
| 1106 |
-
kIsFirst,
|
| 1107 |
-
kIsLast,
|
| 1108 |
-
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
| 1109 |
-
using Epilogue =
|
| 1110 |
-
typename cutlass::epilogue::threadblock::EpiloguePipelined<
|
| 1111 |
-
typename DefaultEpilogue::Shape,
|
| 1112 |
-
typename MM1::Mma::Operator,
|
| 1113 |
-
DefaultEpilogue::kPartitionsK,
|
| 1114 |
-
typename MM1::OutputTileIterator, // destination
|
| 1115 |
-
typename DefaultEpilogue::AccumulatorFragmentIterator,
|
| 1116 |
-
typename DefaultEpilogue::WarpTileIterator,
|
| 1117 |
-
typename DefaultEpilogue::SharedLoadIterator,
|
| 1118 |
-
EpilogueOutputOp,
|
| 1119 |
-
typename DefaultEpilogue::Padding,
|
| 1120 |
-
DefaultEpilogue::kFragmentsPerIteration,
|
| 1121 |
-
true, // IterationsUnroll
|
| 1122 |
-
typename MM1::OutputTileIteratorAccum // source tile
|
| 1123 |
-
>;
|
| 1124 |
-
auto dest_iter = createOutputIter(0);
|
| 1125 |
-
EpilogueOutputOp rescale(s_prime, out_rescale);
|
| 1126 |
-
Epilogue epilogue(
|
| 1127 |
-
shared_storage.epilogue_shared_storage(),
|
| 1128 |
-
thread_id(),
|
| 1129 |
-
warp_id(),
|
| 1130 |
-
lane_id());
|
| 1131 |
-
MM1::Mma::drain_cp_asyncs();
|
| 1132 |
-
epilogue(rescale, dest_iter, accum_o);
|
| 1133 |
-
}
|
| 1134 |
-
|
| 1135 |
-
// 7. Calculate logsumexp
|
| 1136 |
-
// To make the backward easier, we pad logsumexp with `inf`
|
| 1137 |
-
// this avoids a few bound checks, and is not more expensive during fwd
|
| 1138 |
-
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
| 1139 |
-
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
|
| 1140 |
-
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
|
| 1141 |
-
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
| 1142 |
-
if (thread_id() < p.num_queries) {
|
| 1143 |
-
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
|
| 1144 |
-
cutlass::fast_log(accum_t(s_prime[thread_id()]));
|
| 1145 |
-
} else if (thread_id() < lse_dim) {
|
| 1146 |
-
p.logsumexp_ptr[thread_id()] =
|
| 1147 |
-
cutlass::platform::numeric_limits<accum_t>::infinity();
|
| 1148 |
-
}
|
| 1149 |
-
}
|
| 1150 |
-
}
|
| 1151 |
-
|
| 1152 |
-
template <typename WarpIteratorC>
|
| 1153 |
-
CUTLASS_DEVICE static void iterative_softmax(
|
| 1154 |
-
typename WarpIteratorC::Fragment& frag_o, // output so far
|
| 1155 |
-
typename WarpIteratorC::Fragment& frag,
|
| 1156 |
-
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
|
| 1157 |
-
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
|
| 1158 |
-
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
|
| 1159 |
-
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
|
| 1160 |
-
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
|
| 1161 |
-
addition_storage,
|
| 1162 |
-
int8_t lane_id,
|
| 1163 |
-
int8_t thread_id,
|
| 1164 |
-
int8_t warp_id,
|
| 1165 |
-
int max_col,
|
| 1166 |
-
bool is_first,
|
| 1167 |
-
typename WarpIteratorC::TensorCoord const& tile_offset,
|
| 1168 |
-
float scaling) {
|
| 1169 |
-
/* Iterates on the accumulator and corresponding position on result matrix
|
| 1170 |
-
|
| 1171 |
-
(1) Update `mi[r]` to the max value of the row `r`
|
| 1172 |
-
(2) In a second iteration do the following:
|
| 1173 |
-
(a) accum <- exp(accum - mi)
|
| 1174 |
-
(b) m_prime <- exp(m_prime - mi)
|
| 1175 |
-
(c) s_prime <- s_prime * m_prime + sum(accum)
|
| 1176 |
-
|
| 1177 |
-
All of this is done on registers, before we store all of this
|
| 1178 |
-
on shared memory for the next matmul with Value.
|
| 1179 |
-
*/
|
| 1180 |
-
using Fragment = typename WarpIteratorC::Fragment;
|
| 1181 |
-
using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
|
| 1182 |
-
WarpIteratorC,
|
| 1183 |
-
accum_t,
|
| 1184 |
-
kWarpSize>::Iterator;
|
| 1185 |
-
// Convert to `accum_t` (rather than double)
|
| 1186 |
-
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
| 1187 |
-
|
| 1188 |
-
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
|
| 1189 |
-
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
|
| 1190 |
-
|
| 1191 |
-
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
|
| 1192 |
-
|
| 1193 |
-
auto lane_offset =
|
| 1194 |
-
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
|
| 1195 |
-
|
| 1196 |
-
// First update `mi` to the max per-row
|
| 1197 |
-
{
|
| 1198 |
-
accum_t max;
|
| 1199 |
-
LambdaIterator::iterateRows(
|
| 1200 |
-
lane_offset,
|
| 1201 |
-
[&](int accum_m) {
|
| 1202 |
-
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
| 1203 |
-
},
|
| 1204 |
-
[&](int accum_m, int accum_n, int idx) {
|
| 1205 |
-
if (accum_n < max_col) {
|
| 1206 |
-
max = cutlass::fast_max(max, frag[idx]);
|
| 1207 |
-
}
|
| 1208 |
-
},
|
| 1209 |
-
[&](int accum_m) {
|
| 1210 |
-
// Having 4x atomicMax seems faster than reduce within warp
|
| 1211 |
-
// first...
|
| 1212 |
-
atomicMaxFloat(&mi[accum_m], max);
|
| 1213 |
-
});
|
| 1214 |
-
}
|
| 1215 |
-
|
| 1216 |
-
// Make sure we all share the update values for `mi`
|
| 1217 |
-
__syncthreads();
|
| 1218 |
-
|
| 1219 |
-
// Doing this `exp` is quite expensive. Let's
|
| 1220 |
-
// split it across the warps
|
| 1221 |
-
bool restore_mi_to_minus_inf = false;
|
| 1222 |
-
if (lane_id < kLinesPerWarp) {
|
| 1223 |
-
int id = warp_id * kLinesPerWarp + lane_id;
|
| 1224 |
-
auto m_prime_id = m_prime[id];
|
| 1225 |
-
auto mi_id = mi[id];
|
| 1226 |
-
bool changed = m_prime_id < mi_id; // `false` if both are -inf
|
| 1227 |
-
if (changed) {
|
| 1228 |
-
auto m_prime_exp = exp2f(m_prime_id - mi_id);
|
| 1229 |
-
out_rescale[id] = m_prime_exp;
|
| 1230 |
-
s_prime[id] *= m_prime_exp;
|
| 1231 |
-
} else {
|
| 1232 |
-
// Only when bias is enabled, it's possible that all the first values
|
| 1233 |
-
// of attention are masked to `-inf`. In that case we want to avoid
|
| 1234 |
-
// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
|
| 1235 |
-
if (kSupportsBias &&
|
| 1236 |
-
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
|
| 1237 |
-
restore_mi_to_minus_inf = true;
|
| 1238 |
-
mi[id] = 0.0f;
|
| 1239 |
-
}
|
| 1240 |
-
out_rescale[id] = 1.0f;
|
| 1241 |
-
}
|
| 1242 |
-
}
|
| 1243 |
-
__syncthreads(); // Update output fragments
|
| 1244 |
-
if (kKeepOutputInRF && !is_first) {
|
| 1245 |
-
accum_t line_rescale;
|
| 1246 |
-
LambdaIterator::iterateRows(
|
| 1247 |
-
lane_offset,
|
| 1248 |
-
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
|
| 1249 |
-
[&](int accum_m, int accum_n, int idx) {
|
| 1250 |
-
frag_o[idx] = frag_o[idx] * line_rescale;
|
| 1251 |
-
},
|
| 1252 |
-
[&](int accum_m) {});
|
| 1253 |
-
}
|
| 1254 |
-
// Update accum_m, accum_n, ...
|
| 1255 |
-
{
|
| 1256 |
-
accum_t mi_row, total_row;
|
| 1257 |
-
LambdaIterator::iterateRows(
|
| 1258 |
-
lane_offset,
|
| 1259 |
-
[&](int accum_m) { mi_row = mi[accum_m]; },
|
| 1260 |
-
[&](int accum_m, int accum_n, int idx) {
|
| 1261 |
-
frag[idx] =
|
| 1262 |
-
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
|
| 1263 |
-
},
|
| 1264 |
-
[&](int accum_m) {});
|
| 1265 |
-
LambdaIterator::iterateRows(
|
| 1266 |
-
lane_offset,
|
| 1267 |
-
[&](int accum_m) { total_row = 0.0; },
|
| 1268 |
-
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
|
| 1269 |
-
[&](int accum_m) {
|
| 1270 |
-
if (LambdaIterator::reduceSameRow(
|
| 1271 |
-
lane_id, total_row, [](accum_t a, accum_t b) {
|
| 1272 |
-
return a + b;
|
| 1273 |
-
})) {
|
| 1274 |
-
// NOTE: we could atomically add `total_row` to `s_prime`, but
|
| 1275 |
-
// it's faster (and deterministic) to avoid atomics here
|
| 1276 |
-
addition_storage
|
| 1277 |
-
[accum_m + kQueriesPerBlock * tile_offset.column()] =
|
| 1278 |
-
total_row;
|
| 1279 |
-
}
|
| 1280 |
-
});
|
| 1281 |
-
}
|
| 1282 |
-
__syncthreads();
|
| 1283 |
-
if (lane_id < kLinesPerWarp) {
|
| 1284 |
-
int id = warp_id * kLinesPerWarp + lane_id;
|
| 1285 |
-
accum_t total_row = s_prime[id];
|
| 1286 |
-
if (restore_mi_to_minus_inf) {
|
| 1287 |
-
// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
|
| 1288 |
-
mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
| 1289 |
-
} else {
|
| 1290 |
-
m_prime[id] = mi[id];
|
| 1291 |
-
}
|
| 1292 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1293 |
-
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
|
| 1294 |
-
total_row += addition_storage[id + kQueriesPerBlock * i];
|
| 1295 |
-
}
|
| 1296 |
-
s_prime[id] = total_row;
|
| 1297 |
-
}
|
| 1298 |
-
}
|
| 1299 |
-
|
| 1300 |
-
static CUTLASS_DEVICE int8_t lane_id() {
|
| 1301 |
-
return threadIdx.x;
|
| 1302 |
-
}
|
| 1303 |
-
static CUTLASS_DEVICE int8_t warp_id() {
|
| 1304 |
-
return threadIdx.y;
|
| 1305 |
-
}
|
| 1306 |
-
static CUTLASS_DEVICE int16_t thread_id() {
|
| 1307 |
-
return threadIdx.x + threadIdx.y * blockDim.x;
|
| 1308 |
-
}
|
| 1309 |
-
};
|
| 1310 |
-
|
| 1311 |
-
template <typename AK>
|
| 1312 |
-
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
| 1313 |
-
attention_kernel_batched_impl(typename AK::Params p) {
|
| 1314 |
-
if (!p.advance_to_block()) {
|
| 1315 |
-
return;
|
| 1316 |
-
}
|
| 1317 |
-
AK::attention_kernel(p);
|
| 1318 |
-
}
|
| 1319 |
-
|
| 1320 |
-
template <typename AK>
|
| 1321 |
-
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
| 1322 |
-
attention_kernel_batched(typename AK::Params params);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py
DELETED
|
@@ -1,144 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
from typing import List
|
| 34 |
-
import torch
|
| 35 |
-
import subprocess
|
| 36 |
-
import sys
|
| 37 |
-
import tempfile
|
| 38 |
-
import os
|
| 39 |
-
import numpy as np
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
TORCH_DTYPE_NAME = {
|
| 43 |
-
torch.float32: "f32",
|
| 44 |
-
torch.float16: "f16",
|
| 45 |
-
torch.bfloat16: "b16"
|
| 46 |
-
}
|
| 47 |
-
NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()}
|
| 48 |
-
|
| 49 |
-
def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor:
|
| 50 |
-
# PyTorch >= 2.0
|
| 51 |
-
if hasattr(tensor, 'untyped_storage'):
|
| 52 |
-
return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage())
|
| 53 |
-
return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped())
|
| 54 |
-
|
| 55 |
-
class PipedSubprocess:
|
| 56 |
-
def __init__(self, binary: str) -> None:
|
| 57 |
-
self.binary = binary
|
| 58 |
-
self.tempdir_ctx = tempfile.TemporaryDirectory()
|
| 59 |
-
|
| 60 |
-
def __enter__(self) -> "PipedSubprocess":
|
| 61 |
-
self.subp = subprocess.Popen(self.binary, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True, bufsize=0)
|
| 62 |
-
self.tempdir = self.tempdir_ctx.__enter__()
|
| 63 |
-
self.file_counter = 0
|
| 64 |
-
return self
|
| 65 |
-
|
| 66 |
-
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
| 67 |
-
self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb)
|
| 68 |
-
|
| 69 |
-
def temp_filename(self, suffix: str) -> str:
|
| 70 |
-
self.file_counter += 1
|
| 71 |
-
return os.path.join(self.tempdir, f"{self.file_counter}{suffix}")
|
| 72 |
-
|
| 73 |
-
def write(self, *args) -> None:
|
| 74 |
-
for a in args:
|
| 75 |
-
self.subp.stdin.write(str(a) + " ")
|
| 76 |
-
|
| 77 |
-
def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None:
|
| 78 |
-
print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}")
|
| 79 |
-
tensor_u8 = _tensor_from_storage(tensor, torch.uint8)
|
| 80 |
-
self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0])
|
| 81 |
-
filename = self.temp_filename(f"{name}.tensor")
|
| 82 |
-
assert tensor.storage_offset() == 0
|
| 83 |
-
with open(filename, "wb+") as fd:
|
| 84 |
-
fd.write(bytes(tensor_u8.numpy()))
|
| 85 |
-
self.write("file", filename)
|
| 86 |
-
self.write("tensor_end")
|
| 87 |
-
|
| 88 |
-
for stride_name, stride_value in zip(stride_names, tensor.stride()):
|
| 89 |
-
self.write(stride_name, stride_value)
|
| 90 |
-
|
| 91 |
-
def readTensor(self, name, stride_name, shape) -> torch.Tensor:
|
| 92 |
-
tmpfile = self.temp_filename(f"{name}.tensor")
|
| 93 |
-
self.write("tmpfile", tmpfile)
|
| 94 |
-
|
| 95 |
-
self.readExpect("tensor_begin")
|
| 96 |
-
dtype_str, name = self.read().split(":")
|
| 97 |
-
print(f"C++->Py : {dtype_str}:{name}")
|
| 98 |
-
u8len = int(self.read())
|
| 99 |
-
dtype = NAME_TORCH_DTYPE[dtype_str]
|
| 100 |
-
|
| 101 |
-
self.readExpect("file")
|
| 102 |
-
self.readExpect(tmpfile)
|
| 103 |
-
|
| 104 |
-
with open(tmpfile, "rb") as fd:
|
| 105 |
-
data = fd.read(u8len)
|
| 106 |
-
# `np.array` is not strictly needed, but avoids a torch warning
|
| 107 |
-
tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len)
|
| 108 |
-
self.readExpect("tensor_end")
|
| 109 |
-
|
| 110 |
-
tensor = _tensor_from_storage(tensor_u8, dtype)
|
| 111 |
-
strides = []
|
| 112 |
-
for sn in stride_name:
|
| 113 |
-
self.readExpect(sn)
|
| 114 |
-
strides.append(int(self.read()))
|
| 115 |
-
if len(strides) != shape:
|
| 116 |
-
strides.append(1)
|
| 117 |
-
assert len(strides) == len(shape), name
|
| 118 |
-
return torch.as_strided(tensor, shape, strides)
|
| 119 |
-
|
| 120 |
-
def readNamed(self, name: str):
|
| 121 |
-
self.readExpect(name)
|
| 122 |
-
return self.read()
|
| 123 |
-
|
| 124 |
-
def readExpect(self, what: str) -> None:
|
| 125 |
-
r = self.read()
|
| 126 |
-
if r != what:
|
| 127 |
-
raise ValueError(f"Read {r} but expected {what}")
|
| 128 |
-
|
| 129 |
-
def read(self):
|
| 130 |
-
read_all = []
|
| 131 |
-
# Skip initial whitespace
|
| 132 |
-
while True:
|
| 133 |
-
r = self.subp.stdout.read(1)
|
| 134 |
-
if r not in [' ', "\n"]:
|
| 135 |
-
read_all.append(r)
|
| 136 |
-
break
|
| 137 |
-
# Read data
|
| 138 |
-
while True:
|
| 139 |
-
r = self.subp.stdout.read(1)
|
| 140 |
-
if r in [' ', "\n"]:
|
| 141 |
-
break
|
| 142 |
-
read_all.append(r)
|
| 143 |
-
return ''.join(read_all)
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h
DELETED
|
@@ -1,90 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include <cutlass/cutlass.h>
|
| 35 |
-
#include "cutlass/aligned_buffer.h"
|
| 36 |
-
#include "cutlass/array.h"
|
| 37 |
-
#include "cutlass/layout/matrix.h"
|
| 38 |
-
#include "cutlass/layout/pitch_linear.h"
|
| 39 |
-
#include "cutlass/numeric_types.h"
|
| 40 |
-
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 41 |
-
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
| 42 |
-
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
| 43 |
-
|
| 44 |
-
template <
|
| 45 |
-
typename scalar_t, // scalar type
|
| 46 |
-
typename ThreadblockTileShape, // size of tile to load
|
| 47 |
-
int Threads, // number of participating threads
|
| 48 |
-
int ElementsPerAccess> // thread access width in elements
|
| 49 |
-
class TileSmemLoader {
|
| 50 |
-
public:
|
| 51 |
-
using SmemTile =
|
| 52 |
-
cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;
|
| 53 |
-
|
| 54 |
-
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
|
| 55 |
-
cutlass::layout::PitchLinearShape<
|
| 56 |
-
ThreadblockTileShape::kColumn, // contiguous
|
| 57 |
-
ThreadblockTileShape::kRow>, // strided
|
| 58 |
-
Threads, // Threads
|
| 59 |
-
ElementsPerAccess>; // ElementsPerAccess
|
| 60 |
-
|
| 61 |
-
using GmemTileIterator =
|
| 62 |
-
cutlass::transform::threadblock::PredicatedTileIterator<
|
| 63 |
-
ThreadblockTileShape, // Shape
|
| 64 |
-
scalar_t, // Element
|
| 65 |
-
cutlass::layout::RowMajor, // Layout
|
| 66 |
-
0, // AdvanceRank
|
| 67 |
-
ThreadMap>; // ThreadMap
|
| 68 |
-
|
| 69 |
-
using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator<
|
| 70 |
-
ThreadblockTileShape, // Shape
|
| 71 |
-
scalar_t, // Element
|
| 72 |
-
cutlass::layout::RowMajor, // Layout
|
| 73 |
-
0, // AdvanceRank
|
| 74 |
-
ThreadMap>; // ThreadMap
|
| 75 |
-
|
| 76 |
-
using Fragment = typename GmemTileIterator::Fragment;
|
| 77 |
-
|
| 78 |
-
/// load a tile from global memory into shared memory
|
| 79 |
-
CUTLASS_DEVICE
|
| 80 |
-
static void load(
|
| 81 |
-
GmemTileIterator tile_load_iter,
|
| 82 |
-
SmemTileIterator tile_store_iter) {
|
| 83 |
-
Fragment tb_frag;
|
| 84 |
-
tb_frag.clear();
|
| 85 |
-
tile_load_iter.load(tb_frag);
|
| 86 |
-
tile_store_iter.store(tb_frag);
|
| 87 |
-
|
| 88 |
-
__syncthreads();
|
| 89 |
-
}
|
| 90 |
-
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h
DELETED
|
@@ -1,154 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 34 |
-
|
| 35 |
-
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 36 |
-
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 37 |
-
|
| 38 |
-
*/
|
| 39 |
-
|
| 40 |
-
#pragma once
|
| 41 |
-
|
| 42 |
-
#include "cutlass/cutlass.h"
|
| 43 |
-
#include "cutlass/numeric_types.h"
|
| 44 |
-
#include "cutlass/array.h"
|
| 45 |
-
|
| 46 |
-
#include "cutlass/gemm/gemm.h"
|
| 47 |
-
|
| 48 |
-
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 49 |
-
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
| 50 |
-
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 51 |
-
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 52 |
-
|
| 53 |
-
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
| 54 |
-
|
| 55 |
-
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
|
| 56 |
-
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
| 57 |
-
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
| 58 |
-
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
|
| 59 |
-
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
| 60 |
-
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 61 |
-
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
| 62 |
-
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
|
| 63 |
-
|
| 64 |
-
// #include "cutlass/epilogue/threadblock/epilogue.h"
|
| 65 |
-
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
|
| 66 |
-
|
| 67 |
-
#include "fused_bias_act_epilogue.h"
|
| 68 |
-
#include "../warp/fused_bias_act_fragment_iterator_tensor_op.h"
|
| 69 |
-
#include "output_tile_thread_map_for_fused_bias.h"
|
| 70 |
-
#include "default_thread_map_tensor_op_for_fused_bias.h"
|
| 71 |
-
|
| 72 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 73 |
-
|
| 74 |
-
namespace cutlass {
|
| 75 |
-
namespace epilogue {
|
| 76 |
-
namespace threadblock {
|
| 77 |
-
|
| 78 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 82 |
-
|
| 83 |
-
/// Defines sensible defaults for epilogues for TensorOps.
|
| 84 |
-
template <
|
| 85 |
-
typename Shape_,
|
| 86 |
-
typename WarpMmaTensorOp_,
|
| 87 |
-
int PartitionsK,
|
| 88 |
-
typename OutputOp_,
|
| 89 |
-
int ElementsPerAccess
|
| 90 |
-
>
|
| 91 |
-
struct DefaultFusedBiasActEpilogueTensorOp {
|
| 92 |
-
|
| 93 |
-
using Shape = Shape_;
|
| 94 |
-
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
| 95 |
-
static int const kPartitionsK = PartitionsK;
|
| 96 |
-
using OutputOp = OutputOp_;
|
| 97 |
-
static int const kElementsPerAccess = ElementsPerAccess;
|
| 98 |
-
using ElementOutput = typename OutputOp::ElementOutput;
|
| 99 |
-
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
| 100 |
-
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
| 101 |
-
|
| 102 |
-
//
|
| 103 |
-
// Thread map
|
| 104 |
-
//
|
| 105 |
-
|
| 106 |
-
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias<
|
| 107 |
-
Shape,
|
| 108 |
-
typename WarpMmaTensorOp::Shape,
|
| 109 |
-
kPartitionsK,
|
| 110 |
-
ElementOutput,
|
| 111 |
-
kElementsPerAccess
|
| 112 |
-
>::Type;
|
| 113 |
-
|
| 114 |
-
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
| 115 |
-
OutputTileThreadMap,
|
| 116 |
-
ElementOutput
|
| 117 |
-
>;
|
| 118 |
-
|
| 119 |
-
using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
|
| 120 |
-
cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
|
| 121 |
-
typename WarpMmaTensorOp::Shape,
|
| 122 |
-
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 123 |
-
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 124 |
-
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 125 |
-
LayoutC>,
|
| 126 |
-
cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp<
|
| 127 |
-
typename WarpMmaTensorOp::Shape,
|
| 128 |
-
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
| 129 |
-
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
| 130 |
-
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
| 131 |
-
LayoutC> >::type;
|
| 132 |
-
|
| 133 |
-
//
|
| 134 |
-
// Define the epilogue
|
| 135 |
-
//
|
| 136 |
-
using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue<
|
| 137 |
-
Shape,
|
| 138 |
-
WarpMmaTensorOp,
|
| 139 |
-
kPartitionsK,
|
| 140 |
-
OutputTileIterator,
|
| 141 |
-
AccumulatorFragmentIterator,
|
| 142 |
-
OutputOp
|
| 143 |
-
>;
|
| 144 |
-
};
|
| 145 |
-
|
| 146 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 147 |
-
|
| 148 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 149 |
-
|
| 150 |
-
} // namespace threadblock
|
| 151 |
-
} // namespace epilogue
|
| 152 |
-
} // namespace cutlass
|
| 153 |
-
|
| 154 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h
DELETED
|
@@ -1,113 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief
|
| 34 |
-
|
| 35 |
-
*/
|
| 36 |
-
|
| 37 |
-
#pragma once
|
| 38 |
-
|
| 39 |
-
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 40 |
-
#include "cutlass/gemm/gemm.h"
|
| 41 |
-
#include "cutlass/layout/pitch_linear.h"
|
| 42 |
-
|
| 43 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 44 |
-
|
| 45 |
-
namespace cutlass {
|
| 46 |
-
namespace epilogue {
|
| 47 |
-
namespace threadblock {
|
| 48 |
-
|
| 49 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 50 |
-
|
| 51 |
-
/// Defines the optimal thread map for TensorOp accumulator layouts
|
| 52 |
-
template <
|
| 53 |
-
typename ThreadblockShape_,
|
| 54 |
-
typename WarpShape_,
|
| 55 |
-
int PartitionsK,
|
| 56 |
-
typename Element_,
|
| 57 |
-
int ElementsPerAccess
|
| 58 |
-
>
|
| 59 |
-
struct DefaultThreadMapTensorOpForFusedBias {
|
| 60 |
-
|
| 61 |
-
using ThreadblockShape = ThreadblockShape_;
|
| 62 |
-
using WarpShape = WarpShape_;
|
| 63 |
-
static int const kPartitionsK = PartitionsK;
|
| 64 |
-
using Element = Element_;
|
| 65 |
-
static int const kElementsPerAccess = ElementsPerAccess;
|
| 66 |
-
|
| 67 |
-
//
|
| 68 |
-
// Definitions
|
| 69 |
-
//
|
| 70 |
-
|
| 71 |
-
struct Detail {
|
| 72 |
-
|
| 73 |
-
/// Tensor Operations fundamentally perform operations on 8 rows
|
| 74 |
-
static int const kTensorOpRows = 8;
|
| 75 |
-
static int const kWarpSize = 32;
|
| 76 |
-
|
| 77 |
-
static_assert(
|
| 78 |
-
!(ThreadblockShape::kM % WarpShape::kM) &&
|
| 79 |
-
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
|
| 80 |
-
|
| 81 |
-
/// Number of warps
|
| 82 |
-
using WarpCount = gemm::GemmShape<
|
| 83 |
-
ThreadblockShape::kM / WarpShape::kM,
|
| 84 |
-
ThreadblockShape::kN / WarpShape::kN,
|
| 85 |
-
kPartitionsK
|
| 86 |
-
>;
|
| 87 |
-
|
| 88 |
-
/// Number of participating threads
|
| 89 |
-
static int const kThreads = WarpCount::kCount * kWarpSize;
|
| 90 |
-
};
|
| 91 |
-
|
| 92 |
-
//
|
| 93 |
-
// ThreadMap
|
| 94 |
-
//
|
| 95 |
-
|
| 96 |
-
/// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
|
| 97 |
-
using Type = OutputTileOptimalThreadMapBiasAct <
|
| 98 |
-
OutputTileShape<ThreadblockShape::kN, Detail::kTensorOpRows, Detail::WarpCount::kM, 1, 1>,
|
| 99 |
-
OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>,
|
| 100 |
-
Detail::kThreads,
|
| 101 |
-
kElementsPerAccess,
|
| 102 |
-
sizeof_bits<Element>::value
|
| 103 |
-
>;
|
| 104 |
-
};
|
| 105 |
-
|
| 106 |
-
///////////////////////////////////////////////////////////////////////////////
|
| 107 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 108 |
-
|
| 109 |
-
} // namespace threadblock
|
| 110 |
-
} // namespace epilogue
|
| 111 |
-
} // namespace cutlass
|
| 112 |
-
|
| 113 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h
DELETED
|
@@ -1,213 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 34 |
-
|
| 35 |
-
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 36 |
-
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 37 |
-
|
| 38 |
-
*/
|
| 39 |
-
|
| 40 |
-
#pragma once
|
| 41 |
-
#include "cutlass/cutlass.h"
|
| 42 |
-
#include CUDA_STD_HEADER(cassert)
|
| 43 |
-
#include "cutlass/numeric_types.h"
|
| 44 |
-
#include "cutlass/array.h"
|
| 45 |
-
#include "cutlass/layout/vector.h"
|
| 46 |
-
#include "cutlass/layout/tensor.h"
|
| 47 |
-
#include "cutlass/tensor_coord.h"
|
| 48 |
-
#include "cutlass/aligned_buffer.h"
|
| 49 |
-
#include "cutlass/functional.h"
|
| 50 |
-
#include "cutlass/gemm/gemm.h"
|
| 51 |
-
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 52 |
-
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
| 53 |
-
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
| 54 |
-
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 55 |
-
|
| 56 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 57 |
-
|
| 58 |
-
namespace cutlass {
|
| 59 |
-
namespace epilogue {
|
| 60 |
-
namespace threadblock {
|
| 61 |
-
|
| 62 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 63 |
-
|
| 64 |
-
/// Epilogue operator without splitk
|
| 65 |
-
template <
|
| 66 |
-
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
| 67 |
-
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
| 68 |
-
int PartitionsK, ///< Number of partitions of the K dimension
|
| 69 |
-
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
|
| 70 |
-
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
| 71 |
-
typename OutputOp_ ///< Output operator
|
| 72 |
-
>
|
| 73 |
-
class FusedBiasActEpilogue {
|
| 74 |
-
|
| 75 |
-
public:
|
| 76 |
-
|
| 77 |
-
using Shape = Shape_;
|
| 78 |
-
using WarpMmaOperator = WarpMmaOperator_;
|
| 79 |
-
static int const kPartitionsK = PartitionsK;
|
| 80 |
-
using OutputTileIterator = OutputTileIterator_;
|
| 81 |
-
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 82 |
-
using OutputOp = OutputOp_;
|
| 83 |
-
|
| 84 |
-
/// Output layout is always row-major
|
| 85 |
-
using Layout = layout::RowMajor;
|
| 86 |
-
using LongIndex = typename Layout::LongIndex;
|
| 87 |
-
|
| 88 |
-
/// The complete warp-level accumulator tile
|
| 89 |
-
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
|
| 90 |
-
|
| 91 |
-
/// Output element
|
| 92 |
-
using ElementOutput = typename OutputTileIterator::Element;
|
| 93 |
-
|
| 94 |
-
/// Output access size
|
| 95 |
-
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
public:
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
|
| 102 |
-
|
| 103 |
-
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
| 104 |
-
"Divisibility");
|
| 105 |
-
|
| 106 |
-
public:
|
| 107 |
-
|
| 108 |
-
/// Constructor
|
| 109 |
-
CUTLASS_DEVICE
|
| 110 |
-
FusedBiasActEpilogue(
|
| 111 |
-
){ }
|
| 112 |
-
|
| 113 |
-
/// Streams the result to global memory
|
| 114 |
-
CUTLASS_DEVICE
|
| 115 |
-
void operator()(
|
| 116 |
-
OutputOp const &output_op, ///< Output operator
|
| 117 |
-
AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
|
| 118 |
-
AccumulatorTile & fused_bias_act_accumlators,
|
| 119 |
-
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 120 |
-
|
| 121 |
-
bool need_bias = output_op.is_source_needed();
|
| 122 |
-
|
| 123 |
-
if (need_bias)
|
| 124 |
-
compute_source_needed_(output_op, accumulators, fused_bias_act_accumlators, source_iterator);
|
| 125 |
-
else
|
| 126 |
-
compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators);
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
}
|
| 130 |
-
|
| 131 |
-
CUTLASS_DEVICE
|
| 132 |
-
void operator()(
|
| 133 |
-
OutputOp const &output_op, ///< Output operator
|
| 134 |
-
AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
|
| 135 |
-
AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 136 |
-
|
| 137 |
-
compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators);
|
| 138 |
-
}
|
| 139 |
-
|
| 140 |
-
CUTLASS_DEVICE
|
| 141 |
-
void compute_source_needed_(
|
| 142 |
-
OutputOp const &output_op, ///< Output operator
|
| 143 |
-
AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
|
| 144 |
-
AccumulatorTile & fused_bias_act_accumlators,
|
| 145 |
-
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 146 |
-
|
| 147 |
-
typename OutputTileIterator::Fragment source_fragment;
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
source_fragment.clear();
|
| 151 |
-
|
| 152 |
-
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
|
| 153 |
-
AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators);
|
| 154 |
-
|
| 155 |
-
CUTLASS_PRAGMA_UNROLL
|
| 156 |
-
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
|
| 157 |
-
|
| 158 |
-
source_iterator.load(source_fragment);
|
| 159 |
-
++source_iterator;
|
| 160 |
-
|
| 161 |
-
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 162 |
-
|
| 163 |
-
accum_fragment_iterator.load(accum_fragment);
|
| 164 |
-
++accum_fragment_iterator;
|
| 165 |
-
|
| 166 |
-
typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment;
|
| 167 |
-
fused_bias_act_fragment = output_op(accum_fragment, source_fragment);
|
| 168 |
-
|
| 169 |
-
fused_bias_act_fragment_iterator.store(fused_bias_act_fragment);
|
| 170 |
-
++fused_bias_act_fragment_iterator;
|
| 171 |
-
}
|
| 172 |
-
}
|
| 173 |
-
|
| 174 |
-
CUTLASS_DEVICE
|
| 175 |
-
void compute_source_no_needed_(
|
| 176 |
-
OutputOp const &output_op, ///< Output operator
|
| 177 |
-
AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
|
| 178 |
-
AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
|
| 182 |
-
AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators);
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
CUTLASS_PRAGMA_UNROLL
|
| 187 |
-
for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) {
|
| 188 |
-
|
| 189 |
-
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 190 |
-
|
| 191 |
-
accum_fragment_iterator.load(accum_fragment);
|
| 192 |
-
++accum_fragment_iterator;
|
| 193 |
-
|
| 194 |
-
typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment;
|
| 195 |
-
fused_bias_act_fragment = output_op(accum_fragment);
|
| 196 |
-
|
| 197 |
-
fused_bias_act_fragment_iterator.store(fused_bias_act_fragment);
|
| 198 |
-
++fused_bias_act_fragment_iterator;
|
| 199 |
-
}
|
| 200 |
-
}
|
| 201 |
-
|
| 202 |
-
};
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 208 |
-
|
| 209 |
-
} // namespace threadblock
|
| 210 |
-
} // namespace epilogue
|
| 211 |
-
} // namespace cutlass
|
| 212 |
-
|
| 213 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h
DELETED
|
@@ -1,311 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles.
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
*/
|
| 37 |
-
|
| 38 |
-
#pragma once
|
| 39 |
-
|
| 40 |
-
#include "cutlass/cutlass.h"
|
| 41 |
-
#include "cutlass/numeric_types.h"
|
| 42 |
-
#include "cutlass/array.h"
|
| 43 |
-
#include "cutlass/layout/matrix.h"
|
| 44 |
-
#include "cutlass/matrix_shape.h"
|
| 45 |
-
#include "cutlass/tensor_ref.h"
|
| 46 |
-
#include "cutlass/fast_math.h"
|
| 47 |
-
|
| 48 |
-
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
|
| 49 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 50 |
-
|
| 51 |
-
namespace cutlass {
|
| 52 |
-
namespace epilogue {
|
| 53 |
-
namespace threadblock {
|
| 54 |
-
|
| 55 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 56 |
-
|
| 57 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 58 |
-
|
| 59 |
-
namespace detail {
|
| 60 |
-
|
| 61 |
-
/// RowArrangement determines how one or more warps cover a region of consecutive rows.
|
| 62 |
-
template <
|
| 63 |
-
typename Shape,
|
| 64 |
-
int WarpsRemaining,
|
| 65 |
-
int ElementsPerAccess,
|
| 66 |
-
int ElementSize,
|
| 67 |
-
bool Is2dTile
|
| 68 |
-
>
|
| 69 |
-
struct RowArrangementBiasAct;
|
| 70 |
-
|
| 71 |
-
/// RowArrangement in which each warp's access is a 1D tiled arrangement.
|
| 72 |
-
template <
|
| 73 |
-
typename Shape,
|
| 74 |
-
int WarpsRemaining,
|
| 75 |
-
int ElementsPerAccess,
|
| 76 |
-
int ElementSize
|
| 77 |
-
>
|
| 78 |
-
struct RowArrangementBiasAct<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, false> {
|
| 79 |
-
static int const kWarpSize = 32;
|
| 80 |
-
static int const kElementsPerAccess = ElementsPerAccess;
|
| 81 |
-
static int const kElementSize = ElementSize;
|
| 82 |
-
|
| 83 |
-
static int const kIterationsRow = 1;
|
| 84 |
-
static int const kDeltaRow = 1;
|
| 85 |
-
static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize;
|
| 86 |
-
static int const kDeltaColumn = kWarpSize * kElementsPerAccess;
|
| 87 |
-
|
| 88 |
-
static int const kAccessWidth = kWarpSize;
|
| 89 |
-
static int const kAccessRows = 1;
|
| 90 |
-
static int const kWarpPartitionsRow = 1;
|
| 91 |
-
static int const kWarpPartitionsColumn = WarpsRemaining;
|
| 92 |
-
};
|
| 93 |
-
|
| 94 |
-
/// RowArrangement in which each warp's access is a 2D tiled arrangement.
|
| 95 |
-
template <
|
| 96 |
-
typename Shape,
|
| 97 |
-
int WarpsRemaining,
|
| 98 |
-
int ElementsPerAccess,
|
| 99 |
-
int ElementSize
|
| 100 |
-
>
|
| 101 |
-
struct RowArrangementBiasAct<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, true> {
|
| 102 |
-
|
| 103 |
-
static int const kMemoryAccessSize = 4;//128;
|
| 104 |
-
static int const kWarpSize = 32;
|
| 105 |
-
|
| 106 |
-
static int const kElementsPerAccess = ElementsPerAccess;
|
| 107 |
-
static int const kElementSize = ElementSize;
|
| 108 |
-
|
| 109 |
-
struct Detail {
|
| 110 |
-
static int const kShapeRow = Shape::kRow / WarpsRemaining;
|
| 111 |
-
static int const kShapeWidth = Shape::kColumn / kElementsPerAccess;
|
| 112 |
-
|
| 113 |
-
static int const kTargetMemoryAccessWidth =
|
| 114 |
-
kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8);
|
| 115 |
-
|
| 116 |
-
static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth;
|
| 117 |
-
};
|
| 118 |
-
|
| 119 |
-
static int const kAccessWidth =
|
| 120 |
-
(Detail::kTargetAccessRows > Detail::kShapeRow ?
|
| 121 |
-
kWarpSize / Detail::kShapeRow
|
| 122 |
-
: const_min(
|
| 123 |
-
Detail::kShapeWidth,
|
| 124 |
-
const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8))
|
| 125 |
-
));
|
| 126 |
-
|
| 127 |
-
static int const kAccessRows =
|
| 128 |
-
(Detail::kTargetAccessRows > Detail::kShapeRow ?
|
| 129 |
-
Detail::kShapeRow
|
| 130 |
-
: const_min(Shape::kRow, kWarpSize / kAccessWidth));
|
| 131 |
-
|
| 132 |
-
static int const kIterationsRow = Detail::kShapeRow / kAccessRows;
|
| 133 |
-
static int const kDeltaRow = kAccessRows;
|
| 134 |
-
|
| 135 |
-
static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth;
|
| 136 |
-
static int const kDeltaColumn = kAccessWidth * kElementsPerAccess;
|
| 137 |
-
|
| 138 |
-
static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access");
|
| 139 |
-
static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" );
|
| 140 |
-
static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" );
|
| 141 |
-
|
| 142 |
-
static int const kWarpPartitionsRow = 1;
|
| 143 |
-
static int const kWarpPartitionsColumn = 1;
|
| 144 |
-
};
|
| 145 |
-
|
| 146 |
-
}
|
| 147 |
-
|
| 148 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 149 |
-
|
| 150 |
-
/// Template metaprogram for partitioning a 4D space across warps to achieve several performance
|
| 151 |
-
/// objectives:
|
| 152 |
-
///
|
| 153 |
-
/// - coalesced memory accesses in units of 16 Byte lines
|
| 154 |
-
/// - minimal address arithmetic
|
| 155 |
-
/// - minimal predicate calculations
|
| 156 |
-
///
|
| 157 |
-
template <
|
| 158 |
-
typename Shape_,
|
| 159 |
-
typename Count_,
|
| 160 |
-
int Threads,
|
| 161 |
-
int ElementsPerAccess,
|
| 162 |
-
int ElementSize
|
| 163 |
-
>
|
| 164 |
-
struct OutputTileOptimalThreadMapBiasAct {
|
| 165 |
-
|
| 166 |
-
using Shape = Shape_;
|
| 167 |
-
using Count = Count_;
|
| 168 |
-
|
| 169 |
-
static int const kWarpSize = 32;
|
| 170 |
-
static int const kThreads = Threads;
|
| 171 |
-
static int const kWarpCount = kThreads / kWarpSize;
|
| 172 |
-
|
| 173 |
-
static int const kElementsPerAccess = ElementsPerAccess;
|
| 174 |
-
static int const kElementSize = ElementSize;
|
| 175 |
-
|
| 176 |
-
//
|
| 177 |
-
// Metaprogram computation
|
| 178 |
-
//
|
| 179 |
-
|
| 180 |
-
struct Detail {
|
| 181 |
-
|
| 182 |
-
// Clusters
|
| 183 |
-
static int const kIterationsCluster =
|
| 184 |
-
((Shape::kCluster > kWarpCount) ?
|
| 185 |
-
Shape::kCluster / kWarpCount
|
| 186 |
-
: 1);
|
| 187 |
-
|
| 188 |
-
static int const kDeltaCluster =
|
| 189 |
-
((Shape::kCluster > kWarpCount) ?
|
| 190 |
-
Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster
|
| 191 |
-
: 1);
|
| 192 |
-
|
| 193 |
-
static int const kCompactedDeltaCluster =
|
| 194 |
-
((Shape::kCluster > kWarpCount) ?
|
| 195 |
-
Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster
|
| 196 |
-
: 1);
|
| 197 |
-
|
| 198 |
-
static int const kWarpPartitionsCluster =
|
| 199 |
-
((Shape::kCluster > kWarpCount) ?
|
| 200 |
-
kWarpCount
|
| 201 |
-
: kWarpCount / Shape::kCluster);
|
| 202 |
-
|
| 203 |
-
static int const kWarpsRemainingForGroups =
|
| 204 |
-
((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster);
|
| 205 |
-
|
| 206 |
-
// Groups
|
| 207 |
-
static int const kIterationsGroup =
|
| 208 |
-
((Shape::kGroup > kWarpsRemainingForGroups) ?
|
| 209 |
-
Shape::kGroup / kWarpsRemainingForGroups
|
| 210 |
-
: 1);
|
| 211 |
-
|
| 212 |
-
static int const kDeltaGroup =
|
| 213 |
-
((Shape::kGroup > kWarpsRemainingForGroups) ?
|
| 214 |
-
Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup
|
| 215 |
-
: 1);
|
| 216 |
-
|
| 217 |
-
static int const kCompactedDeltaGroup =
|
| 218 |
-
((Shape::kGroup > kWarpsRemainingForGroups) ?
|
| 219 |
-
Shape::kRow * Shape::kGroup / kIterationsGroup
|
| 220 |
-
: 1);
|
| 221 |
-
|
| 222 |
-
static int const kWarpPartitionsGroup =
|
| 223 |
-
((Shape::kGroup > kWarpsRemainingForGroups) ?
|
| 224 |
-
1
|
| 225 |
-
: kWarpsRemainingForGroups / Shape::kGroup);
|
| 226 |
-
|
| 227 |
-
static int const kWarpsRemainingForRows =
|
| 228 |
-
((Shape::kGroup > kWarpsRemainingForGroups) ?
|
| 229 |
-
1
|
| 230 |
-
: kWarpsRemainingForGroups / Shape::kGroup);
|
| 231 |
-
|
| 232 |
-
// Rows
|
| 233 |
-
using RowArrangement = detail::RowArrangementBiasAct<
|
| 234 |
-
Shape,
|
| 235 |
-
kWarpsRemainingForRows,
|
| 236 |
-
kElementsPerAccess,
|
| 237 |
-
kElementSize,
|
| 238 |
-
(Shape::kRow > kWarpsRemainingForRows)
|
| 239 |
-
>;
|
| 240 |
-
|
| 241 |
-
// Warp partitions
|
| 242 |
-
using WarpPartitions = OutputTileShape<
|
| 243 |
-
RowArrangement::kWarpPartitionsColumn,
|
| 244 |
-
RowArrangement::kWarpPartitionsRow,
|
| 245 |
-
kWarpPartitionsGroup,
|
| 246 |
-
kWarpPartitionsCluster,
|
| 247 |
-
1>;
|
| 248 |
-
|
| 249 |
-
static int const kAccessWidth = RowArrangement::kAccessWidth;
|
| 250 |
-
static int const kAccessRows = RowArrangement::kAccessRows;
|
| 251 |
-
};
|
| 252 |
-
|
| 253 |
-
//
|
| 254 |
-
// Output
|
| 255 |
-
//
|
| 256 |
-
|
| 257 |
-
using Iterations = OutputTileShape<
|
| 258 |
-
Detail::RowArrangement::kIterationsColumn,
|
| 259 |
-
Detail::RowArrangement::kIterationsRow,
|
| 260 |
-
Detail::kIterationsGroup,
|
| 261 |
-
Detail::kIterationsCluster,
|
| 262 |
-
1>;
|
| 263 |
-
|
| 264 |
-
using Delta = OutputTileShape<
|
| 265 |
-
Detail::RowArrangement::kDeltaColumn,
|
| 266 |
-
Detail::RowArrangement::kDeltaRow,
|
| 267 |
-
Detail::kDeltaGroup,
|
| 268 |
-
Detail::kDeltaCluster,
|
| 269 |
-
1>;
|
| 270 |
-
|
| 271 |
-
/// Initial offset function
|
| 272 |
-
CUTLASS_HOST_DEVICE
|
| 273 |
-
static MatrixCoord initial_offset(int thread_idx) {
|
| 274 |
-
|
| 275 |
-
int warp_idx = thread_idx / kWarpSize;
|
| 276 |
-
int lane_idx = thread_idx % kWarpSize;
|
| 277 |
-
|
| 278 |
-
// Compute warp location
|
| 279 |
-
int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
|
| 280 |
-
int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
|
| 281 |
-
|
| 282 |
-
int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
|
| 283 |
-
int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
|
| 284 |
-
|
| 285 |
-
int row_idx = residual_group / Detail::WarpPartitions::kRow;
|
| 286 |
-
int col_idx = residual_group % Detail::WarpPartitions::kRow;
|
| 287 |
-
|
| 288 |
-
// Compute per-lane offset
|
| 289 |
-
int lane_row_offset = lane_idx / Detail::kAccessWidth;
|
| 290 |
-
int lane_col_offset = lane_idx % Detail::kAccessWidth;
|
| 291 |
-
|
| 292 |
-
// Compute coordinate in output space
|
| 293 |
-
int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup;
|
| 294 |
-
int group_offset = group_idx * Shape::kRow * Count::kRow;
|
| 295 |
-
int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
|
| 296 |
-
int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
|
| 297 |
-
|
| 298 |
-
return MatrixCoord(
|
| 299 |
-
cluster_offset + group_offset + row_offset + lane_row_offset,
|
| 300 |
-
(column_offset + lane_col_offset) * kElementsPerAccess
|
| 301 |
-
);
|
| 302 |
-
}
|
| 303 |
-
|
| 304 |
-
};
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 308 |
-
|
| 309 |
-
} // namespace threadblock
|
| 310 |
-
} // namespace epilogue
|
| 311 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h
DELETED
|
@@ -1,189 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile
|
| 34 |
-
that participate in one warp-level store operation.
|
| 35 |
-
|
| 36 |
-
Typically, the accumulator tile is the largest single block of register-backed storage
|
| 37 |
-
within the kernel. Storing it to memory is best accomplished by partitioning it into
|
| 38 |
-
smaller tiles and storing these sequentially.
|
| 39 |
-
|
| 40 |
-
Round trips through shared memory during the Epilogue phase require partitioning, as
|
| 41 |
-
shared memory capacity is typically insufficient for a threadblock's total accumulator
|
| 42 |
-
size.
|
| 43 |
-
*/
|
| 44 |
-
|
| 45 |
-
#pragma once
|
| 46 |
-
|
| 47 |
-
#include "cutlass/array.h"
|
| 48 |
-
#include "cutlass/layout/matrix.h"
|
| 49 |
-
|
| 50 |
-
#include "cutlass/epilogue/warp/tensor_op_policy.h"
|
| 51 |
-
|
| 52 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 53 |
-
|
| 54 |
-
namespace cutlass {
|
| 55 |
-
namespace epilogue {
|
| 56 |
-
namespace warp {
|
| 57 |
-
|
| 58 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 59 |
-
|
| 60 |
-
///
|
| 61 |
-
template <
|
| 62 |
-
typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape)
|
| 63 |
-
typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape)
|
| 64 |
-
typename OperatorElementC, ///< matrix multiply operation data type (concept: data type)
|
| 65 |
-
typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array)
|
| 66 |
-
typename Layout ///< target shared memory layout
|
| 67 |
-
>
|
| 68 |
-
class FusedBiasActFragmentIteratorTensorOp;
|
| 69 |
-
|
| 70 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 71 |
-
|
| 72 |
-
/// Partial specialization for row-major shared memory
|
| 73 |
-
template <
|
| 74 |
-
typename WarpShape_, ///< shape of the warp-level GEMM tile
|
| 75 |
-
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape)
|
| 76 |
-
typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type)
|
| 77 |
-
typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array)
|
| 78 |
-
>
|
| 79 |
-
class FusedBiasActFragmentIteratorTensorOp<WarpShape_, OperatorShape_, OperatorElementC_, OperatorFragmentC_, layout::RowMajor> {
|
| 80 |
-
public:
|
| 81 |
-
|
| 82 |
-
using WarpShape = WarpShape_;
|
| 83 |
-
using OperatorShape = OperatorShape_;
|
| 84 |
-
using OperatorElementC = OperatorElementC_;
|
| 85 |
-
using OperatorFragmentC = OperatorFragmentC_;
|
| 86 |
-
using Layout = layout::RowMajor;
|
| 87 |
-
|
| 88 |
-
using Policy = TensorOpPolicy<WarpShape, OperatorShape, Layout>;
|
| 89 |
-
|
| 90 |
-
/// This is the fragment size produced by one access of the iterator.
|
| 91 |
-
using Fragment = Array<
|
| 92 |
-
OperatorElementC,
|
| 93 |
-
Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;
|
| 94 |
-
|
| 95 |
-
/// This is the complete warp-level accumulator tile.
|
| 96 |
-
using AccumulatorTile = Array<
|
| 97 |
-
OperatorElementC,
|
| 98 |
-
OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>;
|
| 99 |
-
|
| 100 |
-
using OutputAccumulatorTile = AccumulatorTile;
|
| 101 |
-
|
| 102 |
-
/// Number of times this iterator can be incremented
|
| 103 |
-
static int const kIterations = Policy::kIterations;
|
| 104 |
-
|
| 105 |
-
private:
|
| 106 |
-
|
| 107 |
-
/// Internal access type
|
| 108 |
-
using AccessType = Array<OperatorElementC, Policy::kElementsPerAccess>;
|
| 109 |
-
|
| 110 |
-
private:
|
| 111 |
-
|
| 112 |
-
//
|
| 113 |
-
// Data members
|
| 114 |
-
//
|
| 115 |
-
|
| 116 |
-
/// Accumulator tile
|
| 117 |
-
AccessType *accumulators_;
|
| 118 |
-
|
| 119 |
-
/// Internal index
|
| 120 |
-
int index_;
|
| 121 |
-
|
| 122 |
-
public:
|
| 123 |
-
|
| 124 |
-
/// Constructs an iterator
|
| 125 |
-
CUTLASS_HOST_DEVICE
|
| 126 |
-
FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum):
|
| 127 |
-
accumulators_(reinterpret_cast<AccessType *>(&accum)),
|
| 128 |
-
index_(0) {
|
| 129 |
-
}
|
| 130 |
-
|
| 131 |
-
/// Increments
|
| 132 |
-
CUTLASS_HOST_DEVICE
|
| 133 |
-
FusedBiasActFragmentIteratorTensorOp &operator++() {
|
| 134 |
-
++index_;
|
| 135 |
-
return *this;
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
/// Decrements
|
| 139 |
-
CUTLASS_HOST_DEVICE
|
| 140 |
-
FusedBiasActFragmentIteratorTensorOp &operator--() {
|
| 141 |
-
--index_;
|
| 142 |
-
return *this;
|
| 143 |
-
}
|
| 144 |
-
|
| 145 |
-
/// Loads a fragment from the referenced part of the accumulator tile
|
| 146 |
-
CUTLASS_HOST_DEVICE
|
| 147 |
-
void load(Fragment &frag, int index_offset = 0) const {
|
| 148 |
-
|
| 149 |
-
int index = index_ + index_offset;
|
| 150 |
-
|
| 151 |
-
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 152 |
-
|
| 153 |
-
CUTLASS_PRAGMA_UNROLL
|
| 154 |
-
for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
|
| 155 |
-
|
| 156 |
-
int accumulator_access_offset =
|
| 157 |
-
index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;
|
| 158 |
-
|
| 159 |
-
frag_ptr[n] = accumulators_[accumulator_access_offset];
|
| 160 |
-
}
|
| 161 |
-
}
|
| 162 |
-
/// Stores a fragment from the referenced part of the accumulator tile
|
| 163 |
-
CUTLASS_HOST_DEVICE
|
| 164 |
-
void store(Fragment &frag, int index_offset = 0) const {
|
| 165 |
-
|
| 166 |
-
int index = index_ + index_offset;
|
| 167 |
-
|
| 168 |
-
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 169 |
-
|
| 170 |
-
CUTLASS_PRAGMA_UNROLL
|
| 171 |
-
for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
|
| 172 |
-
|
| 173 |
-
int accumulator_access_offset =
|
| 174 |
-
index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;
|
| 175 |
-
|
| 176 |
-
accumulators_[accumulator_access_offset] = frag_ptr[n];
|
| 177 |
-
}
|
| 178 |
-
}
|
| 179 |
-
};
|
| 180 |
-
|
| 181 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 182 |
-
|
| 183 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 184 |
-
|
| 185 |
-
} // namespace warp
|
| 186 |
-
} // namespace epilogue
|
| 187 |
-
} // namespace cutlass
|
| 188 |
-
|
| 189 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h
DELETED
|
@@ -1,427 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include "cutlass/cutlass.h"
|
| 35 |
-
|
| 36 |
-
#include "cutlass/array.h"
|
| 37 |
-
#include "cutlass/matrix_shape.h"
|
| 38 |
-
#include "cutlass/layout/matrix.h"
|
| 39 |
-
#include "cutlass/layout/tensor.h"
|
| 40 |
-
#include "cutlass/numeric_conversion.h"
|
| 41 |
-
|
| 42 |
-
namespace cutlass {
|
| 43 |
-
namespace gemm {
|
| 44 |
-
namespace warp {
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 48 |
-
|
| 49 |
-
/// Warp-level fragment iterator over an accumulator tile.
///
/// This is the primary template. It is declared here without a definition;
/// usable behaviour comes from partial specializations selected on the layout
/// and on the IsBetaZero_ flag.
template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Size of the accumulation tile shape (concept: MatrixShape)
    typename AccumulatorShape_,
    /// KBlocks columns to compute residual
    int KBlocksColumn_,
    /// Accumulator Element type
    typename ElementAccumulator_,
    /// Element type
    typename Element_,
    /// Layout of operand in memory
    typename Layout_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Whether beta is zero
    bool IsBetaZero_>
class MmaTensorOpPureFragmentIterator;
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
// Partial specialization for col-major accumulator tile
|
| 70 |
-
// And Element type is the same as Accumulator Element type
|
| 71 |
-
|
| 72 |
-
template <
|
| 73 |
-
/// Shape of warp tile to load (concept: MatrixShape)
|
| 74 |
-
typename Shape_,
|
| 75 |
-
/// Shape of the warp accumulation tile (concept: MatrixShape)
|
| 76 |
-
typename AccumulatorShape_,
|
| 77 |
-
/// KBlocks columns to compute residual
|
| 78 |
-
int KBlocksColumn_,
|
| 79 |
-
/// Element type
|
| 80 |
-
typename Element_,
|
| 81 |
-
/// Shape of one matrix product operation (concept: MatrixShape)
|
| 82 |
-
typename InstructionShape_>
|
| 83 |
-
class MmaTensorOpPureFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Element_, Element_,
|
| 84 |
-
cutlass::layout::ColumnMajor,
|
| 85 |
-
InstructionShape_, true> {
|
| 86 |
-
public:
|
| 87 |
-
|
| 88 |
-
/// Shape of warp tile to load (concept: MatrixShape)
|
| 89 |
-
using Shape = Shape_;
|
| 90 |
-
|
| 91 |
-
/// Shape of the warp accumulation tile (concept: MatrixShape)
|
| 92 |
-
using AccumulatorShape = AccumulatorShape_;
|
| 93 |
-
|
| 94 |
-
/// KBlocks columns to compute residual
|
| 95 |
-
static int const kKBlockColumn = KBlocksColumn_;
|
| 96 |
-
|
| 97 |
-
/// Element type
|
| 98 |
-
using Element = Element_;
|
| 99 |
-
|
| 100 |
-
/// Layout of source tile
|
| 101 |
-
using Layout = cutlass::layout::ColumnMajor;
|
| 102 |
-
|
| 103 |
-
/// Shape of one matrix product operation (concept: MatrixShape)
|
| 104 |
-
using InstructionShape = InstructionShape_;
|
| 105 |
-
|
| 106 |
-
/// Whether beta is zero
|
| 107 |
-
static bool const IsBetaZero = true;
|
| 108 |
-
|
| 109 |
-
/// Number of participating threads
|
| 110 |
-
static int const kThreads = 32;
|
| 111 |
-
|
| 112 |
-
/// Internal structure of iterator - made public to enable introspection
|
| 113 |
-
struct Policy {
|
| 114 |
-
static_assert(
|
| 115 |
-
!(Shape::kRow % InstructionShape::kM) &&
|
| 116 |
-
!(Shape::kColumn % InstructionShape::kN),
|
| 117 |
-
"Shape of warp-level Mma must be divisible by operator shape.");
|
| 118 |
-
static_assert(
|
| 119 |
-
!(AccumulatorShape::kRow % Shape::kRow) &&
|
| 120 |
-
!(AccumulatorShape::kColumn % Shape::kColumn),
|
| 121 |
-
"Shape of Warp Accumulator must be divisible by warp shape.");
|
| 122 |
-
static_assert(
|
| 123 |
-
!(kKBlockColumn % Shape::kColumn),
|
| 124 |
-
"KBlock size must be divisible by warp shape.");
|
| 125 |
-
|
| 126 |
-
/// Number of times this iterator can be incremented
|
| 127 |
-
static int const kIterations = AccumulatorShape::kCount / Shape::kCount;
|
| 128 |
-
};
|
| 129 |
-
|
| 130 |
-
private:
|
| 131 |
-
|
| 132 |
-
static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;
|
| 133 |
-
|
| 134 |
-
/// Number of mma operations performed by a warp
|
| 135 |
-
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
|
| 136 |
-
Shape::kColumn / InstructionShape::kN>;
|
| 137 |
-
/// Number of mma operations performed by the entire accumulator
|
| 138 |
-
using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
|
| 139 |
-
AccumulatorShape::kColumn / InstructionShape::kN>;
|
| 140 |
-
|
| 141 |
-
/// Number of K iterations
|
| 142 |
-
static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
|
| 143 |
-
static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
|
| 144 |
-
static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
|
| 145 |
-
* (AccumulatorShape::kRow / Shape::kRow);
|
| 146 |
-
static int const kResidualIndex = kResidualColumn / Shape::kColumn
|
| 147 |
-
* (AccumulatorShape::kRow / Shape::kRow);
|
| 148 |
-
|
| 149 |
-
public:
|
| 150 |
-
|
| 151 |
-
//
|
| 152 |
-
// Derived quantities
|
| 153 |
-
//
|
| 154 |
-
|
| 155 |
-
/// Fragment object holding a thread's part of a tile
|
| 156 |
-
/// This is the fragment size produced by one access of the iterator.
|
| 157 |
-
using Fragment = Array<Element, Shape::kCount / kThreads>;
|
| 158 |
-
|
| 159 |
-
/// Accumulator Fragment object
|
| 160 |
-
using AccumulatorFragment = Array<Element, AccumulatorShape::kCount / kThreads>;
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
private:
|
| 164 |
-
|
| 165 |
-
/// Internal access type
|
| 166 |
-
using AccessType = Array<Element, kElementsPerAccess>;
|
| 167 |
-
|
| 168 |
-
private:
|
| 169 |
-
//
|
| 170 |
-
// Data members
|
| 171 |
-
//
|
| 172 |
-
|
| 173 |
-
/// Accumulator tile
|
| 174 |
-
AccessType const *accumulators_;
|
| 175 |
-
|
| 176 |
-
/// Internal index
|
| 177 |
-
int index_;
|
| 178 |
-
|
| 179 |
-
/// Used to access residual tile first
|
| 180 |
-
bool is_residual_tile_;
|
| 181 |
-
|
| 182 |
-
public:
|
| 183 |
-
/// Constructs an iterator
|
| 184 |
-
CUTLASS_HOST_DEVICE
|
| 185 |
-
MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum)
|
| 186 |
-
: accumulators_(reinterpret_cast<AccessType const *>(&accum)),
|
| 187 |
-
index_(0), is_residual_tile_(true) {}
|
| 188 |
-
|
| 189 |
-
/// Add offset
|
| 190 |
-
CUTLASS_HOST_DEVICE
|
| 191 |
-
void add_offset(int index_offset) {
|
| 192 |
-
index_ += index_offset;
|
| 193 |
-
if(is_residual_tile_ && index_ >= kKBlockColumnIterations) {
|
| 194 |
-
index_ = index_ - kKBlockColumnIterations + kResidualIndex;
|
| 195 |
-
is_residual_tile_ = false;
|
| 196 |
-
}
|
| 197 |
-
}
|
| 198 |
-
|
| 199 |
-
/// Increments
|
| 200 |
-
CUTLASS_HOST_DEVICE
|
| 201 |
-
MmaTensorOpPureFragmentIterator &operator++() {
|
| 202 |
-
add_offset(1);
|
| 203 |
-
return *this;
|
| 204 |
-
}
|
| 205 |
-
|
| 206 |
-
/// Decrements
|
| 207 |
-
CUTLASS_HOST_DEVICE
|
| 208 |
-
MmaTensorOpPureFragmentIterator &operator--() {
|
| 209 |
-
add_offset(-1);
|
| 210 |
-
return *this;
|
| 211 |
-
}
|
| 212 |
-
|
| 213 |
-
/// Loads a fragment from the referenced part of the accumulator tile
|
| 214 |
-
CUTLASS_HOST_DEVICE
|
| 215 |
-
void load(Fragment &frag) const {
|
| 216 |
-
|
| 217 |
-
AccessType src_fragment;
|
| 218 |
-
src_fragment.clear();
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 222 |
-
|
| 223 |
-
int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
|
| 224 |
-
int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
|
| 225 |
-
* MmaIterations::kColumn;
|
| 226 |
-
|
| 227 |
-
CUTLASS_PRAGMA_UNROLL
|
| 228 |
-
for (int n = 0; n < MmaIterations::kColumn; n++) {
|
| 229 |
-
for (int m = 0; m < MmaIterations::kRow; m++) {
|
| 230 |
-
int accumulator_access_offset =
|
| 231 |
-
(n + index_n) * AccumulatorIterations::kRow + m + index_m;
|
| 232 |
-
|
| 233 |
-
frag_ptr[n * MmaIterations::kRow + m].clear();
|
| 234 |
-
if(!(is_residual_tile_ && index_ >= kResidualIndex))
|
| 235 |
-
frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset];
|
| 236 |
-
// frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment);
|
| 237 |
-
}
|
| 238 |
-
}
|
| 239 |
-
}
|
| 240 |
-
|
| 241 |
-
};
|
| 242 |
-
|
| 243 |
-
// Partial specialization for row-major accumulator tile
|
| 244 |
-
|
| 245 |
-
template <
|
| 246 |
-
/// Shape of warp tile to load (concept: MatrixShape)
|
| 247 |
-
typename Shape_,
|
| 248 |
-
/// Shape of the warp accumulation tile (concept: MatrixShape)
|
| 249 |
-
typename AccumulatorShape_,
|
| 250 |
-
/// KBlocks columns to compute residual
|
| 251 |
-
int KBlocksColumn_,
|
| 252 |
-
/// Accumulator Element type
|
| 253 |
-
typename ElementAccumulator_,
|
| 254 |
-
/// Element type
|
| 255 |
-
typename Element_,
|
| 256 |
-
/// Shape of one matrix product operation (concept: MatrixShape)
|
| 257 |
-
typename InstructionShape_>
|
| 258 |
-
class MmaTensorOpPureFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, ElementAccumulator_, Element_,
|
| 259 |
-
cutlass::layout::RowMajor,
|
| 260 |
-
InstructionShape_, true> {
|
| 261 |
-
public:
|
| 262 |
-
|
| 263 |
-
/// Shape of warp tile to load (concept: MatrixShape)
|
| 264 |
-
using Shape = Shape_;
|
| 265 |
-
|
| 266 |
-
/// Shape of the warp accumulation tile (concept: MatrixShape)
|
| 267 |
-
using AccumulatorShape = AccumulatorShape_;
|
| 268 |
-
|
| 269 |
-
/// KBlocks columns to compute residual
|
| 270 |
-
static int const kKBlockColumn = KBlocksColumn_;
|
| 271 |
-
|
| 272 |
-
/// Accumulator Element type
|
| 273 |
-
using ElementAccumulator = ElementAccumulator_;
|
| 274 |
-
|
| 275 |
-
/// Element type
|
| 276 |
-
using Element = Element_;
|
| 277 |
-
|
| 278 |
-
/// Layout of source tile
|
| 279 |
-
using Layout = cutlass::layout::RowMajor;
|
| 280 |
-
|
| 281 |
-
/// Shape of one matrix product operation (concept: MatrixShape)
|
| 282 |
-
using InstructionShape = InstructionShape_;
|
| 283 |
-
|
| 284 |
-
/// Whether beta is zero
|
| 285 |
-
static bool const IsBetaZero = true;
|
| 286 |
-
|
| 287 |
-
/// Number of participating threads
|
| 288 |
-
static int const kThreads = 32;
|
| 289 |
-
|
| 290 |
-
/// Internal structure of iterator - made public to enable introspection
|
| 291 |
-
struct Policy {
|
| 292 |
-
static_assert(
|
| 293 |
-
!(Shape::kRow % InstructionShape::kM) &&
|
| 294 |
-
!(Shape::kColumn % InstructionShape::kN),
|
| 295 |
-
"Shape of warp-level Mma must be divisible by operator shape.");
|
| 296 |
-
static_assert(
|
| 297 |
-
!(AccumulatorShape::kRow % Shape::kRow) &&
|
| 298 |
-
!(AccumulatorShape::kColumn % Shape::kColumn),
|
| 299 |
-
"Shape of Warp Accumulator must be divisible by warp shape.");
|
| 300 |
-
static_assert(
|
| 301 |
-
!(kKBlockColumn % Shape::kColumn),
|
| 302 |
-
"KBlock size must be divisible by warp shape.");
|
| 303 |
-
|
| 304 |
-
/// Number of times this iterator can be incremented
|
| 305 |
-
static int const kIterations = AccumulatorShape::kCount / Shape::kCount;
|
| 306 |
-
};
|
| 307 |
-
|
| 308 |
-
private:
|
| 309 |
-
|
| 310 |
-
static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;
|
| 311 |
-
|
| 312 |
-
/// Number of mma operations performed by a warp
|
| 313 |
-
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
|
| 314 |
-
Shape::kColumn / InstructionShape::kN>;
|
| 315 |
-
/// Number of mma operations performed by the entire accumulator
|
| 316 |
-
using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
|
| 317 |
-
AccumulatorShape::kColumn / InstructionShape::kN>;
|
| 318 |
-
|
| 319 |
-
/// Number of K iterations
|
| 320 |
-
static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
|
| 321 |
-
static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
|
| 322 |
-
static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
|
| 323 |
-
* (AccumulatorShape::kRow / Shape::kRow);
|
| 324 |
-
static int const kResidualIndex = kResidualColumn / Shape::kColumn
|
| 325 |
-
* (AccumulatorShape::kRow / Shape::kRow);
|
| 326 |
-
|
| 327 |
-
public:
|
| 328 |
-
|
| 329 |
-
//
|
| 330 |
-
// Derived quantities
|
| 331 |
-
//
|
| 332 |
-
|
| 333 |
-
/// Fragment object holding a thread's part of a tile
|
| 334 |
-
/// This is the fragment size produced by one access of the iterator.
|
| 335 |
-
using Fragment = Array<Element, Shape::kCount / kThreads>;
|
| 336 |
-
|
| 337 |
-
/// Accumulator Fragment object
|
| 338 |
-
using AccumulatorFragment = Array<ElementAccumulator, AccumulatorShape::kCount / kThreads>;
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
private:
|
| 342 |
-
|
| 343 |
-
/// Internal access type
|
| 344 |
-
using AccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
| 345 |
-
using FragmentAccessType = Array<Element, kElementsPerAccess>;
|
| 346 |
-
|
| 347 |
-
private:
|
| 348 |
-
//
|
| 349 |
-
// Data members
|
| 350 |
-
//
|
| 351 |
-
|
| 352 |
-
/// Accumulator tile
|
| 353 |
-
AccessType const *accumulators_;
|
| 354 |
-
|
| 355 |
-
/// Internal index
|
| 356 |
-
int index_;
|
| 357 |
-
|
| 358 |
-
/// Used to access residual tile first
|
| 359 |
-
bool is_residual_tile_;
|
| 360 |
-
|
| 361 |
-
public:
|
| 362 |
-
/// Constructs an iterator
|
| 363 |
-
CUTLASS_HOST_DEVICE
|
| 364 |
-
MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum)
|
| 365 |
-
: accumulators_(reinterpret_cast<AccessType const *>(&accum)),
|
| 366 |
-
index_(0), is_residual_tile_(true) {}
|
| 367 |
-
|
| 368 |
-
/// Add offset
|
| 369 |
-
CUTLASS_HOST_DEVICE
|
| 370 |
-
void add_offset(int index_offset) {
|
| 371 |
-
index_ += index_offset;
|
| 372 |
-
if(is_residual_tile_ && index_ >= kKBlockColumnIterations) {
|
| 373 |
-
index_ = index_ - kKBlockColumnIterations + kResidualIndex;
|
| 374 |
-
is_residual_tile_ = false;
|
| 375 |
-
}
|
| 376 |
-
}
|
| 377 |
-
|
| 378 |
-
/// Increments
|
| 379 |
-
CUTLASS_HOST_DEVICE
|
| 380 |
-
MmaTensorOpPureFragmentIterator &operator++() {
|
| 381 |
-
add_offset(1);
|
| 382 |
-
return *this;
|
| 383 |
-
}
|
| 384 |
-
|
| 385 |
-
/// Decrements
|
| 386 |
-
CUTLASS_HOST_DEVICE
|
| 387 |
-
MmaTensorOpPureFragmentIterator &operator--() {
|
| 388 |
-
add_offset(-1);
|
| 389 |
-
return *this;
|
| 390 |
-
}
|
| 391 |
-
|
| 392 |
-
/// Loads a fragment from the referenced part of the accumulator tile
|
| 393 |
-
CUTLASS_HOST_DEVICE
|
| 394 |
-
void load(Fragment &frag) const {
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
FragmentAccessType src_fragment;
|
| 398 |
-
src_fragment.clear();
|
| 399 |
-
|
| 400 |
-
FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);
|
| 401 |
-
|
| 402 |
-
int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
|
| 403 |
-
int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
|
| 404 |
-
* MmaIterations::kColumn;
|
| 405 |
-
|
| 406 |
-
CUTLASS_PRAGMA_UNROLL
|
| 407 |
-
for (int m = 0; m < MmaIterations::kRow; m++) {
|
| 408 |
-
for (int n = 0; n < MmaIterations::kColumn; n++) {
|
| 409 |
-
int accumulator_access_offset =
|
| 410 |
-
(m + index_m) * AccumulatorIterations::kColumn + n + index_n;
|
| 411 |
-
|
| 412 |
-
frag_ptr[m * MmaIterations::kColumn + n].clear();
|
| 413 |
-
if(!(is_residual_tile_ && index_ >= kResidualIndex))
|
| 414 |
-
frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]);
|
| 415 |
-
}
|
| 416 |
-
}
|
| 417 |
-
}
|
| 418 |
-
|
| 419 |
-
};
|
| 420 |
-
|
| 421 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 422 |
-
|
| 423 |
-
} // namespace warp
|
| 424 |
-
} // namespace gemm
|
| 425 |
-
} // namespace cutlass
|
| 426 |
-
|
| 427 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py
DELETED
|
@@ -1,129 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import gen_turing_and_volta as api_generator
|
| 34 |
-
import gen_sample as sample_creater
|
| 35 |
-
import gen_cmake as cmake_creater
|
| 36 |
-
import gen_verify as verify_creater
|
| 37 |
-
import gen_device as b2b_fused_generator
|
| 38 |
-
import replace_fix_impl_header
|
| 39 |
-
|
| 40 |
-
import argparse
|
| 41 |
-
import os
|
| 42 |
-
import json
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
# Driver script: parses the command line, prepares the output directory tree,
# then runs each code generator in turn (fixed implementation headers, fused
# device/kernel code, the one-API wrapper, the C++ sample, the CMake build
# file and the verification code).
#
# The relative paths "../fixed_impl/", "../leaky_bias.h" and "../utils.h" are
# resolved against the current working directory.
import shutil  # local to this script section; used to copy the sample headers

parser = argparse.ArgumentParser(description="Generates Fused Multi-GEMM CUTLASS Kernels")
parser.add_argument("--config-file", default="config.json", help="JSON file containing configuration to generate")
parser.add_argument("--gen-name", default="FusedMultiGemmForward", help="Specific the output name")
parser.add_argument("--output-dir", default="", help="Specifies the output dir")
parser.add_argument("--cutlass-dir", default="", help="Specifies the dependent CUTLASS repo dir")
parser.add_argument("--gen-include-cutlass-dir", default="", help="Specifies the generated CUTLASS code include dir, if needed.")
args = parser.parse_args()

gen_name = args.gen_name

cutlass_deps_dir = args.cutlass_dir

# The generators build paths by plain string concatenation, so the output
# directory must keep its trailing "/". An empty --output-dir falls back to
# the current directory; appending "/" to "" would otherwise point at the
# filesystem root.
output_dir = (args.output_dir or ".") + "/"

# Include root used in generated #include lines; defaults to <cutlass-dir>/include/.
cutlass_deps_root = args.gen_include_cutlass_dir
if cutlass_deps_root == '':
    cutlass_deps_root = cutlass_deps_dir + "/include/"
cutlass_deps_root +='/'

# Create the output tree. exist_ok=True makes this idempotent and avoids the
# check-then-create race of a separate os.path.exists() test.
for sub_dir in (
    "",
    "auto_gen",
    "fixed_impl",
    "sample",
    "auto_gen/device",
    "auto_gen/kernel",
    "auto_gen/threadblock",
):
    os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)

with open(args.config_file, 'r') as infile:
    gemm_info_dict = json.load(infile)

# The GEMM stages are taken in sorted-key order from the JSON configuration.
keys = sorted(gemm_info_dict.keys())
fuse_gemm_info = [gemm_info_dict[k] for k in keys]


for_cutlass_gen_user_include_header_file = [
    cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h",
    cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h",
]

for_fused_wrapper = [
    cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h",
    cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h",
    "auto_gen/device/" + gen_name + ".h",
    cutlass_deps_root + "cutlass/gemm/device/gemm_batched.h",
    cutlass_deps_root + "cutlass/cutlass.h",
]

# Copy fixed implementation to the output directory
fix_impl = replace_fix_impl_header.replace_fix_impl("../fixed_impl/", output_dir +"/fixed_impl/", cutlass_deps_root)
fix_impl.gen_code()

auto_gen_output_dir = output_dir + "/auto_gen/"
project_root = ""
turing_plus = b2b_fused_generator.gen_device(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, cutlass_deps_root, project_root, auto_gen_output_dir)
turing_plus.gen_code(75, 'hmma1688', False)

api = api_generator.gen_one_API(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir)
api.gen_code()

# Generate C++ sample
sample_dir = output_dir + "/sample/"
# shutil.copy instead of shelling out to `cp`: no shell quoting problems with
# unusual paths, portable, and a missing header raises instead of being ignored.
for helper_header in ("../leaky_bias.h", "../utils.h"):
    shutil.copy(helper_header, sample_dir)

sample = sample_creater.gen_test(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, sample_dir)
sample.gen_cpp_sample()

cmake_gen = cmake_creater.gen_build_sys(cutlass_deps_dir, output_dir)
cmake_gen.gen_code()

verify = verify_creater.gen_verify(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir)
verify.gen_code()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py
DELETED
|
@@ -1,131 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
class gen_build_sys:
    """Generates the top-level ``CMakeLists.txt`` for the fused multi-GEMM sample.

    Args:
        cutlass_deps_dir: CUTLASS checkout root; ``/include`` and
            ``/tools/util/include`` are appended to it in the generated file.
        output_dir: Directory that receives ``CMakeLists.txt``.
    """

    def __init__(self, cutlass_deps_dir, output_dir = "../"):
        self.output_dir = output_dir
        self.cutlass_deps_dir = cutlass_deps_dir

    def gen_top(self):
        """Return the full text of the generated ``CMakeLists.txt``."""
        code = ""
        # Only this first section goes through str.format, hence the doubled
        # braces around the CMake variable references.
        code += '''\
# Auto Generated code - Do not edit.

cmake_minimum_required(VERSION 3.8)
project(CUTLASS_MULTI_GEMMS LANGUAGES CXX CUDA)
find_package(CUDAToolkit)
set(CUDA_PATH ${{CUDA_TOOLKIT_ROOT_DIR}})
set(CUTLASS_PATH \"{cutlass_deps_dir}/include\")
set(CUTLASS_UTIL_PATH \"{cutlass_deps_dir}/tools/util/include\")
list(APPEND CMAKE_MODULE_PATH ${{CUDAToolkit_LIBRARY_DIR}})
'''.format(cutlass_deps_dir=self.cutlass_deps_dir)

        # The WMMA check tests the foreach loop variable `arch`. The previous
        # text tested `SM`, which is never set in the generated file, so the
        # branch could never be taken.
        code += '''\
set(GPU_ARCHS \"\" CACHE STRING
\"List of GPU architectures (semicolon-separated) to be compiled for.\")

if(\"${GPU_ARCHS}\" STREQUAL \"\")
set(GPU_ARCHS \"70\")
endif()

foreach(arch ${GPU_ARCHS})
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -gencode arch=compute_${arch},code=sm_${arch}\")
if(arch STREQUAL 70 OR arch STREQUAL 75)
set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -DWMMA\")
set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -DWMMA\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -DWMMA\")
endif()
endforeach()

set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS}\")
set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS}\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -Wall\")

set(CMAKE_C_FLAGS_DEBUG \"${CMAKE_C_FLAGS_DEBUG} -Wall -O0\")
set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0\")
set(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall\")

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

if(CMAKE_CXX_STANDARD STREQUAL \"11\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-extended-lambda\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\")
endif()

set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -O3\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -O3\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-strict-aliasing\")

set(COMMON_HEADER_DIRS
${PROJECT_SOURCE_DIR}
${CUDAToolkit_INCLUDE_DIRS}
)

set(COMMON_LIB_DIRS
${CUDAToolkit_LIBRARY_DIR}
)
list(APPEND COMMON_HEADER_DIRS ${CUTLASS_PATH})
list(APPEND COMMON_HEADER_DIRS ${CUTLASS_UTIL_PATH})
'''
        code += '''\
include_directories(
${COMMON_HEADER_DIRS}
)

link_directories(
${COMMON_LIB_DIRS}
)

add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions(-DGOOGLE_CUDA=1)

add_executable(sample
sample/sample.cu
one_api.cu
)
target_link_libraries(sample PRIVATE
-lcudart
-lnvToolsExt
${CMAKE_THREAD_LIBS_INIT}
)

if(NOT DEFINED LIB_INSTALL_PATH)
set(LIB_INSTALL_PATH ${CMAKE_CURRENT_BINARY_DIR})
endif()
'''
        return code

    def gen_code(self):
        """Write the generated text to ``<output_dir>/CMakeLists.txt``."""
        import os  # local import: this module has no top-level imports

        top_code = self.gen_top()
        # os.path.join gives the same path as plain concatenation when
        # output_dir ends with a separator (or is empty), and still produces a
        # correct path when the caller passes a directory without one.
        with open(os.path.join(self.output_dir, "CMakeLists.txt"), "w") as f:
            f.write(top_code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import ast
|
| 34 |
-
|
| 35 |
-
fuse_gemm_info = [
|
| 36 |
-
{
|
| 37 |
-
'epilogue': {
|
| 38 |
-
'tp': 'LeakyRelu', #'CustomizedLeaky_RELU'
|
| 39 |
-
'bias': {'addbias': False, 'bias_tp': 'mat'},
|
| 40 |
-
'args': [('float', 'leaky_alpha', 1.3), ],
|
| 41 |
-
'func': '''
|
| 42 |
-
y = max(leaky_alpha * x, x)
|
| 43 |
-
y = y * x
|
| 44 |
-
'''
|
| 45 |
-
}
|
| 46 |
-
},
|
| 47 |
-
|
| 48 |
-
]
|
| 49 |
-
class AnalysisNodeVisitor(ast.NodeVisitor):
|
| 50 |
-
def visit_Import(self,node):
|
| 51 |
-
ast.NodeVisitor.generic_visit(self, node)
|
| 52 |
-
|
| 53 |
-
def visit_ImportFrom(self,node):
|
| 54 |
-
ast.NodeVisitor.generic_visit(self, node)
|
| 55 |
-
|
| 56 |
-
def visit_Assign(self,node):
|
| 57 |
-
print('Node type: Assign and fields: ', node._fields)
|
| 58 |
-
# print('Node type: Assign and targets value: ', node.targets, node.value)
|
| 59 |
-
|
| 60 |
-
ast.NodeVisitor.generic_visit(self, node)
|
| 61 |
-
|
| 62 |
-
def visit_BinOp(self, node):
|
| 63 |
-
print('Node type: BinOp and fields: ', node._fields)
|
| 64 |
-
print('node op: ', type(node.op).__name__)
|
| 65 |
-
ast.NodeVisitor.generic_visit(self, node)
|
| 66 |
-
|
| 67 |
-
def visit_Expr(self, node):
|
| 68 |
-
print('Node type: Expr and fields: ', node._fields)
|
| 69 |
-
ast.NodeVisitor.generic_visit(self, node)
|
| 70 |
-
|
| 71 |
-
def visit_Num(self,node):
|
| 72 |
-
print('Node type: Num and fields: ', node._fields)
|
| 73 |
-
print('Node type: Num: ', node.n)
|
| 74 |
-
|
| 75 |
-
def visit_Name(self,node):
|
| 76 |
-
print('Node type: Name and fields: ', node._fields)
|
| 77 |
-
print('Node type: Name and fields: ', type(node.ctx).__name__, node.id)
|
| 78 |
-
|
| 79 |
-
ast.NodeVisitor.generic_visit(self, node)
|
| 80 |
-
|
| 81 |
-
def visit_Str(self, node):
|
| 82 |
-
print('Node type: Str and fields: ', node._fields)
|
| 83 |
-
|
| 84 |
-
class CodeVisitor(ast.NodeVisitor):
|
| 85 |
-
def visit_BinOp(self, node):
|
| 86 |
-
if isinstance(node.op, ast.Add):
|
| 87 |
-
node.op = ast.Sub()
|
| 88 |
-
self.generic_visit(node)
|
| 89 |
-
|
| 90 |
-
def visit_Assign(self, node):
|
| 91 |
-
print('Assign %s' % node.value)
|
| 92 |
-
self.generic_visit(node)
|
| 93 |
-
|
| 94 |
-
def visit_Name(self, node):
|
| 95 |
-
print("Name:", node.id)
|
| 96 |
-
self.generic_visit(node)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def visit_FunctionDef(self, node):
|
| 100 |
-
print('Function Name:%s'% node.name.op)
|
| 101 |
-
self.generic_visit(node)
|
| 102 |
-
func_log_stmt = ast.Print(
|
| 103 |
-
dest = None,
|
| 104 |
-
values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)],
|
| 105 |
-
nl = True,
|
| 106 |
-
lineno = 0,
|
| 107 |
-
col_offset = 0,
|
| 108 |
-
)
|
| 109 |
-
node.body.insert(0, func_log_stmt)
|
| 110 |
-
|
| 111 |
-
visitor = AnalysisNodeVisitor()
|
| 112 |
-
|
| 113 |
-
code = \
|
| 114 |
-
'''
|
| 115 |
-
|
| 116 |
-
a=max(leaky_alpha * x, x +1)
|
| 117 |
-
|
| 118 |
-
'''
|
| 119 |
-
|
| 120 |
-
visitor.visit(ast.parse(code))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py
DELETED
|
@@ -1,469 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
from typing import *
|
| 34 |
-
|
| 35 |
-
import helper
|
| 36 |
-
import gen_ir
|
| 37 |
-
|
| 38 |
-
import gen_kernel as gen_ker
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class gen_device:
|
| 42 |
-
def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, cutlass_deps_root, project_root, output_dir = "../"):
|
| 43 |
-
self.fuse_gemm_info = fuse_gemm_info
|
| 44 |
-
self.raw_gemm_info = fuse_gemm_info
|
| 45 |
-
self.b2b_num = len(fuse_gemm_info)
|
| 46 |
-
self.user_header_file = user_header_file
|
| 47 |
-
self.args = {}
|
| 48 |
-
# device arg struct memebr
|
| 49 |
-
self.arg_member = []
|
| 50 |
-
self.gen_class_name = gen_class_name
|
| 51 |
-
self.gen_kernel_name = gen_class_name + "Kernel"
|
| 52 |
-
self.template_args = []
|
| 53 |
-
self.__tempalate_arg_list = {'Stages': int, 'SplitKSerial': bool, 'IsBetaZero': bool, 'AlignmentA': int, 'AlignmentB': int}
|
| 54 |
-
|
| 55 |
-
self.file_name = output_dir + "/device/" +gen_class_name +".h"
|
| 56 |
-
self.sample_dir = output_dir
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
self.cutlass_deps_root = cutlass_deps_root
|
| 60 |
-
self.project_root = project_root
|
| 61 |
-
self.this_file_root = output_dir + "/device/"
|
| 62 |
-
|
| 63 |
-
self.first_use_1stage = False
|
| 64 |
-
|
| 65 |
-
## gen kernel
|
| 66 |
-
self.gen_kernel = gen_ker.gen_kernel(self.template_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def __check_arg_type(self, temp_arg):
|
| 70 |
-
if temp_arg in self.__tempalate_arg_list.keys():
|
| 71 |
-
return self.__tempalate_arg_list[temp_arg]
|
| 72 |
-
|
| 73 |
-
find_sub = False
|
| 74 |
-
for candidate_arg in self.__tempalate_arg_list.keys():
|
| 75 |
-
if (temp_arg.find(candidate_arg) != -1):
|
| 76 |
-
return self.__tempalate_arg_list[candidate_arg]
|
| 77 |
-
|
| 78 |
-
return 'typename'
|
| 79 |
-
|
| 80 |
-
# def gen_B2b2bGemm_class():
|
| 81 |
-
def set_arch(self, sm_cap, mma_tp):
|
| 82 |
-
if sm_cap == 75 or sm_cap == 80 or sm_cap == 86:
|
| 83 |
-
self.arch = "cutlass::arch::Sm" + str(sm_cap)
|
| 84 |
-
|
| 85 |
-
if mma_tp is 'hmma1688':
|
| 86 |
-
self.mma_shape = [16, 8, 8]
|
| 87 |
-
self.mma_tp = 'hmma'
|
| 88 |
-
elif mma_tp is 'imma8816':
|
| 89 |
-
self.mma_tp = 'imma'
|
| 90 |
-
self.mma_shape = [8, 8, 16]
|
| 91 |
-
else:
|
| 92 |
-
return 0
|
| 93 |
-
|
| 94 |
-
def gen_include_header(self):
|
| 95 |
-
code = '''\
|
| 96 |
-
/* Auto Generated code - Do not edit.*/
|
| 97 |
-
|
| 98 |
-
#pragma once
|
| 99 |
-
|
| 100 |
-
#include \"{cutlass_root}cutlass/cutlass.h\"
|
| 101 |
-
#include \"{cutlass_root}cutlass/numeric_types.h\"
|
| 102 |
-
#include \"{cutlass_root}cutlass/arch/arch.h\"
|
| 103 |
-
#include \"{cutlass_root}cutlass/device_kernel.h\"
|
| 104 |
-
|
| 105 |
-
#include \"{cutlass_root}cutlass/gemm/threadblock/threadblock_swizzle.h\"
|
| 106 |
-
|
| 107 |
-
#include \"{cutlass_root}cutlass/gemm/device/default_gemm_configuration.h\"
|
| 108 |
-
#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination_relu.h\"
|
| 109 |
-
#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination.h\"
|
| 110 |
-
|
| 111 |
-
#include \"{project_root}../kernel/b2b_gemm.h\"
|
| 112 |
-
#include \"{project_root}../kernel/default_b2b_gemm.h\"
|
| 113 |
-
'''.format(cutlass_root=self.cutlass_deps_root, project_root=self.project_root, this_file_root=self.this_file_root)
|
| 114 |
-
include_user_header = ""
|
| 115 |
-
for header in self.user_header_file:
|
| 116 |
-
include_user_header += "#include \"" + header + "\"\n"
|
| 117 |
-
return code + include_user_header
|
| 118 |
-
|
| 119 |
-
def gen_code(self, sm_cap, mma_tp, ifprint = True):
|
| 120 |
-
self.set_arch(sm_cap, mma_tp)
|
| 121 |
-
|
| 122 |
-
self.update_b2b_args()
|
| 123 |
-
print(self.fuse_gemm_info)
|
| 124 |
-
self.update_b2b_class_template_args()
|
| 125 |
-
|
| 126 |
-
func_code = self.gen_all_func()
|
| 127 |
-
member_var_code = "private:\n typename B2bGemmKernel::Params params_;\n"
|
| 128 |
-
|
| 129 |
-
gen_code = gen_ir.gen_template_class(self.gen_class_name, self.template_args, func_code + member_var_code)
|
| 130 |
-
code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("device", gen_code)))
|
| 131 |
-
|
| 132 |
-
if ifprint:
|
| 133 |
-
print(code)
|
| 134 |
-
|
| 135 |
-
print("[INFO]: Gen device code output Dir: is ", self.file_name)
|
| 136 |
-
with open(self.file_name, 'w+') as f:
|
| 137 |
-
f.write(code)
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
gen_kernel = self.gen_kernel.gen_code(self.first_use_1stage)
|
| 141 |
-
print(gen_kernel)
|
| 142 |
-
|
| 143 |
-
def update_b2b_class_template_args(self):
|
| 144 |
-
for arg in self.args.keys():
|
| 145 |
-
self.template_args.append([self.__check_arg_type(arg), arg, self.args[arg]])
|
| 146 |
-
|
| 147 |
-
def update_b2b_args(self):
|
| 148 |
-
|
| 149 |
-
self.args['ElementA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_tp'])
|
| 150 |
-
self.args['LayoutA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_format'])
|
| 151 |
-
|
| 152 |
-
cnt = 0
|
| 153 |
-
|
| 154 |
-
warp_M_tile = 32
|
| 155 |
-
|
| 156 |
-
# Determine maximum N_tile
|
| 157 |
-
Max_Ntile = 0
|
| 158 |
-
for layer in self.fuse_gemm_info:
|
| 159 |
-
n_tile = layer['mnk'][1]
|
| 160 |
-
if n_tile > Max_Ntile:
|
| 161 |
-
Max_Ntile = n_tile
|
| 162 |
-
if Max_Ntile >= 256:
|
| 163 |
-
warp_M_tile = 16
|
| 164 |
-
|
| 165 |
-
stages_temp = []
|
| 166 |
-
|
| 167 |
-
for layer in self.fuse_gemm_info:
|
| 168 |
-
cnt_str = str(cnt)
|
| 169 |
-
B_tp_str= 'ElementB' + cnt_str
|
| 170 |
-
B_format_str = 'LayoutB' + cnt_str
|
| 171 |
-
C_tp_str= 'ElementC' + cnt_str
|
| 172 |
-
C_format_str = 'LayoutC' + cnt_str
|
| 173 |
-
Acc_str = 'ElementAccumulator' + cnt_str
|
| 174 |
-
|
| 175 |
-
self.args[B_tp_str] = helper.type_2_cutlass_type(layer['B_tp'])
|
| 176 |
-
self.args[B_format_str] = helper.type_2_cutlass_type(layer['B_format'])
|
| 177 |
-
self.args[C_tp_str] = helper.type_2_cutlass_type(layer['C_tp'])
|
| 178 |
-
self.args[C_format_str] = helper.type_2_cutlass_type(layer['C_format'])
|
| 179 |
-
self.args[Acc_str] = helper.type_2_cutlass_type(layer['Acc_tp'])
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
mnk = layer['mnk'][:]
|
| 183 |
-
|
| 184 |
-
tile_mnk = mnk[:]
|
| 185 |
-
|
| 186 |
-
tile_mnk[2] = 32 # force the ktile is 32
|
| 187 |
-
|
| 188 |
-
#N tile gen
|
| 189 |
-
if mnk[1] > 1024:
|
| 190 |
-
assert(0)
|
| 191 |
-
elif mnk[1] > 512:
|
| 192 |
-
tile_mnk[1] = 1024
|
| 193 |
-
elif mnk[1] > 256:
|
| 194 |
-
tile_mnk[1] = 512
|
| 195 |
-
elif mnk[1] > 128:
|
| 196 |
-
tile_mnk[1] = 256
|
| 197 |
-
elif mnk[1] > 64:
|
| 198 |
-
tile_mnk[1] = 128
|
| 199 |
-
elif mnk[1] > 32:
|
| 200 |
-
tile_mnk[1] = 64
|
| 201 |
-
else :
|
| 202 |
-
tile_mnk[1] = 32
|
| 203 |
-
|
| 204 |
-
if tile_mnk[1] == 512:
|
| 205 |
-
stages_temp.append(1)
|
| 206 |
-
else:
|
| 207 |
-
stages_temp.append(2)
|
| 208 |
-
|
| 209 |
-
tile_mnk[0] = 4 * warp_M_tile
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
epilogue_setted_type = helper.get_epilogue_tp(layer)
|
| 214 |
-
cutlass_epilogue_name = "LinearCombinationRelu"
|
| 215 |
-
if epilogue_setted_type.lower() == 'leakyrelu':
|
| 216 |
-
cutlass_epilogue_name = "LinearCombinationLeakyRelu"
|
| 217 |
-
elif epilogue_setted_type.lower() == 'identity':
|
| 218 |
-
cutlass_epilogue_name = "LinearCombination"
|
| 219 |
-
|
| 220 |
-
epilogue_str = 'EpilogueOutputOp' + cnt_str
|
| 221 |
-
if cnt != len(self.fuse_gemm_info) - 1:
|
| 222 |
-
n = layer['mnk'][1]
|
| 223 |
-
Fragments = tile_mnk[1] // 8 * 2
|
| 224 |
-
self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "<ElementC0_, " + str(Fragments) +", ElementAccumulator0_, ElementAccumulator0_>"
|
| 225 |
-
else:
|
| 226 |
-
n = layer['mnk'][1]
|
| 227 |
-
n_mod_8 = n % 4
|
| 228 |
-
N_align_elements = 1
|
| 229 |
-
if n_mod_8 == 0:
|
| 230 |
-
N_align_elements = 8
|
| 231 |
-
elif n_mod_8 == 4:
|
| 232 |
-
N_align_elements = 4
|
| 233 |
-
elif n_mod_8 == 2 or n_mod_8 == 6:
|
| 234 |
-
N_align_elements = 2
|
| 235 |
-
|
| 236 |
-
self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "<ElementC0_, " + str(N_align_elements) + ", ElementAccumulator0_, ElementAccumulator0_>"
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
ThreadBlockShape_str = 'ThreadblockShape' + cnt_str
|
| 241 |
-
|
| 242 |
-
self.args[ThreadBlockShape_str] = helper.cvt_2_cutlass_shape(tile_mnk)
|
| 243 |
-
|
| 244 |
-
WarpShape_str = 'WarpShape' + cnt_str
|
| 245 |
-
tile_mnk[0] = warp_M_tile
|
| 246 |
-
self.args[WarpShape_str] = helper.cvt_2_cutlass_shape(tile_mnk)
|
| 247 |
-
cnt += 1
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
self.args['ElementD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_tp'])
|
| 251 |
-
self.args['LayoutD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_format'])
|
| 252 |
-
|
| 253 |
-
self.args['InstructionShape'] = helper.cvt_2_cutlass_shape(self.mma_shape)
|
| 254 |
-
self.args['OperatorClass'] = 'arch::OpClassTensorOp'
|
| 255 |
-
self.args['ArchTag'] = self.arch
|
| 256 |
-
self.args['ThreadblockSwizzle'] = 'threadblock::GemmBatchedIdentityThreadblockSwizzle'
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
for i in range(self.b2b_num):
|
| 260 |
-
self.args[helper.var_idx('Stages', i)] = "2"
|
| 261 |
-
|
| 262 |
-
self.args['AlignmentA'] = str(8)
|
| 263 |
-
self.args['AlignmentB'] = str(8)
|
| 264 |
-
self.args['SplitKSerial'] = 'false'
|
| 265 |
-
self.args['Operator'] = 'typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB0_, ElementC0_, ElementAccumulator0_>::Operator'
|
| 266 |
-
self.args['IsBetaZero'] = 'false'
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
def gen_using_kernel(self):
|
| 270 |
-
code = "using B2bGemmKernel = typename kernel::DefaultB2bGemm<\n"
|
| 271 |
-
code += " " + "ElementA,\n"
|
| 272 |
-
code += " " + "LayoutA,\n"
|
| 273 |
-
|
| 274 |
-
for i in range(self.b2b_num):
|
| 275 |
-
code += " " + helper.var_idx("ElementB", i) + ",\n"
|
| 276 |
-
code += " " + helper.var_idx("LayoutB", i) + ",\n"
|
| 277 |
-
code += " " + helper.var_idx("ElementC", i) + ",\n"
|
| 278 |
-
code += " " + helper.var_idx("LayoutC", i) + ",\n"
|
| 279 |
-
code += " " + helper.var_idx("ElementAccumulator", i) + ",\n"
|
| 280 |
-
code += " " + helper.var_idx("EpilogueOutputOp", i) + ",\n"
|
| 281 |
-
code += " " + helper.var_idx("ThreadblockShape", i) + ",\n"
|
| 282 |
-
code += " " + helper.var_idx("WarpShape", i) + ",\n"
|
| 283 |
-
|
| 284 |
-
code += " " + "ElementD,\n"
|
| 285 |
-
code += " " + "LayoutD,\n"
|
| 286 |
-
code += " " + "InstructionShape,\n"
|
| 287 |
-
code += " " + "OperatorClass,\n"
|
| 288 |
-
code += " " + "ArchTag,\n"
|
| 289 |
-
code += " " + "ThreadblockSwizzle,\n"
|
| 290 |
-
|
| 291 |
-
for i in range(self.b2b_num):
|
| 292 |
-
code += " " + helper.var_idx("Stages", i) + ",\n"
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
code += " " + "AlignmentA,\n"
|
| 296 |
-
code += " " + "AlignmentB,\n"
|
| 297 |
-
code += " " + "SplitKSerial,\n"
|
| 298 |
-
code += " " + "Operator,\n"
|
| 299 |
-
code += " " + "IsBetaZero_\n"
|
| 300 |
-
|
| 301 |
-
code += ">::B2bGemmKernel;\n\n"
|
| 302 |
-
|
| 303 |
-
return code
|
| 304 |
-
|
| 305 |
-
def gen_args(self):
|
| 306 |
-
|
| 307 |
-
def gen_arg_member(b2b_num):
|
| 308 |
-
data_members = []
|
| 309 |
-
|
| 310 |
-
for i in range(b2b_num):
|
| 311 |
-
member_type = "GemmCoord"
|
| 312 |
-
member_name = "problem_size_" + str(i)
|
| 313 |
-
data_members.append((member_type, member_name))
|
| 314 |
-
|
| 315 |
-
member_type = "TensorRef<ElementA const, LayoutA>"
|
| 316 |
-
member_name = "ref_A0"
|
| 317 |
-
data_members.append((member_type, member_name))
|
| 318 |
-
|
| 319 |
-
for i in range(b2b_num):
|
| 320 |
-
member_type = "TensorRef<ElementB" + str(i) + " const, LayoutB" + str(i) +">"
|
| 321 |
-
member_name = "ref_B" + str(i)
|
| 322 |
-
data_members.append((member_type, member_name))
|
| 323 |
-
member_type = "TensorRef<ElementC" + str(i) + " const, LayoutC" + str(i) +">"
|
| 324 |
-
member_name = "ref_C" + str(i)
|
| 325 |
-
data_members.append((member_type, member_name))
|
| 326 |
-
|
| 327 |
-
member_type = "TensorRef<ElementD, LayoutD>"
|
| 328 |
-
member_name = helper.var_idx("ref_D", b2b_num - 1)
|
| 329 |
-
data_members.append((member_type, member_name))
|
| 330 |
-
|
| 331 |
-
for i in range(b2b_num):
|
| 332 |
-
member_type = "typename EpilogueOutputOp" + str(i) + "::Params"
|
| 333 |
-
member_name = "epilogue" + str(i)
|
| 334 |
-
data_members.append((member_type, member_name))
|
| 335 |
-
|
| 336 |
-
data_members.append(('int', 'batch_count'))
|
| 337 |
-
|
| 338 |
-
return data_members
|
| 339 |
-
|
| 340 |
-
def gen_arg_struct_default_ctor(struct_name, data_members, inital_param_num, inital_value):
|
| 341 |
-
constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \
|
| 342 |
-
gen_ir.indentation + struct_name + " (): "
|
| 343 |
-
for i in range(inital_param_num):
|
| 344 |
-
final_param = ','
|
| 345 |
-
if i == inital_param_num - 1:
|
| 346 |
-
final_param = '{ }'
|
| 347 |
-
constructs_code += data_members[i][1] + inital_value + final_param
|
| 348 |
-
|
| 349 |
-
constructs_code += "\n"
|
| 350 |
-
return constructs_code
|
| 351 |
-
|
| 352 |
-
def gen_arg_struct_ctor(struct_name, data_members):
|
| 353 |
-
constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \
|
| 354 |
-
gen_ir.indentation + struct_name + " (\n"
|
| 355 |
-
cnt = 0
|
| 356 |
-
param_num = len(data_members)
|
| 357 |
-
for param in data_members:
|
| 358 |
-
final = ',\n'
|
| 359 |
-
if cnt == param_num - 1:
|
| 360 |
-
final = '\n):\n'
|
| 361 |
-
constructs_code += gen_ir.indentation + param[0] + " " + param[1] + "_" + final
|
| 362 |
-
cnt += 1
|
| 363 |
-
|
| 364 |
-
cnt = 0
|
| 365 |
-
for param in data_members:
|
| 366 |
-
final = '),\n'
|
| 367 |
-
if cnt == param_num - 1:
|
| 368 |
-
final = ") { }\n"
|
| 369 |
-
constructs_code += gen_ir.indentation + param[1] + "(" + param[1] + "_" + final
|
| 370 |
-
cnt += 1
|
| 371 |
-
|
| 372 |
-
constructs_code += "\n"
|
| 373 |
-
return constructs_code
|
| 374 |
-
|
| 375 |
-
# (variable type, variable name)
|
| 376 |
-
struct_member = gen_arg_member(self.b2b_num)
|
| 377 |
-
self.arg_member = struct_member
|
| 378 |
-
|
| 379 |
-
codeBody = ""
|
| 380 |
-
for each_member in struct_member:
|
| 381 |
-
codeBody += gen_ir.indentation + each_member[0] + " " + each_member[1] + ";\n"
|
| 382 |
-
|
| 383 |
-
codeBody += gen_arg_struct_default_ctor("Arguments", struct_member, self.b2b_num, "(0,0,0)") + "\n"
|
| 384 |
-
codeBody += gen_arg_struct_ctor("Arguments", struct_member) + "\n"
|
| 385 |
-
struct_code = gen_ir.gen_struct("Arguments", codeBody)
|
| 386 |
-
return struct_code
|
| 387 |
-
|
| 388 |
-
def gen_func_constructs(self):
|
| 389 |
-
code = self.gen_class_name +"() {}"
|
| 390 |
-
return code
|
| 391 |
-
|
| 392 |
-
def gen_func_initialize(self):
|
| 393 |
-
code = "Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {\n" + \
|
| 394 |
-
"// Determine grid shape\n" + \
|
| 395 |
-
"ThreadblockSwizzle threadblock_swizzle;\n" + \
|
| 396 |
-
"cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(\n" + \
|
| 397 |
-
" args.problem_size_0, \n" + \
|
| 398 |
-
" { ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK },\n" + \
|
| 399 |
-
" args.batch_count);\n" + \
|
| 400 |
-
"// Initialize the Params structure\n" + \
|
| 401 |
-
"params_ = typename B2bGemmKernel::Params{\n"
|
| 402 |
-
for i in range(self.b2b_num):
|
| 403 |
-
code += helper.var_idx(" args.problem_size_", i) + ",\n"
|
| 404 |
-
code += " grid_shape,\n" + \
|
| 405 |
-
" args.ref_A0.non_const_ref(),\n"
|
| 406 |
-
for i in range(self.b2b_num):
|
| 407 |
-
code += helper.var_idx(" args.ref_B", i) + ".non_const_ref(),\n"
|
| 408 |
-
code += helper.var_idx(" args.ref_C", i) + ".non_const_ref(),\n"
|
| 409 |
-
|
| 410 |
-
code += helper.var_idx(" args.ref_D", self.b2b_num - 1) + ",\n"
|
| 411 |
-
for i in range(self.b2b_num):
|
| 412 |
-
code += helper.var_idx(" args.epilogue", i) + ",\n"
|
| 413 |
-
|
| 414 |
-
code += " args.batch_count\n"
|
| 415 |
-
code += "};\n" + \
|
| 416 |
-
"return Status::kSuccess;\n" + \
|
| 417 |
-
"}\n"
|
| 418 |
-
return code
|
| 419 |
-
|
| 420 |
-
def gen_func_run(self):
|
| 421 |
-
code = "Status run(cudaStream_t stream = nullptr) {\n" + \
|
| 422 |
-
"\n" + \
|
| 423 |
-
" ThreadblockSwizzle threadblock_swizzle;\n" + \
|
| 424 |
-
"\n" + \
|
| 425 |
-
" dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);\n" + \
|
| 426 |
-
" dim3 block(B2bGemmKernel::kThreadCount, 1, 1);\n" + \
|
| 427 |
-
"\n" + \
|
| 428 |
-
" cudaError_t result;\n" + \
|
| 429 |
-
"\n" + \
|
| 430 |
-
" int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));\n" + \
|
| 431 |
-
" if (smem_size >= (48 << 10)) {\n" + \
|
| 432 |
-
" result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);\n" + \
|
| 433 |
-
"\n" + \
|
| 434 |
-
" if (result != cudaSuccess) {\n" + \
|
| 435 |
-
" return Status::kErrorInternal;\n" + \
|
| 436 |
-
" }\n" + \
|
| 437 |
-
" }\n" + \
|
| 438 |
-
" cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);\n" + \
|
| 439 |
-
" result = cudaGetLastError();\n" + \
|
| 440 |
-
" return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;\n" + \
|
| 441 |
-
" }\n"
|
| 442 |
-
|
| 443 |
-
return code
|
| 444 |
-
def gen_func_operator(self):
|
| 445 |
-
opeartor_with_arg_code = "Status operator()(\n" + \
|
| 446 |
-
" Arguments const &args,\n" + \
|
| 447 |
-
" void *workspace = nullptr,\n" + \
|
| 448 |
-
" cudaStream_t stream = nullptr) {\n" + \
|
| 449 |
-
" Status status = initialize(args, workspace);\n" + \
|
| 450 |
-
" \n" + \
|
| 451 |
-
" if (status == Status::kSuccess) {\n" + \
|
| 452 |
-
" status = run(stream);\n" + \
|
| 453 |
-
" }\n" + \
|
| 454 |
-
" return status;\n" + \
|
| 455 |
-
"}\n"
|
| 456 |
-
operator_code = "Status operator()(\n" + \
|
| 457 |
-
" cudaStream_t stream = nullptr) {\n" + \
|
| 458 |
-
" Status status = run(stream);\n" + \
|
| 459 |
-
" return status;\n" + \
|
| 460 |
-
"}\n"
|
| 461 |
-
return opeartor_with_arg_code + "\n" + operator_code
|
| 462 |
-
|
| 463 |
-
def gen_all_func(self):
|
| 464 |
-
return self.gen_using_kernel() + "\n" + \
|
| 465 |
-
self.gen_args() + "\n" + \
|
| 466 |
-
self.gen_func_constructs() + "\n" + \
|
| 467 |
-
self.gen_func_initialize() + "\n" + \
|
| 468 |
-
self.gen_func_run() + "\n" + \
|
| 469 |
-
self.gen_func_operator()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py
DELETED
|
@@ -1,249 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import helper
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
# Indentation unit placed in front of every generated template-argument line
# (gen_template_arg) and every exported member line (export_template_args).
# NOTE(review): this literal shows as a single space in this view; whitespace may
# have been collapsed by the diff rendering -- confirm against the upstream file.
indentation = " "
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def append_word(word):
    """Return `word` followed by a single trailing space."""
    return "" + word + " "
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def gen_namespace(namespace, codeBody):
    """Wrap `codeBody` in a C++ namespace block.

    The closing brace carries a trailing `// namespace <name>` comment.
    `codeBody` is inserted verbatim, so it should end with its own newline.
    """
    opening = "namespace " + namespace + " {\n"
    closing = "} // namespace " + namespace + "\n"
    return opening + codeBody + closing
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def gen_expression(type, lval, rval = None):
    """Build `<type> <lval> ` or `<type> <lval> = <rval> `.

    Every token is passed through `append_word`, so each one (including the
    last) is followed by a single space. The `= <rval>` part is emitted only
    when `rval` is not None.
    """
    tokens = [type, lval]
    if rval is not None:
        tokens.extend(["=", rval])

    pieces = []
    for token in tokens:
        pieces.append(append_word(token))
    return "".join(pieces)
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def gen_class(name, codeBody, inheritance_code = None):
    """Emit a C++ class definition around `codeBody`.

    When `inheritance_code` is given it is placed after ` : ` in the class
    head. The closing `};` carries a trailing `// class <name>` comment.
    """
    head = "class " + name
    if inheritance_code is not None:
        head += " : " + inheritance_code
    return head + "{\n" + codeBody + "}; // class " + name + "\n"
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def gen_struct(name, codeBody, specialized = None):
    """Emit a C++ struct definition around `codeBody`.

    When `specialized` is given, `<specialized>` is appended to the struct
    name (a template specialization argument list). The closing `};` carries
    a trailing `// struct <name>` comment.
    """
    suffix = "" if specialized is None else "<" + specialized + ">"
    return "struct " + name + suffix + "{\n" + codeBody + "}; // struct " + name + "\n"
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def gen_template_arg(arg_type, arg_name, default_val = None):
    """Render one template parameter declaration.

    `arg_type` is compared by identity against the Python types `int` and
    `bool`; anything else is rendered as a `typename` parameter. The emitted
    parameter name is `arg_name` with a trailing underscore. A non-None
    `default_val` is stringified and emitted as the default.
    """
    default_text = str(default_val) if default_val is not None else None

    if arg_type is int:
        keyword = "int"
    elif arg_type is bool:
        keyword = "bool"
    else:
        keyword = "typename"

    return indentation + gen_expression(keyword, arg_name + "_", default_text)
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
def gen_template_args(args, set_default = True):
    """Render a template parameter list, one parameter per line.

    Each entry of `args` is indexed as `(type, name)` or
    `(type, name, default)`. The default is used only for 3-element entries
    and only when `set_default` is true. Entries are separated by `,\\n`;
    nothing follows the last one.
    """
    total = len(args)
    pieces = []
    for position, entry in enumerate(args, start=1):
        kind = entry[0]
        label = entry[1]
        default = entry[2] if (len(entry) == 3 and set_default) else None

        pieces.append(gen_template_arg(kind, label, default))
        if position != total:
            pieces.append(",\n")

    return "".join(pieces)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def gen_template_head(args, set_default = True):
    """Return the `template < ... >` header line block for `args`."""
    return "template <\n" + gen_template_args(args, set_default) + ">\n"
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def export_template_args(args):
    """Emit a `public:` section that re-exports each template parameter.

    For each entry, `int` parameters become `static int const`, `bool`
    parameters become `static bool const`, and everything else becomes a
    `using` alias. The member is named `<name>` and is assigned from the
    internal template parameter `<name>_`.
    """
    lines = ["public:\n"]
    for entry in args:
        kind, label = entry[0], entry[1]

        if kind is int:
            qualifier = "static int const"
        elif kind is bool:
            qualifier = "static bool const"
        else:
            qualifier = "using"

        lines.append(indentation + gen_expression(qualifier, label, label + "_") + ";\n")
    return "".join(lines)
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
def gen_template_class(class_name, args, codeBody, set_default = True, inheritance_code = None):
    """Emit a templated C++ class.

    The template header is followed by a class whose body starts with the
    exported template arguments and continues with `codeBody`.
    """
    head = gen_template_head(args, set_default)
    members = export_template_args(args) + codeBody
    return head + gen_class(class_name, members, inheritance_code)
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
def gen_template_struct(struct_name, args, codeBody, speicalized = None, set_default = True, export_args = True):
    """Emit a templated C++ struct.

    The template header is followed by a struct definition. Unless
    `export_args` is exactly False, the struct body starts with the exported
    template arguments. `speicalized` (spelling is part of the interface) is
    forwarded to `gen_struct` as the specialization argument list.
    """
    head = gen_template_head(args, set_default)
    with_exports = export_template_args(args) + codeBody
    body = codeBody if export_args is False else with_exports
    return head + gen_struct(struct_name, body, speicalized)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
def gen_declare_template_struct(name, *params):
    """Return `name<p0, p1, ...>;` followed by a newline.

    `params` are string template arguments, joined with `", "`.
    """
    rendered = ""
    last = len(params) - 1
    for position, param in enumerate(params):
        rendered += param + ("" if position == last else ", ")
    return name + "<" + rendered + ">;\n"
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
def filtered_param(params, name_and_value_pair, keep_ = False):
    """Split template parameters into "still free" and "specialized" parts.

    Args:
        params: iterable of parameter tuples. The name is taken from index 1
            when the tuple has at least two elements, otherwise from index 0.
        name_and_value_pair: iterable of `(filter_name, value)` entries. A
            parameter matches when its name equals `filter_name` or
            `filter_name + "_"`; the first match wins.
        keep_: when exactly True, unmatched parameter names get a trailing
            underscore in the specialization list.

    Returns:
        `(rtn_template_args, specialized_template_arg_str)`: the list of
        unmatched parameter tuples, and the result of
        `helper.list_2_string` applied to the per-parameter specialization
        entries (the matched value, or the parameter name when unmatched).
    """
    rtn_template_args = []
    speicalized_template_args = []

    for param in params:
        # Fix: the original tested `len(param) >= 1` and then read param[1],
        # which raised IndexError for one-element tuples and made the
        # param[0] fallback unreachable for any non-empty tuple.
        param_name = param[1] if len(param) > 1 else param[0]

        for n_v_pair in name_and_value_pair:
            filter_name = n_v_pair[0]
            if param_name == (filter_name + "_") or param_name == filter_name:
                # Matched: the parameter is pinned to the supplied value.
                speicalized_template_args.append(n_v_pair[1])
                break
        else:
            # No filter matched: the parameter stays a template argument.
            rtn_template_args.append(param)
            if keep_ is True:
                speicalized_template_args.append(param_name + "_")
            else:
                speicalized_template_args.append(param_name)

    specialized_template_arg_str = helper.list_2_string(speicalized_template_args)

    return rtn_template_args, specialized_template_arg_str
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
def gen_func(func_name, arg_lists, code_body, only_declare = False, with_cudaStream = True):
    """Emit a C++ `void` function declaration or definition.

    Args:
        func_name: name of the generated function.
        arg_lists: iterable of `(type, name)` string pairs, one per argument.
        code_body: text placed between the braces of the definition.
        only_declare: when true, return just the signature (no body, no
            trailing semicolon).
        with_cudaStream: when true (default), a trailing
            `cudaStream_t stream` argument is appended after the listed
            arguments. When false, the argument list is closed right after
            the listed arguments.

    Returns:
        The generated source text.
    """
    # NOTE(review): the per-argument prefix shows as a single space in this
    # view; whitespace may have been collapsed by the rendering -- confirm.
    rendered_args = [" " + arg[0] + " " + arg[1] for arg in arg_lists]

    code = "void " + func_name + "(\n"
    if with_cudaStream:
        # Every listed argument is followed by a comma because the stream
        # argument always comes last.
        code += "".join(text + ",\n" for text in rendered_args)
        code += "cudaStream_t stream)"
    else:
        # Fix: the flag used to be ignored. Without the stream argument the
        # last listed argument must not carry a trailing comma.
        code += ",\n".join(rendered_args) + ")"

    if only_declare:
        return code

    code += "{\n"
    code += code_body + "\n"
    code += "}\n"
    return code
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
def indent_level(code, level = 0):
    """Prefix `code` with one indentation unit per `level`.

    Only the start of the string is prefixed; embedded newlines are not
    touched. A `level` of zero or less adds nothing.
    """
    # NOTE(review): the unit shows as a single space in this view; whitespace
    # may have been collapsed by the rendering -- confirm against upstream.
    prefix = "".join(" " for _ in range(level))
    return prefix + code
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py
DELETED
|
@@ -1,476 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import gen_ir
|
| 34 |
-
import helper
|
| 35 |
-
import gen_threadblock as gen_tb
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class gen_default_Gemm:
    """Generates the `DefaultB2bGemm` template struct and its specialization.

    The output of `gen_code` is a header: an include block followed by the
    primary template declaration and one partial specialization, all wrapped
    in `cutlass::gemm::kernel`.

    NOTE(review): `helper.var_idx(text, i)` is used throughout to build
    indexed identifiers; presumably it appends the index to the text --
    confirm against the helper module.
    """

    def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
        # `gen_class_name` is accepted but not stored: the class name is fixed.
        self.template_param = template_param
        self.b2b_num = b2b_num
        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root
        self.gen_class_name = "B2bGemm"

    def gen_B2bMma(self, specialized_template_args):
        """Return the `using B2bMma = ...` alias built from the given argument text."""
        return (
            "using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<\n"
            + specialized_template_args
            + ">::ThreadblockB2bMma;\n"
        )

    def gen_epilogue(self):
        """Return the kPartitionsK constant, the Epilogue alias and the B2bGemmKernel alias.

        All indexed names refer to the last GEMM (`b2b_num - 1`).
        """
        last = self.b2b_num - 1
        idx = helper.var_idx
        pad = " "

        pieces = [
            idx("static const int kPartitionsK", last)
            + idx(" = ThreadblockShape", last)
            + idx("::kK / WarpShape", last)
            + "::kK;\n",
            "using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<\n",
            pad + idx("ThreadblockShape", last) + ",\n",
            pad + idx("typename B2bMma::Operator", last) + ",\n",
            pad + idx("kPartitionsK", last) + ",\n",
            pad + idx("EpilogueOutputOp", last) + ",\n",
            pad + idx("EpilogueOutputOp", last) + "::kCount\n",
            ">::Epilogue;\n",
            "using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;\n\n",
        ]
        return "".join(pieces)

    def gen_include_header(self):
        """Return the include block; CUTLASS includes are prefixed with `cutlass_deps_root`."""
        template = '''
/* Auto Generated code - Do not edit.*/

#pragma once
#include \"{cutlass_dir}cutlass/cutlass.h\"

#include \"{cutlass_dir}cutlass/layout/matrix.h\"
#include \"{cutlass_dir}cutlass/numeric_types.h\"

#include \"{cutlass_dir}cutlass/epilogue/threadblock/epilogue.h\"
#include \"{cutlass_dir}cutlass/epilogue/thread/linear_combination.h\"

#include \"{cutlass_dir}cutlass/gemm/gemm.h\"
#include \"{cutlass_dir}cutlass/gemm/kernel/gemm_pipelined.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_simt.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/threadblock_swizzle.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_tensor_op.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_simt.h\"

#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\"

#include \"../kernel/b2b_gemm.h\"
#include \"../threadblock/default_b2b_mma.h\"
'''
        return template.format(cutlass_dir=self.cutlass_deps_root)

    def gen_code(self):
        """Return the full header text: includes, primary template, specialization."""
        struct_name = "Default" + self.gen_class_name

        # Primary (unspecialized) template with an empty body.
        primary = gen_ir.gen_template_struct(struct_name, self.template_param, "", speicalized = None, set_default=False)

        # Parameters pinned by the partial specialization.
        filter_list = [
            ('Stages', 2),
            ("OperatorClass", "arch::OpClassTensorOp"),
            ("ArchTag", "arch::Sm75"),
        ]
        for i in range(self.b2b_num):
            filter_list.append((helper.var_idx("LayoutC", i), "layout::RowMajor"))

        rtn_template_args, speicalized_template_args = gen_ir.filtered_param(self.template_param, filter_list, keep_= True)

        mma_text = self.gen_B2bMma(speicalized_template_args)
        tail_text = self.gen_epilogue()

        specialization = gen_ir.gen_template_struct(struct_name, rtn_template_args, mma_text + tail_text, speicalized = speicalized_template_args, set_default=False)

        wrapped = gen_ir.gen_namespace("kernel", primary + specialization)
        wrapped = gen_ir.gen_namespace("gemm", wrapped)
        wrapped = gen_ir.gen_namespace("cutlass", wrapped)

        return self.gen_include_header() + wrapped
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
class gen_Kernel:
|
| 132 |
-
def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
|
| 133 |
-
self.gen_class_name = "B2bGemm"
|
| 134 |
-
self.template_param = template_param
|
| 135 |
-
self.b2bnum = b2b_num
|
| 136 |
-
|
| 137 |
-
self.cutlass_deps_root = cutlass_deps_root
|
| 138 |
-
self.project_root = project_root
|
| 139 |
-
|
| 140 |
-
def gen_include_header(self):
|
| 141 |
-
code = '''
|
| 142 |
-
#pragma once
|
| 143 |
-
|
| 144 |
-
#include \"{cutlass_dir}cutlass/cutlass.h\"
|
| 145 |
-
#include \"{cutlass_dir}cutlass/gemm/gemm.h\"
|
| 146 |
-
#include \"{cutlass_dir}cutlass/matrix_coord.h\"\n'''.format(cutlass_dir=self.cutlass_deps_root)
|
| 147 |
-
return code
|
| 148 |
-
|
| 149 |
-
def gen_Params(self):
|
| 150 |
-
gen_param = ""
|
| 151 |
-
for i in range(self.b2bnum):
|
| 152 |
-
gen_param += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + ";\n"
|
| 153 |
-
gen_param += " " + "cutlass::gemm::GemmCoord grid_tiled_shape;\n"
|
| 154 |
-
gen_param += " " + "typename B2bMma::IteratorA0::Params params_A0;\n"
|
| 155 |
-
gen_param += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0;\n"
|
| 156 |
-
|
| 157 |
-
for i in range(self.b2bnum):
|
| 158 |
-
gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::Params params_B", i) + ";\n"
|
| 159 |
-
gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ";\n"
|
| 160 |
-
if i == self.b2bnum - 1:
|
| 161 |
-
gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_C", i) + ";\n"
|
| 162 |
-
gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ";\n"
|
| 163 |
-
|
| 164 |
-
else:
|
| 165 |
-
gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::Params params_C", i) + ";\n"
|
| 166 |
-
gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ";\n"
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_D", self.b2bnum - 1) + ";\n"
|
| 172 |
-
gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ";\n"
|
| 173 |
-
|
| 174 |
-
for i in range(self.b2bnum):
|
| 175 |
-
gen_param += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + ";\n"
|
| 176 |
-
|
| 177 |
-
gen_param += " " + 'int batch_count' + ";\n"
|
| 178 |
-
gen_param += " " + 'int gemm_k_iterations_0' + ";\n"
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
return gen_param
|
| 182 |
-
|
| 183 |
-
def gen_Memberfunc(self):
|
| 184 |
-
code_default = "\nCUTLASS_HOST_DEVICE\n"
|
| 185 |
-
code_default += "Params()"
|
| 186 |
-
|
| 187 |
-
code_default += " { } \n\n"
|
| 188 |
-
|
| 189 |
-
code_construct = "\nCUTLASS_HOST_DEVICE\n"
|
| 190 |
-
code_construct += "Params(\n"
|
| 191 |
-
|
| 192 |
-
for i in range(self.b2bnum):
|
| 193 |
-
code_construct += " " + helper.var_idx("cutlass::gemm::GemmCoord const & problem_size_", i) + ",\n"
|
| 194 |
-
|
| 195 |
-
code_construct += " " + "cutlass::gemm::GemmCoord const & grid_tiled_shape,\n"
|
| 196 |
-
|
| 197 |
-
code_construct += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0,\n"
|
| 198 |
-
|
| 199 |
-
for i in range(self.b2bnum):
|
| 200 |
-
code_construct += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ",\n"
|
| 201 |
-
if i == self.b2bnum - 1:
|
| 202 |
-
code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ",\n"
|
| 203 |
-
else:
|
| 204 |
-
code_construct += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ",\n"
|
| 205 |
-
|
| 206 |
-
code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ",\n"
|
| 207 |
-
for i in range(self.b2bnum):
|
| 208 |
-
code_construct += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + helper.var_idx(" = typename OutputOp", i) + "::Params(),\n"
|
| 209 |
-
|
| 210 |
-
code_construct += " " + "int batch_count = 1\n"
|
| 211 |
-
|
| 212 |
-
code_construct += "):\n"
|
| 213 |
-
|
| 214 |
-
for i in range(self.b2bnum):
|
| 215 |
-
code_construct += " " + helper.var_idx("problem_size_", i) + helper.var_idx("(problem_size_", i) + "),\n"
|
| 216 |
-
|
| 217 |
-
code_construct += " " + "grid_tiled_shape(grid_tiled_shape),\n"
|
| 218 |
-
code_construct += " " + "params_A0(ref_A0.layout()),\n"
|
| 219 |
-
code_construct += " " + "ref_A0(ref_A0),\n"
|
| 220 |
-
|
| 221 |
-
for i in range(self.b2bnum):
|
| 222 |
-
code_construct += " " + helper.var_idx("params_B", i) + helper.var_idx("(ref_B", i) + ".layout()),\n"
|
| 223 |
-
code_construct += " " + helper.var_idx("ref_B", i) + helper.var_idx("(ref_B", i) + "),\n"
|
| 224 |
-
code_construct += " " + helper.var_idx("params_C", i) + helper.var_idx("(ref_C", i) + ".layout()),\n"
|
| 225 |
-
code_construct += " " + helper.var_idx("ref_C", i) + helper.var_idx("(ref_C", i) + "),\n"
|
| 226 |
-
|
| 227 |
-
code_construct += " " + helper.var_idx("params_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + ".layout()),\n"
|
| 228 |
-
code_construct += " " + helper.var_idx("ref_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + "),\n"
|
| 229 |
-
|
| 230 |
-
for i in range(self.b2bnum):
|
| 231 |
-
code_construct += " " + helper.var_idx("output_op_", i) + helper.var_idx("(output_op_", i) + "), \n"
|
| 232 |
-
|
| 233 |
-
code_construct += " " + "batch_count(batch_count) {\n"
|
| 234 |
-
code_construct += " " + helper.var_idx("gemm_k_iterations_", 0) + helper.var_idx(" = (problem_size_", 0) + helper.var_idx(".k() + B2bMma::Shape", 0) + helper.var_idx("::kK - 1) / B2bMma::Shape", 0) + "::kK;\n"
|
| 235 |
-
|
| 236 |
-
code_construct += "}\n"
|
| 237 |
-
|
| 238 |
-
return code_default + code_construct
|
| 239 |
-
|
| 240 |
-
def gen_using(self):
|
| 241 |
-
code_using = ""
|
| 242 |
-
|
| 243 |
-
for i in range(self.b2bnum - 1):
|
| 244 |
-
code_using += " " + helper.var_idx("using OutputOp", i) + helper.var_idx(" = typename B2bMma::OutputOp", i) + ";\n"
|
| 245 |
-
|
| 246 |
-
code_using += " " + helper.var_idx("using OutputOp", self.b2bnum - 1) + " = typename Epilogue::OutputOp;\n"
|
| 247 |
-
|
| 248 |
-
for i in range(self.b2bnum - 1):
|
| 249 |
-
code_using += " " + helper.var_idx("using FusedAddBiasEpilogue", i) + helper.var_idx(" = typename B2bMma::FusedAddBiasEpilogue", i) +";\n"
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
code_using += " " + "using WarpCount0 = typename B2bMma::WarpCount0;\n"
|
| 253 |
-
code_using += " " + "static int const kThreadCount = 32 * WarpCount0::kCount;\n"
|
| 254 |
-
|
| 255 |
-
code_using += gen_ir.gen_struct("Params", self.gen_Params() + self.gen_Memberfunc())
|
| 256 |
-
|
| 257 |
-
code_using += "union SharedStorage {\n"
|
| 258 |
-
code_using += " " + "typename B2bMma::B2bMmaSharedStorage main_loop;\n"
|
| 259 |
-
code_using += " " + "typename Epilogue::SharedStorage epilogue;\n"
|
| 260 |
-
code_using += "};\n"
|
| 261 |
-
|
| 262 |
-
return code_using
|
| 263 |
-
|
| 264 |
-
def gen_can_implement(self):
|
| 265 |
-
gen_code = ""
|
| 266 |
-
return gen_code
|
| 267 |
-
|
| 268 |
-
def gen_operator_and_constr(self):
|
| 269 |
-
ctr_code = "CUTLASS_HOST_DEVICE\n"
|
| 270 |
-
ctr_code += self.gen_class_name + "() { } \n\n"
|
| 271 |
-
operator_code = "CUTLASS_DEVICE\n"
|
| 272 |
-
operator_code += "void operator()(Params const ¶ms, SharedStorage &shared_storage) {\n"
|
| 273 |
-
operator_code += " " + "ThreadblockSwizzle threadblock_swizzle;\n"
|
| 274 |
-
operator_code += " " + "cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n"
|
| 275 |
-
operator_code += " " + "int batch_idx = threadblock_tile_offset.k();\n"
|
| 276 |
-
operator_code += " " + "if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||\n"
|
| 277 |
-
operator_code += " " + "params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {\n"
|
| 278 |
-
operator_code += " " + " " + "return;\n"
|
| 279 |
-
operator_code += " " + "}\n"
|
| 280 |
-
|
| 281 |
-
operator_code += " " + "cutlass::MatrixCoord tb_offset_A0{\n"
|
| 282 |
-
operator_code += " " + " " + "threadblock_tile_offset.m() * B2bMma::Shape0::kM,\n"
|
| 283 |
-
operator_code += " " + " " + "0\n"
|
| 284 |
-
operator_code += " " + "};\n"
|
| 285 |
-
|
| 286 |
-
for i in range(self.b2bnum):
|
| 287 |
-
operator_code += " " + helper.var_idx("cutlass::MatrixCoord tb_offset_B", i) + "{\n"
|
| 288 |
-
operator_code += " " + " " + "0,\n"
|
| 289 |
-
operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", i) + "::kN\n"
|
| 290 |
-
operator_code += " " + "};\n"
|
| 291 |
-
|
| 292 |
-
operator_code += " " + "int thread_idx = threadIdx.x;\n\n"
|
| 293 |
-
|
| 294 |
-
operator_code += " " + "MatrixCoord threadblock_offset(\n"
|
| 295 |
-
operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.m() * B2bMma::Shape", self.b2bnum - 1) + "::kM,\n"
|
| 296 |
-
operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", self.b2bnum - 1) + "::kN\n"
|
| 297 |
-
operator_code += " " + ");\n"
|
| 298 |
-
|
| 299 |
-
operator_code += " " + "typename B2bMma::IteratorA0 iterator_A0(\n"
|
| 300 |
-
operator_code += " " + " " + "params.params_A0,\n"
|
| 301 |
-
operator_code += " " + " " + "params.ref_A0.data(),\n"
|
| 302 |
-
operator_code += " " + " " + "params.problem_size_0.mk(),\n"
|
| 303 |
-
operator_code += " " + " " + "thread_idx,\n"
|
| 304 |
-
operator_code += " " + " " + "tb_offset_A0);\n"
|
| 305 |
-
|
| 306 |
-
operator_code += " " + "iterator_A0.add_pointer_offset(batch_idx * params.problem_size_0.m() * params.problem_size_0.k());\n\n"
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
for i in range (self.b2bnum):
|
| 310 |
-
operator_code += " " + helper.var_idx("typename B2bMma::IteratorB", i ) + helper.var_idx(" iterator_B", i) + "(\n"
|
| 311 |
-
operator_code += " " + " " + helper.var_idx("params.params_B", i) + ",\n"
|
| 312 |
-
operator_code += " " + " " + helper.var_idx("params.ref_B", i) + ".data(),\n"
|
| 313 |
-
operator_code += " " + " " + helper.var_idx("params.problem_size_", i) + ".kn(),\n"
|
| 314 |
-
operator_code += " " + " " + "thread_idx,\n"
|
| 315 |
-
operator_code += " " + " " + helper.var_idx("tb_offset_B", i) + ");\n"
|
| 316 |
-
operator_code += " " + helper.var_idx("iterator_B", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * params.problem_size_", i) + ".k());\n\n"
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
for i in range (self.b2bnum - 1):
|
| 320 |
-
operator_code += " " + helper.var_idx("typename FusedAddBiasEpilogue", i ) + helper.var_idx("::OutputTileIterator iterator_C", i) + "(\n"
|
| 321 |
-
operator_code += " " + " " + helper.var_idx("params.params_C", i) + ",\n"
|
| 322 |
-
operator_code += " " + " " + helper.var_idx("params.ref_C", i) + ".data(),\n"
|
| 323 |
-
operator_code += " " + " " + helper.var_idx("params.problem_size_" , i) + ".mn(),\n"
|
| 324 |
-
operator_code += " " + " " + "thread_idx,\n"
|
| 325 |
-
operator_code += " " + " " + "threadblock_offset" + ");\n"
|
| 326 |
-
operator_code += " " + helper.var_idx("int ref_C", i) + helper.var_idx("_stride = params.ref_C", i) + ".stride()[0];\n"
|
| 327 |
-
operator_code += " " + helper.var_idx("iterator_C", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * (ref_C", i) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", i) + ".m()));\n\n"
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
for i in range (self.b2bnum - 1):
|
| 331 |
-
operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n"
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n"
|
| 335 |
-
operator_code += " " + "int lane_idx = threadIdx.x % 32;\n"
|
| 336 |
-
|
| 337 |
-
for i in range (self.b2bnum - 1):
|
| 338 |
-
operator_code += " " + helper.var_idx("OutputOp", i) + helper.var_idx(" output_op_", i) + helper.var_idx("(params.output_op_", i) + ");\n"
|
| 339 |
-
|
| 340 |
-
operator_code += " " + "B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);\n"
|
| 341 |
-
|
| 342 |
-
operator_code += " " + "typename B2bMma::FragmentC0 src_accum;\n"
|
| 343 |
-
operator_code += " " + helper.var_idx("typename B2bMma::FragmentC", self.b2bnum - 1)+ " accumulators;\n"
|
| 344 |
-
|
| 345 |
-
operator_code += " " + "src_accum.clear();\n"
|
| 346 |
-
operator_code += " " + "accumulators.clear();\n"
|
| 347 |
-
operator_code += " " + "b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, "
|
| 348 |
-
|
| 349 |
-
for i in range(self.b2bnum):
|
| 350 |
-
operator_code += helper.var_idx("iterator_B", i) + ", "
|
| 351 |
-
|
| 352 |
-
operator_code += "src_accum"
|
| 353 |
-
if self.b2bnum != 1:
|
| 354 |
-
operator_code += ", "
|
| 355 |
-
for i in range(self.b2bnum - 1):
|
| 356 |
-
operator_code += helper.var_idx("output_op_", i) + ", "
|
| 357 |
-
|
| 358 |
-
for i in range(self.b2bnum - 1):
|
| 359 |
-
operator_code += helper.var_idx("epilogue_", i) + ", "
|
| 360 |
-
|
| 361 |
-
for i in range(self.b2bnum - 1):
|
| 362 |
-
final = ", "
|
| 363 |
-
if i == self.b2bnum - 2:
|
| 364 |
-
final =""
|
| 365 |
-
operator_code += helper.var_idx("iterator_C", i) + final
|
| 366 |
-
operator_code += ");\n"
|
| 367 |
-
|
| 368 |
-
operator_code += " " + helper.var_idx("OutputOp", self.b2bnum - 1) + helper.var_idx(" output_op_", self.b2bnum - 1) + helper.var_idx("(params.output_op_", self.b2bnum - 1) + ");\n"
|
| 369 |
-
operator_code += " " + "threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n"
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_C", self.b2bnum - 1) + "(\n"
|
| 374 |
-
operator_code += " " + " " + helper.var_idx("params.params_C", self.b2bnum - 1) + ",\n"
|
| 375 |
-
operator_code += " " + " " + helper.var_idx("params.ref_C", self.b2bnum - 1) + ".data(),\n"
|
| 376 |
-
operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n"
|
| 377 |
-
operator_code += " " + " " + "thread_idx,\n"
|
| 378 |
-
operator_code += " " + " " + "threadblock_offset\n"
|
| 379 |
-
operator_code += " " + ");\n"
|
| 380 |
-
operator_code += " " + helper.var_idx("int ref_C", self.b2bnum - 1) + helper.var_idx("_stride = params.ref_C", self.b2bnum - 1) + ".stride()[0];\n"
|
| 381 |
-
|
| 382 |
-
operator_code += " " + helper.var_idx("iterator_C", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * (ref_C", self.b2bnum - 1) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", self.b2bnum - 1) + ".m()));\n\n"
|
| 383 |
-
|
| 384 |
-
operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_D", self.b2bnum - 1) + "(\n"
|
| 385 |
-
operator_code += " " + " " + helper.var_idx("params.params_D", self.b2bnum - 1) + ",\n"
|
| 386 |
-
operator_code += " " + " " + helper.var_idx("params.ref_D", self.b2bnum - 1) + ".data(),\n"
|
| 387 |
-
operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n"
|
| 388 |
-
operator_code += " " + " " + "thread_idx,\n"
|
| 389 |
-
operator_code += " " + " " + "threadblock_offset\n"
|
| 390 |
-
operator_code += " " + ");\n"
|
| 391 |
-
operator_code += " " + helper.var_idx("iterator_D", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * params.problem_size_", self.b2bnum - 1) + ".m());\n\n"
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
operator_code += " " + "Epilogue epilogue(\n"
|
| 395 |
-
operator_code += " " + " " + "shared_storage.epilogue,\n"
|
| 396 |
-
operator_code += " " + " " + "thread_idx,\n"
|
| 397 |
-
operator_code += " " + " " + "warp_idx,\n"
|
| 398 |
-
operator_code += " " + " " + "lane_idx\n"
|
| 399 |
-
operator_code += " " + ");\n"
|
| 400 |
-
|
| 401 |
-
operator_code += " " + "epilogue("
|
| 402 |
-
operator_code += helper.var_idx("output_op_", self.b2bnum - 1) + ", "
|
| 403 |
-
operator_code += helper.var_idx("iterator_D", self.b2bnum - 1) + ", "
|
| 404 |
-
operator_code += "accumulators, "
|
| 405 |
-
operator_code += helper.var_idx("iterator_C", self.b2bnum - 1) + ");\n"
|
| 406 |
-
operator_code += "}\n"
|
| 407 |
-
|
| 408 |
-
return ctr_code + operator_code
|
| 409 |
-
|
| 410 |
-
def gen_include_header(self):
    """Return the include prelude placed at the top of the generated kernel header.

    The text starts with a blank line and ``#pragma once`` and then lists the
    CUTLASS headers, each path prefixed with ``self.cutlass_deps_root``.

    Returns:
        str: the prelude, terminated by a newline.
    """
    # One entry per output line; empty entries are blank lines. The leading
    # and trailing empty entries give the text its first and last newline.
    prelude_lines = (
        "",
        "#pragma once",
        "",
        "#include \"{cutlass_dir}cutlass/cutlass.h\"",
        "",
        "#include \"{cutlass_dir}cutlass/gemm/gemm.h\"",
        "#include \"{cutlass_dir}cutlass/matrix_coord.h\"",
        "#include \"{cutlass_dir}cutlass/semaphore.h\"",
        "",
    )
    template = "\n".join(prelude_lines)
    return template.format(cutlass_dir=self.cutlass_deps_root)
|
| 421 |
-
def gen_code(self):
    """Assemble the full text of the generated kernel header.

    The struct body is ``gen_using()`` followed by ``gen_operator_and_constr()``.
    It is turned into a template struct by ``gen_ir.gen_template_struct`` and
    passed through ``gen_ir.gen_namespace`` for ``kernel``, ``gemm`` and
    ``cutlass`` (innermost first). The include prelude from
    ``gen_include_header()`` is placed in front of the result.

    Returns:
        str: the prelude followed by the namespaced struct code.
    """
    # (kind, name) pairs handed to gen_ir.
    # NOTE(review): the last entry passes the Python type ``bool`` rather than
    # the string "bool"; kept unchanged because gen_ir's handling of the kind
    # field is not visible from here -- confirm.
    template_param = [
        ("typename", "B2bMma"),
        ("typename", "Epilogue"),
        ("typename", "ThreadblockSwizzle"),
        (bool, "SplitKSerial"),
    ]

    code_body = self.gen_using()
    code_body += self.gen_operator_and_constr()

    struct_code = gen_ir.gen_template_struct(self.gen_class_name, template_param, code_body)
    namespaced = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", struct_code)))

    # The prelude is emitted exactly once. Seeding the accumulator with it and
    # prepending it again on return would duplicate `#pragma once` and every
    # #include line in the generated file.
    return self.gen_include_header() + namespaced
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
class gen_kernel:
    """Driver that writes the kernel-level generated headers to disk.

    It owns one generator for ``default_b2b_gemm.h``, one for ``b2b_gemm.h``
    and a threadblock generator that is run afterwards.
    """

    def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root):
        """Create the sub-generators and remember the output location.

        Args:
            template_param: forwarded unchanged to every sub-generator.
            gen_class_name: forwarded to the sub-generators and used as the
                prefix of ``gen_kernel_name``.
            b2b_num: forwarded unchanged to every sub-generator.
            output_dir: headers are written under ``output_dir + "/kernel/"``;
                also forwarded to the threadblock generator.
            cutlass_deps_root: stored and forwarded to the sub-generators.
            project_root: stored and forwarded to the sub-generators.
        """
        self.template_param = template_param

        # The struct name is fixed; the caller-supplied name only contributes
        # to gen_kernel_name and to the sub-generators' arguments.
        self.gen_class_name = "B2bGemm"
        self.gen_kernel_name = gen_class_name + "Kernel"
        self.template_args = []

        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root

        # Both kernel-level generators take the same positional arguments.
        shared_args = (template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)
        self.gen_default_b2b_gemm = gen_default_Gemm(*shared_args)
        self.gen_Kerenl = gen_Kernel(*shared_args)

        # Include gen_threadBlock
        self.gen_threadBlock = gen_tb.gen_threadblock(template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root)

        self.file_dir = output_dir + "/kernel/"

    def gen_code(self, first_use_1stage):
        """Generate and write both kernel headers, then run the threadblock generator.

        Args:
            first_use_1stage: passed through to the threadblock generator's
                ``gen_code``.
        """
        # (file name, generator) in the order the files are produced.
        emitters = (
            ("default_b2b_gemm.h", self.gen_default_b2b_gemm),
            ("b2b_gemm.h", self.gen_Kerenl),
        )
        for header_name, emitter in emitters:
            header_text = emitter.gen_code()
            print("[INFO]: Gen kernel code [" + header_name + "]output Dir: is ", self.file_dir)
            with open(self.file_dir + header_name, "w+") as out_file:
                out_file.write(header_text)

        # Call code to gen threadblock
        self.gen_threadBlock.gen_code(first_use_1stage)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import helper
|
| 34 |
-
import gen_ir as ir
|
| 35 |
-
|
| 36 |
-
class gen_test:
    """Generate ``sample.cu``, a C++ driver that times the fused kernel and the
    unfused reference and then compares their final outputs.

    The generated program is invoked as ``sample M [K0 [B]]``.
    """

    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        """Store the generation inputs.

        Args:
            fuse_gemm_info: sequence with one entry per fused GEMM layer.
                ``gen_cpp_sample`` reads the keys 'mnk', 'A_format',
                'B_format', 'C_format', 'A_tp', 'B_tp', 'C_tp' and 'Acc_tp'
                from each entry and also hands the entry to ``helper``.
            gen_class_name: prefix of the ``<name>_verify`` function called by
                the generated program.
            user_header_file: stored on the instance; not read by this class.
            output_dir: prepended verbatim to "sample.cu", so it needs to end
                with a path separator.
        """
        self.fuse_gemm_info = fuse_gemm_info
        self.gen_class_name = gen_class_name
        self.user_header_file = user_header_file
        self.sample_dir = output_dir
        self.b2b_num = len(fuse_gemm_info)

    def gen_cpp_sample(self):
        """Build the C++ source text and write it to ``<sample_dir>sample.cu``.

        Raises:
            IndexError: if ``fuse_gemm_info`` is empty (layer 0 is always read).
            OSError: if the output file cannot be opened for writing.
        """
        tab = " "
        last = self.b2b_num - 1

        # ---- includes ----
        code = "/* Auto Generated code - Do not edit.*/\n"
        code += "#include <cstdio> \n"

        code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n"
        code += "#include \"cutlass/cutlass.h\" \n"

        code += "#include \"../cutlass_irrelevant.h\" \n"
        code += "#include \"../cutlass_verify.h\" \n"

        code += "#include \"leaky_bias.h\" \n"

        code += "#include \"utils.h\" \n"

        # ---- main(): command line and device query ----
        code += "int main(int args, char * argv[]) {\n"
        code += tab + "int M = atoi(argv[1]);\n"
        # Default K0 is layer 0's K extent, i.e. mnk[2] -- the same index used
        # for K everywhere else in this method.
        code += tab + "int K0 = " + str(self.fuse_gemm_info[0]['mnk'][2]) + ";\n"
        # No ';' after the condition: a terminated `if(...);` makes the next
        # statement unconditional, so argv[2]/argv[3] would be read even when
        # they were not supplied. `>=` lets "M K0 B" set both optional values.
        code += tab + "if(args >= 3)\n"
        code += tab + tab + "K0 = atoi(argv[2]);\n"
        code += tab + "int B = 1;\n"
        code += tab + "if(args >= 4)\n"
        code += tab + tab + "B = atoi(argv[3]);\n"

        code += tab + "srand(1234UL);\n"
        code += tab + "int device_id = 0;\n"
        code += tab + "cudaGetDevice(&device_id);\n"
        code += tab + "cudaDeviceProp prop;\n"
        code += tab + "cudaGetDeviceProperties(&prop, device_id);\n"
        code += tab + "int sm = prop.major *10 + prop.minor;\n"
        code += "using ElementCompute = cutlass::half_t;\n"

        # ---- per-layer alpha/beta: beta is 1 only when the epilogue adds a bias ----
        for i in range(self.b2b_num):
            code += tab + helper.var_idx("ElementCompute alpha", i) + " = ElementCompute(1);\n"
            addbias = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i])
            if addbias:
                code += tab + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(1);\n"
            else:
                code += tab + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(0);\n"

        code += tab + "size_t flops = 0;\n"

        # ---- per-layer problem sizes and device buffers ----
        for i in range(self.b2b_num):
            n = self.fuse_gemm_info[i]['mnk'][1]
            k = self.fuse_gemm_info[i]['mnk'][2]

            bias_shape = helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])

            # Layer 0 uses the runtime K0; later layers use their configured K.
            this_k = "K0" if i == 0 else str(k)

            code += tab + "flops += size_t(2) * size_t(M) * size_t(B) * " + "size_t(" + str(n) + ") * size_t(" + this_k + ");\n"

            code += tab + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(" + "M" + ", " + str(n) + ", " + this_k + ");\n"

            code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_A", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".k());\n"
            code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_B", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".n() * problem_size_", i) + ".k());\n"
            code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_C", i) + "(B * " + str(bias_shape[0]) + " * " + str(bias_shape[1]) + ");\n"
            code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_D_cutlass_ref", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".n());\n"

            code += tab + helper.var_idx("Mat_A", i) + ".init();\n"
            code += tab + helper.var_idx("Mat_B", i) + ".init();\n"
            code += tab + helper.var_idx("Mat_C", i) + ".init();\n"

        # Output buffer of the fused kernel, sized from the last layer. The
        # last index is named explicitly instead of relying on the loop
        # variable surviving the loop above.
        code += tab + helper.var_idx("memory_unit<cutlass::half_t> Mat_D", last) + helper.var_idx("(B * problem_size_", last) + helper.var_idx(".m() * problem_size_", last) + ".n());\n"

        # ---- argument struct for the fused entry point ----
        code += tab + "Param arguments = {\n"
        code += tab + tab + "M,\n"
        code += tab + tab + "K0,\n"
        code += tab + tab + "B,\n"

        code += tab + tab + "reinterpret_cast<const void*>(Mat_A0.device_ptr),\n"
        for i in range(self.b2b_num):
            bias_flag = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i])
            code += tab + tab + "reinterpret_cast<const void*>(" + helper.var_idx("Mat_B", i) + ".device_ptr" + "),\n"
            if bias_flag:
                code += tab + tab + "reinterpret_cast<const void*>(" + helper.var_idx("Mat_C", i) + ".device_ptr" + "),\n"
            else:
                code += tab + tab + "reinterpret_cast<const void*>(NULL),\n"

            epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
            acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
            for arg in epilogue_args:
                arg_value = str(arg[2])

                code += tab + tab + helper.type_2_cutlass_type(acc_tp) + "(" + arg_value + "),\n"

            # Intermediate layers write into the reference buffers; only the
            # last layer writes Mat_D and closes the initializer.
            if i != last:
                code += tab + tab + "reinterpret_cast<void*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr" + "),\n"
            else:
                code += tab + tab + "reinterpret_cast<void*>(" + helper.var_idx("Mat_D", i) + ".device_ptr" + ")};\n"

        # ---- timed fused run ----
        code += tab + "TI(FUSED_CUTLASS);\n"
        code += tab + "for(int i = 0; i < 100; i++){\n"
        code += tab + tab + "one_api(arguments, sm, NULL);\n"

        code += tab + "}\n"
        code += tab + "TO(FUSED_CUTLASS, \"FUSED_CUTLASS\", 100);\n"

        code += "\n"

        # ---- per-layer arguments for the unfused reference GEMMs ----
        for i in range(self.b2b_num):
            layer = self.fuse_gemm_info[i]
            code_this = ""

            N_str = str(layer['mnk'][1])

            code_this += tab + helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n"
            code_this += tab + tab + helper.var_idx("problem_size_", i) + ",\n"

            # Default leading dimensions; layer 0 uses the runtime K0.
            ldmA = "K0" if i == 0 else str(layer['mnk'][2])
            ldmB = "K0" if i == 0 else str(layer['mnk'][2])
            ldmC = str(layer['mnk'][1])

            ldmBias = str(helper.get_epilogue_bias_ldm(layer))

            # Compare layout names by value: an identity test against a string
            # literal only succeeds when both happen to be the same object.
            if layer['A_format'] == 'Col':
                ldmA = "M"
            if layer['B_format'] == 'Row':
                ldmB = str(layer['mnk'][1])
            if layer['C_format'] == 'Col':
                ldmC = "M"

            # Operand A is the user input for layer 0 and the previous layer's
            # reference output afterwards.
            if i == 0:
                code_this += tab + tab + "{reinterpret_cast<" + helper.type_2_cutlass_type(layer['A_tp']) + "*>(" + helper.var_idx("Mat_A", i) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n"
            else:
                code_this += tab + tab + "{reinterpret_cast<" + helper.type_2_cutlass_type(layer['A_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i - 1) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n"

            code_this += tab + tab + "{reinterpret_cast<" + helper.type_2_cutlass_type(layer['B_tp']) + "*>(" + helper.var_idx("Mat_B", i) + ".device_ptr), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n"

            M_bias = str(helper.get_epilogue_bias_shape(layer)[0])

            code_this += tab + tab + "{reinterpret_cast<" + helper.type_2_cutlass_type(layer['C_tp']) + "*>(" + helper.var_idx("Mat_C", i) + ".device_ptr), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n"
            code_this += tab + tab + "{reinterpret_cast<" + helper.type_2_cutlass_type(layer['C_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr), " + ldmC + "}, " + "M * " + ldmC + ",\n"
            code_this += tab + tab + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
            for epilogue_arg in helper.get_epilogue_args(layer):
                arg_value = str(epilogue_arg[2])
                code_this += ", " + helper.type_2_cutlass_type(layer['Acc_tp']) + "(" + str(arg_value) + ")"
            code_this += tab + " },\n"
            code_this += tab + tab + "B};\n"

            code += code_this

        # ---- timed unfused run ----
        code += tab + "TI(UNFUSED_CUTLASS);\n"
        code += tab + "for(int i = 0; i < 100; i++){\n"
        code += tab + tab + self.gen_class_name + "_verify(\n"
        for i in range(self.b2b_num):
            code += tab + tab + tab + helper.var_idx("arguments_", i) + ",\n"
        code += tab + tab + tab + "NULL);\n"

        code += tab + "}\n"
        code += tab + "TO(UNFUSED_CUTLASS, \"UNFUSED_CUTLASS\", 100);\n"

        # ---- copy both results back and compare ----
        code += tab + helper.var_idx("Mat_D_cutlass_ref", last) + ".d2h();\n"
        code += tab + helper.var_idx("Mat_D", last) + ".d2h();\n"
        code += tab + helper.var_idx("check_result(Mat_D_cutlass_ref", last) + helper.var_idx(".host_ptr, Mat_D", last) \
                + helper.var_idx(".host_ptr, Mat_D", last) + ".elements);\n"

        code += "\n\n}\n"

        with open(self.sample_dir + "sample.cu", "w+") as f:
            f.write(code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py
DELETED
|
@@ -1,1013 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import gen_ir
|
| 34 |
-
import helper
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class gen_default_b2b_mma:
    """Generator for the header that defines the ``DefaultB2bMma`` template struct.

    The individual ``gen_*`` methods each return a fragment of C++ ``using``
    declarations; ``gen_code`` stitches them into the final header text.
    """

    def __init__(self, template_param, gen_class_name, b2b_num,cutlass_deps_root, project_root):
        """Store the generation inputs.

        Args:
            template_param: template parameter description handed to ``gen_ir``.
            gen_class_name: accepted but not stored; the struct name is fixed.
            b2b_num: number of back-to-back GEMM layers.
            cutlass_deps_root: prefix placed in front of CUTLASS include paths.
            project_root: stored on the instance.
        """
        self.gen_class_name = "DefaultB2bMma"
        self.template_param = template_param
        self.b2b_num = b2b_num

        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root

    def gen_include_header(self):
        """Return the banner, ``#pragma once`` and include lines of the header."""
        # One entry per output line; empty entries are blank lines.
        prelude_lines = (
            "",
            "/* Auto Generated code - Do not edit.*/",
            "",
            "#pragma once",
            "",
            "#include \"{cutlass_dir}cutlass/cutlass.h\"",
            "#include \"{cutlass_dir}cutlass/numeric_types.h\"",
            "#include \"{cutlass_dir}cutlass/arch/arch.h\"",
            "",
            "#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\"",
            "#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h\"",
            "#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\"",
            "#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\"",
            "#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\"",
            "",
            "#include \"../threadblock/b2b_mma_pipelined.h\"",
            "#include \"../../fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h\"",
            "#include \"../../fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h\"",
            "#include \"../../fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h\"",
            "",
        )
        return "\n".join(prelude_lines).format(cutlass_dir=self.cutlass_deps_root)

    def gen_using_MmaCore(self, stage):
        """Return one ``using MmaCore<i> = ...DefaultMmaCore<...>`` alias per layer.

        Args:
            stage: value converted with ``str`` and passed as the stage argument.
        """
        core_template = "typename cutlass::gemm::threadblock::DefaultMmaCore"

        aliases = []
        for layer in range(self.b2b_num):
            declaration = gen_ir.gen_declare_template_struct(
                core_template,
                helper.var_idx("ThreadblockShape", layer),
                helper.var_idx("WarpShape", layer),
                "InstructionShape",
                "ElementA",
                "LayoutA",
                helper.var_idx("ElementB", layer),
                helper.var_idx("LayoutB", layer),
                helper.var_idx("ElementAccumulator", layer),
                "layout::RowMajor",
                "OperatorClass",
                str(stage),
                "Operator",
            )
            aliases.append("using MmaCore" + str(layer) + " = " + declaration)
        return "".join(aliases)

    def gen_using_FusedAddBiasEpilogue(self):
        """Return a ``using FusedAddBiasEpilogue<i>`` alias for every layer but the last."""
        epilogue_template = "typename cutlass::epilogue::threadblock::DefaultFusedBiasActEpilogueTensorOp"

        aliases = []
        for layer in range(self.b2b_num - 1):
            alias_head = helper.var_idx("using FusedAddBiasEpilogue", layer)
            alias_args = (
                helper.var_idx("<ThreadblockShape", layer)
                + helper.var_idx(",typename MmaCore", layer)
                + helper.var_idx("::MmaPolicy::Operator, 1, EpilogueOutputOp", layer)
                + ", 2>::Epilogue"
            )
            aliases.append(alias_head + " = " + epilogue_template + alias_args + ";\n")
        return "".join(aliases)

    def gen_using_Iterator(self):
        """Return the ``IteratorA0`` alias followed by one ``IteratorB<i>`` alias per layer."""
        tile_iterator = "cutlass::transform::threadblock::PredicatedTileIterator"

        # Operand A only has a global-memory iterator for the first layer.
        shape_a = "cutlass::MatrixShape<" + "MmaCore0" + "::Shape::kM, " + "MmaCore0" + "::Shape::kK>"
        pieces = [
            "using IteratorA0" + " = " + gen_ir.gen_declare_template_struct(
                tile_iterator, shape_a, "ElementA", "LayoutA", "1",
                "typename " + "MmaCore0" + "::IteratorThreadMapA", "AlignmentA_")
        ]

        for layer in range(self.b2b_num):
            core = "MmaCore" + str(layer)
            shape_b = "cutlass::MatrixShape<" + core + "::Shape::kK, " + core + "::Shape::kN>"
            thread_map = "typename " + core + "::IteratorThreadMapB"
            pieces.append("using IteratorB" + str(layer) + " = " + gen_ir.gen_declare_template_struct(
                tile_iterator, shape_b,
                helper.var_idx("ElementB", layer), helper.var_idx("LayoutB", layer),
                "0", thread_map, "AlignmentB_"))

        return "".join(pieces)

    def gen_fragment_iterator(self):
        """Return the accumulator layout alias plus ``FragmentIteratorA<i>`` for layers 1 and up."""
        pieces = ["using AccumulatorLayout = cutlass::layout::ColumnMajor;\n"]

        for layer in range(1, self.b2b_num):
            cur_core = "MmaCore" + str(layer)
            prev_core = "MmaCore" + str(layer - 1)
            cur_shape = "cutlass::MatrixShape<" + cur_core + "::WarpShape::kM, " + cur_core + "::InstructionShape::kK>"
            prev_shape = "cutlass::MatrixShape<" + prev_core + "::WarpShape::kM, " + prev_core + "::WarpShape::kN>"

            pieces.append("using FragmentIteratorA" + str(layer) + " = " + gen_ir.gen_declare_template_struct(
                "cutlass::gemm::warp::MmaTensorOpPureFragmentIterator",
                cur_shape, prev_shape, cur_core + "::Shape::kK",
                helper.var_idx("ElementAccumulator", layer - 1), "ElementA",
                "AccumulatorLayout", "InstructionShape_", "true"))

        return "".join(pieces)

    def gen_threadblockmma(self):
        """Return the ``using ThreadblockB2bMma = ...B2bMmaPipelined<...>`` alias."""
        # Arguments that are each followed by ", " in the final list.
        leading = [
            "typename MmaCore0::Shape",
            "IteratorA0",
            "typename MmaCore0::SmemIteratorA",
            "IteratorB0",
            "typename MmaCore0::SmemIteratorB",
        ]

        for layer in range(1, self.b2b_num):
            tag = str(layer)
            leading.extend([
                "typename MmaCore" + tag + "::Shape",
                "FragmentIteratorA" + tag,
                "IteratorB" + tag,
                "typename MmaCore" + tag + "::SmemIteratorB",
            ])

        leading.append("ElementAccumulator0, layout::RowMajor")

        leading.extend("EpilogueOutputOp" + str(layer) for layer in range(self.b2b_num - 1))
        leading.extend("FusedAddBiasEpilogue" + str(layer) for layer in range(self.b2b_num - 1))
        leading.extend("typename MmaCore" + str(layer) + "::MmaPolicy" for layer in range(self.b2b_num))

        # The stage names close the list, so only they are comma-separated
        # without a trailing separator.
        stage_names = ", ".join(helper.var_idx("Stages", layer) for layer in range(self.b2b_num))
        param_list = "".join(entry + ", " for entry in leading) + stage_names

        return "using ThreadblockB2bMma" + " = " + gen_ir.gen_declare_template_struct(
            "cutlass::gemm::threadblock::B2bMmaPipelined", param_list)

    def gen_code(self):
        """Return the complete header: include prelude plus the namespaced structs.

        The primary template is generated with an empty body; the
        specialization for ``LayoutD = cutlass::layout::RowMajor`` carries all
        the ``using`` declarations.
        """
        # Generate default template struct
        primary = gen_ir.gen_template_struct(self.gen_class_name, self.template_param, "", speicalized = None, set_default=False)

        # Generate specialized template struct
        body_parts = [
            self.gen_using_MmaCore(2),
            self.gen_using_Iterator(),
            self.gen_fragment_iterator(),
            self.gen_using_FusedAddBiasEpilogue(),
            self.gen_threadblockmma(),
        ]
        specialized_body = body_parts[0] + body_parts[1] + body_parts[2] + body_parts[3] + body_parts[4]

        # Specialize layout C -> cutlass::layout::RowMajor
        remaining_args, specialized_args = gen_ir.filtered_param(self.template_param, [ ('LayoutD', "cutlass::layout::RowMajor")], keep_= True)

        specialization = gen_ir.gen_template_struct(self.gen_class_name, remaining_args, specialized_body, speicalized = specialized_args, set_default=False)
        namespaced = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", primary + specialization)))

        return self.gen_include_header() + namespaced
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
class gen_b2b_mme_pipelined:
|
| 216 |
-
def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
|
| 217 |
-
self.gen_class_name = "B2bMmaPipelined"
|
| 218 |
-
self.template_param = template_param
|
| 219 |
-
self.b2b_num = b2b_num
|
| 220 |
-
self.cutlass_deps_root = cutlass_deps_root
|
| 221 |
-
self.project_root = project_root
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
def gen_include_header(self):
    """Return the ``#pragma once`` line and ``#include`` block for the generated header.

    Every CUTLASS include path is prefixed with ``self.cutlass_deps_root``;
    the final include of ``../threadblock/b2b_mma_base.h`` is relative and
    is not prefixed.
    """
    deps_root = self.cutlass_deps_root

    # Include groups; each group is preceded by one blank line in the output.
    include_groups = (
        (
            "cutlass/cutlass.h",
            "cutlass/array.h",
            "cutlass/aligned_buffer.h",
            "cutlass/numeric_conversion.h",
        ),
        (
            "cutlass/numeric_types.h",
            "cutlass/matrix_shape.h",
        ),
        (
            "cutlass/gemm/gemm.h",
            "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h",
        ),
    )

    text = "\n#pragma once\n"
    for group in include_groups:
        text += "\n"
        for header in group:
            text += '#include "{}{}"\n'.format(deps_root, header)
    text += '\n#include "../threadblock/b2b_mma_base.h"\n'
    return text
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
def gen_using(self):
    """Generate the type aliases, constants and member declarations of the class body.

    Emits, in order: the ``FragmentA0`` and ``Base`` aliases, per-stage
    ``FragmentB``/``FragmentC``/``Operator`` aliases, ``IteratorC`` aliases
    for every stage but the last, ``ArchTag`` and the ``kTransform``
    constants, a ``private:`` section with the warp fragment aliases, and a
    ``protected:`` section with the shared-memory iterator members.

    Returns:
        The generated C++ text as one string.
    """
    count = self.b2b_num
    var = helper.var_idx
    chunks = ["using FragmentA0 = typename IteratorA0::Fragment;\n"]

    # Base alias: all Shape args, then all Policy args, then all Stage args.
    base_alias = "using Base = B2bMmaBase<"
    for stem in ("Shape", "Policy", "Stage"):
        for k in range(count):
            base_alias += var(stem, k) + "_, "
    # Strip the trailing ", " left behind by the last template argument.
    chunks.append(base_alias[:-2] + ">;\n")

    for k in range(count):
        chunks.append(var("using FragmentB", k) + var(" = typename IteratorB", k) + "::Fragment;\n")
        chunks.append(var("using FragmentC", k) + var(" = typename Policy", k) + "::Operator::FragmentC;\n")
        chunks.append(var("using Operator", k) + var(" = typename Policy", k) + "::Operator;\n")

    # One output-tile iterator per fused epilogue, i.e. per stage except the last.
    for k in range(count - 1):
        chunks.append(
            var("using IteratorC", k)
            + var(" = typename FusedAddBiasEpilogue", k)
            + "::OutputTileIterator;\n"
        )

    chunks.append("using ArchTag = typename Policy0::Operator::ArchTag;\n")
    chunks.append("static ComplexTransform const kTransformA0 = Operator0::kTransformA;\n")

    for k in range(count):
        chunks.append(
            var("static ComplexTransform const kTransformB", k)
            + var(" = Operator", k)
            + "::kTransformB;\n"
        )

    chunks.append("private:\n")
    chunks.append("using WarpFragmentA0 = typename Operator0::FragmentA;\n")
    chunks.append("using WarpFragmentB0 = typename Operator0::FragmentB;\n")

    # Stages after the first take their A fragment type from FragmentIteratorA<k>.
    for k in range(1, count):
        chunks.append(var("using WarpFragmentA", k) + var(" = typename FragmentIteratorA", k) + "::Fragment;\n")
        chunks.append(var("using WarpFragmentB", k) + var(" = typename Operator", k) + "::FragmentB;\n")

    chunks.append("protected:\n")
    chunks.append("SmemIteratorA0 smem_iterator_A_;\n")

    for k in range(count):
        chunks.append(var("SmemIteratorB", k) + var(" smem_iterator_B", k) + "_;\n")

    return "".join(chunks)
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
def gen_operator(self, first_use_1stage = False):
    """Generate the C++ ``operator()`` of the back-to-back MMA class.

    The emitted function consists of the parameter list, the mainloop of the
    first GEMM (single-stage or two-stage variant) and, for every following
    GEMM ``1 .. b2b_num - 1``, an epilogue call on the previous accumulator
    followed by a two-stage mainloop.

    Args:
        first_use_1stage: when true the first GEMM uses the single-stage
            mainloop template, otherwise the two-stage template.

    Returns:
        The generated C++ text as one string.

    Note:
        The per-GEMM template used to read the loop variable of its
        enclosing function for the ``Base::Stage<k>`` suffix. It now uses
        its own argument; the generated text is unchanged because the only
        call site passed that same value.
    """
    var = helper.var_idx

    def _emit(rows):
        # Every template row is written out followed by a newline.
        return "".join(row + "\n" for row in rows)

    def gen_operator_param(b2b_num):
        """Parameter list of operator(), one parameter per line."""
        last = b2b_num - 1
        params = "int gemm_k_iterations_0,\n"
        # The destination accumulator belongs to the last GEMM.
        params += var("FragmentC", last) + var(" &accum", last) + ",\n"
        params += "IteratorA0 iterator_A,\n"
        for k in range(b2b_num):
            params += var("IteratorB", k) + " " + var("iterator_B", k) + ",\n"
        params += "FragmentC0 const &src_accum, \n"
        # One output op / fused epilogue / C iterator per GEMM except the last.
        for type_stem, arg_stem in (
            ("OutputOp", "output_op_"),
            ("FusedAddBiasEpilogue", "epilogue_"),
            ("IteratorC", "iterator_C"),
        ):
            for k in range(last):
                params += var(type_stem, k) + " " + var(arg_stem, k) + ",\n"
        params += "TransformA0 transform_A0 = TransformA0(), \n"
        for k in range(b2b_num):
            # The final parameter has no trailing comma.
            tail = "()\n" if k == last else "(),\n"
            params += var("TransformB", k) + " " + var("transform_B", k) + " = " + var("TransformB", k) + tail
        return params

    def _first_accum_decl(b2b_num):
        # With a single GEMM, accum0 is already the destination parameter, so
        # it is assigned; otherwise a local accum0 is declared.
        if b2b_num == 1:
            return " accum0 = src_accum;\n"
        return " FragmentC0 accum0 = src_accum;\n"

    def gen_first_gemm_1stage(b2b_num):
        """First GEMM, single-stage mainloop."""
        rows = (
            "",
            "FragmentA0 tb_frag_A;",
            "FragmentB0 tb_frag_B0;",
            "",
            "int smem_write_stage_idx = 1;",
            "",
            "tb_frag_A.clear();",
            "tb_frag_B0.clear();",
            "",
            "// The last kblock is loaded in the prolog",
            "iterator_A.load(tb_frag_A);",
            "iterator_B0.load(tb_frag_B0);",
            "",
            "++iterator_A;",
            "++iterator_B0;",
            "",
            "WarpFragmentA0 warp_frag_A0;",
            "WarpFragmentB0 warp_frag_B0;",
            "",
            "Operator0 warp_mma0;",
            "",
            "// Avoid reading out of bounds",
            "if (gemm_k_iterations_0 <= 1) {",
            "iterator_A.clear_mask();",
            "iterator_B0.clear_mask();",
            "}",
            "",
            "// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing ",
            "// shared memory loads (which have the tightest latency requirement).",
            "",
            "//",
            "// Mainloop",
            "//",
            "",
            "// Note: The main loop does not support Base::WarpGemmIterations == 2.",
            "CUTLASS_GEMM_LOOP",
            "for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {",
            "",
            "this->smem_iterator_A_.store(tb_frag_A);",
            "this->smem_iterator_B0_.store(tb_frag_B0);",
            "",
            "__syncthreads();",
            "//",
            "// Loop over GEMM K dimension",
            "//",
            "",
            "CUTLASS_PRAGMA_UNROLL",
            "for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {",
            "",
            "// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group",
            "// as the case may be.",
            "",
            "this->warp_tile_iterator_A0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);",
            "this->warp_tile_iterator_B0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);",
            "",
            "this->warp_tile_iterator_A0_.load(warp_frag_A0);",
            "this->warp_tile_iterator_B0_.load(warp_frag_B0);",
            "",
            "++this->warp_tile_iterator_A0_;",
            "++this->warp_tile_iterator_B0_;",
            "",
            "warp_mma0(accum0, warp_frag_A0, warp_frag_B0, accum0);",
            "}",
            "this->warp_tile_iterator_A0_.add_tile_offset({0, -Policy0::kPartitionsK * Base::kWarpGemmIterations0});",
            "this->warp_tile_iterator_B0_.add_tile_offset({-Policy0::kPartitionsK * Base::kWarpGemmIterations0, 0});",
            "",
            "__syncthreads();",
            "iterator_A.load(tb_frag_A);",
            "iterator_B0.load(tb_frag_B0);",
            "",
            "++iterator_A;",
            "++iterator_B0;",
            "",
            "if(gemm_k_iterations_0 <= 2) {",
            "iterator_A.clear_mask();",
            "iterator_B0.clear_mask();",
            "}",
            "}",
        )
        return _first_accum_decl(b2b_num) + _emit(rows)

    def gen_first_gemm_2stage(b2b_num):
        """First GEMM, two-stage (double-buffered) mainloop."""
        rows = (
            "",
            "FragmentA0 tb_frag_A;",
            "FragmentB0 tb_frag_B0;",
            "",
            "tb_frag_A.clear();",
            "tb_frag_B0.clear();",
            "",
            "// The last kblock is loaded in the prolog",
            "iterator_A.load(tb_frag_A);",
            "iterator_B0.load(tb_frag_B0);",
            "",
            "++iterator_A;",
            "++iterator_B0;",
            "",
            "this->smem_iterator_A_.store(tb_frag_A);",
            "this->smem_iterator_B0_.store(tb_frag_B0);",
            "",
            "++this->smem_iterator_A_;",
            "++this->smem_iterator_B0_;",
            "",
            "__syncthreads();",
            "",
            "// Pair of fragments used to overlap shared memory loads and math instructions",
            "WarpFragmentA0 warp_frag_A0[2];",
            "WarpFragmentB0 warp_frag_B0[2];",
            "",
            "this->warp_tile_iterator_A0_.set_kgroup_index(0);",
            "this->warp_tile_iterator_B0_.set_kgroup_index(0);",
            "",
            "this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);",
            "this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);",
            "",
            "++this->warp_tile_iterator_A0_;",
            "++this->warp_tile_iterator_B0_;",
            "",
            "Operator0 warp_mma0;",
            "",
            "int smem_write_stage_idx = 1;",
            "",
            "// Avoid reading out of bounds",
            "if (gemm_k_iterations_0 <= 1) {",
            "iterator_A.clear_mask();",
            "iterator_B0.clear_mask();",
            "}",
            "",
            "// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing ",
            "// shared memory loads (which have the tightest latency requirement).",
            "iterator_A.load(tb_frag_A);",
            "",
            "//",
            "// Mainloop",
            "//",
            "",
            "// Note: The main loop does not support Base::WarpGemmIterations == 2.",
            "CUTLASS_GEMM_LOOP",
            "for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {",
            "",
            "//",
            "// Loop over GEMM K dimension",
            "//",
            "",
            "CUTLASS_PRAGMA_UNROLL",
            "for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {",
            "",
            "// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group",
            "// as the case may be.",
            "",
            "if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {",
            "",
            "// Write fragments to shared memory",
            "this->smem_iterator_A_.store(tb_frag_A);",
            "",
            "this->smem_iterator_B0_.store(tb_frag_B0);",
            "",
            "__syncthreads();",
            "",
            "// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing ",
            "// shared memory loads (which have the tightest latency requirement).",
            "iterator_A.load(tb_frag_A);",
            "",
            "++this->smem_iterator_B0_;",
            "++this->smem_iterator_A_;",
            "",
            "",
            "// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory",
            "if (smem_write_stage_idx == 1) {",
            "this->smem_iterator_A_.add_tile_offset({0, -Base::Stage0});",
            "this->smem_iterator_B0_.add_tile_offset({-Base::Stage0, 0});",
            "}",
            "else {",
            "this->warp_tile_iterator_A0_.add_tile_offset(",
            "{0, -Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0});",
            "this->warp_tile_iterator_B0_.add_tile_offset(",
            "{-Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0,",
            "0});",
            "}",
            "",
            "smem_write_stage_idx ^= 1;",
            "}",
            "",
            "this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);",
            "this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);",
            "",
            "this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);",
            "this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);",
            "",
            "++this->warp_tile_iterator_A0_;",
            "++this->warp_tile_iterator_B0_;",
            "",
            "if (warp_mma_k == 0) {",
            "",
            "iterator_B0.load(tb_frag_B0);",
            "",
            "++iterator_A;",
            "++iterator_B0;",
            "",
            "// Avoid reading out of bounds if this was the last loop iteration",
            "if (gemm_k_iterations_0 <= 2) {",
            "iterator_A.clear_mask();",
            "iterator_B0.clear_mask();",
            "}",
            "}",
            "",
            "warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);",
            "}",
            "}",
        )
        return _first_accum_decl(b2b_num) + _emit(rows)

    def gen_other_gemms_2stage(b2b_num):
        """Epilogue call plus two-stage mainloop for every GEMM after the first."""

        def _gemm_template(k):
            # k is the index of the GEMM being emitted, p the one feeding it.
            p = k - 1
            out = "// " + str(k + 1) + " Gemm"
            out += " /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile\n"

            # Run the previous GEMM's fused epilogue into a fresh fragment.
            out += " " + var("FragmentC", p) + var(" after_epilogue_accu", p) + ";\n"
            out += (
                " " + var("epilogue_", p) + var("(output_op_", p) + var(", accum", p)
                + var(", after_epilogue_accu", p) + var(", iterator_C", p) + ");\n"
            )

            # The A operand of this GEMM is read from that epilogue result.
            out += " " + var("FragmentIteratorA", k) + var(" warp_tile_iterator_A", k) + "_(" + var("after_epilogue_accu", p) + ");\n"
            out += " " + var("FragmentB", k) + " " + var("tb_frag_B", k) + ";\n"
            out += " " + var("tb_frag_B", k) + ".clear();\n"
            out += " " + var("iterator_B", k) + ".load(" + var("tb_frag_B", k) + ");\n"
            out += " " + "++" + var("iterator_B", k) + ";\n"
            out += " " + var("this->smem_iterator_B", k) + "_.store(" + var("tb_frag_B", k) + ");\n"
            out += " " + var("++this->smem_iterator_B", k) + "_;\n"
            out += " " + "__syncthreads();\n"
            out += " " + var("WarpFragmentA", k) + var(" warp_frag_A", k) + "[2];\n"
            out += " " + var("WarpFragmentB", k) + var(" warp_frag_B", k) + "[2];\n"
            out += " " + var("this->warp_tile_iterator_B", k) + "_.set_kgroup_index(0);\n"
            out += " " + var("warp_tile_iterator_A", k) + var("_.load(warp_frag_A", k) + "[0]);\n"
            out += " " + var("this->warp_tile_iterator_B", k) + var("_.load(warp_frag_B", k) + "[0]);\n"
            out += " " + var("++warp_tile_iterator_A", k) + "_;\n"
            out += " " + var("++this->warp_tile_iterator_B", k) + "_;\n"
            out += " " + var("Operator", k) + " " + var("warp_mma", k) + ";\n"
            out += " " + "smem_write_stage_idx = 1;\n"
            out += (
                " " + var("int gemm_k_iterations_", k) + " = " + var("FragmentIteratorA", k)
                + var("::Policy::kIterations / Base::kWarpGemmIterations", k) + ";\n"
            )
            out += (
                " " + "if (" + var("gemm_k_iterations_", k) + " <= 1 ){\n"
                + "  " + var("iterator_B", k) + ".clear_mask();\n"
                + " " + "}\n"
            )

            # Mainloop over k-blocks, with the unrolled warp-level loop inside.
            out += " " + "CUTLASS_PRAGMA_UNROLL\n"
            out += " " + var("for (; gemm_k_iterations_", k) + var(" > 0; --gemm_k_iterations_", k) + ") {\n"
            out += "  " + "CUTLASS_PRAGMA_UNROLL\n"
            out += "  " + var("for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations", k) + "; ++warp_mma_k) {\n"
            out += "   " + var("if (warp_mma_k == Base::kWarpGemmIterations", k) + " - 1) {\n"
            out += "    " + var(" this->smem_iterator_B", k) + var("_.store(tb_frag_B", k) + ");\n"
            out += "    " + "__syncthreads();\n"
            out += "    " + var(" ++smem_iterator_B", k) + "_;\n"
            # The Stage suffix is taken from this template's own index.
            out += (
                "    " + "if ( smem_write_stage_idx == 1 ) {\n"
                + "     " + var("smem_iterator_B", k) + var("_.add_tile_offset({-Base::Stage", k) + ", 0});\n"
                + "    " + "}\n"
            )
            out += (
                "    " + "else {\n"
                + "     " + var("this->warp_tile_iterator_B", k) + "_.add_tile_offset(\n"
                + "     " + var("{-Base::Stage", k) + var(" * Policy", k) + "::kPartitionsK *\n"
                + "     " + var("Base::kWarpGemmIterations", k) + ",\n"
                + "     " + "0});\n"
                + "    " + "}\n"
            )
            out += (
                "    " + "smem_write_stage_idx ^= 1;\n"
                + "   " + "}\n"
            )

            out += "   " + var("this->warp_tile_iterator_B", k) + var("_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations", k) + ");\n"
            out += "   " + var("warp_tile_iterator_A", k) + var("_.load(warp_frag_A", k) + "[(warp_mma_k + 1) % 2]);\n"
            out += "   " + var("this->warp_tile_iterator_B", k) + var("_.load(warp_frag_B", k) + "[(warp_mma_k + 1) % 2]);\n"
            out += "   " + var("++warp_tile_iterator_A", k) + "_;\n"
            out += "   " + var("++this->warp_tile_iterator_B", k) + "_;\n"
            out += (
                "   " + " if (warp_mma_k == 0) {\n"
                + "    " + var("iterator_B", k) + var(".load(tb_frag_B", k) + ");\n"
                + "    " + var("++iterator_B", k) + ";\n"
                + "    " + var("if (gemm_k_iterations_", k) + " <= 2) {\n"
                + "     " + var("iterator_B", k) + ".clear_mask();\n"
                + "    " + "}\n"
                + "   " + "}\n"
            )
            out += (
                "   " + var("warp_mma", k) + var("(accum", k) + var(", warp_frag_A", k)
                + var("[warp_mma_k % 2], warp_frag_B", k) + var("[warp_mma_k % 2], accum", k) + ");\n"
                + "  " + "}\n"
                + " " + "}\n\n\n"
            )
            return out

        pieces = []
        for k in range(1, b2b_num):
            prologue = ""
            if k != b2b_num - 1:
                # Intermediate accumulators are locals; the last one is the
                # operator's destination parameter and is not redeclared.
                prologue = " " + var("FragmentC", k) + var(" accum", k) + ";\n"
            # NOTE(review): indentation was lost in the reviewed copy; the
            # clear() is emitted for every GEMM, including the last. Confirm
            # it was not meant to sit inside the `if` above.
            prologue += " " + var("accum", k) + ".clear();\n"
            pieces.append(prologue + _gemm_template(k))
        return "".join(pieces)

    operator_code = " CUTLASS_DEVICE\nvoid operator()(\n " + gen_operator_param(self.b2b_num) + ") {\n"
    if first_use_1stage:
        operator_code += gen_first_gemm_1stage(self.b2b_num)
    else:
        operator_code += gen_first_gemm_2stage(self.b2b_num)
    operator_code += gen_other_gemms_2stage(self.b2b_num) + "}\n"
    return operator_code
|
| 688 |
-
|
| 689 |
-
def gen_construct_func(self):
    """Generate the C++ constructor of the class named ``self.gen_class_name``.

    The emitted constructor forwards to ``Base``, initializes the
    shared-memory iterators for operand A and for each B operand, computes
    the warp indices, and offsets the warp tile iterators accordingly.

    Returns:
        The generated C++ text as one string.
    """
    count = self.b2b_num
    var = helper.var_idx

    # Signature.
    pieces = [
        "CUTLASS_DEVICE\n",
        self.gen_class_name + "(\n",
        " typename Base::B2bMmaSharedStorage &shared_storage,\n",
        " int thread_idx,\n",
        " int warp_idx,\n",
        " int lane_idx\n",
        "):\n",
    ]

    # Member initializer list; the last entry also opens the constructor body.
    pieces.append(" Base(shared_storage, thread_idx, warp_idx, lane_idx),\n")
    pieces.append(" smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),\n")
    for k in range(count):
        terminator = " {\n" if k == count - 1 else ",\n"
        pieces.append(
            var("smem_iterator_B", k)
            + var("_(shared_storage.sharedStorage", k)
            + ".operand_B_ref(), thread_idx)"
            + terminator
        )

    # Warp index decomposition, always based on the first GEMM's warp count.
    pieces.append(" int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);\n")
    pieces.append(" int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);\n")
    pieces.append(" int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;\n")
    pieces.append(" int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;\n")

    for k in range(count):
        pieces.append(
            " " + var("int tile_offset_k", k) + var(" = Base::kWarpGemmIterations", k) + " * warp_idx_k;\n"
        )

    pieces.append(" this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k0});\n")

    for k in range(count):
        pieces.append(
            " " + var("this->warp_tile_iterator_B", k) + var("_.add_tile_offset({tile_offset_k", k) + ", warp_idx_n});\n"
        )

    pieces.append("}\n")
    return "".join(pieces)
|
| 724 |
-
|
| 725 |
-
def gen_member_func(self, first_use_1stage):
    """Generate the ``public:`` section: ``operator()`` followed by the constructor.

    Args:
        first_use_1stage: forwarded unchanged to ``gen_operator``.

    Returns:
        The generated C++ text as one string.
    """
    # The operator is generated before the constructor, matching output order.
    operator_src = self.gen_operator(first_use_1stage)
    constructor_src = self.gen_construct_func()
    return "public:\n" + operator_src + constructor_src
|
| 731 |
-
|
| 732 |
-
def gen_code(self, first_use_1stage):
    """Generate the full header text for the pipelined back-to-back MMA class.

    Builds the template parameter list and the ``B2bMmaBase<...>``
    inheritance clause, concatenates ``gen_using()`` with
    ``gen_member_func(first_use_1stage)`` as the class body, and wraps the
    resulting class in ``cutlass::gemm::threadblock`` after the include
    header.

    Args:
        first_use_1stage: forwarded to ``gen_member_func``.

    Returns:
        The generated source as a single string.
    """
    var = helper.var_idx

    def gen_template_args(b2b_num):
        """Ordered template parameter tuples: (kind, name) or (kind, name, default)."""
        params = [
            ("typename", fixed_name)
            for fixed_name in ("Shape0", "IteratorA0", "SmemIteratorA0", "IteratorB0", "SmemIteratorB0")
        ]

        # Stages after the first read operand A through a fragment iterator.
        for k in range(1, b2b_num):
            for stem in ("Shape", "FragmentIteratorA", "IteratorB", "SmemIteratorB"):
                params.append(("typename", var(stem, k)))

        params.append(("typename", "ElementC"))
        params.append(("typename", "LayoutC"))

        # All OutputOp entries first, then all FusedAddBiasEpilogue entries.
        for stem in ("OutputOp", "FusedAddBiasEpilogue"):
            for k in range(0, b2b_num - 1):
                params.append(("typename", var(stem, k)))

        for k in range(0, b2b_num):
            params.append(("typename", var("Policy", k)))
        for k in range(0, b2b_num):
            # NOTE(review): the kind slot holds the builtin `int`, not the
            # string "int" used for other kinds -- confirm that
            # gen_ir.gen_template_class handles this as intended.
            params.append((int, var("Stage", k)))

        params.append((
            "typename",
            "TransformA0",
            "NumericArrayConverter<typename SmemIteratorA0_::Element, typename IteratorA0_::Element, IteratorA0_::Fragment::kElements>",
        ))

        for k in range(0, b2b_num):
            converter = (
                var("NumericArrayConverter<typename SmemIteratorB", k)
                + var("_::Element, typename IteratorB", k)
                + var("_::Element, IteratorB", k)
                + "_::Fragment::kElements>"
            )
            params.append(("typename", var("TransformB", k), converter))

        params.append(("typename", "Enable", "bool"))
        return params

    stage_count = self.b2b_num
    template_param = gen_template_args(stage_count)

    # Base class arguments: every Shape, every Policy, then every Stage.
    inheritance_code = "public B2bMmaBase<"
    for stem in ("Shape", "Policy"):
        for k in range(stage_count):
            inheritance_code += var(stem, k) + "_, "
    for k in range(stage_count - 1):
        inheritance_code += var("Stage", k) + "_, "
    # The last Stage argument closes the list without a trailing comma.
    inheritance_code += var("Stage", stage_count - 1) + "_" + ">"

    using_code = self.gen_using()
    func_code = self.gen_member_func(first_use_1stage)
    code_body = using_code + func_code

    class_code = gen_ir.gen_template_class(
        self.gen_class_name,
        template_param,
        code_body,
        inheritance_code=inheritance_code,
    )

    header = self.gen_include_header()
    # Wrap innermost namespace first: threadblock, then gemm, then cutlass.
    wrapped = class_code
    for namespace in ("threadblock", "gemm", "cutlass"):
        wrapped = gen_ir.gen_namespace(namespace, wrapped)
    return header + wrapped
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
class gen_b2b_mma_base:
|
| 798 |
-
def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
|
| 799 |
-
self.gen_class_name = gen_class_name
|
| 800 |
-
self.template_param = template_param
|
| 801 |
-
self.b2b_num = b2b_num
|
| 802 |
-
self.cutlass_deps_root = cutlass_deps_root
|
| 803 |
-
self.project_root = project_root
|
| 804 |
-
|
| 805 |
-
def gen_include_header(self):
    """Return the ``#pragma once`` line and ``#include`` block for the base header.

    Every include path is prefixed with ``self.cutlass_deps_root``.
    """
    deps_root = self.cutlass_deps_root
    headers = (
        "cutlass/aligned_buffer.h",
        "cutlass/arch/memory.h",
        "cutlass/array.h",
        "cutlass/cutlass.h",
        "cutlass/gemm/gemm.h",
        "cutlass/matrix_shape.h",
        "cutlass/numeric_types.h",
    )

    # Leading blank line, the pragma, one more blank line, then the includes.
    text = "\n#pragma once\n\n"
    for header in headers:
        text += '#include "{}{}"\n'.format(deps_root, header)
    return text
|
| 817 |
-
|
| 818 |
-
def gen_shared_storage(self):
    """Return the C++ source of the ``SharedStorage`` class template.

    The text is constant; it does not depend on any instance attribute.
    Some rows carry significant leading or trailing spaces, which are kept
    as they are.
    """
    # One entry per emitted line; every entry is followed by a newline.
    rows = (
        " template< ",
        "typename Shape_,",
        "typename Policy_,",
        "int ThisStage_",
        ">",
        "class SharedStorage {",
        "public:",
        "using Shape = Shape_;",
        "using Policy = Policy_;",
        "static int const ThisStage = ThisStage_;",
        "using Operator = typename Policy::Operator;",
        "using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;",
        "/// Tensor reference to the B operand ",
        "using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;",
        "",
        "/// Shape of the A matrix operand in shared memory ",
        "using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,",
        "Shape::kK * ThisStage +",
        "Policy::SmemPaddingA::kColumn>;",
        "",
        "/// Shape of the B matrix operand in shared memory",
        "using ShapeB =",
        "MatrixShape<Shape::kK * ThisStage + Policy::SmemPaddingB::kRow,",
        "Shape::kN + Policy::SmemPaddingB::kColumn>;",
        "",
        "public:",
        "",
        "/// Buffer for A operand",
        "AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;",
        "",
        "/// Buffer for B operand",
        "AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;",
        "",
        "public:",
        "",
        "/// Returns a layout object for the A matrix",
        "CUTLASS_DEVICE",
        "static typename Operator::LayoutA LayoutA() {",
        "return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});",
        "}",
        "",
        "/// Returns a layout object for the B matrix",
        "CUTLASS_HOST_DEVICE",
        "static typename Operator::LayoutB LayoutB() {",
        "return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});",
        "}",
        "",
        "/// Returns a TensorRef to the A operand",
        "CUTLASS_HOST_DEVICE",
        "TensorRefA operand_A_ref() {",
        "return TensorRefA{operand_A.data(), LayoutA()};",
        "}",
        "",
        "/// Returns a TensorRef to the B operand",
        "CUTLASS_HOST_DEVICE",
        "TensorRefB operand_B_ref() {",
        "return TensorRefB{operand_B.data(), LayoutB()};",
        "}",
        "CUTLASS_HOST_DEVICE",
        "void * get_B_Shared_ptr() {",
        "return operand_B.data();",
        "}",
        "};",
    )
    return "".join(row + "\n" for row in rows)
|
| 886 |
-
|
| 887 |
-
def gen_using_and_misc(self, b2b_num):
    """Return the alias/constant/storage preamble of the generated class body.

    Args:
        b2b_num: number of back-to-back GEMM stages to emit declarations for.

    Returns:
        A string made of, in order: the ``OperatorN`` / ``WarpGemmN`` /
        ``WarpCountN`` aliases, the ``kWarpGemmIterations`` constants, the
        text from ``self.gen_shared_storage()``, one ``SharedStorage``
        alias per stage, the ``B2bMmaSharedStorage`` union, and one
        ``void *`` ``_smm_ptr`` declaration for every stage but the last.

    Indexed identifiers are produced by ``helper.var_idx(name, i)``.
    """
    stages = range(b2b_num)

    code_using = ""
    for i in stages:
        code_using += "using Operator" + str(i) + " = typename Policy" + str(i) + "::Operator;\n"

    for i in stages:
        code_using += "using WarpGemm" + str(i) + " = typename Policy" + str(i) + "::Operator::Shape;\n"

    for i in stages:
        code_using += "using WarpCount" + str(i) + " = GemmShape<" \
            + helper.var_idx("Shape", i) + "::kM / " + helper.var_idx("WarpGemm", i) + "::kM, " \
            + helper.var_idx("Shape", i) + "::kN / " + helper.var_idx("WarpGemm", i) + "::kN, " \
            + helper.var_idx("Shape", i) + "::kK / " + helper.var_idx("WarpGemm", i) + "::kK>;\n"

    code_misc = ""
    for i in stages:
        code_misc += "static int const " + helper.var_idx("kWarpGemmIterations", i) \
            + " = (" + helper.var_idx("WarpGemm", i) + "::kK / " \
            + helper.var_idx("Operator", i) + "::Policy::MmaShape::kK);\n"

    code = code_using + code_misc + self.gen_shared_storage()

    for i in stages:
        code += "using " + helper.var_idx("SharedStorage", i) + " = SharedStorage<" \
            + helper.var_idx("Shape", i) + ", " + helper.var_idx("Policy", i) + ", " \
            + helper.var_idx("Stage", i) + ">;\n"

    # The union members must match the SharedStorage aliases emitted just
    # above, so they are driven by the same `b2b_num` argument (previously
    # this one spot read self.b2b_num instead).
    union_members = ""
    for i in stages:
        union_members += " " + helper.var_idx("SharedStorage", i) + " " + helper.var_idx("sharedStorage", i) + ";\n"
    code += "union B2bMmaSharedStorage {\n" + union_members + "};\n"

    # One pointer per stage except the last.
    for i in range(b2b_num - 1):
        code += helper.var_idx("void * C", i) + "_smm_ptr;\n"

    return code
|
| 921 |
-
|
| 922 |
-
def gen_protected(self):
    """Return the ``protected:`` section declaring the warp tile iterators.

    Emits one ``IteratorA`` member for stage 0 followed by one
    ``IteratorB`` member for each of the ``self.b2b_num`` stages.
    """
    declarations = ["typename Operator0::IteratorA warp_tile_iterator_A0_;\n"]
    for stage in range(self.b2b_num):
        tag = str(stage)
        declarations.append(
            "typename Operator" + tag + "::IteratorB" + " warp_tile_iterator_B" + tag + "_;\n"
        )
    return "\nprotected:\n" + "".join(declarations)
|
| 928 |
-
|
| 929 |
-
def gen_public_member(self):
    """Return the ``public:`` section holding the ``B2bMmaBase`` constructor.

    The constructor's initializer list sets up ``warp_tile_iterator_A0_``
    and one ``warp_tile_iterator_B<i>_`` per stage; its body assigns one
    ``_smm_ptr`` per stage except the last from ``get_B_Shared_ptr()``.
    """
    final_stage = self.b2b_num - 1

    pieces = [
        "\npublic:\n",
        "CUTLASS_DEVICE\n",
        "B2bMmaBase(\n",
        " B2bMmaSharedStorage & shared_storage,\n",
        " int thread_idx,\n",
        " int warp_idx,\n",
        " int lane_idx\n",
        "):\n",
        " warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),\n",
    ]

    for stage in range(self.b2b_num):
        # Every initializer but the last is followed by a comma.
        terminator = "\n" if stage == final_stage else ",\n"
        member = " warp_tile_iterator_B" + str(stage) + "_"
        source = "shared_storage.sharedStorage" + str(stage) + ".operand_B_ref()"
        pieces.append(member + "(" + source + ", lane_idx)" + terminator)

    pieces.append("{\n")
    for stage in range(final_stage):
        pieces.append(
            helper.var_idx(" C", stage)
            + helper.var_idx("_smm_ptr = shared_storage.sharedStorage", stage)
            + ".get_B_Shared_ptr();\n"
        )
    pieces.append("}\n")

    return "".join(pieces)
|
| 957 |
-
|
| 958 |
-
def gen_code(self):
    """Assemble and return the full text for the ``B2bMmaBase`` template class.

    The template parameter list holds, per stage, a ``Shape``, a ``Policy``
    and a ``Stage`` entry (grouped by kind, in that order).  The class body
    is the concatenation of ``gen_using_and_misc``, ``gen_protected`` and
    ``gen_public_member``; the class is then wrapped by
    ``gen_ir.gen_namespace`` for ``threadblock``, ``gem`` "m" and ``cutlass``
    (innermost first) and prefixed with ``self.gen_include_header()``.
    """
    stage_ids = range(self.b2b_num)

    template_arg = [("typename", helper.var_idx("Shape", i)) for i in stage_ids]
    template_arg += [("typename", helper.var_idx("Policy", i)) for i in stage_ids]
    # Stage parameters are tagged with the Python type object `int`, not the
    # string "int", exactly as the original did.
    template_arg += [(int, helper.var_idx("Stage", i)) for i in stage_ids]

    code_body = "".join([
        self.gen_using_and_misc(self.b2b_num),
        self.gen_protected(),
        self.gen_public_member(),
    ])

    class_code = gen_ir.gen_template_class("B2bMmaBase", template_arg, code_body)

    header = self.gen_include_header()
    wrapped = class_code
    # Innermost namespace first: cutlass::gemm::threadblock.
    for scope in ("threadblock", "gemm", "cutlass"):
        wrapped = gen_ir.gen_namespace(scope, wrapped)

    return header + wrapped
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
class gen_threadblock:
    """Drives the three threadblock-level header generators and writes their output.

    The generated files are ``b2b_mma_base.h``, ``b2b_mma_pipelined.h`` and
    ``default_b2b_mma.h``, all placed under ``<output_dir>/threadblock/``.
    """

    def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root):
        """Store the configuration and construct the three sub-generators."""
        self.gen_class_name = gen_class_name
        self.template_param = template_param
        self.b2b_num = b2b_num
        # Plain string concatenation; the directory is not created here.
        self.file_dir = output_dir + "/threadblock/"

        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root

        # All three sub-generators take the same positional arguments.
        common = (template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)
        self.gen_b2b_mma_base = gen_b2b_mma_base(*common)
        self.gen_b2b_mma_pipelined = gen_b2b_mme_pipelined(*common)
        self.gen_default_b2b_mma = gen_default_b2b_mma(*common)

    def gen_code(self, first_use_1stage):
        """Generate each header in turn, log it, and write it to ``self.file_dir``.

        Args:
            first_use_1stage: forwarded unchanged to the pipelined generator's
                ``gen_code``.
        """
        # (file name, producer) pairs; producers are invoked lazily so that
        # generation, logging and writing stay interleaved per file.
        jobs = (
            ("b2b_mma_base.h",
             lambda: self.gen_b2b_mma_base.gen_code()),
            ("b2b_mma_pipelined.h",
             lambda: self.gen_b2b_mma_pipelined.gen_code(first_use_1stage = first_use_1stage)),
            ("default_b2b_mma.h",
             lambda: self.gen_default_b2b_mma.gen_code()),
        )

        for file_name, produce in jobs:
            text = produce()
            print("[INFO]: Gen kernel code [" + file_name + "]output Dir: is ", self.file_dir)

            with open(self.file_dir + file_name, "w+") as f:
                f.write(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py
DELETED
|
@@ -1,456 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import helper
|
| 34 |
-
import gen_ir as ir
|
| 35 |
-
|
| 36 |
-
class gen_turing_impl:
    """Generates the text of the ``<name>_turing_impl`` wrapper function.

    With more than one GEMM in ``fuse_gemm_info`` the wrapper body uses the
    fused ``b2b_gemm`` device type; with exactly one GEMM it delegates to
    ``gen_volta_turing_fuse_act_impl`` with ``volta=False``.
    """

    def __init__(self,fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        """Store the configuration.

        Args:
            fuse_gemm_info: sequence of per-GEMM dicts; this class reads the
                keys ``'mnk'``, ``'A_tp'``, ``'B_tp'``, ``'C_tp'``,
                ``'Acc_tp'`` and ``'B_format'``.
            gen_class_name: base name of the generated device class.
            user_header_file: iterable of header names, each turned into an
                ``#include "..."`` line.
            output_dir: directory handed to ``helper.write_2_headfile``.
        """
        self.fuse_gemm_info = fuse_gemm_info
        self.class_name = gen_class_name
        self.gen_class_name = gen_class_name + "_turing_impl"
        self.user_header_file = "".join(
            "#include \"" + header + "\"\n" for header in user_header_file
        )
        self.output_dir = output_dir
        self.b2b_num = len(fuse_gemm_info)

        self.gen_turing_unfused = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)

    def gen_using(self):
        """Return the ``using b2b_gemm = ...;`` alias line for the fused device class."""
        code_using = "using b2b_gemm = typename cutlass::gemm::device::" + self.class_name + "<cutlass::half_t>;"

        return code_using + "\n"

    def gen_initialize(self):
        """Return the C++ text that declares and initializes ``gemm_op``.

        Emits, in order: per-GEMM ``alpha`` / ``beta`` / ``problem_size_``
        declarations, then the ``b2b_gemm::Arguments`` aggregate (problem
        sizes, the A0 reference, a B and C reference per GEMM, the final D
        reference, one epilogue parameter group per GEMM, ``Batch``), then
        the ``gemm_op`` declaration and its ``initialize`` call.
        """
        infos = self.fuse_gemm_info
        last = self.b2b_num - 1
        code = ""

        for i in range(self.b2b_num):
            acc_tp = helper.type_2_cutlass_type(infos[i]['Acc_tp'])

            code += helper.var_idx(acc_tp + " alpha", i) + " = " + acc_tp + "(1);\n"

            # beta is 0 only when the helper reports `False` explicitly.
            beta = "(1)"
            if helper.get_epilogue_add_bias_or_not(infos[i]) is False:
                beta = "(0)"
            code += helper.var_idx(acc_tp + " beta", i) + " = " + acc_tp + beta + ";\n"

            # The first GEMM's K comes from the runtime argument K0.
            k_str = "K0" if i == 0 else str(infos[i]['mnk'][2])
            code += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(infos[i]['mnk'][1]) + ", " + k_str + ");\n"

        code += "typename b2b_gemm::Arguments arguments{\n"

        for i in range(self.b2b_num):
            code += " " + helper.var_idx("problem_size_", i) + ",\n"

        # A0 belongs to the first GEMM, so its element type is taken from
        # entry 0 (previously this read the loop variable left over from
        # the loop above, i.e. the last GEMM's 'A_tp').
        code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(infos[0]['A_tp']) + "*>(" + helper.var_idx("A", 0) + "), " + helper.var_idx("problem_size_", 0) + ".k()},\n"

        for i in range(self.b2b_num):
            ldmB = str(infos[i]['mnk'][2])
            if i == 0:
                ldmB = "K0"

            # Compare by value; identity (`is`) on strings is unreliable.
            if infos[i]['B_format'] == 'Row':
                ldmB = str(infos[i]['mnk'][1])

            ldmC = str(helper.get_epilogue_bias_ldm(infos[i]))

            code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(infos[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "},\n"
            code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(infos[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmC + "},\n"

        # Only the final GEMM's D is passed to the fused kernel.
        code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(infos[last]['C_tp']) + "*>(" + helper.var_idx("D", last) + "), " + helper.var_idx("problem_size_", last) + ".n()},\n"

        for i in range(self.b2b_num):
            code += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
            for epilogue_arg in helper.get_epilogue_args(infos[i]):
                arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1]
                code += ", " + helper.type_2_cutlass_type(infos[i]['Acc_tp']) + "(" + str(arg_name) + ")"
            code += "},\n"
        code += " " + "Batch};\n\n"

        code += " " "b2b_gemm gemm_op;\n"
        code += " " + "gemm_op.initialize(arguments);\n"
        return code + "\n"

    def gen_run(self):
        """Return the statement that launches ``gemm_op`` on ``stream``."""
        code = " " + "gemm_op(stream);\n"

        return code

    def gen_wrapper(self):
        """Build the wrapper's argument list and body and pass them to ``ir.gen_func``.

        Arguments are ``M``, ``K0``, ``Batch``, ``A0`` and then, per GEMM,
        ``B``, ``C``, ``D`` and any epilogue arguments reported by
        ``helper.get_epilogue_args``.
        """
        arg_lists = [
            ["int", "M"],
            ["int", "K0"],
            ["int", "Batch"],
            ["void*", helper.var_idx("A", 0)],
        ]
        for i in range(self.b2b_num):
            arg_lists.append(["void*", helper.var_idx("B", i)])
            arg_lists.append(["void*", helper.var_idx("C", i)])
            arg_lists.append(["void*", helper.var_idx("D", i)])
            for arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
                arg_lists.append([arg[0], helper.var_idx("Epilogue", i) + "_" + arg[1]])

        code_body = ""
        if self.b2b_num == 1:
            # A single GEMM has nothing to fuse: use the per-GEMM generator.
            code_body += self.gen_turing_unfused.gen_using(False) #False -> Turing, True -> Volta
            code_body += self.gen_turing_unfused.gen_initialize()
            code_body += self.gen_turing_unfused.gen_run()
        else:
            code_body += self.gen_using()
            code_body += self.gen_initialize()
            code_body += self.gen_run()

        return ir.gen_func(self.gen_class_name, arg_lists, code_body)

    def gen_code(self):
        """Generate the wrapper and hand it to ``helper.write_2_headfile`` as ``turing_impl.h``."""
        code = self.gen_wrapper()
        helper.write_2_headfile("turing_impl.h", self.output_dir, self.user_header_file + "\n" + code)
|
| 149 |
-
|
| 150 |
-
class gen_volta_turing_fuse_act_impl:
    """Generates the text of the ``<name>_volta_impl`` wrapper function.

    The wrapper declares one ``cutlass::gemm::device::GemmBatched`` type per
    entry of ``fuse_gemm_info`` and runs them one after another; GEMM ``i``
    (for ``i > 0``) reads the ``D`` buffer of GEMM ``i - 1`` as its A operand.
    ``gen_using(False)`` emits the same configuration for Sm75.
    """

    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        """Store the configuration.

        Args:
            fuse_gemm_info: sequence of per-GEMM dicts; this class reads the
                keys ``'mnk'``, ``'A_tp'``, ``'B_tp'``, ``'C_tp'``,
                ``'Acc_tp'``, ``'A_format'``, ``'B_format'`` and ``'C_format'``.
            gen_class_name: base name; ``"_volta_impl"`` is appended.
            user_header_file: iterable of header names, each turned into an
                ``#include "..."`` line.
            output_dir: directory handed to ``helper.write_2_headfile``.
        """
        self.fuse_gemm_info = fuse_gemm_info
        self.gen_class_name = gen_class_name + "_volta_impl"
        self.user_header_file = "".join(
            "#include \"" + header + "\"\n" for header in user_header_file
        )
        self.output_dir = output_dir
        self.b2b_num = len(fuse_gemm_info)

    @staticmethod
    def _alignment_elements(extent):
        """Return the largest of 8, 4, 2, 1 that divides ``extent``."""
        remainder = extent % 8
        if remainder == 0:
            return 8
        if remainder == 4:
            return 4
        if remainder in (2, 6):
            return 2
        return 1

    def perf_tiling(self, layer_mnk):
        """Pick threadblock and warp tile shapes for one GEMM.

        Args:
            layer_mnk: mutable sequence ``[m, n, k]``; it is copied, not modified.

        Returns:
            ``(block_tile, warp_tile)``.  M and K of both tiles are always 32.
            The block N tile is 256 / 128 / 64 / 32 for ``n`` above
            128 / 64 / 32 / otherwise; the warp N tile is 64 when the block
            N tile is 256 and 32 in every other case.
        """
        mnk = layer_mnk[:]
        block_tile = mnk[:]
        block_tile[0] = 32
        block_tile[2] = 32  # force the K tile to be 32

        n = mnk[1]
        if n > 128:
            block_tile[1] = 256
        elif n > 64:
            block_tile[1] = 128
        elif n > 32:
            block_tile[1] = 64
        else:
            block_tile[1] = 32

        warp_tile = block_tile[:]
        warp_tile[0] = 32
        warp_tile[1] = 64 if block_tile[1] == 256 else 32

        return block_tile, warp_tile

    def process_epilogue(self, epilogue_tp, n, C_tp, Acc_tp):
        """Return the ``cutlass::epilogue::thread`` type string for one GEMM.

        Args:
            epilogue_tp: ``'leakyrelu'`` or ``'identity'`` (case-insensitive);
                any other value selects ``LinearCombinationRelu``.
            n: N extent; the element count in the type is the largest of
                8, 4, 2, 1 dividing it.
            C_tp: output element type name.
            Acc_tp: type name used for the last two template arguments.
        """
        kind = epilogue_tp.lower()
        if kind == 'leakyrelu':
            cutlass_epilogue_name = "LinearCombinationLeakyRelu"
        elif kind == 'identity':
            cutlass_epilogue_name = "LinearCombination"
        else:
            cutlass_epilogue_name = "LinearCombinationRelu"

        # Previously computed from `n % 4`, which reported 8 for any
        # multiple of 4 and left the 4- and 6-remainder branches unreachable.
        N_align_elements = self._alignment_elements(n)

        epilogue_str = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "<" + C_tp + ", " + str(N_align_elements) + ", " + Acc_tp + ", " + Acc_tp + ">"

        return epilogue_str

    def gen_using(self, volta = True):
        """Return one ``using Gemm<i> = GemmBatched<...>;`` declaration per GEMM.

        Args:
            volta: ``True`` selects ``Sm70`` with an 8x8x4 instruction shape,
                ``False`` selects ``Sm75`` with 16x8x8.
        """
        if volta:
            arch = "cutlass::arch::Sm70"
            tc = "cutlass::gemm::GemmShape<8, 8, 4>"
        else:
            arch = "cutlass::arch::Sm75"
            tc = "cutlass::gemm::GemmShape<16, 8, 8>"

        code_using = ""
        for i in range(self.b2b_num):
            info = self.fuse_gemm_info[i]

            # A/B alignment from K, using modulo 8 (previously `k % 4`).
            ab_ldm = self._alignment_elements(info['mnk'][2])

            block_tile, warp_tile = self.perf_tiling(info['mnk'])

            template_args = [
                helper.type_2_cutlass_type(info['A_tp']),
                helper.type_2_cutlass_type(info['A_format']),
                helper.type_2_cutlass_type(info['B_tp']),
                helper.type_2_cutlass_type(info['B_format']),
                helper.type_2_cutlass_type(info['C_tp']),
                helper.type_2_cutlass_type(info['C_format']),
                helper.type_2_cutlass_type(info['Acc_tp']),
                "cutlass::arch::OpClassTensorOp",
                arch,
                "cutlass::gemm::GemmShape<" + str(block_tile[0]) + ", " + str(block_tile[1]) + ", " + str(block_tile[2]) + ">",
                "cutlass::gemm::GemmShape<" + str(warp_tile[0]) + ", " + str(warp_tile[1]) + ", " + str(warp_tile[2]) + ">",
                tc,
                self.process_epilogue(helper.get_epilogue_tp(info), info['mnk'][1], helper.type_2_cutlass_type(info['C_tp']), helper.type_2_cutlass_type(info['Acc_tp'])),
                "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
                "2",
                str(ab_ldm),
            ]

            this_gemm_config = helper.var_idx("using Gemm", i) + " = cutlass::gemm::device::GemmBatched<\n"
            for template_arg in template_args:
                this_gemm_config += " " + template_arg + ",\n"
            # The final template argument repeats the alignment and closes the list.
            this_gemm_config += " " + str(ab_ldm) + ">;\n"

            code_using += this_gemm_config + "\n"

        return code_using + "\n"

    def gen_initialize(self):
        """Return the C++ text that declares and initializes every ``gemm_op_<i>``.

        For each GEMM this emits ``alpha`` / ``beta`` / ``problem_size_``
        declarations, the ``Gemm<i>::Arguments`` aggregate (A, B, C and D
        references, each followed by a stride expression, then the epilogue
        parameters and ``Batch``), the operator declaration and its
        ``initialize(arguments_<i>, nullptr)`` call.
        """
        code = ""
        for i in range(self.b2b_num):
            info = self.fuse_gemm_info[i]
            acc_tp = helper.type_2_cutlass_type(info['Acc_tp'])
            N_str = str(info['mnk'][1])

            code_this = helper.var_idx(acc_tp + " alpha", i) + " = " + acc_tp + "(1);\n"

            # beta is 0 only when the helper reports `False` explicitly.
            beta = "(1)"
            if helper.get_epilogue_add_bias_or_not(info) is False:
                beta = "(0)"
            code_this += helper.var_idx(acc_tp + " beta", i) + " = " + acc_tp + beta + ";\n"

            # The first GEMM's K comes from the runtime argument K0.
            k_str = "K0" if i == 0 else str(info['mnk'][2])
            code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + N_str + ", " + k_str + ");\n"
            code_this += helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n"
            code_this += " " + helper.var_idx("problem_size_", i) + ",\n"

            ldmA = k_str
            ldmB = k_str
            ldmC = N_str

            ldmBias = str(helper.get_epilogue_bias_ldm(info))

            # Compare layouts by value; identity (`is`) on strings is unreliable.
            if info['A_format'] == 'Col':
                ldmA = "M"
            if info['B_format'] == 'Row':
                ldmB = N_str
            if info['C_format'] == 'Col':
                ldmC = "M"

            # GEMM 0 reads A0; every later GEMM reads the previous GEMM's D.
            if i == 0:
                a_source = helper.var_idx("A", i)
            else:
                a_source = helper.var_idx("D", i - 1)
            code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(info['A_tp']) + "*>(" + a_source + "), " + ldmA + "}, " + "M * " + ldmA + ",\n"

            code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(info['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n"

            M_bias = str(helper.get_epilogue_bias_shape(info)[0])

            code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(info['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n"
            code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(info['C_tp']) + "*>(" + helper.var_idx("D", i) + "), " + ldmC + "}, " + "M * " + ldmC + ",\n"
            code_this += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
            for epilogue_arg in helper.get_epilogue_args(info):
                arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1]
                code_this += ", " + helper.type_2_cutlass_type(info['Acc_tp']) + "(" + str(arg_name) + ")"
            code_this += " },\n"
            code_this += " " + "Batch};\n"

            code_this += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n"
            code_this += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(arguments_", i) + ", nullptr);\n"

            code += code_this + "\n"
        return code + "\n"

    def gen_run(self):
        """Return one ``gemm_op_<i>(stream);`` statement per GEMM, in order."""
        code = ""
        for i in range(self.b2b_num):
            code += " " + helper.var_idx("gemm_op_", i) + "(stream);\n"
        return code

    def gen_wrapper(self):
        """Build the wrapper's argument list and body and pass them to ``ir.gen_func``.

        Arguments are ``M``, ``K0``, ``Batch``, ``A0`` and then, per GEMM,
        ``B``, ``C``, ``D`` and any epilogue arguments reported by
        ``helper.get_epilogue_args``.
        """
        arg_lists = [
            ["int", "M"],
            ["int", "K0"],
            ["int", "Batch"],
            ["void*", helper.var_idx("A", 0)],
        ]
        for i in range(self.b2b_num):
            arg_lists.append(["void*", helper.var_idx("B", i)])
            arg_lists.append(["void*", helper.var_idx("C", i)])
            arg_lists.append(["void*", helper.var_idx("D", i)])
            for arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
                arg_lists.append([arg[0], helper.var_idx("Epilogue", i) + "_" + arg[1]])

        code_body = self.gen_using()
        code_body += self.gen_initialize()
        code_body += self.gen_run()

        return ir.gen_func(self.gen_class_name, arg_lists, code_body)

    def gen_code(self):
        """Generate the wrapper and hand it to ``helper.write_2_headfile`` as ``volta_impl.h``."""
        code = self.gen_wrapper()
        helper.write_2_headfile("volta_impl.h", self.output_dir, self.user_header_file + "\n" + code)
|
| 365 |
-
|
| 366 |
-
class gen_one_API:
    """Generates the CUTLASS-independent ``one_api`` entry point and its headers.

    Produces three files in ``output_dir``: ``one_api.cu`` (dispatch on SM
    version), ``cutlass_irrelevant.h`` (parameter struct and declaration) and
    ``api.h`` (the Turing and Volta wrappers).
    """

    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        """Store the configuration and create the Volta/Turing generators.

        Args:
            fuse_gemm_info: sequence with one description per fused GEMM.
            gen_class_name: base name of the generated wrappers.
            user_header_file: iterable of header names to ``#include``.
            output_dir: prefix prepended verbatim to output file names.
        """
        self.fuse_gemm_info = fuse_gemm_info
        self.gen_class_name = gen_class_name
        self.user_header_file = "".join(
            "#include \"" + header + "\"\n" for header in user_header_file
        )
        self.output_dir = output_dir
        self.b2b_num = len(fuse_gemm_info)

        self.gen_volta = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)

        self.gen_turing = gen_turing_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)

    def gen_CUTLASS_irrelevant_API(self):
        """Return the header text declaring the parameter struct and ``one_api``.

        The struct is named ``Fused<b2b_num>xGemm_<n0>_<n1>_..._Params`` where
        ``<ni>`` is element 1 of GEMM ``i``'s ``'mnk'`` entry.
        """
        code = ""
        code += "#include <cuda_runtime.h>\n"
        code += "#include <cassert>\n"

        param_name = "Fused" + str(self.b2b_num) + "xGemm_"
        for i in range(self.b2b_num):
            param_name += str(self.fuse_gemm_info[i]['mnk'][1]) + "_"
        param_name += "Params"

        params = ""
        params += " " + "int M;\n"
        params += " " + "int K0;\n"
        params += " " + "int Batch;\n"
        params += " " + "const void* A0;\n"
        for i in range(self.b2b_num):
            params += " " + "const void* " + helper.var_idx("B", i) + ";\n"
            params += " " + "const void* " + helper.var_idx("C", i) + ";\n"
            # Each epilogue argument is a (type, name) pair.
            for arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
                arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
                params += " " + arg[0] + " " + arg_name + ";\n"
            params += " " + "void* " + helper.var_idx("D", i) + ";\n"

        code += ir.gen_struct(param_name, params)
        code += "using Param = " + param_name + ";\n"
        code += "void one_api( const Param & param, int sm, cudaStream_t stream);\n"

        return code

    def _gen_impl_call(self, impl_suffix):
        """Return the statement calling ``<gen_class_name><impl_suffix>`` with all params."""
        call = " " + " " + self.gen_class_name + impl_suffix \
            + "(param.M, param.K0, param.Batch, const_cast<void*>(param.A0), "
        for i in range(self.b2b_num):
            call += helper.var_idx("const_cast<void*>(param.B", i) + "), "
            call += helper.var_idx("const_cast<void*>(param.C", i) + "), "
            call += helper.var_idx("param.D", i) + ", "
            for arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
                call += "param." + helper.var_idx("Epilogue", i) + "_" + arg[1] + ", "
        return call + "stream);\n"

    def gen_one_api(self):
        """Return the source of ``one_api.cu``.

        The generated function calls the ``_volta_impl`` wrapper when
        ``sm == 70``, the ``_turing_impl`` wrapper when ``sm >= 75`` and
        ``assert(0)`` otherwise.
        """
        code = ""
        code += "/* Auto Generated code - Do not edit.*/\n"
        code += "#include \"cutlass_irrelevant.h\"\n"
        code += "#include \"api.h\"\n"
        code += "void one_api( const Param & param, int sm, cudaStream_t stream) {\n"

        code += " " + "if (sm == 70) \n"
        code += self._gen_impl_call("_volta_impl")
        code += " " + "else if(sm >= 75) \n"
        code += self._gen_impl_call("_turing_impl")
        code += " " + "else assert(0);\n"
        code += "}\n"
        return code

    def gen_code(self):
        """Write ``one_api.cu``, ``cutlass_irrelevant.h`` and ``api.h``."""
        turing_code = self.gen_turing.gen_wrapper()
        volta_code = self.gen_volta.gen_wrapper()
        cutlass_irrelevant_code = self.gen_CUTLASS_irrelevant_API()

        one_api_code = self.gen_one_api()
        with open(self.output_dir + "one_api.cu", "w+") as f:
            f.write(one_api_code)

        helper.write_2_headfile("cutlass_irrelevant.h", self.output_dir, cutlass_irrelevant_code)

        helper.write_2_headfile("api.h", self.output_dir, self.user_header_file + "\n" + turing_code + volta_code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py
DELETED
|
@@ -1,92 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import helper
|
| 34 |
-
import gen_ir as ir
|
| 35 |
-
|
| 36 |
-
import gen_turing_and_volta as gen_basic
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class gen_verify:
    """Generates ``cutlass_verify.h``: a function that runs every GEMM separately.

    The generated function is named ``<gen_class_name>_verify`` and takes one
    ``typename Gemm<i>::Arguments Arguments_<i>`` parameter per GEMM.
    """

    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        """Store the configuration and pre-compute the parameter list.

        Args:
            fuse_gemm_info: sequence with one description per GEMM.
            gen_class_name: base name; ``_verify`` is appended.
            user_header_file: iterable of header names to ``#include``.
            output_dir: prefix handed to ``helper.write_2_headfile``.
        """
        self.fuse_gemm_info = fuse_gemm_info
        self.name = gen_class_name + "_verify"
        self.b2b_num = len(fuse_gemm_info)
        self.params = []
        self.user_header_file = "".join(
            "#include \"" + header + "\"\n" for header in user_header_file
        )
        self.separate_cutlass = gen_basic.gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
        self.gen_params()
        self.output_dir = output_dir

    def gen_code(self):
        """Assemble the verification function and write ``cutlass_verify.h``."""
        code = ""
        code += self.user_header_file
        code += self.separate_cutlass.gen_using(False) #False -> Turing, True -> Volta

        code_body = ""
        for i in range(self.b2b_num):
            code_body += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n"
            code_body += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(Arguments_", i) + ", nullptr);\n"

        code_body += self.separate_cutlass.gen_run()

        code += ir.gen_func(self.name, self.params, code_body)
        helper.write_2_headfile("cutlass_verify.h", self.output_dir, code)

    def gen_params(self):
        """Append one ``(type, name)`` tuple per GEMM to ``self.params``."""
        for i in range(self.b2b_num):
            self.params.append(
                (
                    helper.var_idx("typename Gemm", i) + "::Arguments",
                    helper.var_idx("Arguments_", i),
                )
            )

    def get_params(self, declaration = True):
        """Return ``self.params`` as ``<type> <name>;`` lines.

        Returns an empty string when ``declaration`` is false.
        """
        code = ""
        if declaration:
            for param in self.params:
                code += param[0] + " " + param[1] + ";\n"

        return code

    def gen_initialize(self):
        """Return an ``initialize`` function wrapping the separate-GEMM init code.

        The original definition had no ``self`` parameter, ignored the
        generated initialization code and returned nothing, so it could not be
        used as a method.
        """
        initialize_code = self.separate_cutlass.gen_initialize()
        # NOTE(review): the `[[]]` argument list is kept from the original;
        # confirm ir.gen_func accepts an empty entry.
        return ir.gen_func("initialize", [[]], initialize_code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py
DELETED
|
@@ -1,135 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
def type_2_cutlass_type(input_type = "fp16"):
    """Map a short type or layout tag to the C++ name used in generated code.

    Args:
        input_type: one of ``fp32``, ``bf16``, ``fp16``, ``int32``, ``int8``,
            ``Row`` or ``Col`` (case-sensitive).

    Returns:
        The matching C++ type/layout name, or ``None`` for an unknown tag.
    """
    cutlass_names = {
        # floating point types
        "fp32": "float",
        "bf16": "cutlass::bfloat16_t",
        "fp16": "cutlass::half_t",
        # integer types
        "int32": "int32_t",
        "int8": "int8_t",
        # layouts
        "Row": "cutlass::layout::RowMajor",
        "Col": "cutlass::layout::ColumnMajor",
    }
    return cutlass_names.get(input_type)
|
| 52 |
-
|
| 53 |
-
def cvt_2_cutlass_shape(gemm_shape):
    """Format a 3-element shape as ``cutlass::gemm::GemmShape<a, b, c>``.

    Returns:
        The formatted string, or ``None`` when ``gemm_shape`` does not have
        exactly three elements.
    """
    # NOTE(review): indentation was lost in this copy; the None result for
    # other lengths assumes the original `return` sat inside the `if`.
    if len(gemm_shape) != 3:
        return None
    dims = ", ".join(str(dim) for dim in gemm_shape)
    return "cutlass::gemm::GemmShape<" + dims + ">"
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def write_2_headfile(filename, file_dir, string):
    """Write ``string`` to ``file_dir + filename`` behind a generated-file banner.

    The banner is an "Auto Generated" comment followed by ``#pragma once``.
    ``file_dir`` is concatenated verbatim, so it must end with a path
    separator. An existing file is overwritten.
    """
    banner = "/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n"
    with open(file_dir + filename, 'w') as out_file:
        out_file.write(banner + string)
|
| 66 |
-
|
| 67 |
-
def var_idx(variable, index):
    """Return ``variable`` with ``str(index)`` appended, e.g. ``("B", 1) -> "B1"``."""
    return variable + str(index)
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def list_2_string(input_list, ):
    """Join the elements' ``str()`` forms, one per line.

    Elements are separated by ``", \\n"`` and the last one is followed by a
    bare newline. An empty input yields an empty string.
    """
    items = [str(element) for element in input_list]
    if not items:
        return ""
    return ", \n".join(items) + "\n"
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def get_epilogue_info(layer_info):
    """Return the ``'epilogue'`` entry of a layer description."""
    return layer_info['epilogue']
|
| 88 |
-
|
| 89 |
-
def get_epilogue_tp(layer_info):
    """Return the ``'tp'`` field of the layer's epilogue description."""
    return get_epilogue_info(layer_info)['tp']
|
| 92 |
-
|
| 93 |
-
def get_epilogue_add_bias_or_not(layer_info):
    """Return the epilogue's ``['bias']['addbias']`` value."""
    return get_epilogue_info(layer_info)['bias']['addbias']
|
| 96 |
-
|
| 97 |
-
def get_epilogue_add_bias_tp(layer_info):
    """Return the epilogue's ``['bias']['bias_tp']`` value."""
    return get_epilogue_info(layer_info)['bias']['bias_tp']
|
| 100 |
-
|
| 101 |
-
def get_epilogue_args(layer_info):
    """Return the ``'args'`` field of the layer's epilogue description."""
    return get_epilogue_info(layer_info)['args']
|
| 104 |
-
|
| 105 |
-
def get_epilogue_bias_shape(layer_info):
    """Return the bias shape for a layer.

    Takes every element of ``layer_info['mnk']`` except the last and replaces
    the first with ``'M'`` for a matrix bias (``'mat'``) or ``1`` for a vector
    bias (``'vec'``). The bias type is matched case-insensitively.

    Raises:
        AssertionError: for any other bias type. Raised explicitly so the
            check still fires when Python runs with assertions disabled.
    """
    bias_tp = get_epilogue_add_bias_tp(layer_info).lower()
    # Slicing copies, so layer_info['mnk'] itself is never modified.
    mn_shape = layer_info['mnk'][:-1]

    if bias_tp == 'mat':
        mn_shape[0] = 'M'
    elif bias_tp == 'vec':
        mn_shape[0] = 1
    else:
        raise AssertionError("unsupported bias type: " + bias_tp)
    return mn_shape
|
| 117 |
-
|
| 118 |
-
def get_epilogue_bias_ldm(layer_info):
    """Return the leading dimension of the bias operand.

    For a matrix bias (``'mat'``) this is element 1 of ``layer_info['mnk']``;
    for a vector bias (``'vec'``) it is 0. Only ``'row'`` for
    ``layer_info['C_format']`` is supported. Both are matched
    case-insensitively.

    Raises:
        AssertionError: for a non-row C layout or an unknown bias type.
            Raised explicitly so the checks still fire when Python runs with
            assertions disabled.
    """
    bias_tp = get_epilogue_add_bias_tp(layer_info).lower()
    mn_shape = layer_info['mnk'][:-1]

    c_layout = layer_info['C_format'].lower()

    if c_layout != 'row':
        raise AssertionError("unsupported C layout: " + c_layout)

    if bias_tp == 'mat':
        return mn_shape[1]
    if bias_tp == 'vec':
        return 0
    raise AssertionError("unsupported bias type: " + bias_tp)
|
| 133 |
-
|
| 134 |
-
def get_epilogue_compute_tp(layer_info):
    """Return the ``'Acc_tp'`` entry of a layer description."""
    return layer_info['Acc_tp']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py
DELETED
|
@@ -1,67 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import os
|
| 34 |
-
|
| 35 |
-
class replace_fix_impl:
    """Copies a source tree while re-rooting its ``#include "cutlass...`` lines."""

    def __init__(self, src_dir, dst_dir, cutlass_deps_root):
        """
        Args:
            src_dir: directory tree to read.
            dst_dir: directory tree to write; created if missing.
            cutlass_deps_root: text inserted right after ``#include "`` on
                every line that starts with ``#include "cutlass``.
        """
        self.src_dir = src_dir
        self.dst_dir = dst_dir
        self.cutlass_deps_root = cutlass_deps_root

    def gen_code(self):
        """Mirror ``src_dir`` into ``dst_dir``, rewriting cutlass include lines.

        Only lines that begin (at column 0) with ``#include "cutlass`` are
        changed; all other lines are copied unchanged.
        """
        include_prefix = "#include \""
        cutlass_include = include_prefix + "cutlass"

        for cur_dir, _sub_dirs, file_names in os.walk(self.src_dir):
            # Same relative location under dst_dir.
            out_dir = self.dst_dir + cur_dir[len(self.src_dir):]

            # makedirs also creates missing parents of dst_dir, which
            # os.mkdir could not.
            os.makedirs(out_dir, exist_ok=True)

            for file_name in file_names:
                with open(cur_dir + "/" + file_name, 'r') as current_file:
                    lines = current_file.readlines()

                output_lines = [
                    include_prefix + self.cutlass_deps_root + line[len(include_prefix):]
                    if line.startswith(cutlass_include) else line
                    for line in lines
                ]

                with open(out_dir + "/" + file_name, "w+") as dest_file:
                    dest_file.writelines(output_lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h
DELETED
|
@@ -1,292 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
#include <cuda_fp16.h>
|
| 34 |
-
|
| 35 |
-
/// Device-side sum of two values of the same type, using the type's `operator+`.
template <typename T>
__device__
T add(T const & a, T const &b){
  return (a + b);
}
|
| 40 |
-
|
| 41 |
-
/// `half2` specialization of `add`: sums both lanes with the `__hadd2` intrinsic.
template <>
__device__
half2 add(half2 const & a, half2 const &b){
  return (__hadd2(a,b));
}
|
| 46 |
-
|
| 47 |
-
/// ReLU functor: returns the input where it is positive and zero otherwise.
template <typename T>
struct RELU{
  /// Scalar form.
  __device__
  T operator()(T const & a){
    return a > T(0) ? a : T(0);
  }

  /// Packed form: clamps each lane in float precision, then rounds back to half2.
  __device__
  half2 operator()(half2 const & a){
    float2 a_fp32x2 = __half22float2(a);
    a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f;
    a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f;
    // The former debug printf was guarded by "lane < 0", which cannot hold
    // after the clamps above, so it was unreachable and has been removed.
    return __float22half2_rn(a_fp32x2);
  }
};
|
| 63 |
-
|
| 64 |
-
/// Leaky-ReLU functor: returns the input where it is positive and
/// `scale * input` otherwise.
template <typename T>
struct LEAKY_RELU{
  /// Scalar form.
  __device__
  T operator()(T const & a, T const & scale = half(1)){
    return a > T(0) ? a : scale * a;
  }

  /// Packed form, computed without branches as `a * (le_zero * scale + gt_zero)`.
  __device__
  half2 operator()(half2 const & a, half const & scale = half(1)){
    half2 zero = __half2half2(half(0));
    // Per-lane comparison results. Despite its name, gt_zero uses ">=";
    // a lane equal to zero sets both flags, but the final product is still zero.
    half2 gt_zero = __hge2(a, zero);
    half2 le_zero = __hle2(a, zero);

    half2 scale_f16x2 = __half2half2(scale);
    half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero);
    return __hmul2(a, mask_scale_f16x2);
  }
};
|
| 82 |
-
|
| 83 |
-
/// In-place `inout = LeakyReLU(inout + bias, scale)` over rows of N halves.
///
/// Block (x, y) handles the row at offset `(y * gridDim.x + x) * N`.
/// With `mat_bias` the bias is read at the same offset as `inout`;
/// otherwise it is read from `bias + y * N`, one row of N per `blockIdx.y`.
/// When N is even, elements are accessed two at a time as `half2`.
/// NOTE(review): uses std::conditional, but this header includes only
/// <cuda_fp16.h> -- confirm <type_traits> is provided by the includer.
template <int N, int BLOCKDIM>
__global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){

  constexpr bool N_MOD_2 = N & 1 ? false : true;

  using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;

  constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);

  // Accesses each thread performs so that BLOCKDIM threads cover the row.
  constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);

  LEAKY_RELU<half> Act;

  int batch_id = blockIdx.y;
  int batch_offset = batch_id * gridDim.x * N;

  for(int i = 0; i < iter; i++){
    int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
    if (idx < N){
      // Each value is used once, so scalars replace the per-iteration arrays.
      Access_tp* dst = reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
      Access_tp src_v = *dst;
      Access_tp bias_v = mat_bias
          ? *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset)
          : *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
      *dst = Act(add(src_v, bias_v), scale);
    }
  }
}
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
/// In-place `inout = LeakyReLU(inout, scale)` over rows of N halves (no bias).
///
/// Block (x, y) handles the row at offset `(y * gridDim.x + x) * N`.
/// When N is even, elements are accessed two at a time as `half2`.
/// NOTE(review): uses std::conditional, but this header includes only
/// <cuda_fp16.h> -- confirm <type_traits> is provided by the includer.
template <int N, int BLOCKDIM>
__global__ void leaky_and_activation(half* inout, half scale){

  constexpr bool N_MOD_2 = N & 1 ? false : true;

  using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;

  constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);

  // Accesses each thread performs so that BLOCKDIM threads cover the row.
  constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);

  int batch_id = blockIdx.y;
  int batch_offset = batch_id * gridDim.x * N;

  LEAKY_RELU<half> Act;

  for(int i = 0; i < iter; i++){
    int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
    if (idx < N){
      // The value is used once, so a scalar replaces the per-iteration array.
      Access_tp* dst = reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
      Access_tp src_v = *dst;
      *dst = Act(src_v, scale);
    }
  }
}
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
/// Host launcher for the leaky-ReLU kernels.
///
/// Launches a grid of (m, b) blocks with BLOCKDIM threads each, with no
/// explicit stream argument. A null `bias` selects the bias-free kernel;
/// otherwise `mat_bias` is forwarded to the bias kernel.
template <int N, int BLOCKDIM>
void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){

  dim3 grid(m, b);
  if (bias == nullptr)
    leaky_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, scale);
  else
    leaky_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, scale, mat_bias);
}
|
| 155 |
-
|
| 156 |
-
/// In-place `inout = ReLU(inout + bias)` over rows of N halves.
///
/// Block (x, y) handles the row at offset `(y * gridDim.x + x) * N`.
/// With `mat_bias` the bias is read at the same offset as `inout`;
/// otherwise it is read from `bias + y * N`, one row of N per `blockIdx.y`.
/// When N is even, elements are accessed two at a time as `half2`.
/// NOTE(review): uses std::conditional, but this header includes only
/// <cuda_fp16.h> -- confirm <type_traits> is provided by the includer.
template <int N, int BLOCKDIM>
__global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){

  constexpr bool N_MOD_2 = N & 1 ? false : true;

  using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;

  constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);

  // Accesses each thread performs so that BLOCKDIM threads cover the row.
  constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);

  RELU<half> Act;

  int batch_id = blockIdx.y;
  int batch_offset = batch_id * gridDim.x * N;

  for(int i = 0; i < iter; i++){
    int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
    if (idx < N){
      // Each value is used once, so scalars replace the per-iteration arrays.
      Access_tp* dst = reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
      Access_tp src_v = *dst;
      Access_tp bias_v = mat_bias
          ? *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset)
          : *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
      *dst = Act(add(src_v, bias_v));
    }
  }
}
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
/// In-place `inout = ReLU(inout)` over rows of N halves (no bias).
///
/// Block (x, y) handles the row at offset `(y * gridDim.x + x) * N`.
/// When N is even, elements are accessed two at a time as `half2`.
/// NOTE(review): uses std::conditional, but this header includes only
/// <cuda_fp16.h> -- confirm <type_traits> is provided by the includer.
template <int N, int BLOCKDIM>
__global__ void relu_and_activation(half* inout){

  constexpr bool N_MOD_2 = N & 1 ? false : true;

  using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;

  constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);

  // Accesses each thread performs so that BLOCKDIM threads cover the row.
  constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);

  int batch_id = blockIdx.y;
  int batch_offset = batch_id * gridDim.x * N;

  RELU<half> Act;

  for(int i = 0; i < iter; i++){
    int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
    if (idx < N){
      // The value is used once, so a scalar replaces the per-iteration array.
      Access_tp* dst = reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
      Access_tp src_v = *dst;
      *dst = Act(src_v);
    }
  }
}
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
template <int N, int BLOCKDIM>
|
| 220 |
-
void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){
|
| 221 |
-
dim3 grid(m, b);
|
| 222 |
-
if (bias == nullptr)
|
| 223 |
-
relu_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout);
|
| 224 |
-
else
|
| 225 |
-
relu_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, mat_bias);
|
| 226 |
-
}
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
template <int N, int BLOCKDIM>
|
| 230 |
-
__global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){
|
| 231 |
-
|
| 232 |
-
constexpr bool N_MOD_2 = N & 1 ? false : true;
|
| 233 |
-
|
| 234 |
-
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
|
| 235 |
-
|
| 236 |
-
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
|
| 237 |
-
|
| 238 |
-
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
|
| 239 |
-
|
| 240 |
-
int batch_id = blockIdx.y;
|
| 241 |
-
int batch_offset = batch_id * gridDim.x * N;
|
| 242 |
-
|
| 243 |
-
Access_tp src_v[iter];
|
| 244 |
-
Access_tp bias_v[iter];
|
| 245 |
-
|
| 246 |
-
for(int i = 0; i < iter; i++){
|
| 247 |
-
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
|
| 248 |
-
if (idx < N){
|
| 249 |
-
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
|
| 250 |
-
if (mat_bias)
|
| 251 |
-
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
|
| 252 |
-
else
|
| 253 |
-
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
|
| 254 |
-
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i]));
|
| 255 |
-
}
|
| 256 |
-
|
| 257 |
-
}
|
| 258 |
-
}
|
| 259 |
-
|
| 260 |
-
template <int N, int BLOCKDIM>
|
| 261 |
-
__global__ void identity_and_activation(half* inout){
|
| 262 |
-
|
| 263 |
-
constexpr bool N_MOD_2 = N & 1 ? false : true;
|
| 264 |
-
|
| 265 |
-
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
|
| 266 |
-
|
| 267 |
-
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
|
| 268 |
-
|
| 269 |
-
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
|
| 270 |
-
|
| 271 |
-
int batch_id = blockIdx.y;
|
| 272 |
-
int batch_offset = batch_id * gridDim.x * N;
|
| 273 |
-
Access_tp src_v[iter];
|
| 274 |
-
|
| 275 |
-
for(int i = 0; i < iter; i++){
|
| 276 |
-
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
|
| 277 |
-
if (idx < N){
|
| 278 |
-
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
|
| 279 |
-
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]);
|
| 280 |
-
}
|
| 281 |
-
|
| 282 |
-
}
|
| 283 |
-
}
|
| 284 |
-
|
| 285 |
-
template <int N, int BLOCKDIM>
|
| 286 |
-
void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){
|
| 287 |
-
dim3 grid(m, b);
|
| 288 |
-
if (bias == nullptr)
|
| 289 |
-
identity_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout);
|
| 290 |
-
else
|
| 291 |
-
identity_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, mat_bias);
|
| 292 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
// TI(tag): start a timing interval named `tag`.
// Declares the elapsed-time variable and the start/end events for `tag` in the
// current scope, creates both events, and records the start event.
// Pair with TO(tag, ...) in the same scope.
#define TI(tag)                                 \
    float _event_time_ ##tag;                   \
    cudaEvent_t _event_start_ ##tag;            \
    cudaEvent_t _event_end_ ##tag;              \
    cudaEventCreate(& _event_end_ ##tag);       \
    cudaEventCreate(& _event_start_ ##tag);     \
    cudaEventRecord(_event_start_ ##tag);
|
| 40 |
-
|
| 41 |
-
// TO(tag, str, times): close the timing interval opened by TI(tag).
//   tag   - same identifier that was passed to TI
//   str   - label printed in front of the measurement
//   times - number of repetitions inside the interval; the elapsed time is divided
//           by it (parenthesized so expressions such as `a + b` divide correctly)
// Prints the per-repetition time in microseconds, then synchronizes the device and
// prints the last CUDA error string. Finally destroys the two events created by TI
// so that repeated timing blocks do not leak event handles.
#define TO(tag, str, times)                                                             \
    cudaEventRecord(_event_end_ ##tag);                                                 \
    cudaEventSynchronize(_event_end_ ##tag);                                            \
    cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag);  \
    float _event_time_once_ ##tag = _event_time_ ##tag / (times);                       \
    printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000);                  \
    cudaDeviceSynchronize();                                                            \
    printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError()));            \
    cudaEventDestroy(_event_start_ ##tag);                                              \
    cudaEventDestroy(_event_end_ ##tag);
|
| 49 |
-
|
| 50 |
-
template<typename T>
|
| 51 |
-
struct memory_unit{
|
| 52 |
-
T* host_ptr;
|
| 53 |
-
T* device_ptr;
|
| 54 |
-
int size_bytes;
|
| 55 |
-
int elements;
|
| 56 |
-
void h2d(){
|
| 57 |
-
cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice);
|
| 58 |
-
}
|
| 59 |
-
void d2h(){
|
| 60 |
-
cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost);
|
| 61 |
-
}
|
| 62 |
-
void free_all(){
|
| 63 |
-
free(host_ptr);
|
| 64 |
-
cudaFree(device_ptr);
|
| 65 |
-
}
|
| 66 |
-
memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){
|
| 67 |
-
host_ptr = (T*) malloc(elements_ * sizeof(T));
|
| 68 |
-
cudaMalloc((void**)&device_ptr, elements_ * sizeof(T));
|
| 69 |
-
}
|
| 70 |
-
void init(int abs_range = 1){
|
| 71 |
-
for(int i = 0; i < elements; i++){
|
| 72 |
-
host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range);
|
| 73 |
-
}
|
| 74 |
-
h2d();
|
| 75 |
-
}
|
| 76 |
-
};
|
| 77 |
-
|
| 78 |
-
// Compare `b` (result under test) against `a` (reference), element by element.
//
// An element counts as an error when its relative error |a[i] - b[i]| / |a[i]|
// exceeds 1e-2. Special cases:
//   * exactly equal values (including 0 vs 0 and inf vs inf) are never errors;
//   * a zero reference with a non-zero result is always an error;
//   * any comparison involving NaN is an error (it can never be "within tolerance").
//
// Absolute values are computed inline rather than through an unqualified `abs`,
// which can bind to the integer overload and truncate sub-unit differences to zero.
//
// Prints "total err: <count> / <N>" and returns the error count.
template<typename T>
int check_result(T * a, T * b, int N){
    int cnt = 0;
    for(int i = 0; i < N; i ++){
        const float expected = float(a[i]);
        const float actual = float(b[i]);

        if (expected == actual) {
            continue;
        }

        // Both are non-negative for ordinary inputs and NaN if either input is NaN.
        const float diff = expected > actual ? expected - actual : actual - expected;
        const float magnitude = expected < 0.0f ? -expected : expected;

        // `!(x <= tol)` instead of `x > tol` so that NaN is counted as a mismatch.
        if (magnitude == 0.0f || !(diff / magnitude <= 1e-2))
        {
            // printf("my: %f , std: %f\n", actual, expected);
            cnt++;
        }
    }
    printf("total err: %d / %d\n", cnt, N);
    return cnt;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h
DELETED
|
@@ -1,499 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Performs a dual gemm in one fused kernel:
|
| 33 |
-
```
|
| 34 |
-
D0 = epilogue0(X @ B0, C0)
|
| 35 |
-
D1 = epilogue1(X @ B1, C1)
|
| 36 |
-
D2 = element_wise(D0, D1)
|
| 37 |
-
```
|
| 38 |
-
*/
|
| 39 |
-
|
| 40 |
-
#pragma once
|
| 41 |
-
|
| 42 |
-
#include "cutlass/cutlass.h"
|
| 43 |
-
#include "cutlass/numeric_types.h"
|
| 44 |
-
#include "cutlass/arch/arch.h"
|
| 45 |
-
#include "cutlass/device_kernel.h"
|
| 46 |
-
|
| 47 |
-
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
| 48 |
-
|
| 49 |
-
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
| 50 |
-
#include "cutlass/gemm/threadblock/default_mma.h"
|
| 51 |
-
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 52 |
-
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 53 |
-
|
| 54 |
-
#include "../kernel/dual_gemm.h"
|
| 55 |
-
#include "../dual_gemm_common.h"
|
| 56 |
-
|
| 57 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 58 |
-
|
| 59 |
-
namespace cutlass {
|
| 60 |
-
namespace gemm {
|
| 61 |
-
namespace device {
|
| 62 |
-
|
| 63 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 64 |
-
|
| 65 |
-
template <
|
| 66 |
-
/// Element type for A matrix operand
|
| 67 |
-
typename ElementA_,
|
| 68 |
-
/// Layout type for A matrix operand
|
| 69 |
-
typename LayoutA_,
|
| 70 |
-
/// Element type for B matrix operand
|
| 71 |
-
typename ElementB_,
|
| 72 |
-
/// Layout type for B0 matrix operand
|
| 73 |
-
typename LayoutB0_,
|
| 74 |
-
/// Layout type for B1 matrix operand
|
| 75 |
-
typename LayoutB1_,
|
| 76 |
-
/// Element type for C and D matrix operands
|
| 77 |
-
typename ElementC_,
|
| 78 |
-
/// Layout type for C and D matrix operands
|
| 79 |
-
typename LayoutC_,
|
| 80 |
-
/// Element type for internal accumulation
|
| 81 |
-
typename ElementAccumulator_,
|
| 82 |
-
/// Operator class tag
|
| 83 |
-
typename OperatorClass_,
|
| 84 |
-
/// Tag indicating architecture to tune for
|
| 85 |
-
typename ArchTag_,
|
| 86 |
-
/// Threadblock-level tile size (concept: GemmShape)
|
| 87 |
-
typename ThreadblockShape_,
|
| 88 |
-
/// Warp-level tile size (concept: GemmShape)
|
| 89 |
-
typename WarpShape_,
|
| 90 |
-
/// Instruction-level tile size (concept: GemmShape)
|
| 91 |
-
typename InstructionShape_,
|
| 92 |
-
/// Epilogue output operator
|
| 93 |
-
typename EpilogueOutputOp0_,
|
| 94 |
-
typename EpilogueOutputOp1_,
|
| 95 |
-
typename EpilogueOutputOp2_,
|
| 96 |
-
/// Threadblock-level swizzling operator
|
| 97 |
-
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
|
| 98 |
-
/// Number of stages used in the pipelined mainloop
|
| 99 |
-
int Stages =
|
| 100 |
-
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
| 101 |
-
ElementC_, ElementAccumulator_>::kStages,
|
| 102 |
-
bool StoreD0 = true,
|
| 103 |
-
bool StoreD1 = true,
|
| 104 |
-
/// If true, kernel supports split-K with serial reduction
|
| 105 |
-
bool SplitKSerial = false,
|
| 106 |
-
/// Access granularity of A matrix in units of elements
|
| 107 |
-
int AlignmentA =
|
| 108 |
-
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
| 109 |
-
ElementC_, ElementAccumulator_>::kAlignmentA,
|
| 110 |
-
/// Access granularity of B matrix in units of elements
|
| 111 |
-
int AlignmentB =
|
| 112 |
-
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
| 113 |
-
ElementC_, ElementAccumulator_>::kAlignmentB,
|
| 114 |
-
/// Operation performed by GEMM
|
| 115 |
-
typename Operator_ = typename DefaultGemmConfiguration<
|
| 116 |
-
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
| 117 |
-
ElementAccumulator_>::Operator>
|
| 118 |
-
class DualGemm {
|
| 119 |
-
public:
|
| 120 |
-
|
| 121 |
-
using ElementA = ElementA_;
|
| 122 |
-
using LayoutA = LayoutA_;
|
| 123 |
-
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
| 124 |
-
using ElementB = ElementB_;
|
| 125 |
-
using LayoutB0 = LayoutB0_;
|
| 126 |
-
using LayoutB1 = LayoutB1_;
|
| 127 |
-
using TensorRefB0 = TensorRef<ElementB const, LayoutB0>;
|
| 128 |
-
using TensorRefB1 = TensorRef<ElementB const, LayoutB1>;
|
| 129 |
-
using ElementC = ElementC_;
|
| 130 |
-
using LayoutC = LayoutC_;
|
| 131 |
-
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
| 132 |
-
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
| 133 |
-
using ElementAccumulator = ElementAccumulator_;
|
| 134 |
-
using OperatorClass = OperatorClass_;
|
| 135 |
-
using ArchTag = ArchTag_;
|
| 136 |
-
using ThreadblockShape = ThreadblockShape_;
|
| 137 |
-
using WarpShape = WarpShape_;
|
| 138 |
-
using InstructionShape = InstructionShape_;
|
| 139 |
-
using EpilogueOutputOp0 = EpilogueOutputOp0_;
|
| 140 |
-
using EpilogueOutputOp1 = EpilogueOutputOp1_;
|
| 141 |
-
using EpilogueOutputOp2 = EpilogueOutputOp2_;
|
| 142 |
-
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
| 143 |
-
using Operator = Operator_;
|
| 144 |
-
static int const kStages = Stages;
|
| 145 |
-
static int const kAlignmentA = AlignmentA;
|
| 146 |
-
static int const kAlignmentB = AlignmentB;
|
| 147 |
-
static int const kAlignmentC = EpilogueOutputOp1::kCount;
|
| 148 |
-
static bool const kSplitKSerial = SplitKSerial;
|
| 149 |
-
static bool constexpr kStoreD0 = StoreD0;
|
| 150 |
-
static bool constexpr kStoreD1 = StoreD1;
|
| 151 |
-
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
| 152 |
-
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
| 153 |
-
|
| 154 |
-
using LayoutScaleBias = layout::RowMajor;
|
| 155 |
-
/// Define the kernel
|
| 156 |
-
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 157 |
-
static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented");
|
| 158 |
-
static_assert(kStages >= 3, "Only multistage is implemented");
|
| 159 |
-
using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
|
| 160 |
-
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB,
|
| 161 |
-
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
|
| 162 |
-
ThreadblockShape, WarpShape,
|
| 163 |
-
InstructionShape, Stages, Operator>::ThreadblockMma;
|
| 164 |
-
using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
|
| 165 |
-
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB,
|
| 166 |
-
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
|
| 167 |
-
ThreadblockShape, WarpShape,
|
| 168 |
-
InstructionShape, Stages, Operator>::ThreadblockMma;
|
| 169 |
-
using DualMma = threadblock::DualMmaMultistage<
|
| 170 |
-
typename Mma0::Shape,
|
| 171 |
-
typename Mma0::IteratorA,
|
| 172 |
-
typename Mma0::SmemIteratorA,
|
| 173 |
-
Mma0::kCacheOpA,
|
| 174 |
-
typename Mma0::IteratorB,
|
| 175 |
-
typename Mma0::SmemIteratorB,
|
| 176 |
-
Mma0::kCacheOpB,
|
| 177 |
-
typename Mma1::IteratorB,
|
| 178 |
-
typename Mma1::SmemIteratorB,
|
| 179 |
-
typename Mma0::ElementC,
|
| 180 |
-
typename Mma0::LayoutC,
|
| 181 |
-
typename Mma0::Policy,
|
| 182 |
-
typename Mma1::Policy,
|
| 183 |
-
Mma0::kStages,
|
| 184 |
-
SharedMemoryClearOption::kNone
|
| 185 |
-
>;
|
| 186 |
-
|
| 187 |
-
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 188 |
-
|
| 189 |
-
/// Define the epilogue
|
| 190 |
-
using Epilogue0 =
|
| 191 |
-
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 192 |
-
ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0,
|
| 193 |
-
EpilogueOutputOp0::kCount>::Epilogue;
|
| 194 |
-
using Epilogue1 =
|
| 195 |
-
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 196 |
-
ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1,
|
| 197 |
-
EpilogueOutputOp1::kCount>::Epilogue;
|
| 198 |
-
|
| 199 |
-
/// Define the kernel-level GEMM operator.
|
| 200 |
-
using DualGemmKernel = kernel::DualGemm<
|
| 201 |
-
DualMma,
|
| 202 |
-
Epilogue0, Epilogue1, EpilogueOutputOp2,
|
| 203 |
-
ThreadblockSwizzle, kSplitKSerial,
|
| 204 |
-
kStoreD0, kStoreD1>;
|
| 205 |
-
|
| 206 |
-
/// Argument structure
|
| 207 |
-
struct Arguments {
|
| 208 |
-
|
| 209 |
-
//
|
| 210 |
-
// Data members
|
| 211 |
-
//
|
| 212 |
-
|
| 213 |
-
DualGemmMode mode;
|
| 214 |
-
GemmCoord problem_size;
|
| 215 |
-
TensorRef<ElementA const, LayoutA> ref_A0;
|
| 216 |
-
TensorRef<ElementB const, LayoutB0> ref_B0;
|
| 217 |
-
TensorRef<ElementC const, LayoutC> ref_C0;
|
| 218 |
-
TensorRef<ElementC, LayoutC> ref_D0;
|
| 219 |
-
TensorRef<ElementB const, LayoutB1> ref_B1;
|
| 220 |
-
TensorRef<ElementC const, LayoutC> ref_C1;
|
| 221 |
-
TensorRef<ElementC, LayoutC> ref_D1;
|
| 222 |
-
TensorRef<ElementC, LayoutC> ref_D2;
|
| 223 |
-
typename EpilogueOutputOp0::Params epilogue0;
|
| 224 |
-
typename EpilogueOutputOp1::Params epilogue1;
|
| 225 |
-
typename EpilogueOutputOp2::Params epilogue2;
|
| 226 |
-
int split_k_slices;
|
| 227 |
-
|
| 228 |
-
int batch_count;
|
| 229 |
-
int64_t batch_stride_A;
|
| 230 |
-
int64_t batch_stride_B0;
|
| 231 |
-
int64_t batch_stride_B1;
|
| 232 |
-
int64_t batch_stride_C;
|
| 233 |
-
int64_t batch_stride_D;
|
| 234 |
-
|
| 235 |
-
//
|
| 236 |
-
// Methods
|
| 237 |
-
//
|
| 238 |
-
|
| 239 |
-
/// Default ctor
|
| 240 |
-
CUTLASS_HOST_DEVICE
|
| 241 |
-
Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
|
| 242 |
-
|
| 243 |
-
}
|
| 244 |
-
|
| 245 |
-
/// Constructs an Arguments structure
|
| 246 |
-
CUTLASS_HOST_DEVICE
|
| 247 |
-
Arguments(
|
| 248 |
-
DualGemmMode mode,
|
| 249 |
-
GemmCoord problem_size_,
|
| 250 |
-
TensorRef<ElementA const, LayoutA> ref_A0_,
|
| 251 |
-
TensorRef<ElementB const, LayoutB0> ref_B0_,
|
| 252 |
-
TensorRef<ElementC const, LayoutC> ref_C0_,
|
| 253 |
-
TensorRef<ElementC, LayoutC> ref_D0_,
|
| 254 |
-
TensorRef<ElementB const, LayoutB1> ref_B1_,
|
| 255 |
-
TensorRef<ElementC const, LayoutC> ref_C1_,
|
| 256 |
-
TensorRef<ElementC, LayoutC> ref_D1_,
|
| 257 |
-
TensorRef<ElementC, LayoutC> ref_D2_,
|
| 258 |
-
typename EpilogueOutputOp0::Params epilogue0_ =
|
| 259 |
-
typename EpilogueOutputOp0::Params(),
|
| 260 |
-
typename EpilogueOutputOp1::Params epilogue1_ =
|
| 261 |
-
typename EpilogueOutputOp1::Params(),
|
| 262 |
-
typename EpilogueOutputOp2::Params epilogue2_ =
|
| 263 |
-
typename EpilogueOutputOp2::Params(),
|
| 264 |
-
int split_k_slices_ = 1,
|
| 265 |
-
int batch_count = 1,
|
| 266 |
-
int64_t batch_stride_A = 0,
|
| 267 |
-
int64_t batch_stride_B0 = 0,
|
| 268 |
-
int64_t batch_stride_B1 = 0,
|
| 269 |
-
int64_t batch_stride_C = 0,
|
| 270 |
-
int64_t batch_stride_D = 0
|
| 271 |
-
):
|
| 272 |
-
mode(mode),
|
| 273 |
-
problem_size(problem_size_),
|
| 274 |
-
ref_A0(ref_A0_),
|
| 275 |
-
ref_B0(ref_B0_),
|
| 276 |
-
ref_C0(ref_C0_),
|
| 277 |
-
ref_D0(ref_D0_),
|
| 278 |
-
ref_B1(ref_B1_),
|
| 279 |
-
ref_C1(ref_C1_),
|
| 280 |
-
ref_D1(ref_D1_),
|
| 281 |
-
ref_D2(ref_D2_),
|
| 282 |
-
epilogue0(epilogue0_),
|
| 283 |
-
epilogue1(epilogue1_),
|
| 284 |
-
epilogue2(epilogue2_),
|
| 285 |
-
split_k_slices(split_k_slices_),
|
| 286 |
-
batch_count(batch_count),
|
| 287 |
-
batch_stride_A(batch_stride_A),
|
| 288 |
-
batch_stride_B0(batch_stride_B0),
|
| 289 |
-
batch_stride_B1(batch_stride_B1),
|
| 290 |
-
batch_stride_C(batch_stride_C),
|
| 291 |
-
batch_stride_D(batch_stride_D) {
|
| 292 |
-
|
| 293 |
-
}
|
| 294 |
-
};
|
| 295 |
-
|
| 296 |
-
private:
|
| 297 |
-
|
| 298 |
-
/// Kernel parameters object
|
| 299 |
-
typename DualGemmKernel::Params params_;
|
| 300 |
-
|
| 301 |
-
public:
|
| 302 |
-
|
| 303 |
-
/// Constructs the GEMM.
|
| 304 |
-
DualGemm() = default;
|
| 305 |
-
|
| 306 |
-
/// Determines whether the GEMM can execute the given problem.
|
| 307 |
-
static Status can_implement(Arguments const &args) {
|
| 308 |
-
|
| 309 |
-
if (args.mode == DualGemmMode::kBatched && kSplitKSerial) {
|
| 310 |
-
return Status::kErrorInvalidProblem;
|
| 311 |
-
}
|
| 312 |
-
if (!kSplitKSerial && args.split_k_slices > 1) {
|
| 313 |
-
return Status::kErrorInvalidProblem;
|
| 314 |
-
}
|
| 315 |
-
if (kStoreD0 != (args.ref_D0.data() != nullptr)) {
|
| 316 |
-
return Status::kErrorInternal;
|
| 317 |
-
}
|
| 318 |
-
if (kStoreD1 != (args.ref_D1.data() != nullptr)) {
|
| 319 |
-
return Status::kErrorInternal;
|
| 320 |
-
}
|
| 321 |
-
|
| 322 |
-
Status status = DualGemmKernel::can_implement(
|
| 323 |
-
args.problem_size,
|
| 324 |
-
args.ref_A0.non_const_ref(),
|
| 325 |
-
args.ref_B0.non_const_ref(),
|
| 326 |
-
args.ref_C0.non_const_ref(),
|
| 327 |
-
args.ref_D0,
|
| 328 |
-
args.ref_B1.non_const_ref(),
|
| 329 |
-
args.ref_C1.non_const_ref(),
|
| 330 |
-
args.ref_D1,
|
| 331 |
-
args.ref_D2
|
| 332 |
-
);
|
| 333 |
-
|
| 334 |
-
if (status != Status::kSuccess) {
|
| 335 |
-
return status;
|
| 336 |
-
}
|
| 337 |
-
|
| 338 |
-
return Status::kSuccess;
|
| 339 |
-
}
|
| 340 |
-
|
| 341 |
-
/// Gets the workspace size
|
| 342 |
-
static size_t get_workspace_size(Arguments const &args) {
|
| 343 |
-
|
| 344 |
-
size_t bytes = 0;
|
| 345 |
-
|
| 346 |
-
if (kSplitKSerial && args.split_k_slices > 1) {
|
| 347 |
-
// Determine grid shape
|
| 348 |
-
ThreadblockSwizzle threadblock_swizzle;
|
| 349 |
-
|
| 350 |
-
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
| 351 |
-
args.problem_size,
|
| 352 |
-
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
| 353 |
-
args.split_k_slices);
|
| 354 |
-
|
| 355 |
-
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
|
| 356 |
-
}
|
| 357 |
-
|
| 358 |
-
return bytes;
|
| 359 |
-
}
|
| 360 |
-
|
| 361 |
-
/// Initializes GEMM state from arguments.
|
| 362 |
-
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
| 363 |
-
|
| 364 |
-
// Determine grid shape
|
| 365 |
-
ThreadblockSwizzle threadblock_swizzle;
|
| 366 |
-
|
| 367 |
-
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
| 368 |
-
args.problem_size,
|
| 369 |
-
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
| 370 |
-
args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices);
|
| 371 |
-
|
| 372 |
-
if (kSplitKSerial) {
|
| 373 |
-
if (args.split_k_slices > 1) {
|
| 374 |
-
if (!workspace) {
|
| 375 |
-
return Status::kErrorWorkspaceNull;
|
| 376 |
-
}
|
| 377 |
-
|
| 378 |
-
size_t bytes = get_workspace_size(args);
|
| 379 |
-
|
| 380 |
-
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
|
| 381 |
-
|
| 382 |
-
if (result != cudaSuccess) {
|
| 383 |
-
return Status::kErrorInternal;
|
| 384 |
-
}
|
| 385 |
-
}
|
| 386 |
-
}
|
| 387 |
-
else {
|
| 388 |
-
|
| 389 |
-
if (args.split_k_slices > 1) {
|
| 390 |
-
return Status::kErrorInvalidProblem;
|
| 391 |
-
}
|
| 392 |
-
}
|
| 393 |
-
|
| 394 |
-
// Initialize the Params structure
|
| 395 |
-
params_ = typename DualGemmKernel::Params{
|
| 396 |
-
args.mode,
|
| 397 |
-
args.problem_size,
|
| 398 |
-
grid_shape,
|
| 399 |
-
args.ref_A0.non_const_ref(),
|
| 400 |
-
args.ref_B0.non_const_ref(),
|
| 401 |
-
args.ref_C0.non_const_ref(),
|
| 402 |
-
args.ref_D0,
|
| 403 |
-
args.ref_B1.non_const_ref(),
|
| 404 |
-
args.ref_C1.non_const_ref(),
|
| 405 |
-
args.ref_D1,
|
| 406 |
-
args.ref_D2,
|
| 407 |
-
args.epilogue0,
|
| 408 |
-
args.epilogue1,
|
| 409 |
-
args.epilogue2,
|
| 410 |
-
reinterpret_cast<int *>(workspace),
|
| 411 |
-
args.batch_stride_A,
|
| 412 |
-
args.batch_stride_B0,
|
| 413 |
-
args.batch_stride_B1,
|
| 414 |
-
args.batch_stride_C,
|
| 415 |
-
args.batch_stride_D,
|
| 416 |
-
};
|
| 417 |
-
|
| 418 |
-
return Status::kSuccess;
|
| 419 |
-
}
|
| 420 |
-
|
| 421 |
-
/// Lightweight update given a subset of arguments
|
| 422 |
-
Status update(Arguments const &args, void *workspace = nullptr) {
|
| 423 |
-
|
| 424 |
-
if (kSplitKSerial && args.split_k_slices > 1) {
|
| 425 |
-
if (!workspace) {
|
| 426 |
-
return Status::kErrorWorkspaceNull;
|
| 427 |
-
}
|
| 428 |
-
}
|
| 429 |
-
|
| 430 |
-
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
|
| 431 |
-
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
|
| 432 |
-
params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
|
| 433 |
-
params_.ref_D0.reset(args.ref_D0.data());
|
| 434 |
-
params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
|
| 435 |
-
params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
|
| 436 |
-
params_.ref_D1.reset(args.ref_D1.data());
|
| 437 |
-
params_.ref_D2.reset(args.ref_D2.data());
|
| 438 |
-
params_.output_op_0 = args.epilogue0;
|
| 439 |
-
params_.output_op_1 = args.epilogue1;
|
| 440 |
-
params_.output_op_2 = args.epilogue2;
|
| 441 |
-
params_.semaphore = reinterpret_cast<int *>(workspace);
|
| 442 |
-
|
| 443 |
-
return Status::kSuccess;
|
| 444 |
-
}
|
| 445 |
-
|
| 446 |
-
/// Runs the kernel using initialized state.
|
| 447 |
-
Status run(cudaStream_t stream = nullptr) {
|
| 448 |
-
|
| 449 |
-
ThreadblockSwizzle threadblock_swizzle;
|
| 450 |
-
|
| 451 |
-
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
| 452 |
-
dim3 block(DualGemmKernel::kThreadCount, 1, 1);
|
| 453 |
-
|
| 454 |
-
cudaError_t result;
|
| 455 |
-
|
| 456 |
-
int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage));
|
| 457 |
-
if (smem_size >= (48 << 10)) {
|
| 458 |
-
result = cudaFuncSetAttribute(Kernel<DualGemmKernel>,
|
| 459 |
-
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 460 |
-
smem_size);
|
| 461 |
-
|
| 462 |
-
if (result != cudaSuccess) {
|
| 463 |
-
return Status::kErrorInternal;
|
| 464 |
-
}
|
| 465 |
-
}
|
| 466 |
-
|
| 467 |
-
cutlass::Kernel<DualGemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
| 468 |
-
|
| 469 |
-
result = cudaGetLastError();
|
| 470 |
-
|
| 471 |
-
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
| 472 |
-
}
|
| 473 |
-
|
| 474 |
-
/// Runs the kernel using initialized state.
|
| 475 |
-
Status operator()(cudaStream_t stream = nullptr) {
|
| 476 |
-
return run(stream);
|
| 477 |
-
}
|
| 478 |
-
|
| 479 |
-
/// Runs the kernel using initialized state.
|
| 480 |
-
Status operator()(
|
| 481 |
-
Arguments const &args,
|
| 482 |
-
void *workspace = nullptr,
|
| 483 |
-
cudaStream_t stream = nullptr) {
|
| 484 |
-
|
| 485 |
-
Status status = initialize(args, workspace, stream);
|
| 486 |
-
|
| 487 |
-
if (status == Status::kSuccess) {
|
| 488 |
-
status = run(stream);
|
| 489 |
-
}
|
| 490 |
-
|
| 491 |
-
return status;
|
| 492 |
-
}
|
| 493 |
-
};
|
| 494 |
-
|
| 495 |
-
} // namespace device
|
| 496 |
-
} // namespace gemm
|
| 497 |
-
} // namespace cutlass
|
| 498 |
-
|
| 499 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h
DELETED
|
@@ -1,52 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Defines common types used for all DualGemm operators.
|
| 33 |
-
*/
|
| 34 |
-
#pragma once
|
| 35 |
-
|
| 36 |
-
namespace cutlass {
|
| 37 |
-
namespace gemm {
|
| 38 |
-
|
| 39 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
-
|
| 41 |
-
enum class DualGemmMode {
|
| 42 |
-
kGemm,
|
| 43 |
-
kBatched,
|
| 44 |
-
kInvalid
|
| 45 |
-
};
|
| 46 |
-
|
| 47 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
-
|
| 49 |
-
} // namespace gemm
|
| 50 |
-
} // namespace cutlass
|
| 51 |
-
|
| 52 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h
DELETED
|
@@ -1,938 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
#include <iostream>
|
| 34 |
-
#include <fstream>
|
| 35 |
-
#include <sstream>
|
| 36 |
-
#include <type_traits>
|
| 37 |
-
|
| 38 |
-
#include "cutlass/util/host_tensor.h"
|
| 39 |
-
#include "cutlass/util/tensor_view_io.h"
|
| 40 |
-
#include "cutlass/util/distribution.h"
|
| 41 |
-
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 42 |
-
#include "cutlass/util/reference/host/tensor_copy.h"
|
| 43 |
-
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 44 |
-
#include "cutlass/util/reference/host/tensor_norm.h"
|
| 45 |
-
#include "cutlass/util/reference/device/gemm.h"
|
| 46 |
-
#include "cutlass/util/reference/device/tensor_relu.h"
|
| 47 |
-
|
| 48 |
-
#include "cutlass/platform/platform.h"
|
| 49 |
-
#include "cutlass/gemm/gemm.h"
|
| 50 |
-
#include "cutlass/gemm/device/gemm_universal.h"
|
| 51 |
-
|
| 52 |
-
#include "dual_gemm_common.h"
|
| 53 |
-
#include "helper.h"
|
| 54 |
-
|
| 55 |
-
#define CHECK_GT(val1, val2) \
|
| 56 |
-
if((val1) <= (val2)) \
|
| 57 |
-
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
| 58 |
-
#define CHECK_TRUE(val) \
|
| 59 |
-
if(!(val)) \
|
| 60 |
-
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
| 61 |
-
|
| 62 |
-
template <
|
| 63 |
-
typename OutputOp,
|
| 64 |
-
typename Element,
|
| 65 |
-
typename Layout>
|
| 66 |
-
struct TensorEpilogueForEachFunc {
|
| 67 |
-
/// View type
|
| 68 |
-
using TensorView = cutlass::TensorView<Element, Layout>;
|
| 69 |
-
|
| 70 |
-
/// Coordinate in tensor's index space
|
| 71 |
-
using TensorCoord = typename TensorView::TensorCoord;
|
| 72 |
-
|
| 73 |
-
/// Parameters structure
|
| 74 |
-
struct Params {
|
| 75 |
-
|
| 76 |
-
//
|
| 77 |
-
// Data members
|
| 78 |
-
//
|
| 79 |
-
|
| 80 |
-
TensorView view_x0;
|
| 81 |
-
TensorView view_x1;
|
| 82 |
-
TensorView view_y;
|
| 83 |
-
OutputOp output_op;
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
//
|
| 87 |
-
// Methods
|
| 88 |
-
//
|
| 89 |
-
|
| 90 |
-
Params(
|
| 91 |
-
TensorView view_x0_ = TensorView(),
|
| 92 |
-
TensorView view_x1_ = TensorView(),
|
| 93 |
-
TensorView view_y_ = TensorView(),
|
| 94 |
-
OutputOp output_op_ = OutputOp(typename OutputOp::Params{})
|
| 95 |
-
):
|
| 96 |
-
view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) {
|
| 97 |
-
}
|
| 98 |
-
};
|
| 99 |
-
|
| 100 |
-
Params params;
|
| 101 |
-
|
| 102 |
-
CUTLASS_DEVICE
|
| 103 |
-
TensorEpilogueForEachFunc(Params const &params): params(params) {
|
| 104 |
-
|
| 105 |
-
}
|
| 106 |
-
|
| 107 |
-
CUTLASS_DEVICE
|
| 108 |
-
void operator()(TensorCoord const &coord) {
|
| 109 |
-
Element const & x0 = params.view_x0.at(coord);
|
| 110 |
-
Element const & x1 = params.view_x1.at(coord);
|
| 111 |
-
Element& y = params.view_y.at(coord);
|
| 112 |
-
y = params.output_op(x0, x1);
|
| 113 |
-
}
|
| 114 |
-
};
|
| 115 |
-
|
| 116 |
-
template <
|
| 117 |
-
typename OutputOp,
|
| 118 |
-
typename Element,
|
| 119 |
-
typename Layout>
|
| 120 |
-
void TensorEpilogueForEach(
|
| 121 |
-
cutlass::TensorView<Element, Layout> x0,
|
| 122 |
-
cutlass::TensorView<Element, Layout> x1,
|
| 123 |
-
cutlass::TensorView<Element, Layout> y) {
|
| 124 |
-
|
| 125 |
-
using Func = TensorEpilogueForEachFunc<OutputOp, Element, Layout>;
|
| 126 |
-
using Params = typename Func::Params;
|
| 127 |
-
|
| 128 |
-
cutlass::reference::device::TensorForEach<Func, Layout::kRank, Params>(
|
| 129 |
-
y.extent(),
|
| 130 |
-
Params(x0, x1, y)
|
| 131 |
-
);
|
| 132 |
-
}
|
| 133 |
-
|
| 134 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 135 |
-
|
| 136 |
-
template <typename Gemm0_, typename Gemm1_>
|
| 137 |
-
struct NonFusedDualGemmRun
|
| 138 |
-
{
|
| 139 |
-
|
| 140 |
-
using Gemm0 = Gemm0_;
|
| 141 |
-
using Gemm1 = Gemm1_;
|
| 142 |
-
using ElementAccumulator = typename Gemm0::ElementAccumulator;
|
| 143 |
-
using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
|
| 144 |
-
|
| 145 |
-
/// Initialization
|
| 146 |
-
cutlass::Distribution::Kind init_A;
|
| 147 |
-
cutlass::Distribution::Kind init_B;
|
| 148 |
-
cutlass::Distribution::Kind init_C;
|
| 149 |
-
cutlass::Distribution::Kind init_Bias;
|
| 150 |
-
uint64_t seed;
|
| 151 |
-
|
| 152 |
-
//
|
| 153 |
-
// Methods
|
| 154 |
-
//
|
| 155 |
-
|
| 156 |
-
NonFusedDualGemmRun(
|
| 157 |
-
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 158 |
-
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 159 |
-
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 160 |
-
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 161 |
-
uint64_t seed_ = 2080
|
| 162 |
-
):
|
| 163 |
-
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }
|
| 164 |
-
|
| 165 |
-
/// Helper to initialize a tensor view
|
| 166 |
-
template <typename Element, typename Layout>
|
| 167 |
-
bool initialize_tensor(
|
| 168 |
-
cutlass::TensorView<Element, Layout> view,
|
| 169 |
-
cutlass::Distribution::Kind dist_kind,
|
| 170 |
-
uint64_t seed) {
|
| 171 |
-
|
| 172 |
-
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 173 |
-
|
| 174 |
-
cutlass::reference::host::TensorFillRandomUniform(
|
| 175 |
-
view, seed, 2, -2, 0);
|
| 176 |
-
}
|
| 177 |
-
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 178 |
-
|
| 179 |
-
cutlass::reference::host::TensorFillIdentity(view);
|
| 180 |
-
}
|
| 181 |
-
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 182 |
-
|
| 183 |
-
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 184 |
-
}
|
| 185 |
-
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 186 |
-
|
| 187 |
-
cutlass::reference::host::BlockFillSequential(
|
| 188 |
-
view.data(), view.capacity());
|
| 189 |
-
}
|
| 190 |
-
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 191 |
-
cutlass::reference::host::TensorFill(view, Element(0));
|
| 192 |
-
}
|
| 193 |
-
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 194 |
-
cutlass::reference::host::TensorFill(view, Element(1));
|
| 195 |
-
}
|
| 196 |
-
else {
|
| 197 |
-
std::cerr << "Not implemented\n";
|
| 198 |
-
return false;
|
| 199 |
-
}
|
| 200 |
-
|
| 201 |
-
return true;
|
| 202 |
-
}
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
/// Executes one test
|
| 208 |
-
bool run(
|
| 209 |
-
cutlass::gemm::GemmCoord problem_size,
|
| 210 |
-
ElementCompute alpha0 = ElementCompute(1),
|
| 211 |
-
ElementCompute beta0 = ElementCompute(0),
|
| 212 |
-
ElementCompute alpha1 = ElementCompute(1),
|
| 213 |
-
ElementCompute beta1 = ElementCompute(0),
|
| 214 |
-
bool is_profiling = true,
|
| 215 |
-
bool relu = false,
|
| 216 |
-
int warm_ups = 1,
|
| 217 |
-
int runs = 100) {
|
| 218 |
-
|
| 219 |
-
//
|
| 220 |
-
// Allocate the GEMM workspace
|
| 221 |
-
//
|
| 222 |
-
|
| 223 |
-
cutlass::HostTensor<
|
| 224 |
-
typename Gemm0::ElementA,
|
| 225 |
-
typename Gemm0::LayoutA> tensor_A0(problem_size.mk());
|
| 226 |
-
|
| 227 |
-
cutlass::HostTensor<
|
| 228 |
-
typename Gemm0::ElementB,
|
| 229 |
-
typename Gemm0::LayoutB> tensor_B0(problem_size.kn());
|
| 230 |
-
|
| 231 |
-
cutlass::HostTensor<
|
| 232 |
-
typename Gemm0::ElementC,
|
| 233 |
-
typename Gemm0::LayoutC> tensor_C0(problem_size.mn());
|
| 234 |
-
|
| 235 |
-
cutlass::HostTensor<
|
| 236 |
-
typename Gemm1::ElementC,
|
| 237 |
-
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()});
|
| 238 |
-
|
| 239 |
-
cutlass::HostTensor<
|
| 240 |
-
typename Gemm0::ElementC,
|
| 241 |
-
typename Gemm0::LayoutC> tensor_D0(problem_size.mn());
|
| 242 |
-
|
| 243 |
-
cutlass::HostTensor<
|
| 244 |
-
typename Gemm0::ElementC,
|
| 245 |
-
typename Gemm0::LayoutC> reference_D0(problem_size.mn());
|
| 246 |
-
|
| 247 |
-
cutlass::HostTensor<
|
| 248 |
-
typename Gemm1::ElementB,
|
| 249 |
-
typename Gemm1::LayoutB> tensor_B1(problem_size.kn());
|
| 250 |
-
|
| 251 |
-
cutlass::HostTensor<
|
| 252 |
-
typename Gemm1::ElementC,
|
| 253 |
-
typename Gemm1::LayoutC> tensor_C1(problem_size.mn());
|
| 254 |
-
|
| 255 |
-
cutlass::HostTensor<
|
| 256 |
-
typename Gemm1::ElementC,
|
| 257 |
-
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()});
|
| 258 |
-
|
| 259 |
-
cutlass::HostTensor<
|
| 260 |
-
typename Gemm1::ElementC,
|
| 261 |
-
typename Gemm1::LayoutC> tensor_D1(problem_size.mn());
|
| 262 |
-
|
| 263 |
-
cutlass::HostTensor<
|
| 264 |
-
typename Gemm1::ElementC,
|
| 265 |
-
typename Gemm1::LayoutC> reference_D1(problem_size.mn());
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
| 269 |
-
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
| 270 |
-
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
| 271 |
-
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
|
| 272 |
-
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
|
| 273 |
-
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
| 274 |
-
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));
|
| 275 |
-
|
| 276 |
-
cutlass::reference::host::TensorFill(
|
| 277 |
-
tensor_D0.host_view());
|
| 278 |
-
cutlass::reference::host::TensorFill(
|
| 279 |
-
tensor_D1.host_view());
|
| 280 |
-
cutlass::reference::host::TensorFill(
|
| 281 |
-
reference_D0.host_view());
|
| 282 |
-
cutlass::reference::host::TensorFill(
|
| 283 |
-
reference_D1.host_view());
|
| 284 |
-
|
| 285 |
-
tensor_A0.sync_device();
|
| 286 |
-
tensor_B0.sync_device();
|
| 287 |
-
tensor_C0.sync_device();
|
| 288 |
-
tensor_Bias0.sync_device();
|
| 289 |
-
tensor_D0.sync_device();
|
| 290 |
-
reference_D0.sync_device();
|
| 291 |
-
tensor_B1.sync_device();
|
| 292 |
-
tensor_C1.sync_device();
|
| 293 |
-
tensor_Bias1.sync_device();
|
| 294 |
-
tensor_D1.sync_device();
|
| 295 |
-
reference_D1.sync_device();
|
| 296 |
-
|
| 297 |
-
//
|
| 298 |
-
// Initialize the GEMM operator
|
| 299 |
-
//
|
| 300 |
-
|
| 301 |
-
int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1;
|
| 302 |
-
typename Gemm0::Arguments arguments_0{
|
| 303 |
-
problem_size,
|
| 304 |
-
tensor_A0.device_ref(),
|
| 305 |
-
tensor_B0.device_ref(),
|
| 306 |
-
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
| 307 |
-
tensor_D0.device_ref(),
|
| 308 |
-
{alpha0, beta0},
|
| 309 |
-
split_k_slices
|
| 310 |
-
};
|
| 311 |
-
|
| 312 |
-
split_k_slices = Gemm1::kSplitKSerial ? 2 : 1;
|
| 313 |
-
typename Gemm1::Arguments arguments_1{
|
| 314 |
-
problem_size,
|
| 315 |
-
tensor_A0.device_ref(),
|
| 316 |
-
tensor_B1.device_ref(),
|
| 317 |
-
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
| 318 |
-
tensor_D1.device_ref(),
|
| 319 |
-
{alpha1, beta1},
|
| 320 |
-
split_k_slices
|
| 321 |
-
};
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
Gemm0 gemm_op_0;
|
| 325 |
-
Gemm1 gemm_op_1;
|
| 326 |
-
|
| 327 |
-
// Allocate workspace memory
|
| 328 |
-
cutlass::device_memory::allocation<uint8_t> workspace0(gemm_op_0.get_workspace_size(arguments_0));
|
| 329 |
-
cutlass::device_memory::allocation<uint8_t> workspace1(gemm_op_1.get_workspace_size(arguments_1));
|
| 330 |
-
|
| 331 |
-
cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get());
|
| 332 |
-
|
| 333 |
-
CUTLASS_CHECK(status);
|
| 334 |
-
|
| 335 |
-
status = gemm_op_1.initialize(arguments_1, workspace1.get());
|
| 336 |
-
|
| 337 |
-
CUTLASS_CHECK(status);
|
| 338 |
-
|
| 339 |
-
for(int i = 0; i < warm_ups; i++) {
|
| 340 |
-
status = gemm_op_0();
|
| 341 |
-
CUTLASS_CHECK(status);
|
| 342 |
-
status = gemm_op_1();
|
| 343 |
-
CUTLASS_CHECK(status);
|
| 344 |
-
}
|
| 345 |
-
|
| 346 |
-
if (is_profiling) {
|
| 347 |
-
//
|
| 348 |
-
// Profile the GEMM
|
| 349 |
-
//
|
| 350 |
-
|
| 351 |
-
cudaEvent_t start, stop1, stop2;
|
| 352 |
-
cudaEventCreate(&start);
|
| 353 |
-
cudaEventCreate(&stop1);
|
| 354 |
-
cudaEventCreate(&stop2);
|
| 355 |
-
|
| 356 |
-
cudaEventRecord(start);
|
| 357 |
-
|
| 358 |
-
for(int i = 0; i < runs; i++) {
|
| 359 |
-
status = gemm_op_0();
|
| 360 |
-
|
| 361 |
-
CUTLASS_CHECK(status);
|
| 362 |
-
}
|
| 363 |
-
cudaEventRecord(stop1);
|
| 364 |
-
for(int i = 0; i < runs; i++) {
|
| 365 |
-
status = gemm_op_1();
|
| 366 |
-
|
| 367 |
-
CUTLASS_CHECK(status);
|
| 368 |
-
}
|
| 369 |
-
|
| 370 |
-
cudaEventRecord(stop2);
|
| 371 |
-
cudaDeviceSynchronize();
|
| 372 |
-
float gemm0Time, gemm1Time, totalTime;
|
| 373 |
-
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
| 374 |
-
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
| 375 |
-
cudaEventElapsedTime(&totalTime, start, stop2);
|
| 376 |
-
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
| 377 |
-
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
| 378 |
-
std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
|
| 379 |
-
}
|
| 380 |
-
|
| 381 |
-
tensor_D0.sync_host();
|
| 382 |
-
tensor_D1.sync_host();
|
| 383 |
-
|
| 384 |
-
//
|
| 385 |
-
// Verify
|
| 386 |
-
//
|
| 387 |
-
cutlass::reference::device::Gemm<
|
| 388 |
-
typename Gemm0::ElementA, typename Gemm0::LayoutA,
|
| 389 |
-
typename Gemm0::ElementB, typename Gemm0::LayoutB,
|
| 390 |
-
typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
|
| 391 |
-
ElementAccumulator, typename Gemm0::Operator>
|
| 392 |
-
reference_gemm_0;
|
| 393 |
-
|
| 394 |
-
cutlass::reference::device::Gemm<
|
| 395 |
-
typename Gemm1::ElementA, typename Gemm1::LayoutA,
|
| 396 |
-
typename Gemm1::ElementB, typename Gemm1::LayoutB,
|
| 397 |
-
typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
|
| 398 |
-
ElementAccumulator, typename Gemm1::Operator>
|
| 399 |
-
reference_gemm_1;
|
| 400 |
-
|
| 401 |
-
reference_gemm_0(
|
| 402 |
-
problem_size,
|
| 403 |
-
alpha0,
|
| 404 |
-
tensor_A0.device_ref(),
|
| 405 |
-
tensor_B0.device_ref(),
|
| 406 |
-
beta0,
|
| 407 |
-
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
| 408 |
-
reference_D0.device_ref()
|
| 409 |
-
);
|
| 410 |
-
|
| 411 |
-
if(relu) {
|
| 412 |
-
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
| 413 |
-
}
|
| 414 |
-
|
| 415 |
-
reference_gemm_1(
|
| 416 |
-
problem_size,
|
| 417 |
-
alpha1,
|
| 418 |
-
tensor_A0.device_ref(),
|
| 419 |
-
tensor_B1.device_ref(),
|
| 420 |
-
beta1,
|
| 421 |
-
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
| 422 |
-
reference_D1.device_ref()
|
| 423 |
-
);
|
| 424 |
-
|
| 425 |
-
if(relu) {
|
| 426 |
-
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
| 427 |
-
}
|
| 428 |
-
|
| 429 |
-
// Wait for kernels to finish
|
| 430 |
-
cudaDeviceSynchronize();
|
| 431 |
-
reference_D0.sync_host();
|
| 432 |
-
reference_D1.sync_host();
|
| 433 |
-
|
| 434 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
| 435 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
| 436 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
| 437 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
| 438 |
-
|
| 439 |
-
bool passed0 = cutlass::reference::host::TensorEquals(
|
| 440 |
-
reference_D1.host_view(),
|
| 441 |
-
tensor_D1.host_view());
|
| 442 |
-
CHECK_TRUE(passed0);
|
| 443 |
-
|
| 444 |
-
bool passed1 = cutlass::reference::host::TensorEquals(
|
| 445 |
-
reference_D1.host_view(),
|
| 446 |
-
tensor_D1.host_view());
|
| 447 |
-
CHECK_TRUE(passed1);
|
| 448 |
-
if (!passed0 || !passed1) {
|
| 449 |
-
|
| 450 |
-
std::stringstream fname;
|
| 451 |
-
|
| 452 |
-
fname << "error_DualGemm_device_nonfused.txt";
|
| 453 |
-
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 454 |
-
|
| 455 |
-
std::ofstream file(fname.str());
|
| 456 |
-
|
| 457 |
-
file
|
| 458 |
-
<< "A0 =\n" << tensor_A0.host_view()
|
| 459 |
-
<< "\nB0 =\n" << tensor_B0.host_view()
|
| 460 |
-
<< "\nC0 =\n" << tensor_C0.host_view()
|
| 461 |
-
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 462 |
-
<< "\nD0 =\n" << tensor_D0.host_view()
|
| 463 |
-
<< "\nB1 =\n" << tensor_B1.host_view()
|
| 464 |
-
<< "\nC1 =\n" << tensor_C1.host_view()
|
| 465 |
-
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 466 |
-
<< "\n\nReference =\n" << reference_D1.host_view()
|
| 467 |
-
<< "\nComputed =\n" << tensor_D1.host_view();
|
| 468 |
-
}
|
| 469 |
-
return passed0 && passed1;
|
| 470 |
-
}
|
| 471 |
-
};
|
| 472 |
-
|
| 473 |
-
template <typename DualGemm_>
|
| 474 |
-
struct DualFusedGemmRun
|
| 475 |
-
{
|
| 476 |
-
|
| 477 |
-
using DualGemm = DualGemm_;
|
| 478 |
-
using ElementAccumulator = typename DualGemm::ElementAccumulator;
|
| 479 |
-
using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute;
|
| 480 |
-
using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2;
|
| 481 |
-
|
| 482 |
-
/// Initialization
|
| 483 |
-
cutlass::Distribution::Kind init_A;
|
| 484 |
-
cutlass::Distribution::Kind init_B;
|
| 485 |
-
cutlass::Distribution::Kind init_C;
|
| 486 |
-
cutlass::Distribution::Kind init_Scale;
|
| 487 |
-
cutlass::Distribution::Kind init_Bias;
|
| 488 |
-
uint64_t seed;
|
| 489 |
-
|
| 490 |
-
//
|
| 491 |
-
// Methods
|
| 492 |
-
//
|
| 493 |
-
|
| 494 |
-
DualFusedGemmRun(
|
| 495 |
-
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 496 |
-
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 497 |
-
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 498 |
-
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
| 499 |
-
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 500 |
-
uint64_t seed_ = 2080
|
| 501 |
-
):
|
| 502 |
-
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
| 503 |
-
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
|
| 504 |
-
|
| 505 |
-
/// Helper to initialize a tensor view
|
| 506 |
-
template <typename Element, typename Layout>
|
| 507 |
-
bool initialize_tensor(
|
| 508 |
-
cutlass::TensorView<Element, Layout> view,
|
| 509 |
-
cutlass::Distribution::Kind dist_kind,
|
| 510 |
-
uint64_t seed) {
|
| 511 |
-
|
| 512 |
-
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 513 |
-
|
| 514 |
-
cutlass::reference::host::TensorFillRandomUniform(
|
| 515 |
-
view, seed, 2, -2, 0);
|
| 516 |
-
}
|
| 517 |
-
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 518 |
-
|
| 519 |
-
cutlass::reference::host::TensorFillIdentity(view);
|
| 520 |
-
}
|
| 521 |
-
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 522 |
-
|
| 523 |
-
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 524 |
-
}
|
| 525 |
-
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 526 |
-
|
| 527 |
-
cutlass::reference::host::BlockFillSequential(
|
| 528 |
-
view.data(), view.capacity());
|
| 529 |
-
}
|
| 530 |
-
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 531 |
-
cutlass::reference::host::TensorFill(view, Element(0));
|
| 532 |
-
}
|
| 533 |
-
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 534 |
-
cutlass::reference::host::TensorFill(view, Element(1));
|
| 535 |
-
}
|
| 536 |
-
else {
|
| 537 |
-
std::cerr << "Not implemented\n";
|
| 538 |
-
return false;
|
| 539 |
-
}
|
| 540 |
-
|
| 541 |
-
return true;
|
| 542 |
-
}
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
/// Executes one test
|
| 548 |
-
bool run(
|
| 549 |
-
cutlass::gemm::GemmCoord problem_size,
|
| 550 |
-
ElementCompute alpha0 = ElementCompute(1),
|
| 551 |
-
ElementCompute beta0 = ElementCompute(1),
|
| 552 |
-
ElementCompute alpha1 = ElementCompute(1),
|
| 553 |
-
ElementCompute beta1 = ElementCompute(1),
|
| 554 |
-
int batch_count = 1,
|
| 555 |
-
bool broadcast_b1 = false,
|
| 556 |
-
bool is_profiling = true,
|
| 557 |
-
bool relu = false,
|
| 558 |
-
int warm_ups = 1,
|
| 559 |
-
int runs = 100) {
|
| 560 |
-
|
| 561 |
-
//
|
| 562 |
-
// Allocate the GEMM workspace
|
| 563 |
-
//
|
| 564 |
-
|
| 565 |
-
cutlass::HostTensor<
|
| 566 |
-
typename DualGemm::ElementA,
|
| 567 |
-
typename DualGemm::LayoutA> tensor_A0(
|
| 568 |
-
cutlass::platform::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
|
| 569 |
-
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
|
| 570 |
-
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k()));
|
| 571 |
-
|
| 572 |
-
cutlass::HostTensor<
|
| 573 |
-
typename DualGemm::ElementB,
|
| 574 |
-
typename DualGemm::LayoutB0> tensor_B0(
|
| 575 |
-
cutlass::platform::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
|
| 576 |
-
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
|
| 577 |
-
cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
|
| 578 |
-
|
| 579 |
-
cutlass::HostTensor<
|
| 580 |
-
typename DualGemm::ElementC,
|
| 581 |
-
typename DualGemm::LayoutC> tensor_C0(
|
| 582 |
-
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
| 583 |
-
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
| 584 |
-
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
| 585 |
-
|
| 586 |
-
cutlass::HostTensor<
|
| 587 |
-
typename DualGemm::ElementC,
|
| 588 |
-
typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()});
|
| 589 |
-
|
| 590 |
-
cutlass::HostTensor<
|
| 591 |
-
typename DualGemm::ElementC,
|
| 592 |
-
typename DualGemm::LayoutC> tensor_D0(
|
| 593 |
-
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
| 594 |
-
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
| 595 |
-
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
| 596 |
-
|
| 597 |
-
cutlass::HostTensor<
|
| 598 |
-
typename DualGemm::ElementC,
|
| 599 |
-
typename DualGemm::LayoutC> reference_D0(
|
| 600 |
-
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
| 601 |
-
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
| 602 |
-
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
| 603 |
-
|
| 604 |
-
cutlass::HostTensor<
|
| 605 |
-
typename DualGemm::ElementB,
|
| 606 |
-
typename DualGemm::LayoutB1> tensor_B1(
|
| 607 |
-
cutlass::platform::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
|
| 608 |
-
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
|
| 609 |
-
cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
|
| 610 |
-
if (broadcast_b1) {
|
| 611 |
-
tensor_B1.resize({problem_size.k(), batch_count});
|
| 612 |
-
}
|
| 613 |
-
|
| 614 |
-
cutlass::HostTensor<
|
| 615 |
-
typename DualGemm::ElementC,
|
| 616 |
-
typename DualGemm::LayoutC> tensor_C1(
|
| 617 |
-
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
| 618 |
-
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
| 619 |
-
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
| 620 |
-
|
| 621 |
-
cutlass::HostTensor<
|
| 622 |
-
typename DualGemm::ElementC,
|
| 623 |
-
typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()});
|
| 624 |
-
|
| 625 |
-
cutlass::HostTensor<
|
| 626 |
-
typename DualGemm::ElementC,
|
| 627 |
-
typename DualGemm::LayoutC> tensor_D1(
|
| 628 |
-
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
| 629 |
-
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
| 630 |
-
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
| 631 |
-
|
| 632 |
-
cutlass::HostTensor<
|
| 633 |
-
typename DualGemm::ElementC,
|
| 634 |
-
typename DualGemm::LayoutC> tensor_D2(
|
| 635 |
-
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
| 636 |
-
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
| 637 |
-
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
| 638 |
-
|
| 639 |
-
cutlass::HostTensor<
|
| 640 |
-
typename DualGemm::ElementC,
|
| 641 |
-
typename DualGemm::LayoutC> reference_D1(
|
| 642 |
-
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
| 643 |
-
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
| 644 |
-
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
| 645 |
-
|
| 646 |
-
cutlass::HostTensor<
|
| 647 |
-
typename DualGemm::ElementC,
|
| 648 |
-
typename DualGemm::LayoutC> reference_D2(
|
| 649 |
-
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
| 650 |
-
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
| 651 |
-
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
| 652 |
-
|
| 653 |
-
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
| 654 |
-
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118));
|
| 655 |
-
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
| 656 |
-
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011));
|
| 657 |
-
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113));
|
| 658 |
-
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
| 659 |
-
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
|
| 660 |
-
|
| 661 |
-
cutlass::reference::host::TensorFill(
|
| 662 |
-
tensor_D0.host_view());
|
| 663 |
-
cutlass::reference::host::TensorFill(
|
| 664 |
-
tensor_D1.host_view());
|
| 665 |
-
cutlass::reference::host::TensorFill(
|
| 666 |
-
tensor_D2.host_view());
|
| 667 |
-
cutlass::reference::host::TensorFill(
|
| 668 |
-
reference_D0.host_view());
|
| 669 |
-
cutlass::reference::host::TensorFill(
|
| 670 |
-
reference_D1.host_view());
|
| 671 |
-
cutlass::reference::host::TensorFill(
|
| 672 |
-
reference_D2.host_view());
|
| 673 |
-
|
| 674 |
-
tensor_A0.sync_device();
|
| 675 |
-
tensor_B0.sync_device();
|
| 676 |
-
tensor_C0.sync_device();
|
| 677 |
-
tensor_Bias0.sync_device();
|
| 678 |
-
tensor_B1.sync_device();
|
| 679 |
-
tensor_C1.sync_device();
|
| 680 |
-
tensor_Bias1.sync_device();
|
| 681 |
-
tensor_D0.sync_device();
|
| 682 |
-
tensor_D1.sync_device();
|
| 683 |
-
tensor_D2.sync_device();
|
| 684 |
-
reference_D0.sync_device();
|
| 685 |
-
reference_D1.sync_device();
|
| 686 |
-
reference_D2.sync_device();
|
| 687 |
-
|
| 688 |
-
//
|
| 689 |
-
// Batch strides (irrelevant when batch_count == 1)
|
| 690 |
-
//
|
| 691 |
-
|
| 692 |
-
int64_t batch_stride_A = problem_size.m() * problem_size.k();
|
| 693 |
-
int64_t batch_stride_B0 = problem_size.k() * problem_size.n();
|
| 694 |
-
int64_t batch_stride_B1 = problem_size.k() * problem_size.n();
|
| 695 |
-
if (broadcast_b1) {
|
| 696 |
-
// B1 is a (column) vector
|
| 697 |
-
batch_stride_B1 = problem_size.k();
|
| 698 |
-
}
|
| 699 |
-
int64_t batch_stride_Bias = problem_size.n();
|
| 700 |
-
int64_t batch_stride_D = problem_size.m() * problem_size.n();
|
| 701 |
-
|
| 702 |
-
//
|
| 703 |
-
// Initialize the GEMM operator
|
| 704 |
-
//
|
| 705 |
-
|
| 706 |
-
int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;
|
| 707 |
-
typename cutlass::TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC> nullptr_ref{};
|
| 708 |
-
decltype(nullptr_ref) ref_B0, ref_B1;
|
| 709 |
-
if (beta0 != ElementCompute(0)) {
|
| 710 |
-
ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)};
|
| 711 |
-
}
|
| 712 |
-
if (beta1 != ElementCompute(0)) {
|
| 713 |
-
ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
|
| 714 |
-
}
|
| 715 |
-
typename DualGemm::Arguments arguments{
|
| 716 |
-
(batch_count > 1 ?
|
| 717 |
-
cutlass::gemm::DualGemmMode::kBatched :
|
| 718 |
-
cutlass::gemm::DualGemmMode::kGemm),
|
| 719 |
-
problem_size,
|
| 720 |
-
tensor_A0.device_ref(),
|
| 721 |
-
tensor_B0.device_ref(),
|
| 722 |
-
ref_B0,
|
| 723 |
-
DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
|
| 724 |
-
(broadcast_b1 ?
|
| 725 |
-
typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
|
| 726 |
-
tensor_B1.device_ref()),
|
| 727 |
-
ref_B1,
|
| 728 |
-
DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
|
| 729 |
-
tensor_D2.device_ref(),
|
| 730 |
-
{alpha0, beta0},
|
| 731 |
-
{alpha1, beta1},
|
| 732 |
-
{},
|
| 733 |
-
split_k_slices,
|
| 734 |
-
batch_count,
|
| 735 |
-
batch_stride_A,
|
| 736 |
-
batch_stride_B0,
|
| 737 |
-
batch_stride_B1,
|
| 738 |
-
batch_stride_Bias,
|
| 739 |
-
batch_stride_D,
|
| 740 |
-
};
|
| 741 |
-
|
| 742 |
-
//
|
| 743 |
-
// Run the GEMM
|
| 744 |
-
//
|
| 745 |
-
|
| 746 |
-
DualGemm b2b_gemm_op;
|
| 747 |
-
|
| 748 |
-
cutlass::device_memory::allocation<uint8_t> workspace(b2b_gemm_op.get_workspace_size(arguments));
|
| 749 |
-
|
| 750 |
-
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
| 751 |
-
|
| 752 |
-
CUTLASS_CHECK(status);
|
| 753 |
-
|
| 754 |
-
status = b2b_gemm_op.initialize(arguments, workspace.get());
|
| 755 |
-
|
| 756 |
-
CUTLASS_CHECK(status);
|
| 757 |
-
|
| 758 |
-
for(int i = 0; i < warm_ups; i++) {
|
| 759 |
-
status = b2b_gemm_op();
|
| 760 |
-
CUTLASS_CHECK(status);
|
| 761 |
-
}
|
| 762 |
-
|
| 763 |
-
if (is_profiling) {
|
| 764 |
-
//
|
| 765 |
-
// Profile the GEMM
|
| 766 |
-
//
|
| 767 |
-
|
| 768 |
-
cudaEvent_t start, stop;
|
| 769 |
-
cudaEventCreate(&start);
|
| 770 |
-
cudaEventCreate(&stop);
|
| 771 |
-
|
| 772 |
-
cudaEventRecord(start);
|
| 773 |
-
|
| 774 |
-
for(int i = 0; i < runs; i++) {
|
| 775 |
-
status = b2b_gemm_op();
|
| 776 |
-
CUTLASS_CHECK(status);
|
| 777 |
-
}
|
| 778 |
-
|
| 779 |
-
cudaEventRecord(stop);
|
| 780 |
-
cudaDeviceSynchronize();
|
| 781 |
-
float gemmTime;
|
| 782 |
-
cudaEventElapsedTime(&gemmTime, start, stop);
|
| 783 |
-
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
| 784 |
-
}
|
| 785 |
-
|
| 786 |
-
tensor_D0.sync_host();
|
| 787 |
-
tensor_D1.sync_host();
|
| 788 |
-
tensor_D2.sync_host();
|
| 789 |
-
|
| 790 |
-
//
|
| 791 |
-
// Verify
|
| 792 |
-
//
|
| 793 |
-
|
| 794 |
-
using GemmUniversal0 = cutlass::gemm::device::GemmUniversal<
|
| 795 |
-
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
| 796 |
-
typename DualGemm::ElementB, typename DualGemm::LayoutB0,
|
| 797 |
-
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
| 798 |
-
ElementAccumulator
|
| 799 |
-
>;
|
| 800 |
-
|
| 801 |
-
GemmUniversal0 reference_gemm0;
|
| 802 |
-
|
| 803 |
-
typename GemmUniversal0::Arguments args0 {
|
| 804 |
-
(batch_count > 1 ?
|
| 805 |
-
cutlass::gemm::GemmUniversalMode::kBatched :
|
| 806 |
-
cutlass::gemm::GemmUniversalMode::kGemm),
|
| 807 |
-
problem_size,
|
| 808 |
-
batch_count,
|
| 809 |
-
{alpha0, beta0},
|
| 810 |
-
tensor_A0.device_data(),
|
| 811 |
-
tensor_B0.device_data(),
|
| 812 |
-
tensor_Bias0.device_data(),
|
| 813 |
-
reference_D0.device_data(),
|
| 814 |
-
batch_stride_A,
|
| 815 |
-
batch_stride_B0,
|
| 816 |
-
batch_stride_Bias,
|
| 817 |
-
batch_stride_D,
|
| 818 |
-
tensor_A0.stride(0),
|
| 819 |
-
tensor_B0.stride(0),
|
| 820 |
-
0, // zero stride for the bias vector
|
| 821 |
-
reference_D0.stride(0),
|
| 822 |
-
};
|
| 823 |
-
|
| 824 |
-
status = reference_gemm0.can_implement(args0);
|
| 825 |
-
CUTLASS_CHECK(status);
|
| 826 |
-
status = reference_gemm0(args0);
|
| 827 |
-
CUTLASS_CHECK(status);
|
| 828 |
-
|
| 829 |
-
using GemmUniversal1 = cutlass::gemm::device::GemmUniversal<
|
| 830 |
-
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
| 831 |
-
typename DualGemm::ElementB, typename DualGemm::LayoutB1,
|
| 832 |
-
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
| 833 |
-
ElementAccumulator
|
| 834 |
-
>;
|
| 835 |
-
|
| 836 |
-
GemmUniversal1 reference_gemm1;
|
| 837 |
-
|
| 838 |
-
typename GemmUniversal1::Arguments args1 {
|
| 839 |
-
(batch_count > 1 ?
|
| 840 |
-
cutlass::gemm::GemmUniversalMode::kBatched :
|
| 841 |
-
cutlass::gemm::GemmUniversalMode::kGemm),
|
| 842 |
-
problem_size,
|
| 843 |
-
batch_count,
|
| 844 |
-
{alpha1, beta1},
|
| 845 |
-
tensor_A0.device_data(),
|
| 846 |
-
tensor_B1.device_data(),
|
| 847 |
-
tensor_Bias1.device_data(),
|
| 848 |
-
reference_D1.device_data(),
|
| 849 |
-
batch_stride_A,
|
| 850 |
-
batch_stride_B1,
|
| 851 |
-
batch_stride_Bias,
|
| 852 |
-
batch_stride_D,
|
| 853 |
-
tensor_A0.stride(0),
|
| 854 |
-
(broadcast_b1 ? 0 : tensor_B1.stride(0)),
|
| 855 |
-
0, // zero stride for the bias vector
|
| 856 |
-
reference_D1.stride(0),
|
| 857 |
-
};
|
| 858 |
-
|
| 859 |
-
status = reference_gemm1.can_implement(args1);
|
| 860 |
-
CUTLASS_CHECK(status);
|
| 861 |
-
status = reference_gemm1(args1);
|
| 862 |
-
CUTLASS_CHECK(status);
|
| 863 |
-
|
| 864 |
-
if(relu) {
|
| 865 |
-
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
| 866 |
-
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
| 867 |
-
}
|
| 868 |
-
|
| 869 |
-
TensorEpilogueForEach<EpilogueOutputOp2>(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view());
|
| 870 |
-
cudaDeviceSynchronize();
|
| 871 |
-
reference_D0.sync_host();
|
| 872 |
-
reference_D1.sync_host();
|
| 873 |
-
reference_D2.sync_host();
|
| 874 |
-
|
| 875 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
| 876 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
| 877 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0);
|
| 878 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0);
|
| 879 |
-
|
| 880 |
-
bool passed_out0 = true;
|
| 881 |
-
if (DualGemm::kStoreD0) {
|
| 882 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
| 883 |
-
passed_out0 = cutlass::reference::host::TensorEquals(
|
| 884 |
-
reference_D0.host_view(),
|
| 885 |
-
tensor_D0.host_view());
|
| 886 |
-
}
|
| 887 |
-
CHECK_TRUE(passed_out0);
|
| 888 |
-
|
| 889 |
-
bool passed_out1 = true;
|
| 890 |
-
if (DualGemm::kStoreD1) {
|
| 891 |
-
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
| 892 |
-
passed_out1 = cutlass::reference::host::TensorEquals(
|
| 893 |
-
reference_D1.host_view(),
|
| 894 |
-
tensor_D1.host_view());
|
| 895 |
-
}
|
| 896 |
-
CHECK_TRUE(passed_out1);
|
| 897 |
-
|
| 898 |
-
bool passed_out2 = cutlass::reference::host::TensorEquals(
|
| 899 |
-
reference_D2.host_view(),
|
| 900 |
-
tensor_D2.host_view());
|
| 901 |
-
CHECK_TRUE(passed_out2);
|
| 902 |
-
|
| 903 |
-
bool passed = passed_out0 && passed_out1 && passed_out2;
|
| 904 |
-
if (!passed)
|
| 905 |
-
{
|
| 906 |
-
std::stringstream fname;
|
| 907 |
-
|
| 908 |
-
fname << "error_DualGemm_device_fused.txt";
|
| 909 |
-
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 910 |
-
|
| 911 |
-
std::ofstream file(fname.str());
|
| 912 |
-
|
| 913 |
-
file
|
| 914 |
-
<< "A0 =\n" << tensor_A0.host_view()
|
| 915 |
-
<< "\nB0 =\n" << tensor_B0.host_view()
|
| 916 |
-
<< "\nC0 =\n" << tensor_C0.host_view()
|
| 917 |
-
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 918 |
-
<< "\nB1 =\n" << tensor_B1.host_view()
|
| 919 |
-
<< "\nC1 =\n" << tensor_C1.host_view()
|
| 920 |
-
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 921 |
-
<< "\n\nReference0 =\n" << reference_D0.host_view()
|
| 922 |
-
<< "\nComputed0 =\n" << tensor_D0.host_view()
|
| 923 |
-
<< "\n\nReference1 =\n" << reference_D1.host_view()
|
| 924 |
-
<< "\nComputed1 =\n" << tensor_D1.host_view()
|
| 925 |
-
<< "\n\nReference2 =\n" << reference_D2.host_view()
|
| 926 |
-
<< "\nComputed2 =\n" << tensor_D2.host_view();
|
| 927 |
-
}
|
| 928 |
-
//std::cout << "A0 " << tensor_A0.host_view() << std::endl;
|
| 929 |
-
// std::cout << "reference_D0 " << reference_D0.host_view() << std::endl;
|
| 930 |
-
// std::cout << "reference_D1 " << reference_D1.host_view() << std::endl;
|
| 931 |
-
// std::cout << "reference_D2 " << reference_D2.host_view() << std::endl;
|
| 932 |
-
//std::cout << "reference_D0 " << reference_D0.host_view() << std::endl;
|
| 933 |
-
return passed;
|
| 934 |
-
}
|
| 935 |
-
|
| 936 |
-
};
|
| 937 |
-
|
| 938 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h
DELETED
|
@@ -1,545 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/cutlass.h"
|
| 38 |
-
|
| 39 |
-
#include "cutlass/gemm/gemm.h"
|
| 40 |
-
#include "cutlass/matrix_coord.h"
|
| 41 |
-
#include "cutlass/semaphore.h"
|
| 42 |
-
|
| 43 |
-
#include "../threadblock/dual_mma_multistage.h"
|
| 44 |
-
#include "../threadblock/dual_epilogue.h"
|
| 45 |
-
#include "../dual_gemm_common.h"
|
| 46 |
-
|
| 47 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
-
|
| 49 |
-
namespace cutlass {
|
| 50 |
-
namespace gemm {
|
| 51 |
-
namespace kernel {
|
| 52 |
-
|
| 53 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
-
|
| 55 |
-
template <
|
| 56 |
-
typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
| 57 |
-
typename Epilogue0_, ///! Epilogue
|
| 58 |
-
typename Epilogue1_, ///! Epilogue
|
| 59 |
-
typename OutputOp2_, ///! Epilogue
|
| 60 |
-
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
| 61 |
-
bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled.
|
| 62 |
-
bool StoreD0,
|
| 63 |
-
bool StoreD1
|
| 64 |
-
>
|
| 65 |
-
struct DualGemm {
|
| 66 |
-
|
| 67 |
-
using DualMma = DualMma_;
|
| 68 |
-
|
| 69 |
-
using Epilogue0 = Epilogue0_;
|
| 70 |
-
using Epilogue1 = Epilogue1_;
|
| 71 |
-
using OutputOp0 = typename Epilogue0::OutputOp;
|
| 72 |
-
using OutputOp1 = typename Epilogue1::OutputOp;
|
| 73 |
-
using OutputOp2 = OutputOp2_;
|
| 74 |
-
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
| 75 |
-
static constexpr bool kStoreD0 = StoreD0;
|
| 76 |
-
static constexpr bool kStoreD1 = StoreD1;
|
| 77 |
-
|
| 78 |
-
using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue<
|
| 79 |
-
typename Epilogue0::Shape,
|
| 80 |
-
typename Epilogue0::WarpMmaOperator,
|
| 81 |
-
Epilogue0::kPartitionsK,
|
| 82 |
-
typename Epilogue0::OutputTileIterator,
|
| 83 |
-
typename Epilogue0::AccumulatorFragmentIterator,
|
| 84 |
-
typename Epilogue0::WarpTileIterator,
|
| 85 |
-
typename Epilogue0::SharedLoadIterator,
|
| 86 |
-
OutputOp0,
|
| 87 |
-
OutputOp1,
|
| 88 |
-
OutputOp2,
|
| 89 |
-
typename Epilogue0::Padding,
|
| 90 |
-
kStoreD0,
|
| 91 |
-
kStoreD1,
|
| 92 |
-
Epilogue0::kFragmentsPerIteration,
|
| 93 |
-
true // IterationsUnroll
|
| 94 |
-
>;
|
| 95 |
-
|
| 96 |
-
using ElementA = typename DualMma::IteratorA::Element;
|
| 97 |
-
using ElementB = typename DualMma::IteratorB0::Element;
|
| 98 |
-
using ElementC = typename DualEpilogue::OutputTileIterator::Element;
|
| 99 |
-
|
| 100 |
-
static bool const kSplitKSerial = SplitKSerial;
|
| 101 |
-
static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1),
|
| 102 |
-
"Split-K serial requires buffers for D0/D1 for reduction");
|
| 103 |
-
|
| 104 |
-
/// Warp count (concept: GemmShape)
|
| 105 |
-
using WarpCount0 = typename DualMma::WarpCount;
|
| 106 |
-
static int const kThreadCount = 32 * WarpCount0::kCount;
|
| 107 |
-
|
| 108 |
-
/// Parameters structure
|
| 109 |
-
struct Params {
|
| 110 |
-
DualGemmMode mode;
|
| 111 |
-
cutlass::gemm::GemmCoord problem_size;
|
| 112 |
-
cutlass::gemm::GemmCoord grid_tiled_shape;
|
| 113 |
-
int swizzle_log_tile;
|
| 114 |
-
|
| 115 |
-
// Mma0
|
| 116 |
-
typename DualMma::IteratorA::Params params_A0;
|
| 117 |
-
typename DualMma::IteratorA::TensorRef ref_A0;
|
| 118 |
-
typename DualMma::IteratorB0::Params params_B0;
|
| 119 |
-
typename DualMma::IteratorB0::TensorRef ref_B0;
|
| 120 |
-
typename Epilogue0::OutputTileIterator::Params params_C0;
|
| 121 |
-
typename Epilogue0::OutputTileIterator::TensorRef ref_C0;
|
| 122 |
-
typename Epilogue0::OutputTileIterator::Params params_D0;
|
| 123 |
-
typename Epilogue0::OutputTileIterator::TensorRef ref_D0;
|
| 124 |
-
typename OutputOp0::Params output_op_0;
|
| 125 |
-
|
| 126 |
-
// Mma1
|
| 127 |
-
typename DualMma::IteratorB1::Params params_B1;
|
| 128 |
-
typename DualMma::IteratorB1::TensorRef ref_B1;
|
| 129 |
-
typename Epilogue1::OutputTileIterator::Params params_C1;
|
| 130 |
-
typename Epilogue1::OutputTileIterator::TensorRef ref_C1;
|
| 131 |
-
typename Epilogue1::OutputTileIterator::Params params_D1;
|
| 132 |
-
typename Epilogue1::OutputTileIterator::TensorRef ref_D1;
|
| 133 |
-
typename OutputOp1::Params output_op_1;
|
| 134 |
-
|
| 135 |
-
typename Epilogue1::OutputTileIterator::Params params_D2;
|
| 136 |
-
typename Epilogue1::OutputTileIterator::TensorRef ref_D2;
|
| 137 |
-
typename OutputOp2::Params output_op_2;
|
| 138 |
-
|
| 139 |
-
int *semaphore;
|
| 140 |
-
int gemm_k_size;
|
| 141 |
-
|
| 142 |
-
int64_t batch_stride_A;
|
| 143 |
-
int64_t batch_stride_B0;
|
| 144 |
-
int64_t batch_stride_B1;
|
| 145 |
-
int64_t batch_stride_C;
|
| 146 |
-
int64_t batch_stride_D;
|
| 147 |
-
|
| 148 |
-
//
|
| 149 |
-
// Methods
|
| 150 |
-
//
|
| 151 |
-
|
| 152 |
-
CUTLASS_HOST_DEVICE
|
| 153 |
-
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { }
|
| 154 |
-
|
| 155 |
-
CUTLASS_HOST_DEVICE
|
| 156 |
-
Params(
|
| 157 |
-
DualGemmMode mode,
|
| 158 |
-
cutlass::gemm::GemmCoord const & problem_size,
|
| 159 |
-
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
| 160 |
-
// Mma0: D0 = A @ B0 + C0
|
| 161 |
-
typename DualMma::IteratorA::TensorRef ref_A0,
|
| 162 |
-
typename DualMma::IteratorB0::TensorRef ref_B0,
|
| 163 |
-
typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
|
| 164 |
-
typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
|
| 165 |
-
// Mma1: D1 = A @ B1 + C1
|
| 166 |
-
typename DualMma::IteratorB1::TensorRef ref_B1,
|
| 167 |
-
typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
|
| 168 |
-
typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
|
| 169 |
-
|
| 170 |
-
typename Epilogue1::OutputTileIterator::TensorRef ref_D2,
|
| 171 |
-
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
| 172 |
-
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
| 173 |
-
typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(),
|
| 174 |
-
int *workspace = nullptr,
|
| 175 |
-
int64_t batch_stride_A = 1,
|
| 176 |
-
int64_t batch_stride_B0 = 1,
|
| 177 |
-
int64_t batch_stride_B1 = 1,
|
| 178 |
-
int64_t batch_stride_C = 1,
|
| 179 |
-
int64_t batch_stride_D = 1
|
| 180 |
-
):
|
| 181 |
-
mode(mode),
|
| 182 |
-
problem_size(problem_size),
|
| 183 |
-
grid_tiled_shape(grid_tiled_shape),
|
| 184 |
-
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
| 185 |
-
// Mma0
|
| 186 |
-
params_A0(ref_A0.layout()),
|
| 187 |
-
ref_A0(ref_A0),
|
| 188 |
-
params_B0(ref_B0.layout()),
|
| 189 |
-
ref_B0(ref_B0),
|
| 190 |
-
params_C0(ref_C0.layout()),
|
| 191 |
-
ref_C0(ref_C0),
|
| 192 |
-
params_D0(ref_D0.layout()),
|
| 193 |
-
ref_D0(ref_D0),
|
| 194 |
-
// Mma1
|
| 195 |
-
params_B1(ref_B1.layout()),
|
| 196 |
-
ref_B1(ref_B1),
|
| 197 |
-
params_C1(ref_C1.layout()),
|
| 198 |
-
ref_C1(ref_C1),
|
| 199 |
-
params_D1(ref_D1.layout()),
|
| 200 |
-
ref_D1(ref_D1),
|
| 201 |
-
params_D2(ref_D2.layout()),
|
| 202 |
-
ref_D2(ref_D2),
|
| 203 |
-
output_op_0(output_op_0),
|
| 204 |
-
output_op_1(output_op_1),
|
| 205 |
-
output_op_2(output_op_2),
|
| 206 |
-
batch_stride_A(batch_stride_A),
|
| 207 |
-
batch_stride_B0(batch_stride_B0),
|
| 208 |
-
batch_stride_B1(batch_stride_B1),
|
| 209 |
-
batch_stride_C(batch_stride_C),
|
| 210 |
-
batch_stride_D(batch_stride_D) {
|
| 211 |
-
|
| 212 |
-
int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
|
| 213 |
-
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
| 214 |
-
gemm_k_size = gemm_k_iterations * DualMma::Shape::kK;
|
| 215 |
-
|
| 216 |
-
semaphore = workspace;
|
| 217 |
-
}
|
| 218 |
-
};
|
| 219 |
-
|
| 220 |
-
/// Shared memory storage structure
|
| 221 |
-
union SharedStorage {
|
| 222 |
-
typename DualMma::SharedStorage main_loop;
|
| 223 |
-
typename DualEpilogue::SharedStorage epilogue;
|
| 224 |
-
};
|
| 225 |
-
|
| 226 |
-
//
|
| 227 |
-
// Methods
|
| 228 |
-
//
|
| 229 |
-
|
| 230 |
-
CUTLASS_HOST_DEVICE
|
| 231 |
-
DualGemm() { }
|
| 232 |
-
|
| 233 |
-
/// Determines whether kernel satisfies alignment
|
| 234 |
-
static Status can_implement(
|
| 235 |
-
cutlass::gemm::GemmCoord const & problem_size,
|
| 236 |
-
typename DualMma::IteratorA::TensorRef ref_A0,
|
| 237 |
-
typename DualMma::IteratorB0::TensorRef ref_B0,
|
| 238 |
-
typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
|
| 239 |
-
typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
|
| 240 |
-
typename DualMma::IteratorB1::TensorRef ref_B1,
|
| 241 |
-
typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
|
| 242 |
-
typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
|
| 243 |
-
typename Epilogue1::OutputTileIterator::TensorRef ref_D2) {
|
| 244 |
-
|
| 245 |
-
static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements;
|
| 246 |
-
static int const kAlignmentB = DualMma::IteratorB0::AccessType::kElements;
|
| 247 |
-
static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess;
|
| 248 |
-
|
| 249 |
-
if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
|
| 250 |
-
return Status::kErrorMisalignedOperand;
|
| 251 |
-
}
|
| 252 |
-
|
| 253 |
-
if (!TensorRef_aligned(ref_B0, kAlignmentB)) {
|
| 254 |
-
return Status::kErrorMisalignedOperand;
|
| 255 |
-
}
|
| 256 |
-
|
| 257 |
-
if (!TensorRef_aligned(ref_C0, kAlignmentC)) {
|
| 258 |
-
return Status::kErrorMisalignedOperand;
|
| 259 |
-
}
|
| 260 |
-
|
| 261 |
-
if (!TensorRef_aligned(ref_D0, kAlignmentC)) {
|
| 262 |
-
return Status::kErrorMisalignedOperand;
|
| 263 |
-
}
|
| 264 |
-
|
| 265 |
-
if (!TensorRef_aligned(ref_B1, kAlignmentB)) {
|
| 266 |
-
return Status::kErrorMisalignedOperand;
|
| 267 |
-
}
|
| 268 |
-
|
| 269 |
-
if (!TensorRef_aligned(ref_C1, kAlignmentC)) {
|
| 270 |
-
return Status::kErrorMisalignedOperand;
|
| 271 |
-
}
|
| 272 |
-
|
| 273 |
-
if (!TensorRef_aligned(ref_D1, kAlignmentC)) {
|
| 274 |
-
return Status::kErrorMisalignedOperand;
|
| 275 |
-
}
|
| 276 |
-
|
| 277 |
-
if (!TensorRef_aligned(ref_D2, kAlignmentC)) {
|
| 278 |
-
return Status::kErrorMisalignedOperand;
|
| 279 |
-
}
|
| 280 |
-
|
| 281 |
-
return Status::kSuccess;
|
| 282 |
-
}
|
| 283 |
-
|
| 284 |
-
/// Executes one GEMM
|
| 285 |
-
CUTLASS_DEVICE
|
| 286 |
-
void operator()(Params const &params, SharedStorage &shared_storage) {
|
| 287 |
-
// Compute threadblock location
|
| 288 |
-
ThreadblockSwizzle threadblock_swizzle;
|
| 289 |
-
|
| 290 |
-
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
| 291 |
-
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 292 |
-
|
| 293 |
-
// Early exit if CTA is out of range
|
| 294 |
-
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
| 295 |
-
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
| 296 |
-
|
| 297 |
-
return;
|
| 298 |
-
}
|
| 299 |
-
|
| 300 |
-
int offset_k = 0;
|
| 301 |
-
int problem_size_k = params.problem_size.k();
|
| 302 |
-
|
| 303 |
-
ElementA *ptr_A0 = static_cast<ElementA *>(params.ref_A0.data());
|
| 304 |
-
ElementB *ptr_B0 = static_cast<ElementB *>(params.ref_B0.data());
|
| 305 |
-
ElementB *ptr_B1 = static_cast<ElementB *>(params.ref_B1.data());
|
| 306 |
-
|
| 307 |
-
//
|
| 308 |
-
// Fetch pointers based on mode.
|
| 309 |
-
//
|
| 310 |
-
if (params.mode == DualGemmMode::kGemm) {
|
| 311 |
-
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
|
| 312 |
-
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
| 313 |
-
}
|
| 314 |
-
|
| 315 |
-
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
| 316 |
-
}
|
| 317 |
-
else if (params.mode == DualGemmMode::kBatched) {
|
| 318 |
-
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A;
|
| 319 |
-
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
|
| 320 |
-
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
|
| 321 |
-
}
|
| 322 |
-
|
| 323 |
-
// Compute initial location in logical coordinates
|
| 324 |
-
cutlass::MatrixCoord tb_offset_A0{
|
| 325 |
-
threadblock_tile_offset.m() * DualMma::Shape::kM,
|
| 326 |
-
offset_k,
|
| 327 |
-
};
|
| 328 |
-
|
| 329 |
-
cutlass::MatrixCoord tb_offset_B0{
|
| 330 |
-
offset_k,
|
| 331 |
-
threadblock_tile_offset.n() * DualMma::Shape::kN
|
| 332 |
-
};
|
| 333 |
-
|
| 334 |
-
cutlass::MatrixCoord tb_offset_B1{
|
| 335 |
-
offset_k,
|
| 336 |
-
threadblock_tile_offset.n() * DualMma::Shape::kN
|
| 337 |
-
};
|
| 338 |
-
|
| 339 |
-
// Compute position within threadblock
|
| 340 |
-
int thread_idx = threadIdx.x;
|
| 341 |
-
|
| 342 |
-
// Construct iterators to A and B operands
|
| 343 |
-
typename DualMma::IteratorA iterator_A0(
|
| 344 |
-
params.params_A0,
|
| 345 |
-
ptr_A0,
|
| 346 |
-
{params.problem_size.m(), problem_size_k},
|
| 347 |
-
thread_idx,
|
| 348 |
-
tb_offset_A0);
|
| 349 |
-
|
| 350 |
-
typename DualMma::IteratorB0 iterator_B0(
|
| 351 |
-
params.params_B0,
|
| 352 |
-
ptr_B0,
|
| 353 |
-
{problem_size_k, params.problem_size.n()},
|
| 354 |
-
thread_idx,
|
| 355 |
-
tb_offset_B0);
|
| 356 |
-
|
| 357 |
-
typename DualMma::IteratorB1 iterator_B1(
|
| 358 |
-
params.params_B1,
|
| 359 |
-
ptr_B1,
|
| 360 |
-
{problem_size_k, params.problem_size.n()},
|
| 361 |
-
thread_idx,
|
| 362 |
-
tb_offset_B1);
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
| 366 |
-
// is compiled as warp-uniform.
|
| 367 |
-
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 368 |
-
int lane_idx = threadIdx.x % 32;
|
| 369 |
-
|
| 370 |
-
//
|
| 371 |
-
// Main loop
|
| 372 |
-
//
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
// Construct thread-scoped matrix multiply
|
| 376 |
-
typename DualMma::FragmentC accum0;
|
| 377 |
-
typename DualMma::FragmentC accum1;
|
| 378 |
-
accum0.clear();
|
| 379 |
-
accum1.clear();
|
| 380 |
-
|
| 381 |
-
// Compute threadblock-scoped matrix multiply-add
|
| 382 |
-
int gemm_k_iterations = (problem_size_k - offset_k + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
|
| 383 |
-
|
| 384 |
-
DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
| 385 |
-
if (!kSplitKSerial || gemm_k_iterations > 0) {
|
| 386 |
-
// Compute threadblock-scoped matrix multiply-add
|
| 387 |
-
mma(gemm_k_iterations,
|
| 388 |
-
accum0, accum1,
|
| 389 |
-
iterator_A0, iterator_B0, iterator_B1,
|
| 390 |
-
accum0, accum1);
|
| 391 |
-
}
|
| 392 |
-
|
| 393 |
-
//
|
| 394 |
-
// Epilogue
|
| 395 |
-
//
|
| 396 |
-
|
| 397 |
-
OutputOp0 output_op_0(params.output_op_0);
|
| 398 |
-
OutputOp1 output_op_1(params.output_op_1);
|
| 399 |
-
OutputOp2 output_op_2(params.output_op_2);
|
| 400 |
-
|
| 401 |
-
//
|
| 402 |
-
// Masked tile iterators constructed from members
|
| 403 |
-
//
|
| 404 |
-
|
| 405 |
-
threadblock_tile_offset =
|
| 406 |
-
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 407 |
-
|
| 408 |
-
//assume identity swizzle
|
| 409 |
-
MatrixCoord threadblock_offset(
|
| 410 |
-
threadblock_tile_offset.m() * DualMma::Shape::kM,
|
| 411 |
-
threadblock_tile_offset.n() * DualMma::Shape::kN
|
| 412 |
-
);
|
| 413 |
-
|
| 414 |
-
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
| 415 |
-
|
| 416 |
-
ElementC *ptr_C0 = static_cast<ElementC *>(params.ref_C0.data());
|
| 417 |
-
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
|
| 418 |
-
ElementC *ptr_D0 = static_cast<ElementC *>(params.ref_D0.data());
|
| 419 |
-
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
|
| 420 |
-
ElementC *ptr_D2 = static_cast<ElementC *>(params.ref_D2.data());
|
| 421 |
-
|
| 422 |
-
// Construct the semaphore.
|
| 423 |
-
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
| 424 |
-
|
| 425 |
-
if (params.mode == DualGemmMode::kGemm) {
|
| 426 |
-
// If performing a reduction via split-K, fetch the initial synchronization
|
| 427 |
-
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
| 428 |
-
|
| 429 |
-
// Fetch the synchronization lock initially but do not block.
|
| 430 |
-
semaphore.fetch();
|
| 431 |
-
|
| 432 |
-
// Indicate which position in a serial reduction the output operator is currently updating
|
| 433 |
-
output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
| 434 |
-
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
| 435 |
-
}
|
| 436 |
-
}
|
| 437 |
-
else if (params.mode == DualGemmMode::kBatched) {
|
| 438 |
-
ptr_C0 += threadblock_tile_offset.k() * params.batch_stride_C;
|
| 439 |
-
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C;
|
| 440 |
-
ptr_D0 += threadblock_tile_offset.k() * params.batch_stride_D;
|
| 441 |
-
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D;
|
| 442 |
-
ptr_D2 += threadblock_tile_offset.k() * params.batch_stride_D;
|
| 443 |
-
}
|
| 444 |
-
|
| 445 |
-
// Tile iterator loading from source tensor.
|
| 446 |
-
typename Epilogue0::OutputTileIterator iterator_C0(
|
| 447 |
-
params.params_C0,
|
| 448 |
-
ptr_C0,
|
| 449 |
-
params.problem_size.mn(),
|
| 450 |
-
thread_idx,
|
| 451 |
-
threadblock_offset
|
| 452 |
-
);
|
| 453 |
-
typename Epilogue1::OutputTileIterator iterator_C1(
|
| 454 |
-
params.params_C1,
|
| 455 |
-
ptr_C1,
|
| 456 |
-
params.problem_size.mn(),
|
| 457 |
-
thread_idx,
|
| 458 |
-
threadblock_offset
|
| 459 |
-
);
|
| 460 |
-
|
| 461 |
-
// Tile iterator writing to destination tensor.
|
| 462 |
-
typename Epilogue0::OutputTileIterator iterator_D0(
|
| 463 |
-
params.params_D0,
|
| 464 |
-
ptr_D0,
|
| 465 |
-
params.problem_size.mn(),
|
| 466 |
-
thread_idx,
|
| 467 |
-
threadblock_offset
|
| 468 |
-
);
|
| 469 |
-
typename Epilogue1::OutputTileIterator iterator_D1(
|
| 470 |
-
params.params_D1,
|
| 471 |
-
ptr_D1,
|
| 472 |
-
params.problem_size.mn(),
|
| 473 |
-
thread_idx,
|
| 474 |
-
threadblock_offset
|
| 475 |
-
);
|
| 476 |
-
typename Epilogue1::OutputTileIterator iterator_D2(
|
| 477 |
-
params.params_D2,
|
| 478 |
-
ptr_D2,
|
| 479 |
-
params.problem_size.mn(),
|
| 480 |
-
thread_idx,
|
| 481 |
-
threadblock_offset
|
| 482 |
-
);
|
| 483 |
-
|
| 484 |
-
DualEpilogue epilogue(
|
| 485 |
-
shared_storage.epilogue,
|
| 486 |
-
thread_idx,
|
| 487 |
-
warp_idx,
|
| 488 |
-
lane_idx);
|
| 489 |
-
|
| 490 |
-
// Wait on the semaphore - this latency may have been covered by iterator construction
|
| 491 |
-
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
| 492 |
-
|
| 493 |
-
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
| 494 |
-
if (threadblock_tile_offset.k()) {
|
| 495 |
-
iterator_C0 = iterator_D0;
|
| 496 |
-
iterator_C1 = iterator_D1;
|
| 497 |
-
}
|
| 498 |
-
|
| 499 |
-
semaphore.wait(threadblock_tile_offset.k());
|
| 500 |
-
|
| 501 |
-
__threadfence();
|
| 502 |
-
}
|
| 503 |
-
|
| 504 |
-
// Execute the epilogue operator to update the destination tensor.
|
| 505 |
-
typename Epilogue0::OutputTileIterator source_iters[] = {
|
| 506 |
-
iterator_C0, iterator_C1
|
| 507 |
-
};
|
| 508 |
-
const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1);
|
| 509 |
-
epilogue(
|
| 510 |
-
output_op_0, output_op_1, output_op_2,
|
| 511 |
-
iterator_D0, iterator_D1, iterator_D2,
|
| 512 |
-
accum0, accum1,
|
| 513 |
-
source_iters,
|
| 514 |
-
writeToD2
|
| 515 |
-
);
|
| 516 |
-
|
| 517 |
-
//
|
| 518 |
-
// Release the semaphore
|
| 519 |
-
//
|
| 520 |
-
|
| 521 |
-
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
| 522 |
-
|
| 523 |
-
int lock = 0;
|
| 524 |
-
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
| 525 |
-
|
| 526 |
-
// The final threadblock resets the semaphore for subsequent grids.
|
| 527 |
-
lock = 0;
|
| 528 |
-
}
|
| 529 |
-
else {
|
| 530 |
-
// Otherwise, the semaphore is incremented
|
| 531 |
-
lock = threadblock_tile_offset.k() + 1;
|
| 532 |
-
}
|
| 533 |
-
|
| 534 |
-
__threadfence();
|
| 535 |
-
semaphore.release(lock);
|
| 536 |
-
}
|
| 537 |
-
}
|
| 538 |
-
};
|
| 539 |
-
|
| 540 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 541 |
-
|
| 542 |
-
} // namespace kernel
|
| 543 |
-
} // namespace gemm
|
| 544 |
-
} // namespace cutlass
|
| 545 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h
DELETED
|
@@ -1,95 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
#include <iostream>
|
| 34 |
-
|
| 35 |
-
// Run tests on GPUs
|
| 36 |
-
|
| 37 |
-
int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string & test_name) {
|
| 38 |
-
|
| 39 |
-
bool supported = false;
|
| 40 |
-
|
| 41 |
-
int arch_major = arch / 10;
|
| 42 |
-
int arch_minor = arch - arch / 10 * 10;
|
| 43 |
-
|
| 44 |
-
if(arch_major >= 8) {
|
| 45 |
-
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
| 46 |
-
//
|
| 47 |
-
// CUTLASS must be compiled with CUDA 11 Toolkit to run these examples.
|
| 48 |
-
if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) {
|
| 49 |
-
supported = true;
|
| 50 |
-
}
|
| 51 |
-
}
|
| 52 |
-
else if(arch_major >= 7) {
|
| 53 |
-
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
| 54 |
-
//
|
| 55 |
-
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
| 56 |
-
if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) {
|
| 57 |
-
supported = true;
|
| 58 |
-
}
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
cudaDeviceProp props;
|
| 62 |
-
|
| 63 |
-
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
| 64 |
-
if (error != cudaSuccess) {
|
| 65 |
-
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
| 66 |
-
return -1;
|
| 67 |
-
}
|
| 68 |
-
|
| 69 |
-
if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) {
|
| 70 |
-
supported = false;
|
| 71 |
-
}
|
| 72 |
-
|
| 73 |
-
if (!supported) {
|
| 74 |
-
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
| 75 |
-
std::cout << "This example isn't supported on current architecture" << std::endl;
|
| 76 |
-
return 0;
|
| 77 |
-
}
|
| 78 |
-
|
| 79 |
-
bool pass = true;
|
| 80 |
-
|
| 81 |
-
std::cout << "Device: " << props.name << std::endl;
|
| 82 |
-
std::cout << "Arch: SM" << arch << std::endl;
|
| 83 |
-
std::cout << "Test: " << test_name << std::endl;
|
| 84 |
-
for(auto func : test_funcs) {
|
| 85 |
-
pass &= func();
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
if(pass)
|
| 90 |
-
return 0;
|
| 91 |
-
else
|
| 92 |
-
return -1;
|
| 93 |
-
|
| 94 |
-
}
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h
DELETED
|
@@ -1,150 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Functor computing SiLU(lhs) * rhs element-wise, used by epilogues.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/cutlass.h"
|
| 38 |
-
#include "cutlass/numeric_types.h"
|
| 39 |
-
#include "cutlass/array.h"
|
| 40 |
-
#include "cutlass/functional.h"
|
| 41 |
-
#include "cutlass/numeric_conversion.h"
|
| 42 |
-
#include "cutlass/epilogue/thread/scale_type.h"
|
| 43 |
-
#include "cutlass/epilogue/thread/linear_combination_params.h"
|
| 44 |
-
|
| 45 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
-
|
| 47 |
-
namespace cutlass {
|
| 48 |
-
namespace epilogue {
|
| 49 |
-
namespace thread {
|
| 50 |
-
|
| 51 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
-
|
| 53 |
-
/// Applies SiLU to the left operand and multiplies the result by the right operand, element-wise.
|
| 54 |
-
///
|
| 55 |
-
/// D = SiLU(lhs) * rhs
|
| 56 |
-
///
|
| 57 |
-
template <
|
| 58 |
-
typename ElementOutput_, ///< Data type used to load and store tensors
|
| 59 |
-
int Count, ///< Number of elements computed per operation.
|
| 60 |
-
///< Usually it is 128/sizeof_bits<ElementOutput_>,
|
| 61 |
-
///< but we use 64 or 32 sometimes when there are not enough data to store
|
| 62 |
-
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
| 63 |
-
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
| 64 |
-
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
| 65 |
-
>
|
| 66 |
-
class LeftSiLUAndMul {
|
| 67 |
-
public:
|
| 68 |
-
|
| 69 |
-
using ElementOutput = ElementOutput_;
|
| 70 |
-
using ElementAccumulator = ElementAccumulator_;
|
| 71 |
-
using ElementCompute = ElementCompute_;
|
| 72 |
-
|
| 73 |
-
static int const kCount = Count;
|
| 74 |
-
using FragmentOutput = Array<ElementOutput, kCount>;
|
| 75 |
-
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
| 76 |
-
using ComputeFragment = Array<ElementCompute, kCount>;
|
| 77 |
-
|
| 78 |
-
static FloatRoundStyle const kRound = Round;
|
| 79 |
-
|
| 80 |
-
struct Params{};
|
| 81 |
-
|
| 82 |
-
private:
|
| 83 |
-
|
| 84 |
-
//
|
| 85 |
-
// Data members
|
| 86 |
-
//
|
| 87 |
-
|
| 88 |
-
ElementCompute alpha_;
|
| 89 |
-
ElementCompute beta_;
|
| 90 |
-
|
| 91 |
-
public:
|
| 92 |
-
|
| 93 |
-
/// Constructs the function object; Params is empty, so nothing is read from it
|
| 94 |
-
CUTLASS_HOST_DEVICE
|
| 95 |
-
LeftSiLUAndMul(Params const &/*params*/) {}
|
| 96 |
-
|
| 97 |
-
/// Returns true if source is needed
|
| 98 |
-
CUTLASS_HOST_DEVICE
|
| 99 |
-
bool is_source_needed() const {
|
| 100 |
-
return true;
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
/// Functionally required for serial reduction in the epilogue
|
| 104 |
-
CUTLASS_HOST_DEVICE
|
| 105 |
-
void set_k_partition(int k_partition, int k_partition_count) {
|
| 106 |
-
assert(false);
|
| 107 |
-
}
|
| 108 |
-
|
| 109 |
-
/// Computes D = SiLU(lhs) * rhs element-wise on accumulator fragments
|
| 110 |
-
CUTLASS_HOST_DEVICE
|
| 111 |
-
FragmentOutput operator()(
|
| 112 |
-
FragmentAccumulator const &lhs,
|
| 113 |
-
FragmentAccumulator const &rhs) const {
|
| 114 |
-
|
| 115 |
-
// Convert accumulators to internal compute numeric type
|
| 116 |
-
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_to_compute;
|
| 117 |
-
|
| 118 |
-
// Convert to destination numeric type
|
| 119 |
-
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> compute_to_output;
|
| 120 |
-
|
| 121 |
-
ComputeFragment converted_lhs = accumulator_to_compute(lhs);
|
| 122 |
-
ComputeFragment converted_rhs = accumulator_to_compute(rhs);
|
| 123 |
-
|
| 124 |
-
cutlass::epilogue::thread::SiLu<ComputeFragment> silu;
|
| 125 |
-
cutlass::multiplies<ComputeFragment> mul;
|
| 126 |
-
auto silu_lhs = silu(converted_lhs);
|
| 127 |
-
return compute_to_output(mul(silu_lhs, converted_rhs));
|
| 128 |
-
}
|
| 129 |
-
|
| 130 |
-
CUTLASS_HOST_DEVICE
|
| 131 |
-
ElementOutput operator()(
|
| 132 |
-
ElementAccumulator const& lhs,
|
| 133 |
-
ElementAccumulator const& rhs
|
| 134 |
-
) const {
|
| 135 |
-
ElementCompute convert_lhs(lhs);
|
| 136 |
-
ElementCompute convert_rhs(rhs);
|
| 137 |
-
cutlass::epilogue::thread::SiLu<ElementCompute> silu;
|
| 138 |
-
cutlass::multiplies<ElementCompute> mul;
|
| 139 |
-
auto silu_lhs = silu(convert_lhs);
|
| 140 |
-
return ElementOutput(mul(silu_lhs, convert_rhs));
|
| 141 |
-
}
|
| 142 |
-
};
|
| 143 |
-
|
| 144 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 145 |
-
|
| 146 |
-
} // namespace thread
|
| 147 |
-
} // namespace epilogue
|
| 148 |
-
} // namespace cutlass
|
| 149 |
-
|
| 150 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h
DELETED
|
@@ -1,424 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
-
|
| 34 |
-
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
-
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
-
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#pragma once
|
| 40 |
-
#include "cutlass/array.h"
|
| 41 |
-
#include CUDA_STD_HEADER(cassert)
|
| 42 |
-
#include "cutlass/cutlass.h"
|
| 43 |
-
#include "cutlass/numeric_types.h"
|
| 44 |
-
#include "cutlass/layout/vector.h"
|
| 45 |
-
#include "cutlass/layout/tensor.h"
|
| 46 |
-
#include "cutlass/tensor_coord.h"
|
| 47 |
-
#include "cutlass/aligned_buffer.h"
|
| 48 |
-
#include "cutlass/functional.h"
|
| 49 |
-
|
| 50 |
-
#include "cutlass/gemm/gemm.h"
|
| 51 |
-
|
| 52 |
-
#include "cutlass/transform/pitch_linear_thread_map.h"
|
| 53 |
-
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
| 54 |
-
|
| 55 |
-
#include "cutlass/epilogue/threadblock/epilogue_base.h"
|
| 56 |
-
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 57 |
-
#include "cutlass/numeric_types.h"
|
| 58 |
-
|
| 59 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 60 |
-
|
| 61 |
-
namespace cutlass {
|
| 62 |
-
namespace epilogue {
|
| 63 |
-
namespace threadblock {
|
| 64 |
-
|
| 65 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 66 |
-
|
| 67 |
-
/// Epilogue operator
|
| 68 |
-
template <
|
| 69 |
-
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
| 70 |
-
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
| 71 |
-
int PartitionsK, ///< Number of partitions of the K dimension
|
| 72 |
-
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
|
| 73 |
-
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
| 74 |
-
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
|
| 75 |
-
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
|
| 76 |
-
///< Output operator
|
| 77 |
-
typename OutputOp0_,
|
| 78 |
-
typename OutputOp1_,
|
| 79 |
-
typename OutputOp2_,
|
| 80 |
-
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
|
| 81 |
-
bool StoreD0 = true,
|
| 82 |
-
bool StoreD1 = true,
|
| 83 |
-
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
|
| 84 |
-
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
|
| 85 |
-
(!IsEpilogueFunctorHeavy<OutputOp0_>::value)
|
| 86 |
-
>
|
| 87 |
-
class DualEpilogue {
|
| 88 |
-
|
| 89 |
-
public:
|
| 90 |
-
|
| 91 |
-
using Base = EpilogueBase<
|
| 92 |
-
Shape_,
|
| 93 |
-
typename WarpMmaOperator_::Shape,
|
| 94 |
-
PartitionsK,
|
| 95 |
-
AccumulatorFragmentIterator_,
|
| 96 |
-
WarpTileIterator_,
|
| 97 |
-
Padding_,
|
| 98 |
-
FragmentsPerPartition>;
|
| 99 |
-
|
| 100 |
-
using Shape = Shape_;
|
| 101 |
-
using WarpMmaOperator = WarpMmaOperator_;
|
| 102 |
-
static int const kPartitionsK = PartitionsK;
|
| 103 |
-
static bool constexpr kStoreD0 = StoreD0;
|
| 104 |
-
static bool constexpr kStoreD1 = StoreD1;
|
| 105 |
-
using OutputTileIterator = OutputTileIterator_;
|
| 106 |
-
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
| 107 |
-
using WarpTileIterator = WarpTileIterator_;
|
| 108 |
-
using SharedLoadIterator = SharedLoadIterator_;
|
| 109 |
-
using OutputOp0 = OutputOp0_;
|
| 110 |
-
using OutputOp1 = OutputOp1_;
|
| 111 |
-
using OutputOp2 = OutputOp2_;
|
| 112 |
-
using Padding = Padding_;
|
| 113 |
-
|
| 114 |
-
using Layout = layout::RowMajor;
|
| 115 |
-
using LongIndex = typename Layout::LongIndex;
|
| 116 |
-
|
| 117 |
-
/// The complete warp-level accumulator tile
|
| 118 |
-
using AccumulatorTile = typename Base::AccumulatorTile;
|
| 119 |
-
|
| 120 |
-
/// Accumulator element
|
| 121 |
-
using ElementAccumulator = typename WarpTileIterator::Element;
|
| 122 |
-
|
| 123 |
-
/// Output element
|
| 124 |
-
using ElementOutput = typename OutputTileIterator::Element;
|
| 125 |
-
|
| 126 |
-
/// Output access size
|
| 127 |
-
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
| 128 |
-
|
| 129 |
-
/// Tensor reference to destination tensor
|
| 130 |
-
using TensorRef = typename OutputTileIterator::TensorRef;
|
| 131 |
-
|
| 132 |
-
/// Tensor reference to sync tensor
|
| 133 |
-
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
| 134 |
-
|
| 135 |
-
/// Const tensor reference to source tensor
|
| 136 |
-
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
| 137 |
-
|
| 138 |
-
/// Array type used to output
|
| 139 |
-
using OutputAccessType = Array<
|
| 140 |
-
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 141 |
-
|
| 142 |
-
/// Array type used by output functor
|
| 143 |
-
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
| 144 |
-
|
| 145 |
-
/// Number of warps
|
| 146 |
-
using WarpCount = typename Base::WarpCount;
|
| 147 |
-
|
| 148 |
-
struct SharedStorage {
|
| 149 |
-
using Element = typename WarpTileIterator::Element;
|
| 150 |
-
|
| 151 |
-
/// Tensor reference to shared memory allocation
|
| 152 |
-
using TensorRef = typename WarpTileIterator::TensorRef;
|
| 153 |
-
|
| 154 |
-
/// Logical shape of the shared memory tile written to by all warps.
|
| 155 |
-
using Shape = typename Base::Shape;
|
| 156 |
-
|
| 157 |
-
/// Shape of the shared memory allocation for the epilogue
|
| 158 |
-
using StorageShape = typename Base::SharedStorage::StorageShape;
|
| 159 |
-
|
| 160 |
-
//
|
| 161 |
-
// Data members
|
| 162 |
-
//
|
| 163 |
-
|
| 164 |
-
AlignedBuffer<Element, StorageShape::kCount> storage[2];
|
| 165 |
-
|
| 166 |
-
//
|
| 167 |
-
// Methods
|
| 168 |
-
//
|
| 169 |
-
|
| 170 |
-
/// Returns a tensor reference to the shared memory buffer
|
| 171 |
-
CUTLASS_DEVICE
|
| 172 |
-
TensorRef reference(int i) {
|
| 173 |
-
return TensorRef(
|
| 174 |
-
storage[i].data(),
|
| 175 |
-
Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
|
| 176 |
-
}
|
| 177 |
-
};
|
| 178 |
-
|
| 179 |
-
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
|
| 180 |
-
static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles;
|
| 181 |
-
|
| 182 |
-
public:
|
| 183 |
-
|
| 184 |
-
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
|
| 185 |
-
"Mismatch between shared load iterator and output tile iterator.");
|
| 186 |
-
|
| 187 |
-
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
|
| 188 |
-
|
| 189 |
-
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
|
| 190 |
-
"Divisibility");
|
| 191 |
-
|
| 192 |
-
private:
|
| 193 |
-
|
| 194 |
-
/// Loads fragment from shared memory aligned with output tensor
|
| 195 |
-
SharedLoadIterator shared_load_iterator0_;
|
| 196 |
-
SharedLoadIterator shared_load_iterator1_;
|
| 197 |
-
|
| 198 |
-
/// Stores a warp's fragment of accumulators to SMEM
|
| 199 |
-
WarpTileIterator warp_tile_iterator0_;
|
| 200 |
-
WarpTileIterator warp_tile_iterator1_;
|
| 201 |
-
|
| 202 |
-
public:
|
| 203 |
-
|
| 204 |
-
/// Constructor
|
| 205 |
-
CUTLASS_DEVICE
|
| 206 |
-
DualEpilogue(
|
| 207 |
-
SharedStorage &shared_storage, ///< Shared storage object
|
| 208 |
-
int thread_idx, ///< ID of a thread within the threadblock
|
| 209 |
-
int warp_idx, ///< ID of warp within threadblock
|
| 210 |
-
int lane_idx ///< Id of thread within warp
|
| 211 |
-
):
|
| 212 |
-
shared_load_iterator0_(shared_storage.reference(0), thread_idx),
|
| 213 |
-
shared_load_iterator1_(shared_storage.reference(1), thread_idx),
|
| 214 |
-
warp_tile_iterator0_(shared_storage.reference(0), lane_idx),
|
| 215 |
-
warp_tile_iterator1_(shared_storage.reference(1), lane_idx)
|
| 216 |
-
{
|
| 217 |
-
int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
|
| 218 |
-
int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
|
| 219 |
-
int warp_m = warp_mn % WarpCount::kM;
|
| 220 |
-
int warp_n = warp_mn / WarpCount::kM;
|
| 221 |
-
|
| 222 |
-
MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};
|
| 223 |
-
|
| 224 |
-
warp_tile_iterator0_.add_tile_offset(warp_offset);
|
| 225 |
-
warp_tile_iterator1_.add_tile_offset(warp_offset);
|
| 226 |
-
}
|
| 227 |
-
|
| 228 |
-
/// Streams the result to global memory
|
| 229 |
-
CUTLASS_DEVICE
|
| 230 |
-
void operator()(
|
| 231 |
-
OutputOp0 const &output_op0,
|
| 232 |
-
OutputOp1 const &output_op1,
|
| 233 |
-
OutputOp2 const &output_op2,
|
| 234 |
-
OutputTileIterator dest0,
|
| 235 |
-
OutputTileIterator dest1,
|
| 236 |
-
OutputTileIterator dest2,
|
| 237 |
-
AccumulatorTile const &accumulator0,
|
| 238 |
-
AccumulatorTile const &accumulator1,
|
| 239 |
-
OutputTileIterator source_iterator[2],
|
| 240 |
-
bool writeToD2 // true if it's the final split-k
|
| 241 |
-
) {
|
| 242 |
-
// TODO: Implement when no source is needed
|
| 243 |
-
|
| 244 |
-
typename OutputTileIterator::Fragment source_fragment[2];
|
| 245 |
-
CUTLASS_PRAGMA_UNROLL
|
| 246 |
-
for (int i = 0; i < 2; ++i) {
|
| 247 |
-
source_fragment[i].clear();
|
| 248 |
-
}
|
| 249 |
-
|
| 250 |
-
//
|
| 251 |
-
// Iterator over warp-level accumulator fragment
|
| 252 |
-
//
|
| 253 |
-
|
| 254 |
-
AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1};
|
| 255 |
-
|
| 256 |
-
//
|
| 257 |
-
// Iterate over accumulator tile
|
| 258 |
-
//
|
| 259 |
-
|
| 260 |
-
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
|
| 261 |
-
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
|
| 262 |
-
|
| 263 |
-
//
|
| 264 |
-
// Load the source
|
| 265 |
-
//
|
| 266 |
-
|
| 267 |
-
CUTLASS_PRAGMA_UNROLL
|
| 268 |
-
for (int i = 0; i < 2; ++i) {
|
| 269 |
-
source_iterator[i].load(source_fragment[i]);
|
| 270 |
-
++source_iterator[i];
|
| 271 |
-
}
|
| 272 |
-
|
| 273 |
-
//
|
| 274 |
-
// Convert and store fragment
|
| 275 |
-
//
|
| 276 |
-
|
| 277 |
-
__syncthreads();
|
| 278 |
-
|
| 279 |
-
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
| 280 |
-
iter, accum_fragment_iterator[0], this->warp_tile_iterator0_);
|
| 281 |
-
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
| 282 |
-
iter, accum_fragment_iterator[1], this->warp_tile_iterator1_);
|
| 283 |
-
|
| 284 |
-
__syncthreads();
|
| 285 |
-
|
| 286 |
-
//
|
| 287 |
-
// Load fragments from shared memory
|
| 288 |
-
//
|
| 289 |
-
|
| 290 |
-
typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK];
|
| 291 |
-
typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK];
|
| 292 |
-
|
| 293 |
-
shared_load_iterator0_.load(aligned_accum_fragment0[0]);
|
| 294 |
-
shared_load_iterator1_.load(aligned_accum_fragment1[0]);
|
| 295 |
-
|
| 296 |
-
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
|
| 297 |
-
if (kPartitionsK > 1) {
|
| 298 |
-
|
| 299 |
-
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
| 300 |
-
|
| 301 |
-
CUTLASS_PRAGMA_UNROLL
|
| 302 |
-
for ( int i = 1; i < kPartitionsK; ++i) {
|
| 303 |
-
shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset);
|
| 304 |
-
shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset);
|
| 305 |
-
shared_load_iterator0_.load(aligned_accum_fragment0[i]);
|
| 306 |
-
shared_load_iterator1_.load(aligned_accum_fragment1[i]);
|
| 307 |
-
aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]);
|
| 308 |
-
aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]);
|
| 309 |
-
}
|
| 310 |
-
|
| 311 |
-
shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
|
| 312 |
-
shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
|
| 313 |
-
}
|
| 314 |
-
|
| 315 |
-
//
|
| 316 |
-
// Compute the output result
|
| 317 |
-
//
|
| 318 |
-
|
| 319 |
-
typename OutputTileIterator::Fragment output_fragment[3];
|
| 320 |
-
|
| 321 |
-
apply_output_operator_(output_fragment,
|
| 322 |
-
output_op0, output_op1, output_op2,
|
| 323 |
-
aligned_accum_fragment0[0], aligned_accum_fragment1[0],
|
| 324 |
-
source_fragment);
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
//
|
| 328 |
-
// Store the final result
|
| 329 |
-
//
|
| 330 |
-
|
| 331 |
-
if (kStoreD0) {
|
| 332 |
-
dest0.store(output_fragment[0]);
|
| 333 |
-
++dest0;
|
| 334 |
-
}
|
| 335 |
-
if (kStoreD1) {
|
| 336 |
-
dest1.store(output_fragment[1]);
|
| 337 |
-
++dest1;
|
| 338 |
-
}
|
| 339 |
-
if (writeToD2) {
|
| 340 |
-
dest2.store(output_fragment[2]);
|
| 341 |
-
++dest2;
|
| 342 |
-
}
|
| 343 |
-
}
|
| 344 |
-
}
|
| 345 |
-
|
| 346 |
-
private:
|
| 347 |
-
|
| 348 |
-
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
|
| 349 |
-
|
| 350 |
-
template<class Seq>
|
| 351 |
-
struct acc2smem_source_needed;
|
| 352 |
-
|
| 353 |
-
template <size_t... Seq>
|
| 354 |
-
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
|
| 355 |
-
template<int Advance>
|
| 356 |
-
CUTLASS_DEVICE
|
| 357 |
-
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
| 358 |
-
WarpTileIterator &warp_tile_iterator) {
|
| 359 |
-
CUTLASS_PRAGMA_UNROLL
|
| 360 |
-
for (int i = 0; i < Advance; i++) {
|
| 361 |
-
++accum_fragment_iterator;
|
| 362 |
-
}
|
| 363 |
-
|
| 364 |
-
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
| 365 |
-
accum_fragment_iterator.load(accum_fragment);
|
| 366 |
-
warp_tile_iterator.store(accum_fragment);
|
| 367 |
-
}
|
| 368 |
-
|
| 369 |
-
CUTLASS_DEVICE
|
| 370 |
-
static void push(size_t pos,
|
| 371 |
-
AccumulatorFragmentIterator const &iterator_begin,
|
| 372 |
-
WarpTileIterator &warp_tile_iterator) {
|
| 373 |
-
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
|
| 374 |
-
}
|
| 375 |
-
};
|
| 376 |
-
|
| 377 |
-
/// Helper to invoke the output functor over each vector of output
|
| 378 |
-
CUTLASS_DEVICE
|
| 379 |
-
void apply_output_operator_(
|
| 380 |
-
typename OutputTileIterator::Fragment (&output_fragment)[3],
|
| 381 |
-
OutputOp0 const &output_op0,
|
| 382 |
-
OutputOp1 const &output_op1,
|
| 383 |
-
OutputOp2 const &output_op2,
|
| 384 |
-
typename SharedLoadIterator::Fragment const& aligned_accum_fragment0,
|
| 385 |
-
typename SharedLoadIterator::Fragment const& aligned_accum_fragment1,
|
| 386 |
-
typename OutputTileIterator::Fragment const (&source_fragment)[2]) {
|
| 387 |
-
|
| 388 |
-
OutputAccessType* output_frag_ptr[3] = {
|
| 389 |
-
reinterpret_cast<OutputAccessType *>(&output_fragment[0]),
|
| 390 |
-
reinterpret_cast<OutputAccessType *>(&output_fragment[1]),
|
| 391 |
-
reinterpret_cast<OutputAccessType *>(&output_fragment[2])
|
| 392 |
-
};
|
| 393 |
-
|
| 394 |
-
AccumulatorAccessType const *compute_frag_ptr[2] = {
|
| 395 |
-
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment0),
|
| 396 |
-
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment1)
|
| 397 |
-
};
|
| 398 |
-
|
| 399 |
-
OutputAccessType const *source_frag_ptr[2] = {
|
| 400 |
-
reinterpret_cast<OutputAccessType const *>(&source_fragment[0]),
|
| 401 |
-
reinterpret_cast<OutputAccessType const *>(&source_fragment[1])
|
| 402 |
-
};
|
| 403 |
-
|
| 404 |
-
int const kOutputOpIterations =
|
| 405 |
-
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
| 406 |
-
|
| 407 |
-
CUTLASS_PRAGMA_UNROLL
|
| 408 |
-
for (int i = 0; i < kOutputOpIterations; ++i) {
|
| 409 |
-
|
| 410 |
-
// Call the output operators
|
| 411 |
-
output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]);
|
| 412 |
-
output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]);
|
| 413 |
-
output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]);
|
| 414 |
-
}
|
| 415 |
-
}
|
| 416 |
-
};
|
| 417 |
-
|
| 418 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 419 |
-
|
| 420 |
-
} // namespace threadblock
|
| 421 |
-
} // namespace epilogue
|
| 422 |
-
} // namespace cutlass
|
| 423 |
-
|
| 424 |
-
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h
DELETED
|
@@ -1,232 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/aligned_buffer.h"
|
| 38 |
-
#include "cutlass/arch/memory.h"
|
| 39 |
-
#include "cutlass/array.h"
|
| 40 |
-
#include "cutlass/cutlass.h"
|
| 41 |
-
#include "cutlass/gemm/gemm.h"
|
| 42 |
-
#include "cutlass/matrix_shape.h"
|
| 43 |
-
#include "cutlass/numeric_types.h"
|
| 44 |
-
|
| 45 |
-
#include "cutlass/gemm/threadblock/mma_base.h"
|
| 46 |
-
|
| 47 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 48 |
-
|
| 49 |
-
namespace cutlass {
|
| 50 |
-
namespace gemm {
|
| 51 |
-
namespace threadblock {
|
| 52 |
-
|
| 53 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 54 |
-
|
| 55 |
-
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
| 56 |
-
/// instructions.
|
| 57 |
-
template <
|
| 58 |
-
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
| 59 |
-
typename Shape_,
|
| 60 |
-
/// Policy describing tuning details (concept: MmaPolicy)
|
| 61 |
-
typename Policy0_,
|
| 62 |
-
/// B1-specific version of the policy (concept: MmaPolicy)
|
| 63 |
-
typename Policy1_,
|
| 64 |
-
/// Number of stages,
|
| 65 |
-
int Stages,
|
| 66 |
-
/// Used for partial specialization
|
| 67 |
-
typename Enable = bool>
|
| 68 |
-
class DualMmaBase {
|
| 69 |
-
public:
|
| 70 |
-
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
| 71 |
-
using Shape = Shape_;
|
| 72 |
-
|
| 73 |
-
///< Policy describing tuning details
|
| 74 |
-
using Policy0 = Policy0_;
|
| 75 |
-
using Policy1 = Policy1_;
|
| 76 |
-
|
| 77 |
-
//
|
| 78 |
-
// Dependent types
|
| 79 |
-
//
|
| 80 |
-
|
| 81 |
-
/// Warp-level Mma
|
| 82 |
-
using Operator0 = typename Policy0::Operator;
|
| 83 |
-
using Operator1 = typename Policy1::Operator;
|
| 84 |
-
|
| 85 |
-
/// Shape describing the overall GEMM computed from shared memory
|
| 86 |
-
/// by each warp.
|
| 87 |
-
using WarpGemm = typename Policy0::Operator::Shape;
|
| 88 |
-
|
| 89 |
-
/// Shape describing the number of warps filling the CTA
|
| 90 |
-
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
|
| 91 |
-
Shape::kN / WarpGemm::kN,
|
| 92 |
-
Shape::kK / WarpGemm::kK>;
|
| 93 |
-
|
| 94 |
-
/// Number of warp-level GEMM operations
|
| 95 |
-
static int const kWarpGemmIterations =
|
| 96 |
-
(WarpGemm::kK / Operator0::Policy::MmaShape::kK);
|
| 97 |
-
|
| 98 |
-
/// Number of stages
|
| 99 |
-
static int const kStages = Stages;
|
| 100 |
-
|
| 101 |
-
/// Tensor reference to the A operand
|
| 102 |
-
using TensorRefA = TensorRef<typename Operator0::ElementA, typename Operator0::LayoutA>;
|
| 103 |
-
|
| 104 |
-
/// Tensor reference to the B operand
|
| 105 |
-
using TensorRefB0 = TensorRef<typename Operator0::ElementB, typename Operator0::LayoutB>;
|
| 106 |
-
using TensorRefB1 = TensorRef<typename Operator1::ElementB, typename Operator1::LayoutB>;
|
| 107 |
-
|
| 108 |
-
static_assert(kWarpGemmIterations > 1,
|
| 109 |
-
"The pipelined structure requires at least two warp-level "
|
| 110 |
-
"GEMM operations.");
|
| 111 |
-
|
| 112 |
-
static_assert((kWarpGemmIterations % 2) == 0,
|
| 113 |
-
"Inner loop iteration must be an even number.");
|
| 114 |
-
|
| 115 |
-
//
|
| 116 |
-
// Nested structs
|
| 117 |
-
//
|
| 118 |
-
|
| 119 |
-
/// Shared storage object needed by threadblock-scoped GEMM
|
| 120 |
-
class SharedStorage {
|
| 121 |
-
public:
|
| 122 |
-
//
|
| 123 |
-
// Type definitions
|
| 124 |
-
//
|
| 125 |
-
|
| 126 |
-
/// Shape of the A matrix operand in shared memory
|
| 127 |
-
using ShapeA = MatrixShape<Shape::kM + Policy0::SmemPaddingA::kRow,
|
| 128 |
-
Shape::kK * kStages +
|
| 129 |
-
Policy0::SmemPaddingA::kColumn>;
|
| 130 |
-
|
| 131 |
-
/// Shape of the B matrix operand in shared memory
|
| 132 |
-
using ShapeB0 =
|
| 133 |
-
MatrixShape<Shape::kK * kStages + Policy0::SmemPaddingB::kRow,
|
| 134 |
-
Shape::kN + Policy0::SmemPaddingB::kColumn>;
|
| 135 |
-
using ShapeB1 =
|
| 136 |
-
MatrixShape<Shape::kK * kStages + Policy1::SmemPaddingB::kRow,
|
| 137 |
-
Shape::kN + Policy1::SmemPaddingB::kColumn>;
|
| 138 |
-
|
| 139 |
-
public:
|
| 140 |
-
//
|
| 141 |
-
// Data members
|
| 142 |
-
//
|
| 143 |
-
|
| 144 |
-
/// Buffer for A operand
|
| 145 |
-
AlignedBuffer<typename Operator0::ElementA, ShapeA::kCount> operand_A;
|
| 146 |
-
|
| 147 |
-
/// Buffer for B operand
|
| 148 |
-
AlignedBuffer<typename Operator0::ElementB, ShapeB0::kCount> operand_B0;
|
| 149 |
-
AlignedBuffer<typename Operator1::ElementB, ShapeB1::kCount> operand_B1;
|
| 150 |
-
|
| 151 |
-
public:
|
| 152 |
-
|
| 153 |
-
//
|
| 154 |
-
// Methods
|
| 155 |
-
//
|
| 156 |
-
|
| 157 |
-
/// Returns a layout object for the A matrix
|
| 158 |
-
CUTLASS_DEVICE
|
| 159 |
-
static typename Operator0::LayoutA LayoutA() {
|
| 160 |
-
return Operator0::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
/// Returns a layout object for the B matrix
|
| 164 |
-
CUTLASS_HOST_DEVICE
|
| 165 |
-
static typename Operator0::LayoutB LayoutB0() {
|
| 166 |
-
return Operator0::LayoutB::packed({ShapeB0::kRow, ShapeB0::kColumn});
|
| 167 |
-
}
|
| 168 |
-
|
| 169 |
-
/// Returns a layout object for the B matrix
|
| 170 |
-
CUTLASS_HOST_DEVICE
|
| 171 |
-
static typename Operator1::LayoutB LayoutB1() {
|
| 172 |
-
return Operator1::LayoutB::packed({ShapeB1::kRow, ShapeB1::kColumn});
|
| 173 |
-
}
|
| 174 |
-
|
| 175 |
-
/// Returns a TensorRef to the A operand
|
| 176 |
-
CUTLASS_HOST_DEVICE
|
| 177 |
-
TensorRefA operand_A_ref() {
|
| 178 |
-
return TensorRefA{operand_A.data(), LayoutA()};
|
| 179 |
-
}
|
| 180 |
-
|
| 181 |
-
/// Returns a TensorRef to the B operand
|
| 182 |
-
CUTLASS_HOST_DEVICE
|
| 183 |
-
TensorRefB0 operand_B0_ref() {
|
| 184 |
-
return TensorRefB0{operand_B0.data(), LayoutB0()};
|
| 185 |
-
}
|
| 186 |
-
CUTLASS_HOST_DEVICE
|
| 187 |
-
TensorRefB1 operand_B1_ref() {
|
| 188 |
-
return TensorRefB1{operand_B1.data(), LayoutB1()};
|
| 189 |
-
}
|
| 190 |
-
};
|
| 191 |
-
|
| 192 |
-
protected:
|
| 193 |
-
|
| 194 |
-
//
|
| 195 |
-
// Data members
|
| 196 |
-
//
|
| 197 |
-
|
| 198 |
-
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
| 199 |
-
typename Operator0::IteratorA warp_tile_iterator_A_;
|
| 200 |
-
|
| 201 |
-
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
| 202 |
-
typename Operator0::IteratorB warp_tile_iterator_B0_;
|
| 203 |
-
typename Operator1::IteratorB warp_tile_iterator_B1_;
|
| 204 |
-
|
| 205 |
-
public:
|
| 206 |
-
|
| 207 |
-
/// Construct from tensor references
|
| 208 |
-
CUTLASS_DEVICE
|
| 209 |
-
DualMmaBase(
|
| 210 |
-
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
| 211 |
-
SharedStorage &shared_storage,
|
| 212 |
-
///< ID within the threadblock
|
| 213 |
-
int thread_idx,
|
| 214 |
-
///< ID of warp
|
| 215 |
-
int warp_idx,
|
| 216 |
-
///< ID of each thread within a warp
|
| 217 |
-
int lane_idx
|
| 218 |
-
):
|
| 219 |
-
warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
|
| 220 |
-
warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx),
|
| 221 |
-
warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) {
|
| 222 |
-
|
| 223 |
-
}
|
| 224 |
-
};
|
| 225 |
-
|
| 226 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 227 |
-
|
| 228 |
-
} // namespace threadblock
|
| 229 |
-
} // namespace gemm
|
| 230 |
-
} // namespace cutlass
|
| 231 |
-
|
| 232 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h
DELETED
|
@@ -1,775 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/aligned_buffer.h"
|
| 38 |
-
#include "cutlass/arch/memory.h"
|
| 39 |
-
#include "cutlass/array.h"
|
| 40 |
-
#include "cutlass/cutlass.h"
|
| 41 |
-
#include "cutlass/gemm/gemm.h"
|
| 42 |
-
#include "cutlass/matrix_shape.h"
|
| 43 |
-
#include "cutlass/numeric_types.h"
|
| 44 |
-
|
| 45 |
-
#include "cutlass/gemm/threadblock/mma_base.h"
|
| 46 |
-
#include "dual_mma_base.h"
|
| 47 |
-
|
| 48 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
-
|
| 50 |
-
namespace cutlass {
|
| 51 |
-
namespace gemm {
|
| 52 |
-
namespace threadblock {
|
| 53 |
-
|
| 54 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
-
|
| 56 |
-
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
| 57 |
-
/// instructions.
|
| 58 |
-
template <
|
| 59 |
-
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
| 60 |
-
typename Shape_,
|
| 61 |
-
/// Iterates over tiles of A operand in global memory
|
| 62 |
-
// (concept: ReadableTileIterator | ForwardTileIterator |
|
| 63 |
-
// MaskedTileIterator)
|
| 64 |
-
typename IteratorA_,
|
| 65 |
-
/// Iterates over tiles of A operand in shared memory
|
| 66 |
-
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
| 67 |
-
typename SmemIteratorA_,
|
| 68 |
-
/// Cache operation for operand A
|
| 69 |
-
cutlass::arch::CacheOperation::Kind CacheOpA,
|
| 70 |
-
/// Iterates over tiles of B0 operand in global memory
|
| 71 |
-
// (concept: ReadableTileIterator | ForwardTileIterator |
|
| 72 |
-
// MaskedTileIterator)
|
| 73 |
-
typename IteratorB0_,
|
| 74 |
-
/// Iterates over tiles of B0 operand in shared memory
|
| 75 |
-
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
| 76 |
-
typename SmemIteratorB0_,
|
| 77 |
-
/// Cache operation for operands B0 and B1
|
| 78 |
-
cutlass::arch::CacheOperation::Kind CacheOpB,
|
| 79 |
-
/// Iterates over tiles of B1 operand in global memory
|
| 80 |
-
// (concept: ReadableTileIterator | ForwardTileIterator |
|
| 81 |
-
// MaskedTileIterator)
|
| 82 |
-
typename IteratorB1_,
|
| 83 |
-
/// Iterates over tiles of B1 operand in shared memory
|
| 84 |
-
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
| 85 |
-
typename SmemIteratorB1_,
|
| 86 |
-
/// Data type of accumulator matrix
|
| 87 |
-
typename ElementC_,
|
| 88 |
-
/// Layout of accumulator matrix
|
| 89 |
-
typename LayoutC_,
|
| 90 |
-
/// Policy describing tuning details (concept: MmaPolicy)
|
| 91 |
-
typename Policy0_,
|
| 92 |
-
/// B1-specific version of the policy (concept: MmaPolicy)
|
| 93 |
-
typename Policy1_,
|
| 94 |
-
/// Number of stages
|
| 95 |
-
int Stages,
|
| 96 |
-
/// Use zfill or predicate for out-of-bound cp.async
|
| 97 |
-
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
| 98 |
-
/// Used for partial specialization
|
| 99 |
-
typename Enable = bool>
|
| 100 |
-
class DualMmaMultistage :
|
| 101 |
-
public DualMmaBase<Shape_, Policy0_, Policy1_, Stages> {
|
| 102 |
-
public:
|
| 103 |
-
///< Base class
|
| 104 |
-
using Base = DualMmaBase<Shape_, Policy0_, Policy1_, Stages>;
|
| 105 |
-
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
| 106 |
-
using Shape = Shape_;
|
| 107 |
-
///< Iterates over tiles of A operand in global memory
|
| 108 |
-
using IteratorA = IteratorA_;
|
| 109 |
-
///< Iterates over tiles of B0 operand in global memory
|
| 110 |
-
using IteratorB0 = IteratorB0_;
|
| 111 |
-
///< Iterates over tiles of B1 operand in global memory
|
| 112 |
-
using IteratorB1 = IteratorB1_;
|
| 113 |
-
///< Data type of accumulator matrix
|
| 114 |
-
using ElementC = ElementC_;
|
| 115 |
-
///< Layout of accumulator matrix
|
| 116 |
-
using LayoutC = LayoutC_;
|
| 117 |
-
///< Policy describing tuning details
|
| 118 |
-
using Policy0 = Policy0_;
|
| 119 |
-
using Policy1 = Policy1_;
|
| 120 |
-
|
| 121 |
-
using SmemIteratorA = SmemIteratorA_;
|
| 122 |
-
using SmemIteratorB0 = SmemIteratorB0_;
|
| 123 |
-
using SmemIteratorB1 = SmemIteratorB1_;
|
| 124 |
-
|
| 125 |
-
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
| 126 |
-
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
| 127 |
-
|
| 128 |
-
//
|
| 129 |
-
// Dependent types
|
| 130 |
-
//
|
| 131 |
-
|
| 132 |
-
/// Fragment of accumulator tile
|
| 133 |
-
using FragmentC = typename Policy0::Operator::FragmentC;
|
| 134 |
-
|
| 135 |
-
/// Warp-level Mma
|
| 136 |
-
using Operator0 = typename Policy0::Operator;
|
| 137 |
-
using Operator1 = typename Policy1::Operator;
|
| 138 |
-
|
| 139 |
-
/// Minimum architecture is Sm80 to support cp.async
|
| 140 |
-
using ArchTag = arch::Sm80;
|
| 141 |
-
|
| 142 |
-
/// Complex transform on A operand
|
| 143 |
-
static ComplexTransform const kTransformA = Operator0::kTransformA;
|
| 144 |
-
|
| 145 |
-
/// Complex transform on B operand
|
| 146 |
-
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
| 147 |
-
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
| 148 |
-
|
| 149 |
-
/// Internal structure exposed for introspection.
|
| 150 |
-
struct Detail {
|
| 151 |
-
|
| 152 |
-
/// Number of cp.async instructions to load one stage of operand A
|
| 153 |
-
static int const AsyncCopyIterationsPerStageA =
|
| 154 |
-
IteratorA::ThreadMap::Iterations::kCount;
|
| 155 |
-
|
| 156 |
-
/// Number of cp.async instructions to load one stage of operand B
|
| 157 |
-
static int const AsyncCopyIterationsPerStageB =
|
| 158 |
-
IteratorB0::ThreadMap::Iterations::kCount;
|
| 159 |
-
|
| 160 |
-
/// Number of stages
|
| 161 |
-
static int const kStages = Stages;
|
| 162 |
-
|
| 163 |
-
/// Number of cp.async instructions to load one group of operand A
|
| 164 |
-
static int const kAccessesPerGroupA =
|
| 165 |
-
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
| 166 |
-
|
| 167 |
-
/// Number of cp.async instructions to load one group of operand B
|
| 168 |
-
static int const kAccessesPerGroupB =
|
| 169 |
-
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
| 170 |
-
};
|
| 171 |
-
|
| 172 |
-
private:
|
| 173 |
-
|
| 174 |
-
using WarpLoadedFragmentA = typename Operator0::FragmentA;
|
| 175 |
-
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
| 176 |
-
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
| 177 |
-
using WarpTransformedFragmentA = typename Operator0::TransformedFragmentA;
|
| 178 |
-
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
| 179 |
-
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
| 180 |
-
|
| 181 |
-
private:
|
| 182 |
-
|
| 183 |
-
//
|
| 184 |
-
// Data members
|
| 185 |
-
//
|
| 186 |
-
|
| 187 |
-
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
| 188 |
-
SmemIteratorA smem_iterator_A_;
|
| 189 |
-
|
| 190 |
-
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
| 191 |
-
SmemIteratorB0 smem_iterator_B0_;
|
| 192 |
-
SmemIteratorB1 smem_iterator_B1_;
|
| 193 |
-
|
| 194 |
-
public:
|
| 195 |
-
|
| 196 |
-
/// Construct from tensor references
|
| 197 |
-
CUTLASS_DEVICE
|
| 198 |
-
DualMmaMultistage(
|
| 199 |
-
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
| 200 |
-
typename Base::SharedStorage &shared_storage,
|
| 201 |
-
///< ID within the threadblock
|
| 202 |
-
int thread_idx,
|
| 203 |
-
///< ID of warp
|
| 204 |
-
int warp_idx,
|
| 205 |
-
///< ID of each thread within a warp
|
| 206 |
-
int lane_idx
|
| 207 |
-
):
|
| 208 |
-
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
| 209 |
-
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
| 210 |
-
smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx),
|
| 211 |
-
smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx)
|
| 212 |
-
{
|
| 213 |
-
// Compute warp location within threadblock tile by mapping the warp_id to
|
| 214 |
-
// three coordinates:
|
| 215 |
-
// _m: the warp's position within the threadblock along the M dimension
|
| 216 |
-
// _n: the warp's position within the threadblock along the N dimension
|
| 217 |
-
// _k: the warp's position within the threadblock along the K dimension
|
| 218 |
-
|
| 219 |
-
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
| 220 |
-
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
| 221 |
-
|
| 222 |
-
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
| 223 |
-
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
| 224 |
-
|
| 225 |
-
// Add per-warp offsets in units of warp-level tiles
|
| 226 |
-
this->warp_tile_iterator_A_.add_tile_offset(
|
| 227 |
-
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
| 228 |
-
this->warp_tile_iterator_B0_.add_tile_offset(
|
| 229 |
-
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
| 230 |
-
this->warp_tile_iterator_B1_.add_tile_offset(
|
| 231 |
-
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
| 232 |
-
}
|
| 233 |
-
|
| 234 |
-
CUTLASS_DEVICE
|
| 235 |
-
void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB0 &iterator_B0, IteratorB1 &iterator_B1,
|
| 236 |
-
int group_start_A = 0, int group_start_B = 0) {
|
| 237 |
-
iterator_A.set_iteration_index(group_start_A *
|
| 238 |
-
IteratorA::kAccessesPerVector);
|
| 239 |
-
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
| 240 |
-
|
| 241 |
-
// Async Copy for operand A
|
| 242 |
-
CUTLASS_PRAGMA_UNROLL
|
| 243 |
-
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
| 244 |
-
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
|
| 245 |
-
typename IteratorA::AccessType *dst_ptr =
|
| 246 |
-
reinterpret_cast<typename IteratorA::AccessType *>(
|
| 247 |
-
this->smem_iterator_A_.get());
|
| 248 |
-
|
| 249 |
-
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
| 250 |
-
IteratorA::ThreadMap::kElementsPerAccess /
|
| 251 |
-
IteratorA::kAccessesPerVector / 8;
|
| 252 |
-
|
| 253 |
-
CUTLASS_PRAGMA_UNROLL
|
| 254 |
-
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
| 255 |
-
auto gmem_ptr = iterator_A.get();
|
| 256 |
-
|
| 257 |
-
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
| 258 |
-
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
| 259 |
-
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
| 260 |
-
} else {
|
| 261 |
-
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
|
| 262 |
-
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
| 263 |
-
}
|
| 264 |
-
|
| 265 |
-
++iterator_A;
|
| 266 |
-
}
|
| 267 |
-
|
| 268 |
-
++this->smem_iterator_A_;
|
| 269 |
-
}
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
iterator_B0.set_iteration_index(group_start_B *
|
| 273 |
-
IteratorB0::kAccessesPerVector);
|
| 274 |
-
iterator_B1.set_iteration_index(group_start_B *
|
| 275 |
-
IteratorB1::kAccessesPerVector);
|
| 276 |
-
this->smem_iterator_B0_.set_iteration_index(group_start_B);
|
| 277 |
-
this->smem_iterator_B1_.set_iteration_index(group_start_B);
|
| 278 |
-
|
| 279 |
-
// Async Copy for operand B0
|
| 280 |
-
CUTLASS_PRAGMA_UNROLL
|
| 281 |
-
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
| 282 |
-
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
| 283 |
-
typename IteratorB0::AccessType *dst_ptr =
|
| 284 |
-
reinterpret_cast<typename IteratorB0::AccessType *>(
|
| 285 |
-
this->smem_iterator_B0_.get());
|
| 286 |
-
|
| 287 |
-
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
|
| 288 |
-
IteratorB0::ThreadMap::kElementsPerAccess /
|
| 289 |
-
IteratorB0::kAccessesPerVector / 8;
|
| 290 |
-
|
| 291 |
-
CUTLASS_PRAGMA_UNROLL
|
| 292 |
-
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
| 293 |
-
auto gmem_ptr = iterator_B0.get();
|
| 294 |
-
|
| 295 |
-
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
| 296 |
-
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
| 297 |
-
dst_ptr + v, gmem_ptr, iterator_B0.valid());
|
| 298 |
-
} else {
|
| 299 |
-
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
|
| 300 |
-
dst_ptr + v, gmem_ptr, iterator_B0.valid());
|
| 301 |
-
}
|
| 302 |
-
|
| 303 |
-
++iterator_B0;
|
| 304 |
-
}
|
| 305 |
-
++this->smem_iterator_B0_;
|
| 306 |
-
}
|
| 307 |
-
}
|
| 308 |
-
// Async Copy for operand B1
|
| 309 |
-
CUTLASS_PRAGMA_UNROLL
|
| 310 |
-
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
| 311 |
-
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
| 312 |
-
typename IteratorB1::AccessType *dst_ptr =
|
| 313 |
-
reinterpret_cast<typename IteratorB1::AccessType *>(
|
| 314 |
-
this->smem_iterator_B1_.get());
|
| 315 |
-
|
| 316 |
-
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
|
| 317 |
-
IteratorB1::ThreadMap::kElementsPerAccess /
|
| 318 |
-
IteratorB1::kAccessesPerVector / 8;
|
| 319 |
-
|
| 320 |
-
CUTLASS_PRAGMA_UNROLL
|
| 321 |
-
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
| 322 |
-
auto gmem_ptr = iterator_B1.get();
|
| 323 |
-
|
| 324 |
-
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
| 325 |
-
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
| 326 |
-
dst_ptr + v, gmem_ptr, iterator_B1.valid());
|
| 327 |
-
} else {
|
| 328 |
-
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
|
| 329 |
-
dst_ptr + v, gmem_ptr, iterator_B1.valid());
|
| 330 |
-
}
|
| 331 |
-
|
| 332 |
-
++iterator_B1;
|
| 333 |
-
}
|
| 334 |
-
++this->smem_iterator_B1_;
|
| 335 |
-
}
|
| 336 |
-
}
|
| 337 |
-
}
|
| 338 |
-
|
| 339 |
-
/// Perform a threadblock-scoped matrix multiply-accumulate
|
| 340 |
-
CUTLASS_DEVICE
|
| 341 |
-
void operator()(
|
| 342 |
-
///< problem size of GEMM
|
| 343 |
-
int gemm_k_iterations,
|
| 344 |
-
///< destination accumulator tile
|
| 345 |
-
FragmentC &accum0,
|
| 346 |
-
FragmentC &accum1,
|
| 347 |
-
///< iterator over A operand in global memory
|
| 348 |
-
IteratorA iterator_A,
|
| 349 |
-
///< iterator over B operand in global memory
|
| 350 |
-
IteratorB0 iterator_B0,
|
| 351 |
-
IteratorB1 iterator_B1,
|
| 352 |
-
///< initial value of accumulator
|
| 353 |
-
FragmentC const &src_accum0,
|
| 354 |
-
FragmentC const &src_accum1
|
| 355 |
-
) {
|
| 356 |
-
|
| 357 |
-
//
|
| 358 |
-
// Prologue
|
| 359 |
-
//
|
| 360 |
-
|
| 361 |
-
// Issue several complete stages
|
| 362 |
-
CUTLASS_PRAGMA_UNROLL
|
| 363 |
-
for (int stage = 0; stage < Base::kStages - 1;
|
| 364 |
-
++stage, --gemm_k_iterations) {
|
| 365 |
-
|
| 366 |
-
iterator_A.clear_mask(gemm_k_iterations == 0);
|
| 367 |
-
iterator_B0.clear_mask(gemm_k_iterations == 0);
|
| 368 |
-
iterator_B1.clear_mask(gemm_k_iterations == 0);
|
| 369 |
-
|
| 370 |
-
iterator_A.set_iteration_index(0);
|
| 371 |
-
this->smem_iterator_A_.set_iteration_index(0);
|
| 372 |
-
|
| 373 |
-
// Async Copy for operand A
|
| 374 |
-
CUTLASS_PRAGMA_UNROLL
|
| 375 |
-
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
| 376 |
-
typename IteratorA::AccessType *dst_ptr =
|
| 377 |
-
reinterpret_cast<typename IteratorA::AccessType *>(
|
| 378 |
-
this->smem_iterator_A_.get());
|
| 379 |
-
|
| 380 |
-
CUTLASS_PRAGMA_UNROLL
|
| 381 |
-
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
| 382 |
-
int const kSrcBytes =
|
| 383 |
-
sizeof_bits<typename IteratorA::Element>::value *
|
| 384 |
-
IteratorA::ThreadMap::kElementsPerAccess /
|
| 385 |
-
IteratorA::kAccessesPerVector / 8;
|
| 386 |
-
|
| 387 |
-
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
| 388 |
-
|
| 389 |
-
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
| 390 |
-
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
| 391 |
-
|
| 392 |
-
++iterator_A;
|
| 393 |
-
}
|
| 394 |
-
|
| 395 |
-
++this->smem_iterator_A_;
|
| 396 |
-
}
|
| 397 |
-
|
| 398 |
-
iterator_B0.set_iteration_index(0);
|
| 399 |
-
iterator_B1.set_iteration_index(0);
|
| 400 |
-
this->smem_iterator_B0_.set_iteration_index(0);
|
| 401 |
-
this->smem_iterator_B1_.set_iteration_index(0);
|
| 402 |
-
|
| 403 |
-
// Async Copy for operand B0
|
| 404 |
-
CUTLASS_PRAGMA_UNROLL
|
| 405 |
-
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
| 406 |
-
typename IteratorB0::AccessType *dst_ptr =
|
| 407 |
-
reinterpret_cast<typename IteratorB0::AccessType *>(
|
| 408 |
-
this->smem_iterator_B0_.get());
|
| 409 |
-
|
| 410 |
-
CUTLASS_PRAGMA_UNROLL
|
| 411 |
-
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
| 412 |
-
int const kSrcBytes =
|
| 413 |
-
sizeof_bits<typename IteratorB0::Element>::value *
|
| 414 |
-
IteratorB0::ThreadMap::kElementsPerAccess /
|
| 415 |
-
IteratorB0::kAccessesPerVector / 8;
|
| 416 |
-
|
| 417 |
-
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
| 418 |
-
dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
|
| 419 |
-
|
| 420 |
-
++iterator_B0;
|
| 421 |
-
}
|
| 422 |
-
|
| 423 |
-
++this->smem_iterator_B0_;
|
| 424 |
-
}
|
| 425 |
-
// Async Copy for operand B1
|
| 426 |
-
CUTLASS_PRAGMA_UNROLL
|
| 427 |
-
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
| 428 |
-
typename IteratorB1::AccessType *dst_ptr =
|
| 429 |
-
reinterpret_cast<typename IteratorB1::AccessType *>(
|
| 430 |
-
this->smem_iterator_B1_.get());
|
| 431 |
-
|
| 432 |
-
CUTLASS_PRAGMA_UNROLL
|
| 433 |
-
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
| 434 |
-
int const kSrcBytes =
|
| 435 |
-
sizeof_bits<typename IteratorB1::Element>::value *
|
| 436 |
-
IteratorB1::ThreadMap::kElementsPerAccess /
|
| 437 |
-
IteratorB1::kAccessesPerVector / 8;
|
| 438 |
-
|
| 439 |
-
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
| 440 |
-
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
|
| 441 |
-
|
| 442 |
-
++iterator_B1;
|
| 443 |
-
}
|
| 444 |
-
|
| 445 |
-
++this->smem_iterator_B1_;
|
| 446 |
-
}
|
| 447 |
-
|
| 448 |
-
// Move to the next stage
|
| 449 |
-
iterator_A.add_tile_offset({0, 1});
|
| 450 |
-
iterator_B0.add_tile_offset({1, 0});
|
| 451 |
-
iterator_B1.add_tile_offset({1, 0});
|
| 452 |
-
|
| 453 |
-
this->smem_iterator_A_.add_tile_offset({0, 1});
|
| 454 |
-
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
| 455 |
-
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
| 456 |
-
|
| 457 |
-
// Defines the boundary of a stage of cp.async.
|
| 458 |
-
cutlass::arch::cp_async_fence();
|
| 459 |
-
}
|
| 460 |
-
|
| 461 |
-
// Perform accumulation in the 'd' output operand
|
| 462 |
-
accum0 = src_accum0;
|
| 463 |
-
accum1 = src_accum1;
|
| 464 |
-
|
| 465 |
-
//
|
| 466 |
-
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
|
| 467 |
-
// so that all accumulator elements outside the GEMM footprint are zero.
|
| 468 |
-
//
|
| 469 |
-
|
| 470 |
-
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
|
| 471 |
-
|
| 472 |
-
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
| 473 |
-
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
| 474 |
-
|
| 475 |
-
typename IteratorA::AccessType zero_A;
|
| 476 |
-
zero_A.clear();
|
| 477 |
-
|
| 478 |
-
last_smem_iterator_A.set_iteration_index(0);
|
| 479 |
-
|
| 480 |
-
// Zero-fill one stage of operand A in shared memory (plain stores, not cp.async)
|
| 481 |
-
CUTLASS_PRAGMA_UNROLL
|
| 482 |
-
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
| 483 |
-
|
| 484 |
-
typename IteratorA::AccessType *dst_ptr =
|
| 485 |
-
reinterpret_cast<typename IteratorA::AccessType *>(
|
| 486 |
-
last_smem_iterator_A.get());
|
| 487 |
-
|
| 488 |
-
*dst_ptr = zero_A;
|
| 489 |
-
|
| 490 |
-
++last_smem_iterator_A;
|
| 491 |
-
}
|
| 492 |
-
|
| 493 |
-
typename IteratorB0::AccessType zero_B;
|
| 494 |
-
zero_B.clear();
|
| 495 |
-
|
| 496 |
-
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
|
| 497 |
-
SmemIteratorB0 last_smem_iterator_B0(this->smem_iterator_B0_);
|
| 498 |
-
last_smem_iterator_B0.set_iteration_index(0);
|
| 499 |
-
|
| 500 |
-
// Zero-fill one stage of operand B0 in shared memory (plain stores, not cp.async)
|
| 501 |
-
CUTLASS_PRAGMA_UNROLL
|
| 502 |
-
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
| 503 |
-
typename IteratorB0::AccessType *dst_ptr =
|
| 504 |
-
reinterpret_cast<typename IteratorB0::AccessType *>(
|
| 505 |
-
last_smem_iterator_B0.get());
|
| 506 |
-
|
| 507 |
-
*dst_ptr = zero_B;
|
| 508 |
-
|
| 509 |
-
++last_smem_iterator_B0;
|
| 510 |
-
}
|
| 511 |
-
|
| 512 |
-
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
|
| 513 |
-
SmemIteratorB1 last_smem_iterator_B1(this->smem_iterator_B1_);
|
| 514 |
-
last_smem_iterator_B1.set_iteration_index(0);
|
| 515 |
-
|
| 516 |
-
// Zero-fill one stage of operand B1 in shared memory (plain stores, not cp.async)
|
| 517 |
-
CUTLASS_PRAGMA_UNROLL
|
| 518 |
-
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
| 519 |
-
|
| 520 |
-
typename IteratorB1::AccessType *dst_ptr =
|
| 521 |
-
reinterpret_cast<typename IteratorB1::AccessType *>(
|
| 522 |
-
last_smem_iterator_B1.get());
|
| 523 |
-
|
| 524 |
-
*dst_ptr = zero_B;
|
| 525 |
-
|
| 526 |
-
++last_smem_iterator_B1;
|
| 527 |
-
}
|
| 528 |
-
}
|
| 529 |
-
|
| 530 |
-
// Waits until stages up to the previous (kStages-2)th stage have committed.
|
| 531 |
-
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
| 532 |
-
__syncthreads();
|
| 533 |
-
|
| 534 |
-
// Pair of fragments used to overlap shared memory loads and math
|
| 535 |
-
// instructions
|
| 536 |
-
WarpLoadedFragmentA warp_loaded_frag_A[2];
|
| 537 |
-
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
| 538 |
-
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
| 539 |
-
WarpTransformedFragmentA warp_transformed_frag_A[2];
|
| 540 |
-
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
| 541 |
-
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
| 542 |
-
|
| 543 |
-
Operator0 warp_mma0;
|
| 544 |
-
Operator1 warp_mma1;
|
| 545 |
-
|
| 546 |
-
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
| 547 |
-
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
| 548 |
-
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
| 549 |
-
|
| 550 |
-
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
|
| 551 |
-
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
| 552 |
-
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
| 553 |
-
|
| 554 |
-
++this->warp_tile_iterator_A_;
|
| 555 |
-
++this->warp_tile_iterator_B0_;
|
| 556 |
-
++this->warp_tile_iterator_B1_;
|
| 557 |
-
|
| 558 |
-
iterator_A.clear_mask(gemm_k_iterations == 0);
|
| 559 |
-
iterator_B0.clear_mask(gemm_k_iterations == 0);
|
| 560 |
-
iterator_B1.clear_mask(gemm_k_iterations == 0);
|
| 561 |
-
|
| 562 |
-
int smem_write_stage_idx = Base::kStages - 1;
|
| 563 |
-
int smem_read_stage_idx = 0;
|
| 564 |
-
|
| 565 |
-
warp_mma0.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0],
|
| 566 |
-
warp_loaded_frag_A[0], warp_loaded_frag_B0[0]);
|
| 567 |
-
warp_mma1.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0],
|
| 568 |
-
warp_loaded_frag_A[0], warp_loaded_frag_B1[0]);
|
| 569 |
-
|
| 570 |
-
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
|
| 571 |
-
// accumulator and this temporary accumulator is added to the final
|
| 572 |
-
// accumulator once in every mainloop iteration.
|
| 573 |
-
plus<FragmentC> plus_accum;
|
| 574 |
-
|
| 575 |
-
FragmentC tmp_accum0, tmp_accum1;
|
| 576 |
-
|
| 577 |
-
if (platform::is_same<typename Operator0::MathOperator,
|
| 578 |
-
arch::OpMultiplyAddFastF32>::value
|
| 579 |
-
|| platform::is_same<typename Operator0::MathOperator,
|
| 580 |
-
arch::OpMultiplyAddComplexFastF32>::value) {
|
| 581 |
-
|
| 582 |
-
tmp_accum0.clear();
|
| 583 |
-
tmp_accum1.clear();
|
| 584 |
-
}
|
| 585 |
-
|
| 586 |
-
//
|
| 587 |
-
// Mainloop
|
| 588 |
-
//
|
| 589 |
-
|
| 590 |
-
CUTLASS_GEMM_LOOP
|
| 591 |
-
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
| 592 |
-
//
|
| 593 |
-
// Loop over GEMM K dimension
|
| 594 |
-
//
|
| 595 |
-
|
| 596 |
-
// Computes a warp-level GEMM on data held in shared memory
|
| 597 |
-
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
| 598 |
-
CUTLASS_PRAGMA_UNROLL
|
| 599 |
-
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
|
| 600 |
-
++warp_mma_k) {
|
| 601 |
-
|
| 602 |
-
// Load warp-level tiles from shared memory, wrapping to k offset if
|
| 603 |
-
// this is the last group as the case may be.
|
| 604 |
-
|
| 605 |
-
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
| 606 |
-
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
| 607 |
-
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
| 608 |
-
|
| 609 |
-
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
|
| 610 |
-
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
| 611 |
-
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
| 612 |
-
|
| 613 |
-
++this->warp_tile_iterator_A_;
|
| 614 |
-
++this->warp_tile_iterator_B0_;
|
| 615 |
-
++this->warp_tile_iterator_B1_;
|
| 616 |
-
|
| 617 |
-
if (warp_mma_k > 0) {
|
| 618 |
-
warp_mma0.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
| 619 |
-
warp_transformed_frag_B0[warp_mma_k % 2],
|
| 620 |
-
warp_loaded_frag_A[warp_mma_k % 2],
|
| 621 |
-
warp_loaded_frag_B0[warp_mma_k % 2]);
|
| 622 |
-
warp_mma1.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
| 623 |
-
warp_transformed_frag_B1[warp_mma_k % 2],
|
| 624 |
-
warp_loaded_frag_A[warp_mma_k % 2],
|
| 625 |
-
warp_loaded_frag_B1[warp_mma_k % 2]);
|
| 626 |
-
}
|
| 627 |
-
|
| 628 |
-
if (platform::is_same<typename Operator0::MathOperator,
|
| 629 |
-
arch::OpMultiplyAddFastF32>::value
|
| 630 |
-
|| platform::is_same<typename Operator0::MathOperator,
|
| 631 |
-
arch::OpMultiplyAddComplexFastF32>::value) {
|
| 632 |
-
|
| 633 |
-
warp_mma0(
|
| 634 |
-
tmp_accum0,
|
| 635 |
-
warp_transformed_frag_A[warp_mma_k % 2],
|
| 636 |
-
warp_transformed_frag_B0[warp_mma_k % 2],
|
| 637 |
-
tmp_accum0
|
| 638 |
-
);
|
| 639 |
-
warp_mma1(
|
| 640 |
-
tmp_accum1,
|
| 641 |
-
warp_transformed_frag_A[warp_mma_k % 2],
|
| 642 |
-
warp_transformed_frag_B1[warp_mma_k % 2],
|
| 643 |
-
tmp_accum1
|
| 644 |
-
);
|
| 645 |
-
|
| 646 |
-
if (warp_mma_k == 0) {
|
| 647 |
-
accum0 = plus_accum(accum0, tmp_accum0);
|
| 648 |
-
accum1 = plus_accum(accum1, tmp_accum1);
|
| 649 |
-
tmp_accum0.clear();
|
| 650 |
-
tmp_accum1.clear();
|
| 651 |
-
}
|
| 652 |
-
} else {
|
| 653 |
-
warp_mma0(
|
| 654 |
-
accum0,
|
| 655 |
-
warp_transformed_frag_A[warp_mma_k % 2],
|
| 656 |
-
warp_transformed_frag_B0[warp_mma_k % 2],
|
| 657 |
-
accum0
|
| 658 |
-
);
|
| 659 |
-
warp_mma1(
|
| 660 |
-
accum1,
|
| 661 |
-
warp_transformed_frag_A[warp_mma_k % 2],
|
| 662 |
-
warp_transformed_frag_B1[warp_mma_k % 2],
|
| 663 |
-
accum1
|
| 664 |
-
);
|
| 665 |
-
}
|
| 666 |
-
|
| 667 |
-
// Issue global->shared copies for this stage
|
| 668 |
-
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
|
| 669 |
-
int group_start_iteration_A, group_start_iteration_B;
|
| 670 |
-
|
| 671 |
-
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
| 672 |
-
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
| 673 |
-
|
| 674 |
-
copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A,
|
| 675 |
-
group_start_iteration_B);
|
| 676 |
-
}
|
| 677 |
-
|
| 678 |
-
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
| 679 |
-
int group_start_iteration_A, group_start_iteration_B;
|
| 680 |
-
group_start_iteration_A =
|
| 681 |
-
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
| 682 |
-
group_start_iteration_B =
|
| 683 |
-
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
| 684 |
-
|
| 685 |
-
copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A,
|
| 686 |
-
group_start_iteration_B);
|
| 687 |
-
|
| 688 |
-
// Inserts a memory fence between stages of cp.async instructions.
|
| 689 |
-
cutlass::arch::cp_async_fence();
|
| 690 |
-
|
| 691 |
-
// Waits until stages up to the previous (kStages-2)th stage have committed.
|
| 692 |
-
arch::cp_async_wait<Base::kStages - 2>();
|
| 693 |
-
__syncthreads();
|
| 694 |
-
|
| 695 |
-
// Move to the next stage
|
| 696 |
-
iterator_A.add_tile_offset({0, 1});
|
| 697 |
-
iterator_B0.add_tile_offset({1, 0});
|
| 698 |
-
iterator_B1.add_tile_offset({1, 0});
|
| 699 |
-
|
| 700 |
-
this->smem_iterator_A_.add_tile_offset({0, 1});
|
| 701 |
-
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
| 702 |
-
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
| 703 |
-
|
| 704 |
-
// Add negative offsets to return iterators to the 'start' of the
|
| 705 |
-
// circular buffer in shared memory
|
| 706 |
-
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
| 707 |
-
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
| 708 |
-
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
| 709 |
-
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
| 710 |
-
smem_write_stage_idx = 0;
|
| 711 |
-
} else {
|
| 712 |
-
++smem_write_stage_idx;
|
| 713 |
-
}
|
| 714 |
-
|
| 715 |
-
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
| 716 |
-
this->warp_tile_iterator_A_.add_tile_offset(
|
| 717 |
-
{0, -Base::kStages * Policy0::kPartitionsK *
|
| 718 |
-
Base::kWarpGemmIterations});
|
| 719 |
-
this->warp_tile_iterator_B0_.add_tile_offset(
|
| 720 |
-
{-Base::kStages * Policy0::kPartitionsK *
|
| 721 |
-
Base::kWarpGemmIterations,
|
| 722 |
-
0});
|
| 723 |
-
this->warp_tile_iterator_B1_.add_tile_offset(
|
| 724 |
-
{-Base::kStages * Policy1::kPartitionsK *
|
| 725 |
-
Base::kWarpGemmIterations,
|
| 726 |
-
0});
|
| 727 |
-
smem_read_stage_idx = 0;
|
| 728 |
-
} else {
|
| 729 |
-
++smem_read_stage_idx;
|
| 730 |
-
}
|
| 731 |
-
|
| 732 |
-
--gemm_k_iterations;
|
| 733 |
-
iterator_A.clear_mask(gemm_k_iterations == 0);
|
| 734 |
-
iterator_B0.clear_mask(gemm_k_iterations == 0);
|
| 735 |
-
iterator_B1.clear_mask(gemm_k_iterations == 0);
|
| 736 |
-
}
|
| 737 |
-
|
| 738 |
-
// Do any conversions feeding the first stage at the end of the loop so
|
| 739 |
-
// we can start right away on mma instructions
|
| 740 |
-
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
| 741 |
-
warp_mma0.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
| 742 |
-
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
| 743 |
-
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
| 744 |
-
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
| 745 |
-
warp_mma1.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
| 746 |
-
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
| 747 |
-
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
| 748 |
-
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
| 749 |
-
}
|
| 750 |
-
}
|
| 751 |
-
|
| 752 |
-
}
|
| 753 |
-
|
| 754 |
-
if (platform::is_same<typename Operator0::MathOperator,
|
| 755 |
-
arch::OpMultiplyAddFastF32>::value
|
| 756 |
-
|| platform::is_same<typename Operator0::MathOperator,
|
| 757 |
-
arch::OpMultiplyAddComplexFastF32>::value) {
|
| 758 |
-
accum0 = plus_accum(accum0, tmp_accum0);
|
| 759 |
-
accum1 = plus_accum(accum1, tmp_accum1);
|
| 760 |
-
}
|
| 761 |
-
|
| 762 |
-
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
| 763 |
-
cutlass::arch::cp_async_fence();
|
| 764 |
-
cutlass::arch::cp_async_wait<0>();
|
| 765 |
-
__syncthreads();
|
| 766 |
-
}
|
| 767 |
-
};
|
| 768 |
-
|
| 769 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 770 |
-
|
| 771 |
-
} // namespace threadblock
|
| 772 |
-
} // namespace gemm
|
| 773 |
-
} // namespace cutlass
|
| 774 |
-
|
| 775 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh
DELETED
|
@@ -1,139 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
#include "cute/tensor.hpp"
|
| 34 |
-
|
| 35 |
-
#include "cutlass/arch/arch.h"
|
| 36 |
-
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
| 37 |
-
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
| 38 |
-
#include "cutlass/gemm/collective/collective_builder.hpp"
|
| 39 |
-
|
| 40 |
-
#include "cutlass/epilogue/collective/collective_epilogue.hpp"
|
| 41 |
-
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 42 |
-
|
| 43 |
-
namespace example {
|
| 44 |
-
|
| 45 |
-
//
|
| 46 |
-
// GETT entry point
|
| 47 |
-
//
|
| 48 |
-
template <
|
| 49 |
-
class ProblemShapeMNKL,
|
| 50 |
-
class ElementA,
|
| 51 |
-
class StrideA,
|
| 52 |
-
class ElementB,
|
| 53 |
-
class StrideB,
|
| 54 |
-
class ElementAccumulator,
|
| 55 |
-
class ElementC,
|
| 56 |
-
class StrideC,
|
| 57 |
-
class ElementD,
|
| 58 |
-
class StrideD,
|
| 59 |
-
class ElementEpilogue>
|
| 60 |
-
cutlass::Status
|
| 61 |
-
gett_kernel(
|
| 62 |
-
ProblemShapeMNKL problem_shape_mnkl,
|
| 63 |
-
ElementA const* ptr_A, StrideA stride_a_mkl,
|
| 64 |
-
ElementB const* ptr_B, StrideB stride_b_nkl,
|
| 65 |
-
ElementAccumulator _,
|
| 66 |
-
ElementC const* ptr_C, StrideC stride_c_mnl,
|
| 67 |
-
ElementD * ptr_D, StrideD stride_d_mnl,
|
| 68 |
-
ElementEpilogue alpha, ElementEpilogue beta,
|
| 69 |
-
cudaStream_t stream = 0) {
|
| 70 |
-
using namespace cute;
|
| 71 |
-
|
| 72 |
-
// TileShape -- GETT configuration
|
| 73 |
-
// Specify the number of elements to take from each mode
|
| 74 |
-
// BLK_M = (M0,M1,...) BLK_N = (N0,N1,...) BLK_K = (K0,K1,...)
|
| 75 |
-
|
| 76 |
-
// Take 128 from m0, 128 from n0, 64 from k0
|
| 77 |
-
using TileShape = Shape<Shape<_128>, Shape<_128>, Shape<_64>>;
|
| 78 |
-
|
| 79 |
-
/* Other examples:
|
| 80 |
-
* Take 32 elements from m0 and 4 elements from m1
|
| 81 |
-
* Take 64 elements from n0 and 2 elements from n1
|
| 82 |
-
* Take 8 elements from k0 and 8 elements from k1
|
| 83 |
-
**/
|
| 84 |
-
// using TileShape = Shape<Shape<_32,_4>, Shape<_64,_2>, Shape<_8,_8>>;
|
| 85 |
-
|
| 86 |
-
using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination<
|
| 87 |
-
ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default,
|
| 88 |
-
cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
|
| 89 |
-
|
| 90 |
-
// No changes are required to the default epilogue
|
| 91 |
-
using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
| 92 |
-
cutlass::epilogue::collective::DefaultEpilogue<
|
| 93 |
-
ElementC,
|
| 94 |
-
StrideC,
|
| 95 |
-
StrideD,
|
| 96 |
-
EpilogueThreadOp,
|
| 97 |
-
cutlass::gemm::EpilogueDefault>>;
|
| 98 |
-
|
| 99 |
-
// CollectiveMma for GETTs can be built using the CollectiveBuilders
|
| 100 |
-
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
| 101 |
-
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
| 102 |
-
ElementA, StrideA, 128 / cutlass::sizeof_bits<ElementA>::value,
|
| 103 |
-
ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
|
| 104 |
-
ElementAccumulator,
|
| 105 |
-
TileShape, Shape<_1,_2,_1>,
|
| 106 |
-
cutlass::gemm::collective::StageCountAutoCarveout<
|
| 107 |
-
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
| 108 |
-
cutlass::gemm::collective::KernelScheduleAuto
|
| 109 |
-
>::CollectiveOp;
|
| 110 |
-
|
| 111 |
-
// The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM
|
| 112 |
-
using GettKernel = cutlass::gemm::kernel::GemmUniversal<
|
| 113 |
-
ProblemShapeMNKL,
|
| 114 |
-
CollectiveMainloop,
|
| 115 |
-
CollectiveEpilogue>;
|
| 116 |
-
|
| 117 |
-
using GettOperator = cutlass::gemm::device::GemmUniversalAdapter<GettKernel>;
|
| 118 |
-
|
| 119 |
-
typename GettOperator::Arguments args {
|
| 120 |
-
cutlass::gemm::GemmUniversalMode::kBatched,
|
| 121 |
-
problem_shape_mnkl,
|
| 122 |
-
{ ptr_A, stride_a_mkl, ptr_B, stride_b_nkl },
|
| 123 |
-
{ {alpha, beta}, ptr_C, stride_c_mnl, ptr_D, stride_d_mnl }
|
| 124 |
-
};
|
| 125 |
-
|
| 126 |
-
#if CUTLASS_DEBUG_TRACE_LEVEL > 0
|
| 127 |
-
print("Problem shape:");
|
| 128 |
-
print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n");
|
| 129 |
-
print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n");
|
| 130 |
-
print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n");
|
| 131 |
-
print("\tL: "); print(cute::get<3>(problem_shape_mnkl)); print("\n");
|
| 132 |
-
print("TileSape:"); print(TileShape{}); print("\n");
|
| 133 |
-
#endif
|
| 134 |
-
|
| 135 |
-
GettOperator op;
|
| 136 |
-
return op(args, stream);
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
} // namespace example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp
DELETED
|
@@ -1,421 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
#include "cutlass/cutlass.h"
|
| 34 |
-
#include "cutlass/kernel_hardware_info.hpp"
|
| 35 |
-
#include "cutlass/gemm/gemm.h"
|
| 36 |
-
#include "cutlass/gemm/dispatch_policy.hpp"
|
| 37 |
-
|
| 38 |
-
#include "cute/tensor.hpp"
|
| 39 |
-
|
| 40 |
-
#include "gather_tensor.hpp"
|
| 41 |
-
|
| 42 |
-
namespace cutlass {
|
| 43 |
-
///Forward declaration
|
| 44 |
-
struct CudaHostAdapter;
|
| 45 |
-
}
|
| 46 |
-
|
| 47 |
-
namespace cutlass::gemm::kernel {
|
| 48 |
-
|
| 49 |
-
///////////////////////////////////////////////////////////////////////////////
|
| 50 |
-
|
| 51 |
-
template <
|
| 52 |
-
class ProblemShape_,
|
| 53 |
-
class CollectiveMainloop_,
|
| 54 |
-
class CollectiveEpilogue_,
|
| 55 |
-
class TileScheduler_,
|
| 56 |
-
class GatherA_,
|
| 57 |
-
class GatherB_
|
| 58 |
-
>
|
| 59 |
-
class GemmGather
|
| 60 |
-
{
|
| 61 |
-
public:
|
| 62 |
-
//
|
| 63 |
-
// Type Aliases
|
| 64 |
-
//
|
| 65 |
-
using ProblemShape = ProblemShape_;
|
| 66 |
-
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
| 67 |
-
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
| 68 |
-
|
| 69 |
-
// Mainloop derived types
|
| 70 |
-
using CollectiveMainloop = CollectiveMainloop_;
|
| 71 |
-
using TileShape = typename CollectiveMainloop::TileShape;
|
| 72 |
-
using TiledMma = typename CollectiveMainloop::TiledMma;
|
| 73 |
-
using ArchTag = typename CollectiveMainloop::ArchTag;
|
| 74 |
-
using ElementA = typename CollectiveMainloop::ElementA;
|
| 75 |
-
using StrideA = typename CollectiveMainloop::StrideA;
|
| 76 |
-
using ElementB = typename CollectiveMainloop::ElementB;
|
| 77 |
-
using StrideB = typename CollectiveMainloop::StrideB;
|
| 78 |
-
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
| 79 |
-
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
| 80 |
-
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
| 81 |
-
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
| 82 |
-
using MainloopParams = typename CollectiveMainloop::Params;
|
| 83 |
-
static_assert(ArchTag::kMinComputeCapability >= 90);
|
| 84 |
-
|
| 85 |
-
// Epilogue derived types
|
| 86 |
-
using CollectiveEpilogue = CollectiveEpilogue_;
|
| 87 |
-
using ElementC = typename CollectiveEpilogue::ElementC;
|
| 88 |
-
using StrideC = typename CollectiveEpilogue::StrideC;
|
| 89 |
-
using ElementD = typename CollectiveEpilogue::ElementD;
|
| 90 |
-
using StrideD = typename CollectiveEpilogue::StrideD;
|
| 91 |
-
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
| 92 |
-
using EpilogueParams = typename CollectiveEpilogue::Params;
|
| 93 |
-
|
| 94 |
-
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
|
| 95 |
-
"Non-persistent warp-specialized kernel does not support specializing the tile scheduler.");
|
| 96 |
-
using TileSchedulerTag = TileScheduler_;
|
| 97 |
-
using TileScheduler = typename detail::TileSchedulerSelector<
|
| 98 |
-
TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
|
| 99 |
-
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
| 100 |
-
|
| 101 |
-
using GatherA = GatherA_;
|
| 102 |
-
using GatherB = GatherB_;
|
| 103 |
-
|
| 104 |
-
// Kernel level shared memory storage
|
| 105 |
-
struct SharedStorage {
|
| 106 |
-
union TensorStorage {
|
| 107 |
-
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
| 108 |
-
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
| 109 |
-
|
| 110 |
-
MainloopTensorStorage mainloop;
|
| 111 |
-
EpilogueTensorStorage epilogue;
|
| 112 |
-
} tensors;
|
| 113 |
-
|
| 114 |
-
struct PipelineStorage : cute::aligned_struct<16, _2> {
|
| 115 |
-
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
| 116 |
-
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
| 117 |
-
|
| 118 |
-
alignas(16) MainloopPipelineStorage mainloop;
|
| 119 |
-
alignas(16) EpiLoadPipelineStorage epi_load;
|
| 120 |
-
} pipelines;
|
| 121 |
-
};
|
| 122 |
-
|
| 123 |
-
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 124 |
-
|
| 125 |
-
using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA;
|
| 126 |
-
using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB;
|
| 127 |
-
static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same.");
|
| 128 |
-
|
| 129 |
-
static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup;
|
| 130 |
-
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(cute::size(TiledMma{})) / NumThreadsPerWarpGroup;
|
| 131 |
-
static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups;
|
| 132 |
-
static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance.");
|
| 133 |
-
|
| 134 |
-
static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup;
|
| 135 |
-
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
| 136 |
-
|
| 137 |
-
// Device side arguments
|
| 138 |
-
struct Arguments {
|
| 139 |
-
GemmUniversalMode mode{};
|
| 140 |
-
ProblemShape problem_shape{};
|
| 141 |
-
MainloopArguments mainloop{};
|
| 142 |
-
EpilogueArguments epilogue{};
|
| 143 |
-
KernelHardwareInfo hw_info{};
|
| 144 |
-
TileSchedulerArguments scheduler{};
|
| 145 |
-
GatherA gather_A{};
|
| 146 |
-
GatherB gather_B{};
|
| 147 |
-
};
|
| 148 |
-
|
| 149 |
-
// Kernel entry point API
|
| 150 |
-
struct Params {
|
| 151 |
-
GemmUniversalMode mode{};
|
| 152 |
-
ProblemShape problem_shape{};
|
| 153 |
-
MainloopParams mainloop{};
|
| 154 |
-
EpilogueParams epilogue{};
|
| 155 |
-
GatherA gather_A{};
|
| 156 |
-
GatherB gather_B{};
|
| 157 |
-
};
|
| 158 |
-
|
| 159 |
-
//
|
| 160 |
-
// Methods
|
| 161 |
-
//
|
| 162 |
-
|
| 163 |
-
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
| 164 |
-
static
|
| 165 |
-
Params
|
| 166 |
-
to_underlying_arguments(Arguments const& args, void* workspace) {
|
| 167 |
-
(void) workspace;
|
| 168 |
-
auto problem_shape = args.problem_shape;
|
| 169 |
-
if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
|
| 170 |
-
// swap M/N
|
| 171 |
-
get<0>(problem_shape) = get<1>(args.problem_shape);
|
| 172 |
-
get<1>(problem_shape) = get<0>(args.problem_shape);
|
| 173 |
-
}
|
| 174 |
-
return {
|
| 175 |
-
args.mode,
|
| 176 |
-
problem_shape,
|
| 177 |
-
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
|
| 178 |
-
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
|
| 179 |
-
args.gather_A,
|
| 180 |
-
args.gather_B
|
| 181 |
-
};
|
| 182 |
-
}
|
| 183 |
-
|
| 184 |
-
static bool
|
| 185 |
-
can_implement(Arguments const& args) {
|
| 186 |
-
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
|
| 187 |
-
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
| 188 |
-
if (!implementable) {
|
| 189 |
-
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
|
| 190 |
-
return implementable;
|
| 191 |
-
}
|
| 192 |
-
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
| 193 |
-
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
| 194 |
-
return implementable;
|
| 195 |
-
}
|
| 196 |
-
|
| 197 |
-
static
|
| 198 |
-
size_t
|
| 199 |
-
get_workspace_size(Arguments const& args) {
|
| 200 |
-
return 0;
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
-
static
|
| 204 |
-
cutlass::Status
|
| 205 |
-
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
| 206 |
-
CudaHostAdapter* cuda_adapter = nullptr) {
|
| 207 |
-
return Status::kSuccess;
|
| 208 |
-
}
|
| 209 |
-
|
| 210 |
-
// Computes the kernel launch grid shape based on runtime parameters
|
| 211 |
-
static dim3
|
| 212 |
-
get_grid_shape(Params const& params) {
|
| 213 |
-
auto cluster_shape = Shape<_1,_1,_1>{};
|
| 214 |
-
auto tile_shape = TileShape{};
|
| 215 |
-
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
| 216 |
-
return TileScheduler::get_tiled_cta_shape_mnl(
|
| 217 |
-
problem_shape_MNKL, tile_shape, cluster_shape);
|
| 218 |
-
}
|
| 219 |
-
|
| 220 |
-
static dim3
|
| 221 |
-
get_block_shape() {
|
| 222 |
-
return dim3(MaxThreadsPerBlock, 1, 1);
|
| 223 |
-
}
|
| 224 |
-
|
| 225 |
-
CUTLASS_DEVICE
|
| 226 |
-
void
|
| 227 |
-
operator()(Params const& params, char* smem_buf) {
|
| 228 |
-
using namespace cute;
|
| 229 |
-
using X = Underscore;
|
| 230 |
-
|
| 231 |
-
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
| 232 |
-
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
| 233 |
-
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
|
| 234 |
-
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
|
| 235 |
-
return;
|
| 236 |
-
}
|
| 237 |
-
#endif
|
| 238 |
-
|
| 239 |
-
enum class WarpGroupRole {
|
| 240 |
-
Producer = 0,
|
| 241 |
-
Consumer = 1,
|
| 242 |
-
};
|
| 243 |
-
|
| 244 |
-
// Kernel level shared memory storage
|
| 245 |
-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 246 |
-
|
| 247 |
-
int thread_idx = int(threadIdx.x);
|
| 248 |
-
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
| 249 |
-
int warp_group_idx = canonical_warp_group_idx();
|
| 250 |
-
CUTLASS_ASSERT(warp_group_idx < NumWarpGroups);
|
| 251 |
-
WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer;
|
| 252 |
-
|
| 253 |
-
// Mainloop Load pipeline
|
| 254 |
-
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
| 255 |
-
typename MainloopPipeline::Params mainloop_pipeline_params;
|
| 256 |
-
if (warp_group_role == WarpGroupRole::Producer) {
|
| 257 |
-
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
| 258 |
-
}
|
| 259 |
-
if (warp_group_role == WarpGroupRole::Consumer) {
|
| 260 |
-
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
|
| 261 |
-
}
|
| 262 |
-
mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup;
|
| 263 |
-
mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup;
|
| 264 |
-
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params);
|
| 265 |
-
|
| 266 |
-
// Epilogue Load pipeline
|
| 267 |
-
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
| 268 |
-
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
| 269 |
-
if (warp_group_role == WarpGroupRole::Producer) {
|
| 270 |
-
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
| 271 |
-
}
|
| 272 |
-
if (warp_group_role == WarpGroupRole::Consumer) {
|
| 273 |
-
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
| 274 |
-
}
|
| 275 |
-
epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup;
|
| 276 |
-
epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup;
|
| 277 |
-
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
|
| 278 |
-
|
| 279 |
-
// Epilogue Store pipeline
|
| 280 |
-
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
| 281 |
-
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
| 282 |
-
epi_store_pipeline_params.always_wait = true;
|
| 283 |
-
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
| 284 |
-
|
| 285 |
-
// Initialize starting pipeline states for the collectives
|
| 286 |
-
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
| 287 |
-
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
| 288 |
-
|
| 289 |
-
// For the DMA Load (producer) we start with an opposite phase
|
| 290 |
-
// i.e., we skip all waits since we know that the buffer is indeed empty
|
| 291 |
-
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
| 292 |
-
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
| 293 |
-
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
| 294 |
-
|
| 295 |
-
// Preconditions
|
| 296 |
-
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
| 297 |
-
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
| 298 |
-
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
| 299 |
-
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
| 300 |
-
|
| 301 |
-
// Separate out problem shape for convenience
|
| 302 |
-
// Optionally append 1s until the problem shape is rank-4 in case it is only rank-3 (MNK)
|
| 303 |
-
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
| 304 |
-
auto M = get<0>(problem_shape_MNKL);
|
| 305 |
-
auto N = get<1>(problem_shape_MNKL);
|
| 306 |
-
auto K = get<2>(problem_shape_MNKL);
|
| 307 |
-
auto L = get<3>(problem_shape_MNKL);
|
| 308 |
-
|
| 309 |
-
// Represent the full tensors
|
| 310 |
-
Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l)
|
| 311 |
-
Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l)
|
| 312 |
-
|
| 313 |
-
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
| 314 |
-
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
| 315 |
-
TiledMma tiled_mma;
|
| 316 |
-
|
| 317 |
-
// Make tiled views, defer the slice
|
| 318 |
-
Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
| 319 |
-
Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
| 320 |
-
|
| 321 |
-
// Compute m_coord, n_coord, and l_coord with their post-tiled shapes
|
| 322 |
-
auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl));
|
| 323 |
-
auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl));
|
| 324 |
-
auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl));
|
| 325 |
-
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
| 326 |
-
|
| 327 |
-
// Slice with m_coord and n_coord
|
| 328 |
-
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
| 329 |
-
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
| 330 |
-
|
| 331 |
-
// Get pipeline iterators and increments from tensor shapes
|
| 332 |
-
auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
|
| 333 |
-
auto k_tile_count = size<2>(gA);
|
| 334 |
-
auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
|
| 335 |
-
auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
|
| 336 |
-
|
| 337 |
-
// Wait for all threads in the thread block
|
| 338 |
-
__syncthreads();
|
| 339 |
-
|
| 340 |
-
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
| 341 |
-
CollectiveMainloop collective_mainloop;
|
| 342 |
-
CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue};
|
| 343 |
-
|
| 344 |
-
if (warp_group_role == WarpGroupRole::Producer) {
|
| 345 |
-
// Compute tile residues for predication
|
| 346 |
-
auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord
|
| 347 |
-
auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord
|
| 348 |
-
auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max
|
| 349 |
-
auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
|
| 350 |
-
|
| 351 |
-
collective_mainloop.load(
|
| 352 |
-
mainloop_pipeline,
|
| 353 |
-
mainloop_pipe_producer_state,
|
| 354 |
-
gA,
|
| 355 |
-
gB,
|
| 356 |
-
k_tile_iter, k_tile_count,
|
| 357 |
-
residue_mnk,
|
| 358 |
-
thread_idx,
|
| 359 |
-
shared_storage.tensors.mainloop
|
| 360 |
-
);
|
| 361 |
-
// Update starting mainloop pipeline state for the pipeline drain
|
| 362 |
-
mainloop_pipe_producer_state.advance(k_tile_count);
|
| 363 |
-
// Make sure mainloop consumer has been waited upon before issuing epilogue load
|
| 364 |
-
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
| 365 |
-
|
| 366 |
-
if (collective_epilogue.is_producer_load_needed()) {
|
| 367 |
-
epi_load_pipe_producer_state =
|
| 368 |
-
collective_epilogue.load(
|
| 369 |
-
epi_load_pipeline,
|
| 370 |
-
epi_load_pipe_producer_state,
|
| 371 |
-
problem_shape_MNKL,
|
| 372 |
-
blk_shape,
|
| 373 |
-
blk_coord,
|
| 374 |
-
tiled_mma,
|
| 375 |
-
thread_idx,
|
| 376 |
-
shared_storage.tensors.epilogue
|
| 377 |
-
);
|
| 378 |
-
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
| 379 |
-
}
|
| 380 |
-
}
|
| 381 |
-
else if (warp_group_role == WarpGroupRole::Consumer) {
|
| 382 |
-
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
| 383 |
-
|
| 384 |
-
collective_mainloop.mma(
|
| 385 |
-
mainloop_pipeline,
|
| 386 |
-
mainloop_pipe_consumer_state,
|
| 387 |
-
accumulators,
|
| 388 |
-
k_tile_count,
|
| 389 |
-
warp_group_thread_idx,
|
| 390 |
-
shared_storage.tensors.mainloop,
|
| 391 |
-
params.mainloop
|
| 392 |
-
);
|
| 393 |
-
|
| 394 |
-
// Make sure the math instructions are done and free buffers before entering the epilogue
|
| 395 |
-
collective_mainloop.mma_tail(
|
| 396 |
-
mainloop_pipeline,
|
| 397 |
-
mainloop_pipe_consumer_state,
|
| 398 |
-
k_tile_count
|
| 399 |
-
);
|
| 400 |
-
|
| 401 |
-
// Epilogue and write to gD
|
| 402 |
-
collective_epilogue.store(
|
| 403 |
-
epi_load_pipeline,
|
| 404 |
-
epi_load_pipe_consumer_state,
|
| 405 |
-
epi_store_pipeline,
|
| 406 |
-
epi_store_pipe_producer_state,
|
| 407 |
-
problem_shape_MNKL,
|
| 408 |
-
blk_shape,
|
| 409 |
-
blk_coord,
|
| 410 |
-
accumulators,
|
| 411 |
-
tiled_mma,
|
| 412 |
-
warp_group_thread_idx,
|
| 413 |
-
shared_storage.tensors.epilogue
|
| 414 |
-
);
|
| 415 |
-
}
|
| 416 |
-
}
|
| 417 |
-
};
|
| 418 |
-
|
| 419 |
-
///////////////////////////////////////////////////////////////////////////////
|
| 420 |
-
|
| 421 |
-
} // namespace cutlass::gemm::kernel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh
DELETED
|
@@ -1,136 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
#include "cute/numeric/math.hpp"
|
| 34 |
-
|
| 35 |
-
namespace example
|
| 36 |
-
{
|
| 37 |
-
|
| 38 |
-
// Naive grid-stride loop implementation of gather
|
| 39 |
-
template<typename Element, typename Func>
|
| 40 |
-
__global__ void
|
| 41 |
-
gather_kernel(Element const * __restrict__ input,
|
| 42 |
-
Element * __restrict__ output,
|
| 43 |
-
Func func,
|
| 44 |
-
int num_elems_input,
|
| 45 |
-
int num_elems_output,
|
| 46 |
-
cutlass::FastDivmod stride_divmod)
|
| 47 |
-
{
|
| 48 |
-
Element const * input_b = input + blockIdx.z * num_elems_input;
|
| 49 |
-
Element * output_b = output + blockIdx.z * num_elems_output;
|
| 50 |
-
int tidx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 51 |
-
for (int k = tidx; k < num_elems_output; k += blockDim.x * gridDim.x) {
|
| 52 |
-
int i,j;
|
| 53 |
-
stride_divmod(j, i, k);
|
| 54 |
-
output_b[k] = input_b[i + func(j) * stride_divmod.divisor];
|
| 55 |
-
}
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
// Gather elements along strided dimension of the tensor according to given indices
|
| 59 |
-
template<typename Element, typename Func>
|
| 60 |
-
void
|
| 61 |
-
gather(Element const * input,
|
| 62 |
-
Element * output,
|
| 63 |
-
Func func,
|
| 64 |
-
int batch_size,
|
| 65 |
-
int num_elems_input,
|
| 66 |
-
int num_elems_output,
|
| 67 |
-
int stride,
|
| 68 |
-
cutlass::KernelHardwareInfo const& hw_info)
|
| 69 |
-
{
|
| 70 |
-
// Upcast to uint128_t data type
|
| 71 |
-
int factor = 128 / cutlass::sizeof_bits<Element>::value;
|
| 72 |
-
assert(stride % factor == 0);
|
| 73 |
-
int stride_upcast = stride/factor;
|
| 74 |
-
int num_elems_input_upcast = num_elems_input / factor;
|
| 75 |
-
int num_elems_output_upcast = num_elems_output / factor;
|
| 76 |
-
|
| 77 |
-
cutlass::FastDivmod stride_divmod(stride_upcast);
|
| 78 |
-
dim3 blocks(hw_info.sm_count, 1, batch_size);
|
| 79 |
-
gather_kernel<<<blocks, 1024>>>(reinterpret_cast<cute::uint128_t const *>(input),
|
| 80 |
-
reinterpret_cast<cute::uint128_t *>(output),
|
| 81 |
-
func,
|
| 82 |
-
num_elems_input_upcast,
|
| 83 |
-
num_elems_output_upcast,
|
| 84 |
-
stride_divmod);
|
| 85 |
-
}
|
| 86 |
-
|
| 87 |
-
// Naive grid-stride loop implementation of scatter
|
| 88 |
-
template<typename Element, typename Func>
|
| 89 |
-
__global__ void
|
| 90 |
-
scatter_kernel(Element const * __restrict__ input,
|
| 91 |
-
Element * __restrict__ output,
|
| 92 |
-
Func func,
|
| 93 |
-
int num_elems_input,
|
| 94 |
-
int num_elems_output,
|
| 95 |
-
cutlass::FastDivmod stride_divmod)
|
| 96 |
-
{
|
| 97 |
-
Element const * input_b = input + blockIdx.z * num_elems_input;
|
| 98 |
-
Element * output_b = output + blockIdx.z * num_elems_output;
|
| 99 |
-
int tidx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 100 |
-
for (int k = tidx; k < num_elems_input; k += blockDim.x * gridDim.x) {
|
| 101 |
-
int i,j;
|
| 102 |
-
stride_divmod(j, i, k);
|
| 103 |
-
output_b[i + func(j) * stride_divmod.divisor] = input_b[k];
|
| 104 |
-
}
|
| 105 |
-
}
|
| 106 |
-
|
| 107 |
-
// Scatter elements along strided dimension of the tensor according to given indices
|
| 108 |
-
template<typename Element, typename Func>
|
| 109 |
-
void
|
| 110 |
-
scatter(Element const * input,
|
| 111 |
-
Element * output,
|
| 112 |
-
Func func,
|
| 113 |
-
int batch_size,
|
| 114 |
-
int num_elems_input,
|
| 115 |
-
int num_elems_output,
|
| 116 |
-
int stride,
|
| 117 |
-
cutlass::KernelHardwareInfo const& hw_info)
|
| 118 |
-
{
|
| 119 |
-
// Upcast to uint128_t data type
|
| 120 |
-
int factor = 128 / cutlass::sizeof_bits<Element>::value;
|
| 121 |
-
assert(stride % factor == 0);
|
| 122 |
-
int stride_upcast = stride/factor;
|
| 123 |
-
int num_elems_input_upcast = num_elems_input / factor;
|
| 124 |
-
int num_elems_output_upcast = num_elems_output / factor;
|
| 125 |
-
|
| 126 |
-
cutlass::FastDivmod stride_divmod(stride_upcast);
|
| 127 |
-
dim3 blocks(hw_info.sm_count, 1, batch_size);
|
| 128 |
-
scatter_kernel<<<blocks, 1024>>>(reinterpret_cast<cute::uint128_t const *>(input),
|
| 129 |
-
reinterpret_cast<cute::uint128_t *>(output),
|
| 130 |
-
func,
|
| 131 |
-
num_elems_input_upcast,
|
| 132 |
-
num_elems_output_upcast,
|
| 133 |
-
stride_divmod);
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
} // namespace example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp
DELETED
|
@@ -1,222 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Functor performing elementwise operations used by epilogues.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/cutlass.h"
|
| 38 |
-
#include "cutlass/gemm/dispatch_policy.hpp"
|
| 39 |
-
#include "cutlass/epilogue/collective/detail.hpp"
|
| 40 |
-
|
| 41 |
-
#include "cute/tensor.hpp"
|
| 42 |
-
#include "cute/numeric/numeric_types.hpp"
|
| 43 |
-
|
| 44 |
-
#include "gather_tensor.hpp"
|
| 45 |
-
|
| 46 |
-
namespace cutlass::epilogue::collective {
|
| 47 |
-
|
| 48 |
-
/// Applies an element wise operation to all elements within the fragment
|
| 49 |
-
/// and scatter-writes them out to destination storage.
|
| 50 |
-
/// GatherC and ScatterD are types of user-defined functions that apply the
|
| 51 |
-
/// transformation of the strided coordinate (e.g. through an index array).
|
| 52 |
-
template <
|
| 53 |
-
class StrideC_,
|
| 54 |
-
class StrideD_,
|
| 55 |
-
class ThreadEpilogueOp_,
|
| 56 |
-
class EpilogueSchedule_,
|
| 57 |
-
class GatherC_,
|
| 58 |
-
class ScatterD_
|
| 59 |
-
>
|
| 60 |
-
class EpilogueGatherScatter {
|
| 61 |
-
public:
|
| 62 |
-
//
|
| 63 |
-
// Type Aliases
|
| 64 |
-
//
|
| 65 |
-
using EpilogueSchedule = EpilogueSchedule_;
|
| 66 |
-
|
| 67 |
-
// derived types of output thread level operator
|
| 68 |
-
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
| 69 |
-
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
|
| 70 |
-
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
|
| 71 |
-
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
|
| 72 |
-
using ElementScalar = ElementCompute;
|
| 73 |
-
using ElementC = typename ThreadEpilogueOp::ElementC;
|
| 74 |
-
using StrideC = StrideC_;
|
| 75 |
-
using ElementD = typename ThreadEpilogueOp::ElementD;
|
| 76 |
-
using StrideD = StrideD_;
|
| 77 |
-
|
| 78 |
-
// Every epilogue needs these two GmemTiledCopy{C,D} aliases.
|
| 79 |
-
// If you don't know what they should be, just use void.
|
| 80 |
-
using GmemTiledCopyC = void;
|
| 81 |
-
using GmemTiledCopyD = void;
|
| 82 |
-
|
| 83 |
-
using GatherC = GatherC_;
|
| 84 |
-
using ScatterD = ScatterD_;
|
| 85 |
-
|
| 86 |
-
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
|
| 87 |
-
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
|
| 88 |
-
|
| 89 |
-
static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
| 90 |
-
static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
| 91 |
-
|
| 92 |
-
struct SharedStorage { };
|
| 93 |
-
|
| 94 |
-
// Host side epilogue arguments
|
| 95 |
-
struct Arguments {
|
| 96 |
-
typename ThreadEpilogueOp::Params thread_params{};
|
| 97 |
-
ElementC const* ptr_C = nullptr;
|
| 98 |
-
StrideC dC{};
|
| 99 |
-
ElementD* ptr_D = nullptr;
|
| 100 |
-
StrideD dD{};
|
| 101 |
-
GatherC gather_C{};
|
| 102 |
-
ScatterD scatter_D{};
|
| 103 |
-
};
|
| 104 |
-
|
| 105 |
-
// Device side epilogue params
|
| 106 |
-
using Params = Arguments;
|
| 107 |
-
|
| 108 |
-
//
|
| 109 |
-
// Methods
|
| 110 |
-
//
|
| 111 |
-
|
| 112 |
-
template <class ProblemShape>
|
| 113 |
-
static constexpr Params
|
| 114 |
-
to_underlying_arguments(
|
| 115 |
-
[[maybe_unused]] ProblemShape const& _,
|
| 116 |
-
Arguments const& args,
|
| 117 |
-
[[maybe_unused]] void* workspace) {
|
| 118 |
-
return args;
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
template<class ProblemShape>
|
| 122 |
-
static bool
|
| 123 |
-
can_implement(
|
| 124 |
-
[[maybe_unused]] ProblemShape const& problem_shape,
|
| 125 |
-
[[maybe_unused]] Arguments const& args) {
|
| 126 |
-
return true;
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
CUTLASS_HOST_DEVICE
|
| 130 |
-
EpilogueGatherScatter(Params const& params_) : params(params_) { }
|
| 131 |
-
|
| 132 |
-
template<
|
| 133 |
-
class ProblemShapeMNKL,
|
| 134 |
-
class BlockShapeMNK,
|
| 135 |
-
class BlockCoordMNKL,
|
| 136 |
-
class FrgEngine, class FrgLayout,
|
| 137 |
-
class TiledMma,
|
| 138 |
-
class ResidueMNK
|
| 139 |
-
>
|
| 140 |
-
CUTLASS_DEVICE void
|
| 141 |
-
operator()(
|
| 142 |
-
ProblemShapeMNKL problem_shape_mnkl,
|
| 143 |
-
BlockShapeMNK blk_shape_MNK,
|
| 144 |
-
BlockCoordMNKL blk_coord_mnkl,
|
| 145 |
-
cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
|
| 146 |
-
TiledMma tiled_mma,
|
| 147 |
-
ResidueMNK residue_mnk,
|
| 148 |
-
int thread_idx,
|
| 149 |
-
char* smem_buf)
|
| 150 |
-
{
|
| 151 |
-
using namespace cute;
|
| 152 |
-
using X = Underscore;
|
| 153 |
-
|
| 154 |
-
static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
|
| 155 |
-
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
|
| 156 |
-
static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
|
| 157 |
-
static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
|
| 158 |
-
|
| 159 |
-
(void) smem_buf;
|
| 160 |
-
ThreadEpilogueOp epilogue_op{params.thread_params};
|
| 161 |
-
|
| 162 |
-
// Separate out problem shape for convenience
|
| 163 |
-
auto M = get<0>(problem_shape_mnkl);
|
| 164 |
-
auto N = get<1>(problem_shape_mnkl);
|
| 165 |
-
auto L = get<3>(problem_shape_mnkl);
|
| 166 |
-
|
| 167 |
-
auto stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC);
|
| 168 |
-
auto stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD);
|
| 169 |
-
|
| 170 |
-
// Represent the full output tensor
|
| 171 |
-
Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C); // (m,n,l)
|
| 172 |
-
Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l)
|
| 173 |
-
|
| 174 |
-
Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
| 175 |
-
Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
| 176 |
-
|
| 177 |
-
// Slice to get the tile this CTA is responsible for
|
| 178 |
-
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
|
| 179 |
-
Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
|
| 180 |
-
Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
|
| 181 |
-
|
| 182 |
-
// Partition source and destination tiles to match the accumulator partitioning
|
| 183 |
-
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
| 184 |
-
Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N)
|
| 185 |
-
Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N)
|
| 186 |
-
|
| 187 |
-
static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static");
|
| 188 |
-
CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD),
|
| 189 |
-
"Source and destination must have the same number of elements.");
|
| 190 |
-
CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators),
|
| 191 |
-
"Accumulator count must have the same destination element count.");
|
| 192 |
-
|
| 193 |
-
// Make an identity coordinate tensor for predicating our output MN tile
|
| 194 |
-
auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
|
| 195 |
-
Tensor tCcD = thr_mma.partition_C(cD);
|
| 196 |
-
|
| 197 |
-
// source is needed
|
| 198 |
-
if (epilogue_op.is_source_needed()) {
|
| 199 |
-
CUTLASS_PRAGMA_UNROLL
|
| 200 |
-
for (int i = 0; i < size(accumulators); ++i) {
|
| 201 |
-
if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
|
| 202 |
-
tCgD(i) = epilogue_op(accumulators(i), tCgC(i));
|
| 203 |
-
}
|
| 204 |
-
}
|
| 205 |
-
}
|
| 206 |
-
// source is not needed, avoid load
|
| 207 |
-
else {
|
| 208 |
-
CUTLASS_PRAGMA_UNROLL
|
| 209 |
-
for (int i = 0; i < size(accumulators); ++i) {
|
| 210 |
-
if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
|
| 211 |
-
tCgD(i) = epilogue_op(accumulators(i));
|
| 212 |
-
}
|
| 213 |
-
}
|
| 214 |
-
}
|
| 215 |
-
}
|
| 216 |
-
|
| 217 |
-
private:
|
| 218 |
-
Params params;
|
| 219 |
-
};
|
| 220 |
-
|
| 221 |
-
} // namespace cutlass::epilogue::collective
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh
DELETED
|
@@ -1,92 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief Simple permutation kernel implementation.
|
| 34 |
-
*/
|
| 35 |
-
|
| 36 |
-
#include "cutlass/layout/pitch_linear.h"
|
| 37 |
-
#include "cutlass/layout/matrix.h"
|
| 38 |
-
#include "cutlass/tensor_view.h"
|
| 39 |
-
#include "cutlass/fast_math.h"
|
| 40 |
-
#include "cute/numeric/numeric_types.hpp"
|
| 41 |
-
|
| 42 |
-
namespace example
|
| 43 |
-
{
|
| 44 |
-
|
| 45 |
-
/**
|
| 46 |
-
* Assumes column-major input (M mode is contiguous, N mode is strided).
|
| 47 |
-
* For row major, the inputs must be switched accordingly.
|
| 48 |
-
*/
|
| 49 |
-
template<bool Batched, typename Element, typename Permute>
|
| 50 |
-
__global__ void
|
| 51 |
-
permute_kernel(Element const* __restrict__ input,
|
| 52 |
-
Element* __restrict__ output,
|
| 53 |
-
Permute permute,
|
| 54 |
-
int64_t num_elems,
|
| 55 |
-
cutlass::FastDivmod stride_divmod)
|
| 56 |
-
{
|
| 57 |
-
// CUTLASS 2.x batched permute functions assume 0 batch stride for target tensor
|
| 58 |
-
Element const * input_b = input + blockIdx.z * num_elems;
|
| 59 |
-
Element * output_b = output + (Batched ? 0 : blockIdx.z * num_elems);
|
| 60 |
-
for (int64_t k = threadIdx.x + blockIdx.x * blockDim.x; k < num_elems; k += blockDim.x * gridDim.x)
|
| 61 |
-
{
|
| 62 |
-
int i, j;
|
| 63 |
-
stride_divmod(j, i, k);
|
| 64 |
-
output_b[permute(cutlass::PitchLinearCoord(i, j))] = input_b[i + j * stride_divmod.divisor];
|
| 65 |
-
}
|
| 66 |
-
}
|
| 67 |
-
|
| 68 |
-
template<bool Batched, typename Permute, typename Element>
|
| 69 |
-
void permute(Element const* input,
|
| 70 |
-
Element * output,
|
| 71 |
-
int64_t num_elems,
|
| 72 |
-
int stride,
|
| 73 |
-
int batch_count,
|
| 74 |
-
cutlass::KernelHardwareInfo const& hw_info)
|
| 75 |
-
{
|
| 76 |
-
// Upcast to uint128_t data type
|
| 77 |
-
int factor = 128 / cutlass::sizeof_bits<Element>::value;
|
| 78 |
-
assert(stride % factor == 0);
|
| 79 |
-
int stride_upcast = stride/factor;
|
| 80 |
-
int64_t num_elems_upcast = num_elems / factor;
|
| 81 |
-
Permute permute_upcast(cutlass::PitchLinearCoord(stride_upcast, int(num_elems_upcast/stride_upcast)), stride_upcast);
|
| 82 |
-
|
| 83 |
-
cutlass::FastDivmod stride_divmod(stride);
|
| 84 |
-
dim3 blocks(hw_info.sm_count, 1, batch_count);
|
| 85 |
-
permute_kernel<Batched><<<blocks, 1024>>>(reinterpret_cast<cute::uint128_t const *>(input),
|
| 86 |
-
reinterpret_cast<cute::uint128_t *>(output),
|
| 87 |
-
permute_upcast,
|
| 88 |
-
num_elems_upcast,
|
| 89 |
-
stride_upcast);
|
| 90 |
-
}
|
| 91 |
-
|
| 92 |
-
} // namespace example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp
DELETED
|
@@ -1,274 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief Additional permutation information for the example.
|
| 34 |
-
*/
|
| 35 |
-
|
| 36 |
-
#include "cutlass/layout/permute.h"
|
| 37 |
-
#include "cutlass/gemm/gemm.h"
|
| 38 |
-
|
| 39 |
-
namespace example
|
| 40 |
-
{
|
| 41 |
-
|
| 42 |
-
using namespace cute;
|
| 43 |
-
|
| 44 |
-
// Describes a CUTLASS 2.x permutation op in terms of target CuTe shape and
// stride order. The primary template is intentionally empty; it is specialized
// below for the individual permutation ops.
template <class Permute>
struct PermuteTraits { };
|
| 48 |
-
|
| 49 |
-
// Use X as a placeholder for shape division result
|
| 50 |
-
using X = Underscore;
|
| 51 |
-
|
| 52 |
-
// Reshape a rank-2 shape into a multidimensional shape.
|
| 53 |
-
// Input:
|
| 54 |
-
// shape = (A, B, ...)
|
| 55 |
-
// target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...)
|
| 56 |
-
// Output:
|
| 57 |
-
// ((A1, ..., A/prod(A1..Am), ..., Am), (B1, ..., B/prod(B1..Bn), ..., Bn), ...)
|
| 58 |
-
template<class Shape, class TargetShape>
|
| 59 |
-
constexpr auto
|
| 60 |
-
reshape(Shape const& shape, TargetShape const& target_shape)
|
| 61 |
-
{
|
| 62 |
-
if constexpr (is_tuple<Shape>::value) {
|
| 63 |
-
return cute::transform(shape, target_shape, [](auto && s, auto && t){ return reshape(s, t); });
|
| 64 |
-
}
|
| 65 |
-
else {
|
| 66 |
-
auto idx = find_if(target_shape, [](auto x){ return is_underscore<decltype(x)>{}; });
|
| 67 |
-
constexpr int I = decltype(idx)::value;
|
| 68 |
-
static_assert(I < tuple_size_v<TargetShape>, "Each mode of TargetShape must contain a placeholder X");
|
| 69 |
-
auto divisors = remove<I>(target_shape);
|
| 70 |
-
assert(shape % product(divisors) == 0);
|
| 71 |
-
return replace<I>(target_shape, shape / product(divisors));
|
| 72 |
-
}
|
| 73 |
-
}
|
| 74 |
-
|
| 75 |
-
// Given a tensor layout, compute a permutation layout consisting of:
|
| 76 |
-
// - sub-modes corresponding to the implied multidimensional shape of the source tensor
|
| 77 |
-
// - strides accounting for the permutation operation being performed
|
| 78 |
-
template<class Permute, bool Transpose, class Shape, class Stride>
|
| 79 |
-
constexpr auto
|
| 80 |
-
make_permute_layout(Layout<Shape,Stride> const& layout) {
|
| 81 |
-
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
|
| 82 |
-
if constexpr (Transpose) {
|
| 83 |
-
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
|
| 84 |
-
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
|
| 85 |
-
return select<1,0,2>(make_permute_layout<Permute, false>(select<1,0,2>(layout)));
|
| 86 |
-
}
|
| 87 |
-
else {
|
| 88 |
-
if constexpr (cutlass::layout::is_trivial_permute<Permute>) {
|
| 89 |
-
// Special case for NoPermute. Use a depth-2 layout for consistency with other permutations.
|
| 90 |
-
using ShapeProfile = tuple<tuple<X>, tuple<X>, tuple<X>>;
|
| 91 |
-
return unflatten(layout, ShapeProfile{});
|
| 92 |
-
}
|
| 93 |
-
else {
|
| 94 |
-
// Here's where the permutation layout is actually built
|
| 95 |
-
using ShapeProfile = typename PermuteTraits<Permute>::ShapeProfile;
|
| 96 |
-
using StrideOrder = typename PermuteTraits<Permute>::StrideOrder;
|
| 97 |
-
return make_ordered_layout(reshape(layout.shape(), ShapeProfile{}), StrideOrder{});
|
| 98 |
-
}
|
| 99 |
-
}
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
namespace detail
|
| 103 |
-
{
|
| 104 |
-
|
| 105 |
-
template<int I>
|
| 106 |
-
struct is_constant_pred {
|
| 107 |
-
template <class T>
|
| 108 |
-
constexpr auto operator()(T) {
|
| 109 |
-
return is_constant<I, T>{};
|
| 110 |
-
}
|
| 111 |
-
};
|
| 112 |
-
|
| 113 |
-
template<class Permutation, int... I>
|
| 114 |
-
constexpr auto
|
| 115 |
-
inverse_impl(Permutation const & perm, seq<I...>) {
|
| 116 |
-
return cute::make_tuple(Int<find_if(Permutation{}, is_constant_pred<I>{})>{}...);
|
| 117 |
-
}
|
| 118 |
-
|
| 119 |
-
} // namespace detail
|
| 120 |
-
|
| 121 |
-
// Compute an inverse of a permutation represented as a tuple of cute::Int<>
|
| 122 |
-
template<class Permutation>
|
| 123 |
-
constexpr auto
|
| 124 |
-
inverse(Permutation const & perm) {
|
| 125 |
-
auto flat_perm = flatten(perm);
|
| 126 |
-
return unflatten(detail::inverse_impl(flat_perm, tuple_seq<decltype(flat_perm)>{}), perm);
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
template<class T>
|
| 130 |
-
using inverse_t = decltype(inverse(T{}));
|
| 131 |
-
|
| 132 |
-
// Given a rank-2 layout of tensor that is assumed to have been permuted,
|
| 133 |
-
// compute the original rank-2 layout of the tensor prior to the permutation.
|
| 134 |
-
// This is needed to form the correct input to the standalone permutation kernel.
|
| 135 |
-
template<class Permute, bool Transpose, class Shape, class Stride>
|
| 136 |
-
constexpr auto
|
| 137 |
-
make_original_layout(Layout<Shape,Stride> const& layout) {
|
| 138 |
-
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
|
| 139 |
-
if constexpr (Transpose) {
|
| 140 |
-
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
|
| 141 |
-
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
|
| 142 |
-
return select<1,0,2>(make_original_layout<Permute, false>(select<1,0,2>(layout)));
|
| 143 |
-
}
|
| 144 |
-
else {
|
| 145 |
-
using ShapeProfile = typename PermuteTraits<Permute>::ShapeProfile;
|
| 146 |
-
auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{}));
|
| 147 |
-
using IndexOrder = typename PermuteTraits<Permute>::IndexOrder;
|
| 148 |
-
auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get<i>(re_shape); });
|
| 149 |
-
using OrigOrder = conditional_t<cutlass::gemm::detail::is_major<0,Stride>(), seq<0,1,2>, seq<1,0,2>>;
|
| 150 |
-
// print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n");
|
| 151 |
-
// print("Original shape: "); print(orig_shape); print("\n");
|
| 152 |
-
return make_ordered_layout(product_each(orig_shape), OrigOrder{});
|
| 153 |
-
}
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
/////////////// Tensor4DPermute0213 ////////////////////
|
| 157 |
-
|
| 158 |
-
template<int D1, int D2>
|
| 159 |
-
struct PermuteTraits<cutlass::layout::Tensor4DPermute0213ColumnMajor<D1, D2>>
|
| 160 |
-
{
|
| 161 |
-
static constexpr bool kBatched = false;
|
| 162 |
-
using ShapeProfile = Shape<Shape<X,Int<D1>>, Shape<Int<D2>,X>, Shape<X>>;
|
| 163 |
-
using IndexOrder = Step<Step<_0,_2>, Step<_1,_3>, Step<_4>>;
|
| 164 |
-
using StrideOrder = inverse_t<IndexOrder>; // Step<Step<_0,_2>, Step<_1,_3>, Step<_4>>;
|
| 165 |
-
};
|
| 166 |
-
|
| 167 |
-
template<int D1, int D2>
|
| 168 |
-
struct PermuteTraits<cutlass::layout::Tensor4DPermute0213ColumnMajorInverse<D1, D2>>
|
| 169 |
-
{
|
| 170 |
-
static constexpr bool kBatched = false;
|
| 171 |
-
using ShapeProfile = Shape<Shape<X,Int<D2>>, Shape<Int<D1>,X>, Shape<X>>;
|
| 172 |
-
using IndexOrder = Step<Step<_0,_2>, Step<_1,_3>, Step<_4>>;
|
| 173 |
-
using StrideOrder = inverse_t<IndexOrder>; // Step<Step<_0,_2>, Step<_1,_3>, Step<_4>>;
|
| 174 |
-
};
|
| 175 |
-
|
| 176 |
-
template<int D1, int D2>
|
| 177 |
-
struct PermuteTraits<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>>
|
| 178 |
-
{
|
| 179 |
-
static constexpr bool kBatched = false;
|
| 180 |
-
using ShapeProfile = Shape<Shape<Int<D1>,X>, Shape<X,Int<D2>>, Shape<X>>;
|
| 181 |
-
using IndexOrder = Step<Step<_1,_3>, Step<_0,_2>, Step<_4>>;
|
| 182 |
-
using StrideOrder = Step<Step<_1,_3>, Step<_0,_2>, Step<_4>>;
|
| 183 |
-
};
|
| 184 |
-
|
| 185 |
-
template<int D1, int D2>
|
| 186 |
-
struct PermuteTraits<cutlass::layout::Tensor4DPermute0213RowMajorInverse<D1, D2>>
|
| 187 |
-
{
|
| 188 |
-
static constexpr bool kBatched = false;
|
| 189 |
-
using ShapeProfile = Shape<Shape<Int<D2>,X>, Shape<X,Int<D1>>, Shape<X>>;
|
| 190 |
-
using IndexOrder = Step<Step<_1,_3>, Step<_0,_2>, Step<_4>>;
|
| 191 |
-
using StrideOrder = Step<Step<_1,_3>, Step<_0,_2>, Step<_4>>;
|
| 192 |
-
};
|
| 193 |
-
|
| 194 |
-
/////////////// Tensor4DPermuteBMM0321 ////////////////////
|
| 195 |
-
|
| 196 |
-
template<int D>
|
| 197 |
-
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D>>
|
| 198 |
-
{
|
| 199 |
-
static constexpr bool kBatched = true;
|
| 200 |
-
using ShapeProfile = Shape<Shape<X>, Shape<X>, Shape<Int<D>,X>>;
|
| 201 |
-
using IndexOrder = Step<Step<_0,_2>, Step<_1>, Step<_3>>;
|
| 202 |
-
using StrideOrder = Step<Step<_0>, Step<_2>, Step<_1,_3>>;
|
| 203 |
-
};
|
| 204 |
-
|
| 205 |
-
template<int D>
|
| 206 |
-
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajorInverse<D>>
|
| 207 |
-
{
|
| 208 |
-
static constexpr bool kBatched = true;
|
| 209 |
-
using ShapeProfile = Shape<Shape<X,Int<D>>, Shape<X>, Shape<X>>;
|
| 210 |
-
using IndexOrder = Step<Step<_0>, Step<_2>, Step<_1,_3>>;
|
| 211 |
-
using StrideOrder = Step<Step<_0,_2>, Step<_1>, Step<_3>>;
|
| 212 |
-
};
|
| 213 |
-
|
| 214 |
-
/////////////// Tensor4DPermuteBMM0213 ////////////////////
|
| 215 |
-
|
| 216 |
-
template<int D>
|
| 217 |
-
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D>>
|
| 218 |
-
{
|
| 219 |
-
static constexpr bool kBatched = true;
|
| 220 |
-
using ShapeProfile = Shape<Shape<X>, Shape<X>, Shape<Int<D>,X>>;
|
| 221 |
-
using IndexOrder = Step<Step<_0>, Step<_1,_2>, Step<_3>>;
|
| 222 |
-
using StrideOrder = Step<Step<_2>, Step<_0>, Step<_1,_3>>;
|
| 223 |
-
};
|
| 224 |
-
|
| 225 |
-
template<int D>
|
| 226 |
-
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajorInverse<D>>
|
| 227 |
-
{
|
| 228 |
-
static constexpr bool kBatched = true;
|
| 229 |
-
using ShapeProfile = Shape<Shape<X>, Shape<X,Int<D>>, Shape<X>>;
|
| 230 |
-
using IndexOrder = Step<Step<_0>, Step<_1>, Step<_2,_3>>;
|
| 231 |
-
using StrideOrder = Step<Step<_1>, Step<_0,_2>, Step<_3>>;
|
| 232 |
-
};
|
| 233 |
-
|
| 234 |
-
/////////////// Tensor5DPermute02413 ////////////////////
|
| 235 |
-
|
| 236 |
-
template<int D1, int D2, int D3>
|
| 237 |
-
struct PermuteTraits<cutlass::layout::Tensor5DPermute02413ColumnMajor<D1, D2, D3>>
|
| 238 |
-
{
|
| 239 |
-
static constexpr bool kBatched = false;
|
| 240 |
-
using ShapeProfile = Shape<Shape<X,Int<D1>>, Shape<Int<D2>,Int<D3>,X>, Shape<X>>;
|
| 241 |
-
using IndexOrder = Step<Step<_0,_2>, Step<_4,_1,_3>, Step<_5>>;
|
| 242 |
-
using StrideOrder = inverse_t<IndexOrder>; // Step<Step<_0,_3>, Step<_1,_4,_2>, Step<_5>>;
|
| 243 |
-
};
|
| 244 |
-
|
| 245 |
-
template<int D1, int D2, int D3>
|
| 246 |
-
struct PermuteTraits<cutlass::layout::Tensor5DPermute02413ColumnMajorInverse<D1, D2, D3>>
|
| 247 |
-
{
|
| 248 |
-
static constexpr bool kBatched = false;
|
| 249 |
-
using ShapeProfile = Shape<Shape<X,Int<D2>>, Shape<X,Int<D1>,Int<D3>>, Shape<X>>;
|
| 250 |
-
using IndexOrder = Step<Step<_0,_3>, Step<_1,_4,_2>, Step<_5>>;
|
| 251 |
-
using StrideOrder = inverse_t<IndexOrder>; // Step<Step<_0,_2>, Step<_4,_1,_3>, Step<_5>>;
|
| 252 |
-
};
|
| 253 |
-
|
| 254 |
-
/////////////// Tensor5DPermute20314 ////////////////////
|
| 255 |
-
|
| 256 |
-
template<int D1, int D2, int D3>
|
| 257 |
-
struct PermuteTraits<cutlass::layout::Tensor5DPermute20314RowMajor<D1, D2, D3>>
|
| 258 |
-
{
|
| 259 |
-
static constexpr bool kBatched = false;
|
| 260 |
-
using ShapeProfile = Shape<Shape<Int<D1>,X>, Shape<X,Int<D3>,Int<D2>>, Shape<X>>;
|
| 261 |
-
using IndexOrder = Step<Step<_2,_0>, Step<_3,_1,_4>, Step<_5>>;
|
| 262 |
-
using StrideOrder = Step<Step<_1,_3>, Step<_0,_2,_4>, Step<_5>>;
|
| 263 |
-
};
|
| 264 |
-
|
| 265 |
-
template<int D1, int D2, int D3>
|
| 266 |
-
struct PermuteTraits<cutlass::layout::Tensor5DPermute20314RowMajorInverse<D1, D2, D3>>
|
| 267 |
-
{
|
| 268 |
-
static constexpr bool kBatched = false;
|
| 269 |
-
using ShapeProfile = Shape<Shape<X,Int<D2>>, Shape<X,Int<D1>,Int<D3>>, Shape<X>>;
|
| 270 |
-
using IndexOrder = Step<Step<_3,_0>, Step<_2,_4,_1>, Step<_5>>;
|
| 271 |
-
using StrideOrder = Step<Step<_4,_2>, Step<_0,_3,_1>, Step<_5>>;
|
| 272 |
-
};
|
| 273 |
-
|
| 274 |
-
} // namespace example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp
DELETED
|
@@ -1,129 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
// Command line options parsing
|
| 33 |
-
template<typename RasterOrderOptions>
|
| 34 |
-
struct Options {
|
| 35 |
-
|
| 36 |
-
bool help = false;
|
| 37 |
-
|
| 38 |
-
float alpha = 1.f, beta = 0.f;
|
| 39 |
-
float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f;
|
| 40 |
-
bool device_scale = false;
|
| 41 |
-
bool save_aux = true;
|
| 42 |
-
bool save_amax = true;
|
| 43 |
-
int iterations = 1000;
|
| 44 |
-
int m = 1024, n = 512, k = 1024, l = 1;
|
| 45 |
-
RasterOrderOptions raster;
|
| 46 |
-
int swizzle;
|
| 47 |
-
|
| 48 |
-
// Parses the command line
|
| 49 |
-
void parse(int argc, char const **args) {
|
| 50 |
-
cutlass::CommandLine cmd(argc, args);
|
| 51 |
-
|
| 52 |
-
if (cmd.check_cmd_line_flag("help")) {
|
| 53 |
-
help = true;
|
| 54 |
-
return;
|
| 55 |
-
}
|
| 56 |
-
|
| 57 |
-
cmd.get_cmd_line_argument("m", m);
|
| 58 |
-
cmd.get_cmd_line_argument("n", n);
|
| 59 |
-
cmd.get_cmd_line_argument("k", k);
|
| 60 |
-
cmd.get_cmd_line_argument("l", l);
|
| 61 |
-
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
| 62 |
-
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
| 63 |
-
cmd.get_cmd_line_argument("scale_a", scale_a, 1.f);
|
| 64 |
-
cmd.get_cmd_line_argument("scale_b", scale_b, 1.f);
|
| 65 |
-
cmd.get_cmd_line_argument("scale_c", scale_c, 1.f);
|
| 66 |
-
cmd.get_cmd_line_argument("scale_d", scale_d, 1.f);
|
| 67 |
-
cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f);
|
| 68 |
-
cmd.get_cmd_line_argument("device_scale", device_scale, false);
|
| 69 |
-
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
| 70 |
-
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
| 71 |
-
cmd.get_cmd_line_argument("iterations", iterations);
|
| 72 |
-
|
| 73 |
-
char raster_char;
|
| 74 |
-
cmd.get_cmd_line_argument("raster", raster_char);
|
| 75 |
-
|
| 76 |
-
if (raster_char == 'N' || raster_char == 'n') {
|
| 77 |
-
raster = RasterOrderOptions::AlongN;
|
| 78 |
-
}
|
| 79 |
-
else if (raster_char == 'M' || raster_char == 'm') {
|
| 80 |
-
raster = RasterOrderOptions::AlongM;
|
| 81 |
-
}
|
| 82 |
-
else if (raster_char == 'H' || raster_char == 'h') {
|
| 83 |
-
raster = RasterOrderOptions::Heuristic;
|
| 84 |
-
}
|
| 85 |
-
|
| 86 |
-
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
|
| 87 |
-
}
|
| 88 |
-
|
| 89 |
-
/// Prints the usage statement.
|
| 90 |
-
std::ostream & print_usage(std::ostream &out) const {
|
| 91 |
-
|
| 92 |
-
out << "54_fp8_hopper_warp_specialized_gemm\n\n"
|
| 93 |
-
<< " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
|
| 94 |
-
<< "Options:\n\n"
|
| 95 |
-
<< " --help If specified, displays this usage statement\n\n"
|
| 96 |
-
<< " --m=<int> Sets the M extent of the GEMM\n"
|
| 97 |
-
<< " --n=<int> Sets the N extent of the GEMM\n"
|
| 98 |
-
<< " --k=<int> Sets the K extent of the GEMM\n"
|
| 99 |
-
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
| 100 |
-
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
| 101 |
-
<< " --beta=<f32> Epilogue scalar beta\n"
|
| 102 |
-
<< " --scale_a=<f32> Scaling factor for A\n"
|
| 103 |
-
<< " --scale_b=<f32> Scaling factor for B\n"
|
| 104 |
-
<< " --scale_c=<f32> Scaling factor for C\n"
|
| 105 |
-
<< " --scale_d=<f32> Scaling factor for D (ignored for non-fp8 D)\n"
|
| 106 |
-
<< " --scale_aux=<f32> Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n"
|
| 107 |
-
<< " --device_scale=<bool> Copy scalars to device memory before kernel launch (default: false)\n"
|
| 108 |
-
<< " --save_aux=<bool> Save the pre-activation as an auxiliary tensor (default: true)\n"
|
| 109 |
-
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
|
| 110 |
-
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
| 111 |
-
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
| 112 |
-
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
| 113 |
-
|
| 114 |
-
out
|
| 115 |
-
<< "\n\nExamples:\n\n"
|
| 116 |
-
<< "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
| 117 |
-
|
| 118 |
-
return out;
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
/// Compute performance in GFLOP/s
|
| 122 |
-
double gflops(double runtime_s) const
|
| 123 |
-
{
|
| 124 |
-
// Two flops per multiply-add
|
| 125 |
-
uint64_t flop = uint64_t(2) * m * n * k;
|
| 126 |
-
double gflop = double(flop) / double(1.0e9);
|
| 127 |
-
return gflop / runtime_s;
|
| 128 |
-
}
|
| 129 |
-
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp
DELETED
|
@@ -1,246 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include "cutlass/cutlass.h"
|
| 35 |
-
#include "cutlass/gemm/dispatch_policy.hpp"
|
| 36 |
-
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
| 37 |
-
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
| 38 |
-
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
| 39 |
-
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
| 40 |
-
#include "cutlass/util/command_line.h"
|
| 41 |
-
#include "cutlass/util/reference/device/tensor_fill.h"
|
| 42 |
-
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 43 |
-
|
| 44 |
-
#include "cute/tensor.hpp"
|
| 45 |
-
|
| 46 |
-
#include <cuda.h>
|
| 47 |
-
#include <numeric>
|
| 48 |
-
#include "helper.h"
|
| 49 |
-
|
| 50 |
-
enum MixedDtypeGemmMode {
|
| 51 |
-
ConvertOnly,
|
| 52 |
-
ScaleOnly,
|
| 53 |
-
ScaleWithZeroPoint
|
| 54 |
-
};
|
| 55 |
-
|
| 56 |
-
/// Command line options parsing
|
| 57 |
-
struct MixedDtypeOptions {
|
| 58 |
-
|
| 59 |
-
bool help = false;
|
| 60 |
-
|
| 61 |
-
float alpha = 1.0f;
|
| 62 |
-
float beta = 0.0f;
|
| 63 |
-
int iterations = 100;
|
| 64 |
-
int warmup = 10;
|
| 65 |
-
int mode = 1;
|
| 66 |
-
int m = 5120, n = 4096, k = 4096;
|
| 67 |
-
int g = 128;
|
| 68 |
-
int l = 1;
|
| 69 |
-
|
| 70 |
-
// Parses the command line
|
| 71 |
-
void parse(int argc, char const **args) {
|
| 72 |
-
cutlass::CommandLine cmd(argc, args);
|
| 73 |
-
|
| 74 |
-
if (cmd.check_cmd_line_flag("help")) {
|
| 75 |
-
help = true;
|
| 76 |
-
return;
|
| 77 |
-
}
|
| 78 |
-
|
| 79 |
-
cmd.get_cmd_line_argument("m", m);
|
| 80 |
-
cmd.get_cmd_line_argument("n", n);
|
| 81 |
-
cmd.get_cmd_line_argument("k", k);
|
| 82 |
-
cmd.get_cmd_line_argument("l", l);
|
| 83 |
-
cmd.get_cmd_line_argument("g", g);
|
| 84 |
-
cmd.get_cmd_line_argument("mode", mode);
|
| 85 |
-
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
| 86 |
-
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
| 87 |
-
cmd.get_cmd_line_argument("iterations", iterations);
|
| 88 |
-
cmd.get_cmd_line_argument("warmup", warmup);
|
| 89 |
-
}
|
| 90 |
-
|
| 91 |
-
/// Prints the usage statement.
|
| 92 |
-
std::ostream & print_usage(std::ostream &out) const {
|
| 93 |
-
|
| 94 |
-
out << "55_hopper_mixed_dtype_gemm\n\n"
|
| 95 |
-
<< " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n"
|
| 96 |
-
<< "Options:\n\n"
|
| 97 |
-
<< " --help If specified, displays this usage statement\n\n"
|
| 98 |
-
<< " --m=<int> Sets the M extent of the GEMM\n"
|
| 99 |
-
<< " --n=<int> Sets the N extent of the GEMM\n"
|
| 100 |
-
<< " --k=<int> Sets the K extent of the GEMM\n"
|
| 101 |
-
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
|
| 102 |
-
<< " --g=<int> The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n"
|
| 103 |
-
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
|
| 104 |
-
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
| 105 |
-
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
| 106 |
-
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
| 107 |
-
<< " --warmup=<int> Number of warmup iterations to perform.\n\n";
|
| 108 |
-
|
| 109 |
-
out
|
| 110 |
-
<< "\n\nExamples:\n\n"
|
| 111 |
-
<< "$ " << "55_hopper_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
|
| 112 |
-
|
| 113 |
-
return out;
|
| 114 |
-
}
|
| 115 |
-
|
| 116 |
-
/// Compute performance in GFLOP/s
|
| 117 |
-
double gflops(double runtime_s) const
|
| 118 |
-
{
|
| 119 |
-
// Two flops per multiply-add
|
| 120 |
-
uint64_t flop = uint64_t(2) * m * n * k * l;
|
| 121 |
-
double gflop = double(flop) / double(1.0e9);
|
| 122 |
-
return gflop / runtime_s;
|
| 123 |
-
}
|
| 124 |
-
};
|
| 125 |
-
|
| 126 |
-
/// Result structure
|
| 127 |
-
struct MixedDtypeResult
|
| 128 |
-
{
|
| 129 |
-
double avg_runtime_ms = 0.0;
|
| 130 |
-
double gflops = 0.0;
|
| 131 |
-
cutlass::Status status = cutlass::Status::kSuccess;
|
| 132 |
-
cudaError_t error = cudaSuccess;
|
| 133 |
-
bool passed = false;
|
| 134 |
-
|
| 135 |
-
};
|
| 136 |
-
|
| 137 |
-
/// Profiling Loop
|
| 138 |
-
template <class Gemm>
|
| 139 |
-
void mixed_dtype_profiling(
|
| 140 |
-
Gemm& gemm,
|
| 141 |
-
MixedDtypeOptions const& options,
|
| 142 |
-
MixedDtypeResult& result) {
|
| 143 |
-
|
| 144 |
-
if (options.iterations <= 0) return;
|
| 145 |
-
|
| 146 |
-
cudaEvent_t start, stop;
|
| 147 |
-
cudaEventCreate(&start);
|
| 148 |
-
cudaEventCreate(&stop);
|
| 149 |
-
|
| 150 |
-
std::vector<float> runtimes;
|
| 151 |
-
runtimes.reserve(options.iterations);
|
| 152 |
-
|
| 153 |
-
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
| 154 |
-
cudaEventRecord(start);
|
| 155 |
-
CUTLASS_CHECK(gemm.run());
|
| 156 |
-
cudaEventRecord(stop);
|
| 157 |
-
cudaEventSynchronize(stop);
|
| 158 |
-
|
| 159 |
-
if (iter >= options.warmup) {
|
| 160 |
-
float milliseconds = 0;
|
| 161 |
-
cudaEventElapsedTime(&milliseconds, start, stop);
|
| 162 |
-
runtimes.push_back(milliseconds);
|
| 163 |
-
}
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
-
cudaEventDestroy(start);
|
| 167 |
-
cudaEventDestroy(stop);
|
| 168 |
-
|
| 169 |
-
// Compute average setup and runtime and GFLOPs.
|
| 170 |
-
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
|
| 171 |
-
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
| 172 |
-
|
| 173 |
-
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
| 174 |
-
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
| 175 |
-
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
| 176 |
-
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
/// Helpers to initialize a block of device data
|
| 180 |
-
template <class Element>
|
| 181 |
-
bool initialize_tensor(
|
| 182 |
-
cutlass::DeviceAllocation<Element>& block,
|
| 183 |
-
uint64_t seed = 2023) {
|
| 184 |
-
|
| 185 |
-
double scope_max, scope_min;
|
| 186 |
-
int bits_input = cutlass::sizeof_bits<Element>::value;
|
| 187 |
-
int bits_output = cutlass::sizeof_bits<Element>::value;
|
| 188 |
-
|
| 189 |
-
if (bits_input == 1) {
|
| 190 |
-
scope_max = 2;
|
| 191 |
-
scope_min = 0;
|
| 192 |
-
}
|
| 193 |
-
else if (bits_input <= 8) {
|
| 194 |
-
scope_max = 2;
|
| 195 |
-
scope_min = -2;
|
| 196 |
-
}
|
| 197 |
-
else if (bits_output == 16) {
|
| 198 |
-
scope_max = 5;
|
| 199 |
-
scope_min = -5;
|
| 200 |
-
}
|
| 201 |
-
else {
|
| 202 |
-
scope_max = 8;
|
| 203 |
-
scope_min = -8;
|
| 204 |
-
}
|
| 205 |
-
cutlass::reference::device::BlockFillRandomUniform(
|
| 206 |
-
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
| 207 |
-
|
| 208 |
-
return true;
|
| 209 |
-
}
|
| 210 |
-
|
| 211 |
-
template <class Element>
|
| 212 |
-
bool initialize_scale(
|
| 213 |
-
cutlass::DeviceAllocation<Element>& block,
|
| 214 |
-
MixedDtypeOptions const& options,
|
| 215 |
-
uint64_t seed = 2023) {
|
| 216 |
-
|
| 217 |
-
// If no scales, initialize with 1 so we can use the same kernel to dequantize the data
|
| 218 |
-
float scope_max = 1.0f, scope_min = 1.0f;
|
| 219 |
-
if (options.mode != MixedDtypeGemmMode::ConvertOnly) {
|
| 220 |
-
float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
|
| 221 |
-
scope_max = 2.f;
|
| 222 |
-
scope_min = 0.1f;
|
| 223 |
-
}
|
| 224 |
-
cutlass::reference::device::BlockFillRandomUniform(
|
| 225 |
-
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
| 226 |
-
|
| 227 |
-
return true;
|
| 228 |
-
}
|
| 229 |
-
|
| 230 |
-
template <class Element>
|
| 231 |
-
bool initialize_zero(
|
| 232 |
-
cutlass::DeviceAllocation<Element>& block,
|
| 233 |
-
MixedDtypeOptions const& options,
|
| 234 |
-
uint64_t seed = 2023) {
|
| 235 |
-
|
| 236 |
-
// If no bias, initialize with 0 so we can use the same kernel to dequantize the data
|
| 237 |
-
float scope_max = 0.0f, scope_min = 0.0f;
|
| 238 |
-
if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
|
| 239 |
-
scope_max = 2.0f;
|
| 240 |
-
scope_min = -2.0f;
|
| 241 |
-
}
|
| 242 |
-
cutlass::reference::device::BlockFillRandomUniform(
|
| 243 |
-
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
| 244 |
-
|
| 245 |
-
return true;
|
| 246 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h
DELETED
|
@@ -1,320 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
#include "cute/tensor.hpp"
|
| 34 |
-
#include "cute/atom/mma_atom.hpp"
|
| 35 |
-
#include "cute/atom/copy_atom.hpp"
|
| 36 |
-
#include <random>
|
| 37 |
-
|
| 38 |
-
#include "cutlass/util/print_error.hpp"
|
| 39 |
-
|
| 40 |
-
#include "cutlass/gemm/dispatch_policy.hpp"
|
| 41 |
-
#include "cutlass/gemm/collective/collective_mma.hpp"
|
| 42 |
-
|
| 43 |
-
using namespace cute;
|
| 44 |
-
|
| 45 |
-
struct AmpereUnpredicatedFprop {
|
| 46 |
-
//
|
| 47 |
-
// Static config for conv problem shape
|
| 48 |
-
//
|
| 49 |
-
using D = _6;
|
| 50 |
-
using H = _4;
|
| 51 |
-
using W = _4;
|
| 52 |
-
|
| 53 |
-
using T = _3;
|
| 54 |
-
using R = _3;
|
| 55 |
-
using S = _3;
|
| 56 |
-
|
| 57 |
-
using Z = _4;
|
| 58 |
-
using P = _2;
|
| 59 |
-
using Q = _2;
|
| 60 |
-
|
| 61 |
-
using C = _64;
|
| 62 |
-
using K = _128;
|
| 63 |
-
|
| 64 |
-
// Tiler config
|
| 65 |
-
using Tiler_K = decltype(cute::min(K{}, _128{}));
|
| 66 |
-
using Tiler_C = decltype(cute::min(C{}, _32{}));
|
| 67 |
-
using Tiler_N = _4;
|
| 68 |
-
using TileM = Tiler_K;
|
| 69 |
-
using TileN = Shape<Tiler_N, Z, P, Q>;
|
| 70 |
-
using TileK = Shape<Tiler_C,_1,_1,_1>;
|
| 71 |
-
using PIPE = _3;
|
| 72 |
-
using TilerFlt = Shape<TileM, TileK>;
|
| 73 |
-
using TilerAct = Shape<TileN, TileK>;
|
| 74 |
-
using TilerOut = Shape<TileM, TileN>;
|
| 75 |
-
|
| 76 |
-
using TileSizeM = Int<size(TileM{})>;
|
| 77 |
-
using TileSizeN = Int<size(TileN{})>;
|
| 78 |
-
using TileSizeK = Int<size(TileK{})>;
|
| 79 |
-
static constexpr int Stages = PIPE::value;
|
| 80 |
-
|
| 81 |
-
using ElementFlt = tfloat32_t;
|
| 82 |
-
using ElementAct = tfloat32_t;
|
| 83 |
-
using ElementOut = float;
|
| 84 |
-
|
| 85 |
-
using TiledMma = TiledMMA<
|
| 86 |
-
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
|
| 87 |
-
Layout<Shape<_2,_2,_1>>,
|
| 88 |
-
Tile<_32,_32,Underscore>>;
|
| 89 |
-
|
| 90 |
-
static constexpr int MaxThreadsPerBlock = size(TiledMma{});
|
| 91 |
-
static constexpr int MinBlocksPerMultiprocessor = 1;
|
| 92 |
-
|
| 93 |
-
union SharedStorage {
|
| 94 |
-
struct {
|
| 95 |
-
ElementFlt sAMatrix[size(TileM{}) * size(TileK{}) * size(PIPE{})];
|
| 96 |
-
ElementAct sBMatrix[size(TileN{}) * size(TileK{}) * size(PIPE{})];
|
| 97 |
-
} mainloop;
|
| 98 |
-
|
| 99 |
-
struct {
|
| 100 |
-
ElementOut sCMatrix[size(TileM{}) * size(TileN{})];
|
| 101 |
-
} epilogue;
|
| 102 |
-
};
|
| 103 |
-
|
| 104 |
-
//
|
| 105 |
-
// Stencil tensor
|
| 106 |
-
//
|
| 107 |
-
|
| 108 |
-
using GmemLayoutFlt = decltype(make_ordered_layout(
|
| 109 |
-
Shape< K, Shape< C, T, R, S>>{},
|
| 110 |
-
tuple<_4, tuple<_0,_3,_2,_1>>{}));
|
| 111 |
-
|
| 112 |
-
// We have 64 elements * 32b each in the major mode that we can vectorize
|
| 113 |
-
// Max vector size is 128b, so lay 16 threads along the major mode with a vector size of 4
|
| 114 |
-
// Rest along the minor mode
|
| 115 |
-
using GmemTiledCopyFlt = decltype(make_tiled_copy(
|
| 116 |
-
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementFlt>{},
|
| 117 |
-
Layout<Shape <_16, _8>,
|
| 118 |
-
Stride< _8, _1>>{},
|
| 119 |
-
Layout<Shape < _1, _4>>{}));
|
| 120 |
-
|
| 121 |
-
// Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
|
| 122 |
-
// using SmemLayoutFlt = decltype(
|
| 123 |
-
// composition(Swizzle<3,2,3>{},
|
| 124 |
-
// make_ordered_layout(
|
| 125 |
-
// Shape<TileSizeM,TileSizeK,PIPE>{},
|
| 126 |
-
// tuple< _1, _0, _2>{})));
|
| 127 |
-
|
| 128 |
-
using SmemLayoutAtomFlt = decltype(
|
| 129 |
-
composition(Swizzle<1,2,3>{},
|
| 130 |
-
Layout<Shape <_8,Shape <_4, _2>>,
|
| 131 |
-
Stride<_4,Stride<_1,_32>>>{}));
|
| 132 |
-
|
| 133 |
-
using SmemCopyAtomFlt = Copy_Atom<SM75_U32x4_LDSM_N, ElementFlt>;
|
| 134 |
-
|
| 135 |
-
//
|
| 136 |
-
// Activation tensor
|
| 137 |
-
//
|
| 138 |
-
|
| 139 |
-
// Activation tensor is major in the contraction mode, so vectorize that mode first
|
| 140 |
-
// Then lay out the rest of the threads along the other mode
|
| 141 |
-
using GmemTiledCopyAct = decltype(make_tiled_copy(
|
| 142 |
-
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementAct>{},
|
| 143 |
-
Layout<Shape <_16, _8>,
|
| 144 |
-
Stride< _8, _1>>{},
|
| 145 |
-
Layout<Shape < _1, _4>>{}));
|
| 146 |
-
|
| 147 |
-
// Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
|
| 148 |
-
// using SmemLayoutAct = decltype(
|
| 149 |
-
// composition(Swizzle<3,2,3>{},
|
| 150 |
-
// make_ordered_layout(
|
| 151 |
-
// Shape<TileSizeN,TileSizeK,PIPE>{},
|
| 152 |
-
// tuple< _1, _0, _2>{})));
|
| 153 |
-
|
| 154 |
-
using SmemLayoutAtomAct = decltype(
|
| 155 |
-
composition(Swizzle<1,2,3>{},
|
| 156 |
-
Layout<Shape <_8,Shape <_4, _2>>,
|
| 157 |
-
Stride<_4,Stride<_1,_32>>>{}));
|
| 158 |
-
|
| 159 |
-
using SmemCopyAtomAct = Copy_Atom<SM75_U32x4_LDSM_N, ElementAct>;
|
| 160 |
-
|
| 161 |
-
//
|
| 162 |
-
// Output tensor
|
| 163 |
-
//
|
| 164 |
-
|
| 165 |
-
using GmemTiledCopyOut = decltype(make_tiled_copy(
|
| 166 |
-
Copy_Atom<UniversalCopy<uint128_t>, ElementAct>{},
|
| 167 |
-
Layout<Shape <_8, _16>,
|
| 168 |
-
Stride<_1, _8>>{},
|
| 169 |
-
Layout<Shape <_4, _1>>{}));
|
| 170 |
-
|
| 171 |
-
using SmemCopyAtomOut = Copy_Atom<UniversalCopy<uint32_t>, ElementOut>;
|
| 172 |
-
|
| 173 |
-
// This can be optimized to make accesses BCF, but we use a col-major layout here to show off composability
|
| 174 |
-
using SmemLayoutOut = Layout<Shape<TileSizeM, TileSizeN>>;
|
| 175 |
-
|
| 176 |
-
//
|
| 177 |
-
// Conv functor
|
| 178 |
-
//
|
| 179 |
-
template <class EngineFlt, class TensorActivation, class TensorOutput>
|
| 180 |
-
void __device__
|
| 181 |
-
operator()(cute::Tensor<EngineFlt, GmemLayoutFlt> mFlt, // ( K, (C,T,R,S))
|
| 182 |
-
TensorActivation mAct, // ((N,Z,P,Q), (C,T,R,S))
|
| 183 |
-
TensorOutput mOut, // ( K, (N,Z,P,Q))
|
| 184 |
-
char* smem_buf) const {
|
| 185 |
-
using namespace cute;
|
| 186 |
-
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveMma<
|
| 187 |
-
cutlass::gemm::MainloopSm80CpAsyncUnpredicated<PIPE::value>,
|
| 188 |
-
Shape<TileM,TileN,TileK>,
|
| 189 |
-
ElementFlt,
|
| 190 |
-
Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
|
| 191 |
-
ElementAct,
|
| 192 |
-
Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
|
| 193 |
-
TiledMma,
|
| 194 |
-
GmemTiledCopyFlt,
|
| 195 |
-
SmemLayoutAtomFlt,
|
| 196 |
-
SmemCopyAtomFlt,
|
| 197 |
-
cute::identity,
|
| 198 |
-
GmemTiledCopyAct,
|
| 199 |
-
SmemLayoutAtomAct,
|
| 200 |
-
SmemCopyAtomAct,
|
| 201 |
-
cute::identity>;
|
| 202 |
-
|
| 203 |
-
TiledMma tiled_mma;
|
| 204 |
-
Tensor accum = partition_fragment_C(tiled_mma, TilerOut{});
|
| 205 |
-
clear(accum);
|
| 206 |
-
|
| 207 |
-
// Set up tensors
|
| 208 |
-
// NOTE: blockIdx.x projects onto act-NDHW mode, y along the flt-K mode for the sake of higher dynamic range in NDHW
|
| 209 |
-
Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_)); // (BLK_M,BLK_K,m',k')
|
| 210 |
-
Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_)); // (BLK_N,BLK_K,n',_1)
|
| 211 |
-
Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_)); // (BLK_M,BLK_N,m',n')
|
| 212 |
-
|
| 213 |
-
// Compute m_coord and n_coord with their post-tiled shapes
|
| 214 |
-
auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk));
|
| 215 |
-
auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk));
|
| 216 |
-
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k')
|
| 217 |
-
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,_1)
|
| 218 |
-
Tensor gC = gC_mn(_,_,m_coord,n_coord); // (BLK_M,BLK_N)
|
| 219 |
-
|
| 220 |
-
auto k_tile_iter = cute::make_coord_iterator(size<2>(gA));
|
| 221 |
-
int k_tile_count = size<2>(gA);
|
| 222 |
-
|
| 223 |
-
CollectiveMainloop collective_mma;
|
| 224 |
-
collective_mma(
|
| 225 |
-
accum,
|
| 226 |
-
gA,
|
| 227 |
-
gB,
|
| 228 |
-
accum,
|
| 229 |
-
k_tile_iter, k_tile_count,
|
| 230 |
-
Underscore{}, // no residue since we do not support predication
|
| 231 |
-
threadIdx.x,
|
| 232 |
-
smem_buf);
|
| 233 |
-
|
| 234 |
-
//
|
| 235 |
-
// Epilogue
|
| 236 |
-
//
|
| 237 |
-
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 238 |
-
Tensor sC = make_tensor(make_smem_ptr(&storage.epilogue.sCMatrix[0]), SmemLayoutOut{});
|
| 239 |
-
|
| 240 |
-
auto smem_tiled_copy_C = make_tiled_copy_C(SmemCopyAtomOut{}, tiled_mma);
|
| 241 |
-
auto smem_thr_copy_C = smem_tiled_copy_C.get_slice(threadIdx.x);
|
| 242 |
-
auto tCrC = smem_thr_copy_C.retile_S(accum);
|
| 243 |
-
auto tCsC = smem_thr_copy_C.partition_D(sC);
|
| 244 |
-
copy(smem_tiled_copy_C, tCrC, tCsC);
|
| 245 |
-
|
| 246 |
-
__syncthreads();
|
| 247 |
-
|
| 248 |
-
GmemTiledCopyOut gmem_tiled_copy_C;
|
| 249 |
-
auto gmem_thr_copy_C = gmem_tiled_copy_C.get_slice(threadIdx.x);
|
| 250 |
-
auto tDsC = gmem_thr_copy_C.partition_S(sC);
|
| 251 |
-
auto tDgC = gmem_thr_copy_C.partition_D(gC);
|
| 252 |
-
copy(gmem_tiled_copy_C, tDsC, tDgC);
|
| 253 |
-
|
| 254 |
-
#if 0
|
| 255 |
-
if (thread0()) {
|
| 256 |
-
print("mAct = "); print(mAct); print('\n');
|
| 257 |
-
print("mFlt = "); print(mFlt); print('\n');
|
| 258 |
-
print("mOut = "); print(mOut); print('\n');
|
| 259 |
-
print("gA = "); print(gA); print('\n');
|
| 260 |
-
print("gB = "); print(gB); print('\n');
|
| 261 |
-
print("gC = "); print(gC); print('\n');
|
| 262 |
-
print("sA = "); print(sA.layout()); print('\n');
|
| 263 |
-
print("sB = "); print(sB.layout()); print('\n');
|
| 264 |
-
print("sC = "); print(sC.layout()); print('\n');
|
| 265 |
-
print("tAgA = "); print(tAgA.layout()); print('\n');
|
| 266 |
-
print("tBgB = "); print(tBgB.layout()); print('\n');
|
| 267 |
-
print("tAsA = "); print(tAsA.layout()); print('\n');
|
| 268 |
-
print("tBsB = "); print(tBsB.layout()); print('\n');
|
| 269 |
-
print("tCsA = "); print(tCsA.layout()); print('\n');
|
| 270 |
-
print("tCsB = "); print(tCsB.layout()); print('\n');
|
| 271 |
-
print("tCrC = "); print(tCrC.layout()); print('\n');
|
| 272 |
-
print("tCsC = "); print(tCsC.layout()); print('\n');
|
| 273 |
-
print("tDsC = "); print(tDsC.layout()); print('\n');
|
| 274 |
-
print("tDgC = "); print(tDgC.layout()); print('\n');
|
| 275 |
-
print("gmem tiled copy A = "); print(gmem_tiled_copy_A); print('\n');
|
| 276 |
-
print("gmem tiled copy B = "); print(gmem_tiled_copy_B); print('\n');
|
| 277 |
-
print("gmem tiled copy C = "); print(gmem_tiled_copy_C); print('\n');
|
| 278 |
-
print("k_tile_count = "); print(size<2>(gA)); print('\n');
|
| 279 |
-
print("k_tile_iter = "); print(*k_tile_iter); print('\n');
|
| 280 |
-
print("K_BLOCK_MAX = "); print(K_BLOCK_MAX); print('\n');
|
| 281 |
-
}
|
| 282 |
-
#endif
|
| 283 |
-
}
|
| 284 |
-
};
|
| 285 |
-
|
| 286 |
-
template <class TensorFlt, class TensorAct, class TensorOut>
|
| 287 |
-
inline int
|
| 288 |
-
fprop_reference(
|
| 289 |
-
TensorFlt mStencil, // Logical MK: ( K, (C,T,R,S))
|
| 290 |
-
TensorAct mActivation, // Logical NK: ((N,Z,P,Q), (C,T,R,S))
|
| 291 |
-
TensorOut mOutput, // Logical MN: ( K, (N,Z,P,Q))
|
| 292 |
-
TensorOut mOutputRef) {
|
| 293 |
-
int32_t N = size<1,0>(mOutputRef);
|
| 294 |
-
int32_t Z = size<1,1>(mOutputRef);
|
| 295 |
-
int32_t P = size<1,2>(mOutputRef);
|
| 296 |
-
int32_t Q = size<1,3>(mOutputRef);
|
| 297 |
-
int32_t T = size<1,3>(mStencil);
|
| 298 |
-
int32_t R = size<1,2>(mStencil);
|
| 299 |
-
int32_t S = size<1,1>(mStencil);
|
| 300 |
-
int32_t C = size<1,0>(mStencil);
|
| 301 |
-
|
| 302 |
-
size_t K = static_cast<size_t>(size<0>(mOutputRef));
|
| 303 |
-
size_t NZPQ = static_cast<size_t>(size<1>(mOutputRef));
|
| 304 |
-
size_t CTRS = static_cast<size_t>(size<1>(mStencil));
|
| 305 |
-
|
| 306 |
-
#if defined(_OPENMP)
|
| 307 |
-
#pragma omp parallel for
|
| 308 |
-
#endif
|
| 309 |
-
for (size_t logical_m = 0; logical_m < K; ++logical_m) {
|
| 310 |
-
for (size_t logical_n = 0; logical_n < NZPQ; ++logical_n) {
|
| 311 |
-
auto accumulator = float(0);
|
| 312 |
-
for (size_t logical_k = 0; logical_k < CTRS; ++logical_k) {
|
| 313 |
-
accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k);
|
| 314 |
-
}
|
| 315 |
-
mOutputRef(logical_m, logical_n) = accumulator;
|
| 316 |
-
}
|
| 317 |
-
}
|
| 318 |
-
|
| 319 |
-
return print_relative_error(mOutput, mOutputRef, /*print_verbose*/ false, /*print_error*/ true, /*error_margin*/ 0.01);
|
| 320 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp
DELETED
|
@@ -1,242 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include "cutlass/gemm/collective/collective_builder.hpp"
|
| 35 |
-
|
| 36 |
-
#include "dispatch_policy_extra.hpp"
|
| 37 |
-
#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp"
|
| 38 |
-
#include "../pipeline/prefetch_pipeline_sm90.hpp"
|
| 39 |
-
|
| 40 |
-
namespace cutlass::gemm::collective {
|
| 41 |
-
|
| 42 |
-
namespace detail {
|
| 43 |
-
|
| 44 |
-
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
| 45 |
-
template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int stages>
|
| 46 |
-
constexpr int
|
| 47 |
-
compute_stage_count_or_override_prefetch(StageCount<stages> stage_count) {
|
| 48 |
-
return stages;
|
| 49 |
-
}
|
| 50 |
-
|
| 51 |
-
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
| 52 |
-
template<int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, int carveout_bytes>
|
| 53 |
-
constexpr int
|
| 54 |
-
compute_stage_count_or_override_prefetch(StageCountAutoCarveout<carveout_bytes> stage_count) {
|
| 55 |
-
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
|
| 56 |
-
constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage<PrefetchStages>);
|
| 57 |
-
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
|
| 58 |
-
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
|
| 59 |
-
constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size
|
| 60 |
-
constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}));
|
| 61 |
-
constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast<int>(mainloop_pipeline_bytes);
|
| 62 |
-
|
| 63 |
-
return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes;
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
-
} // namespace detail
|
| 67 |
-
|
| 68 |
-
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch
|
| 69 |
-
template <
|
| 70 |
-
class ElementA,
|
| 71 |
-
class GmemLayoutATag,
|
| 72 |
-
int AlignmentA,
|
| 73 |
-
class ElementB,
|
| 74 |
-
class GmemLayoutBTag,
|
| 75 |
-
int AlignmentB,
|
| 76 |
-
class ElementAccumulator,
|
| 77 |
-
class TileShape_MNK,
|
| 78 |
-
class ClusterShape_MNK,
|
| 79 |
-
class StageCountType,
|
| 80 |
-
class KernelScheduleType
|
| 81 |
-
>
|
| 82 |
-
struct CollectiveBuilder<
|
| 83 |
-
arch::Sm90,
|
| 84 |
-
arch::OpClassTensorOp,
|
| 85 |
-
ElementA,
|
| 86 |
-
GmemLayoutATag,
|
| 87 |
-
AlignmentA,
|
| 88 |
-
ElementB,
|
| 89 |
-
GmemLayoutBTag,
|
| 90 |
-
AlignmentB,
|
| 91 |
-
ElementAccumulator,
|
| 92 |
-
TileShape_MNK,
|
| 93 |
-
ClusterShape_MNK,
|
| 94 |
-
StageCountType,
|
| 95 |
-
KernelScheduleType,
|
| 96 |
-
cute::enable_if_t<
|
| 97 |
-
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>>
|
| 98 |
-
> {
|
| 99 |
-
static_assert(is_static<TileShape_MNK>::value);
|
| 100 |
-
static_assert(is_static<ClusterShape_MNK>::value);
|
| 101 |
-
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
| 102 |
-
"Not meet TMA alignment requirement yet\n");
|
| 103 |
-
static_assert(detail::is_input_fp8<ElementA, ElementB>(),
|
| 104 |
-
"Only FP8 datatypes are compatible with these kernel schedules\n");
|
| 105 |
-
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
| 106 |
-
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
|
| 107 |
-
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
| 108 |
-
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
| 109 |
-
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
| 110 |
-
#endif
|
| 111 |
-
|
| 112 |
-
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
|
| 113 |
-
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
|
| 114 |
-
|
| 115 |
-
using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
|
| 116 |
-
|
| 117 |
-
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
| 118 |
-
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
| 119 |
-
|
| 120 |
-
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
| 121 |
-
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
| 122 |
-
|
| 123 |
-
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
| 124 |
-
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
| 125 |
-
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
| 126 |
-
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
| 127 |
-
|
| 128 |
-
static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
|
| 129 |
-
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
| 130 |
-
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
| 131 |
-
|
| 132 |
-
using SmemCopyAtomA = void;
|
| 133 |
-
using SmemCopyAtomB = void;
|
| 134 |
-
|
| 135 |
-
using CollectiveOp = CollectiveMma<
|
| 136 |
-
DispatchPolicy,
|
| 137 |
-
TileShape_MNK,
|
| 138 |
-
ElementA,
|
| 139 |
-
TagToStrideA_t<GmemLayoutATag>,
|
| 140 |
-
ElementB,
|
| 141 |
-
TagToStrideB_t<GmemLayoutBTag>,
|
| 142 |
-
TiledMma,
|
| 143 |
-
GmemTiledCopyA,
|
| 144 |
-
SmemLayoutAtomA,
|
| 145 |
-
SmemCopyAtomA,
|
| 146 |
-
cute::identity,
|
| 147 |
-
GmemTiledCopyB,
|
| 148 |
-
SmemLayoutAtomB,
|
| 149 |
-
SmemCopyAtomB,
|
| 150 |
-
cute::identity
|
| 151 |
-
>;
|
| 152 |
-
};
|
| 153 |
-
|
| 154 |
-
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps
|
| 155 |
-
template <
|
| 156 |
-
class ElementA,
|
| 157 |
-
class GmemLayoutATag,
|
| 158 |
-
int AlignmentA,
|
| 159 |
-
class ElementB,
|
| 160 |
-
class GmemLayoutBTag,
|
| 161 |
-
int AlignmentB,
|
| 162 |
-
class ElementAccumulator,
|
| 163 |
-
class TileShape_MNK,
|
| 164 |
-
class ClusterShape_MNK,
|
| 165 |
-
class StageCountType,
|
| 166 |
-
class KernelScheduleType
|
| 167 |
-
>
|
| 168 |
-
struct CollectiveBuilder<
|
| 169 |
-
arch::Sm90,
|
| 170 |
-
arch::OpClassTensorOp,
|
| 171 |
-
ElementA,
|
| 172 |
-
GmemLayoutATag,
|
| 173 |
-
AlignmentA,
|
| 174 |
-
ElementB,
|
| 175 |
-
GmemLayoutBTag,
|
| 176 |
-
AlignmentB,
|
| 177 |
-
ElementAccumulator,
|
| 178 |
-
TileShape_MNK,
|
| 179 |
-
ClusterShape_MNK,
|
| 180 |
-
StageCountType,
|
| 181 |
-
KernelScheduleType,
|
| 182 |
-
cute::enable_if_t<
|
| 183 |
-
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>>
|
| 184 |
-
> {
|
| 185 |
-
static_assert(is_static<TileShape_MNK>::value);
|
| 186 |
-
static_assert(is_static<ClusterShape_MNK>::value);
|
| 187 |
-
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
| 188 |
-
"Not meet TMA alignment requirement yet\n");
|
| 189 |
-
static_assert(detail::is_input_fp8<ElementA, ElementB>(),
|
| 190 |
-
"Only FP8 datatypes are compatible with these kernel schedules\n");
|
| 191 |
-
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
| 192 |
-
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
|
| 193 |
-
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
| 194 |
-
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
| 195 |
-
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
| 196 |
-
#endif
|
| 197 |
-
|
| 198 |
-
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
|
| 199 |
-
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
|
| 200 |
-
|
| 201 |
-
using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
|
| 202 |
-
|
| 203 |
-
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
| 204 |
-
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
| 205 |
-
|
| 206 |
-
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
| 207 |
-
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
| 208 |
-
|
| 209 |
-
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
| 210 |
-
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
| 211 |
-
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
| 212 |
-
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
| 213 |
-
|
| 214 |
-
static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch<detail::sm90_smem_capacity_bytes,
|
| 215 |
-
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
| 216 |
-
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
| 217 |
-
|
| 218 |
-
using SmemCopyAtomA = void;
|
| 219 |
-
using SmemCopyAtomB = void;
|
| 220 |
-
|
| 221 |
-
using CollectiveOp = CollectiveMma<
|
| 222 |
-
DispatchPolicy,
|
| 223 |
-
TileShape_MNK,
|
| 224 |
-
ElementA,
|
| 225 |
-
TagToStrideA_t<GmemLayoutATag>,
|
| 226 |
-
ElementB,
|
| 227 |
-
TagToStrideB_t<GmemLayoutBTag>,
|
| 228 |
-
TiledMma,
|
| 229 |
-
GmemTiledCopyA,
|
| 230 |
-
SmemLayoutAtomA,
|
| 231 |
-
SmemCopyAtomA,
|
| 232 |
-
cute::identity,
|
| 233 |
-
GmemTiledCopyB,
|
| 234 |
-
SmemLayoutAtomB,
|
| 235 |
-
SmemCopyAtomB,
|
| 236 |
-
cute::identity
|
| 237 |
-
>;
|
| 238 |
-
};
|
| 239 |
-
|
| 240 |
-
} // namespace cutlass::gemm::collective
|
| 241 |
-
|
| 242 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp
DELETED
|
@@ -1,61 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
namespace cutlass::gemm {
|
| 35 |
-
|
| 36 |
-
// Standard non-persistent kernel with a single producer warp, and one prefetch warp.
|
| 37 |
-
// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
|
| 38 |
-
// while the producer warp is waiting on griddepcontrol.
|
| 39 |
-
// GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and
|
| 40 |
-
// according to prefetch ratio.
|
| 41 |
-
struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { };
|
| 42 |
-
|
| 43 |
-
// Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp.
|
| 44 |
-
// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
|
| 45 |
-
// while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not
|
| 46 |
-
// wait on griddepcontrol and loads immediately.
|
| 47 |
-
struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { };
|
| 48 |
-
|
| 49 |
-
template<
|
| 50 |
-
int Stages_,
|
| 51 |
-
class ClusterShape_ = Shape<_1,_1,_1>,
|
| 52 |
-
class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch
|
| 53 |
-
>
|
| 54 |
-
struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch {
|
| 55 |
-
constexpr static int Stages = Stages_;
|
| 56 |
-
using ClusterShape = ClusterShape_;
|
| 57 |
-
using ArchTag = arch::Sm90;
|
| 58 |
-
using Schedule = KernelSchedule;
|
| 59 |
-
};
|
| 60 |
-
|
| 61 |
-
} // namespace cutlass::gemm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp
DELETED
|
@@ -1,871 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include "cutlass/cutlass.h"
|
| 35 |
-
#include "cutlass/gemm/dispatch_policy.hpp"
|
| 36 |
-
#include "cutlass/numeric_types.h"
|
| 37 |
-
#include "cutlass/pipeline/pipeline.hpp"
|
| 38 |
-
#include "cutlass/trace.h"
|
| 39 |
-
|
| 40 |
-
#include "cute/arch/cluster_sm90.hpp"
|
| 41 |
-
#include "cute/arch/copy_sm90.hpp"
|
| 42 |
-
#include "cute/algorithm/functional.hpp"
|
| 43 |
-
#include "cute/atom/mma_atom.hpp"
|
| 44 |
-
#include "cute/algorithm/gemm.hpp"
|
| 45 |
-
#include "cute/numeric/arithmetic_tuple.hpp"
|
| 46 |
-
#include "cutlass/arch/grid_dependency_control.h"
|
| 47 |
-
|
| 48 |
-
#include "dispatch_policy_extra.hpp"
|
| 49 |
-
|
| 50 |
-
#include "../pipeline/prefetch_pipeline_sm90.hpp"
|
| 51 |
-
|
| 52 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
-
|
| 54 |
-
namespace cutlass::gemm::collective {
|
| 55 |
-
using namespace cute;
|
| 56 |
-
|
| 57 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 58 |
-
|
| 59 |
-
namespace detail {
|
| 60 |
-
|
| 61 |
-
constexpr int PrefetchStages = 4;
|
| 62 |
-
constexpr int PrefetchInitialStages = 1;
|
| 63 |
-
// This determines how much shmem we set aside for prefetch.
|
| 64 |
-
// We don't reuse anything loaded by prefetcher, so we can keep
|
| 65 |
-
// loading into the same place -- there will be a conflict when
|
| 66 |
-
// writing, but it doesn't affect performance as much as the doors
|
| 67 |
-
// that this opens.
|
| 68 |
-
constexpr int PrefetchStagesActual = 1;
|
| 69 |
-
|
| 70 |
-
} // namespace detail
|
| 71 |
-
|
| 72 |
-
// WarpSpecialized Mainloop
|
| 73 |
-
template <
|
| 74 |
-
int Stages,
|
| 75 |
-
class ClusterShape,
|
| 76 |
-
class KernelSchedule,
|
| 77 |
-
class TileShape_,
|
| 78 |
-
class ElementA_,
|
| 79 |
-
class StrideA_,
|
| 80 |
-
class ElementB_,
|
| 81 |
-
class StrideB_,
|
| 82 |
-
class TiledMma_,
|
| 83 |
-
class GmemTiledCopyA_,
|
| 84 |
-
class SmemLayoutAtomA_,
|
| 85 |
-
class SmemCopyAtomA_,
|
| 86 |
-
class TransformA_,
|
| 87 |
-
class GmemTiledCopyB_,
|
| 88 |
-
class SmemLayoutAtomB_,
|
| 89 |
-
class SmemCopyAtomB_,
|
| 90 |
-
class TransformB_>
|
| 91 |
-
struct CollectiveMma<
|
| 92 |
-
MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>,
|
| 93 |
-
TileShape_,
|
| 94 |
-
ElementA_,
|
| 95 |
-
StrideA_,
|
| 96 |
-
ElementB_,
|
| 97 |
-
StrideB_,
|
| 98 |
-
TiledMma_,
|
| 99 |
-
GmemTiledCopyA_,
|
| 100 |
-
SmemLayoutAtomA_,
|
| 101 |
-
SmemCopyAtomA_,
|
| 102 |
-
TransformA_,
|
| 103 |
-
GmemTiledCopyB_,
|
| 104 |
-
SmemLayoutAtomB_,
|
| 105 |
-
SmemCopyAtomB_,
|
| 106 |
-
TransformB_>
|
| 107 |
-
{
|
| 108 |
-
//
|
| 109 |
-
// Type Aliases
|
| 110 |
-
//
|
| 111 |
-
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>;
|
| 112 |
-
using TileShape = TileShape_;
|
| 113 |
-
using ElementA = ElementA_;
|
| 114 |
-
using StrideA = StrideA_;
|
| 115 |
-
using ElementB = ElementB_;
|
| 116 |
-
using StrideB = StrideB_;
|
| 117 |
-
using TiledMma = TiledMma_;
|
| 118 |
-
using ElementAccumulator = typename TiledMma::ValTypeC;
|
| 119 |
-
using GmemTiledCopyA = GmemTiledCopyA_;
|
| 120 |
-
using GmemTiledCopyB = GmemTiledCopyB_;
|
| 121 |
-
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
| 122 |
-
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
| 123 |
-
using SmemCopyAtomA = SmemCopyAtomA_;
|
| 124 |
-
using SmemCopyAtomB = SmemCopyAtomB_;
|
| 125 |
-
using TransformA = TransformA_;
|
| 126 |
-
using TransformB = TransformB_;
|
| 127 |
-
using ArchTag = typename DispatchPolicy::ArchTag;
|
| 128 |
-
|
| 129 |
-
static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1");
|
| 130 |
-
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
| 131 |
-
|
| 132 |
-
using PrefetcherPipeline = cutlass::PrefetchPipeline<detail::PrefetchStages>;
|
| 133 |
-
|
| 134 |
-
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
| 135 |
-
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
| 136 |
-
using PipelineParams = typename MainloopPipeline::Params;
|
| 137 |
-
|
| 138 |
-
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
| 139 |
-
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
| 140 |
-
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
| 141 |
-
|
| 142 |
-
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
| 143 |
-
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
| 144 |
-
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
| 145 |
-
|
| 146 |
-
// Tile along modes in a way that maximizes the TMA box size.
|
| 147 |
-
using SmemLayoutA = decltype(tile_to_shape(
|
| 148 |
-
SmemLayoutAtomA{},
|
| 149 |
-
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
| 150 |
-
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
| 151 |
-
using SmemLayoutB = decltype(tile_to_shape(
|
| 152 |
-
SmemLayoutAtomB{},
|
| 153 |
-
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
| 154 |
-
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
| 155 |
-
|
| 156 |
-
static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages);
|
| 157 |
-
static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages);
|
| 158 |
-
|
| 159 |
-
using PrefetchSmemLayoutA = decltype(make_layout(make_shape(
|
| 160 |
-
cute::Int<size<0>(SmemLayoutA{})>{},
|
| 161 |
-
cute::Int<size<1>(SmemLayoutA{})>{},
|
| 162 |
-
cute::Int<detail::PrefetchStagesActual>{})));
|
| 163 |
-
|
| 164 |
-
static constexpr auto prefetch_smem_size = cute::cosize_v<PrefetchSmemLayoutA>;
|
| 165 |
-
|
| 166 |
-
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
|
| 167 |
-
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
| 168 |
-
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
| 169 |
-
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
| 170 |
-
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
| 171 |
-
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
| 172 |
-
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
| 173 |
-
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
| 174 |
-
|
| 175 |
-
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
|
| 176 |
-
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
|
| 177 |
-
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
|
| 178 |
-
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
|
| 179 |
-
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
|
| 180 |
-
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
|
| 181 |
-
|
| 182 |
-
// Defined outside the class where it's used, to work around MSVC issues
|
| 183 |
-
using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage<detail::PrefetchStages>;
|
| 184 |
-
|
| 185 |
-
struct SharedStorage {
|
| 186 |
-
struct TensorStorage : cute::aligned_struct<128, _0> {
|
| 187 |
-
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
| 188 |
-
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
|
| 189 |
-
cute::array_aligned<typename TiledMma::ValTypeA, prefetch_smem_size> smem_prefetch;
|
| 190 |
-
} tensors;
|
| 191 |
-
|
| 192 |
-
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
| 193 |
-
PipelineStorage pipeline;
|
| 194 |
-
PrefetcherPipelineStorage prefetcher_pipeline;
|
| 195 |
-
};
|
| 196 |
-
using TensorStorage = typename SharedStorage::TensorStorage;
|
| 197 |
-
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
| 198 |
-
|
| 199 |
-
// Host side kernel arguments
|
| 200 |
-
struct Arguments {
|
| 201 |
-
ElementA const* ptr_A;
|
| 202 |
-
StrideA dA;
|
| 203 |
-
ElementB const* ptr_B;
|
| 204 |
-
StrideB dB;
|
| 205 |
-
uint32_t mma_promotion_interval = 4;
|
| 206 |
-
float overlap_ratio = 0.5;
|
| 207 |
-
float prefetch_ratio = -1.0;
|
| 208 |
-
};
|
| 209 |
-
|
| 210 |
-
// Device side kernel params
|
| 211 |
-
struct Params {
|
| 212 |
-
// Assumption: StrideA is congruent with Problem_MK
|
| 213 |
-
using TMA_A = decltype(make_tma_copy_A_sm90(
|
| 214 |
-
GmemTiledCopyA{},
|
| 215 |
-
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
| 216 |
-
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
| 217 |
-
TileShape{},
|
| 218 |
-
ClusterShape{}));
|
| 219 |
-
// Assumption: StrideB is congruent with Problem_NK
|
| 220 |
-
using TMA_B = decltype(make_tma_copy_B_sm90(
|
| 221 |
-
GmemTiledCopyB{},
|
| 222 |
-
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
| 223 |
-
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
| 224 |
-
TileShape{},
|
| 225 |
-
ClusterShape{}));
|
| 226 |
-
|
| 227 |
-
TMA_A tma_load_a;
|
| 228 |
-
TMA_B tma_load_b;
|
| 229 |
-
uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
|
| 230 |
-
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
|
| 231 |
-
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
|
| 232 |
-
float overlap_ratio = 0.5;
|
| 233 |
-
float prefetch_ratio = -1.0;
|
| 234 |
-
};
|
| 235 |
-
|
| 236 |
-
//
|
| 237 |
-
// Methods
|
| 238 |
-
//
|
| 239 |
-
|
| 240 |
-
template <class ProblemShape>
|
| 241 |
-
static constexpr Params
|
| 242 |
-
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
| 243 |
-
(void) workspace;
|
| 244 |
-
|
| 245 |
-
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
| 246 |
-
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
| 247 |
-
auto [M,N,K,L] = problem_shape_MNKL;
|
| 248 |
-
|
| 249 |
-
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
|
| 250 |
-
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
|
| 251 |
-
|
| 252 |
-
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
| 253 |
-
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
|
| 254 |
-
|
| 255 |
-
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
|
| 256 |
-
GmemTiledCopyA{},
|
| 257 |
-
tensor_a,
|
| 258 |
-
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
| 259 |
-
TileShape{},
|
| 260 |
-
ClusterShape{});
|
| 261 |
-
typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
|
| 262 |
-
GmemTiledCopyB{},
|
| 263 |
-
tensor_b,
|
| 264 |
-
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
| 265 |
-
TileShape{},
|
| 266 |
-
ClusterShape{});
|
| 267 |
-
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
|
| 268 |
-
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
|
| 269 |
-
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
|
| 270 |
-
|
| 271 |
-
return {
|
| 272 |
-
tma_load_a,
|
| 273 |
-
tma_load_b,
|
| 274 |
-
transaction_bytes,
|
| 275 |
-
transaction_bytes_mk,
|
| 276 |
-
transaction_bytes_nk,
|
| 277 |
-
args.overlap_ratio,
|
| 278 |
-
args.prefetch_ratio
|
| 279 |
-
};
|
| 280 |
-
}
|
| 281 |
-
|
| 282 |
-
template<class ProblemShape>
|
| 283 |
-
static bool
|
| 284 |
-
can_implement(
|
| 285 |
-
ProblemShape const& problem_shape,
|
| 286 |
-
[[maybe_unused]] Arguments const& args) {
|
| 287 |
-
constexpr int tma_alignment_bits = 128;
|
| 288 |
-
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
| 289 |
-
auto [M,N,K,L] = problem_shape_MNKL;
|
| 290 |
-
|
| 291 |
-
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
| 292 |
-
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
| 293 |
-
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
| 294 |
-
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
| 295 |
-
|
| 296 |
-
if (!implementable) {
|
| 297 |
-
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
| 298 |
-
return false;
|
| 299 |
-
}
|
| 300 |
-
|
| 301 |
-
if (args.overlap_ratio > 1.0) {
|
| 302 |
-
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n");
|
| 303 |
-
return false;
|
| 304 |
-
}
|
| 305 |
-
|
| 306 |
-
if (args.prefetch_ratio > 1.0) {
|
| 307 |
-
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n");
|
| 308 |
-
return false;
|
| 309 |
-
}
|
| 310 |
-
|
| 311 |
-
return true;
|
| 312 |
-
}
|
| 313 |
-
|
| 314 |
-
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
| 315 |
-
static constexpr int K_PIPE_MMAS = 1;
|
| 316 |
-
static constexpr uint32_t TmaTransactionBytesMK =
|
| 317 |
-
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
|
| 318 |
-
static constexpr uint32_t TmaTransactionBytesNK =
|
| 319 |
-
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
|
| 320 |
-
|
| 321 |
-
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
| 322 |
-
CUTLASS_DEVICE
|
| 323 |
-
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
| 324 |
-
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
| 325 |
-
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
| 326 |
-
}
|
| 327 |
-
|
| 328 |
-
/// Set up the data needed by this collective for load and mma.
|
| 329 |
-
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
| 330 |
-
/// Returned tuple must contain at least two elements, with the first two elements being:
|
| 331 |
-
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
| 332 |
-
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
| 333 |
-
/// The rest of the tensors can be specified as needed by this collective.
|
| 334 |
-
template <class ProblemShape_MNKL>
|
| 335 |
-
CUTLASS_DEVICE auto
|
| 336 |
-
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
| 337 |
-
using X = Underscore;
|
| 338 |
-
// Separate out problem shape for convenience
|
| 339 |
-
auto [M,N,K,L] = problem_shape_MNKL;
|
| 340 |
-
|
| 341 |
-
// TMA requires special handling of strides to deal with coord codomain mapping
|
| 342 |
-
// Represent the full tensors -- get these from TMA
|
| 343 |
-
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
|
| 344 |
-
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
| 345 |
-
|
| 346 |
-
// Make tiled views, defer the slice
|
| 347 |
-
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
| 348 |
-
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
| 349 |
-
|
| 350 |
-
return cute::make_tuple(gA_mkl, gB_nkl);
|
| 351 |
-
}
|
| 352 |
-
|
| 353 |
-
template <
|
| 354 |
-
class TensorA, class TensorB,
|
| 355 |
-
class KTileIterator, class BlockCoord
|
| 356 |
-
>
|
| 357 |
-
CUTLASS_DEVICE void
|
| 358 |
-
load(
|
| 359 |
-
Params const& mainloop_params,
|
| 360 |
-
MainloopPipeline pipeline,
|
| 361 |
-
PrefetcherPipeline prefetcher_pipeline,
|
| 362 |
-
PipelineState smem_pipe_write,
|
| 363 |
-
TensorA const& gA_mkl,
|
| 364 |
-
TensorB const& gB_nkl,
|
| 365 |
-
BlockCoord const& blk_coord,
|
| 366 |
-
KTileIterator k_tile_iter, int k_tile_count,
|
| 367 |
-
int thread_idx,
|
| 368 |
-
uint32_t block_rank_in_cluster,
|
| 369 |
-
TensorStorage& shared_tensors) {
|
| 370 |
-
int lane_predicate = cute::elect_one_sync();
|
| 371 |
-
|
| 372 |
-
if (lane_predicate) {
|
| 373 |
-
bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
|
| 374 |
-
float overlap_ratio = mainloop_params.overlap_ratio;
|
| 375 |
-
int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);
|
| 376 |
-
|
| 377 |
-
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
| 378 |
-
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
| 379 |
-
|
| 380 |
-
//
|
| 381 |
-
// Prepare the TMA loads for A
|
| 382 |
-
//
|
| 383 |
-
|
| 384 |
-
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
| 385 |
-
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
| 386 |
-
|
| 387 |
-
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
| 388 |
-
auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
| 389 |
-
|
| 390 |
-
// Partition the inputs based on the current block coordinates.
|
| 391 |
-
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
| 392 |
-
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
| 393 |
-
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
| 394 |
-
|
| 395 |
-
// Applies the mapping from cta_tma_a
|
| 396 |
-
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
| 397 |
-
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
| 398 |
-
|
| 399 |
-
// Applies the mapping from cta_tma_b
|
| 400 |
-
Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
| 401 |
-
Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
| 402 |
-
|
| 403 |
-
uint16_t mcast_mask_a = 0;
|
| 404 |
-
uint16_t mcast_mask_b = 0;
|
| 405 |
-
|
| 406 |
-
// Issue TmaLoads
|
| 407 |
-
// Maps the tile -> block, value
|
| 408 |
-
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
| 409 |
-
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
| 410 |
-
for (int n = 0; n < size<1>(block_layout); ++n) {
|
| 411 |
-
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
| 412 |
-
}
|
| 413 |
-
}
|
| 414 |
-
|
| 415 |
-
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
| 416 |
-
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
| 417 |
-
for (int m = 0; m < size<0>(block_layout); ++m) {
|
| 418 |
-
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
| 419 |
-
}
|
| 420 |
-
}
|
| 421 |
-
|
| 422 |
-
// We have to wait on dependent grids because of B.
|
| 423 |
-
cutlass::arch::wait_on_dependent_grids();
|
| 424 |
-
|
| 425 |
-
// Signal prefetcher to stop
|
| 426 |
-
prefetcher_pipeline.producer_arrive();
|
| 427 |
-
|
| 428 |
-
bool launch_dep_grids = false;
|
| 429 |
-
// Mainloop
|
| 430 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 431 |
-
for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
|
| 432 |
-
// LOCK smem_pipe_write for _writing_
|
| 433 |
-
pipeline.producer_acquire(smem_pipe_write);
|
| 434 |
-
|
| 435 |
-
//
|
| 436 |
-
// Copy gmem to smem for *k_tile_iter
|
| 437 |
-
//
|
| 438 |
-
|
| 439 |
-
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
| 440 |
-
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
| 441 |
-
|
| 442 |
-
int write_stage = smem_pipe_write.index();
|
| 443 |
-
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
| 444 |
-
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
| 445 |
-
++k_tile_iter;
|
| 446 |
-
|
| 447 |
-
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
| 448 |
-
launch_dep_grids = true;
|
| 449 |
-
cutlass::arch::launch_dependent_grids();
|
| 450 |
-
}
|
| 451 |
-
|
| 452 |
-
// Advance smem_pipe_write
|
| 453 |
-
++smem_pipe_write;
|
| 454 |
-
}
|
| 455 |
-
if (!disable_gdc && !launch_dep_grids) {
|
| 456 |
-
cutlass::arch::launch_dependent_grids();
|
| 457 |
-
}
|
| 458 |
-
}
|
| 459 |
-
}
|
| 460 |
-
|
| 461 |
-
template <
|
| 462 |
-
class TensorA,
|
| 463 |
-
class KTileIterator, class BlockCoord
|
| 464 |
-
>
|
| 465 |
-
CUTLASS_DEVICE void
|
| 466 |
-
load_MK(
|
| 467 |
-
Params const& mainloop_params,
|
| 468 |
-
MainloopPipeline pipeline,
|
| 469 |
-
PrefetcherPipeline prefetcher_pipeline,
|
| 470 |
-
PipelineState smem_pipe_write,
|
| 471 |
-
TensorA const& gA_mkl,
|
| 472 |
-
BlockCoord const& blk_coord,
|
| 473 |
-
KTileIterator k_tile_iter, int k_tile_count,
|
| 474 |
-
int thread_idx,
|
| 475 |
-
uint32_t block_rank_in_cluster,
|
| 476 |
-
TensorStorage& shared_tensors) {
|
| 477 |
-
int lane_predicate = cute::elect_one_sync();
|
| 478 |
-
|
| 479 |
-
if (lane_predicate) {
|
| 480 |
-
bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
|
| 481 |
-
float overlap_ratio = mainloop_params.overlap_ratio;
|
| 482 |
-
int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);
|
| 483 |
-
|
| 484 |
-
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
| 485 |
-
|
| 486 |
-
//
|
| 487 |
-
// Prepare the TMA loads for A
|
| 488 |
-
//
|
| 489 |
-
|
| 490 |
-
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
| 491 |
-
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
| 492 |
-
|
| 493 |
-
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
| 494 |
-
|
| 495 |
-
// Partition the inputs based on the current block coordinates.
|
| 496 |
-
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
| 497 |
-
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
| 498 |
-
|
| 499 |
-
// Applies the mapping from cta_tma_a
|
| 500 |
-
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
| 501 |
-
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
| 502 |
-
|
| 503 |
-
uint16_t mcast_mask_a = 0;
|
| 504 |
-
|
| 505 |
-
// Issue TmaLoads
|
| 506 |
-
// Maps the tile -> block, value
|
| 507 |
-
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
| 508 |
-
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
| 509 |
-
for (int n = 0; n < size<1>(block_layout); ++n) {
|
| 510 |
-
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
| 511 |
-
}
|
| 512 |
-
}
|
| 513 |
-
|
| 514 |
-
// Don't wait on dependent grids when loading `A`, because
|
| 515 |
-
// we assume `A` (weights) are static.
|
| 516 |
-
|
| 517 |
-
bool launch_dep_grids = false;
|
| 518 |
-
// Mainloop
|
| 519 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 520 |
-
for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
|
| 521 |
-
// LOCK smem_pipe_write for _writing_
|
| 522 |
-
pipeline.producer_acquire(smem_pipe_write);
|
| 523 |
-
|
| 524 |
-
//
|
| 525 |
-
// Copy gmem to smem for *k_tile_iter
|
| 526 |
-
//
|
| 527 |
-
|
| 528 |
-
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
| 529 |
-
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
| 530 |
-
|
| 531 |
-
int write_stage = smem_pipe_write.index();
|
| 532 |
-
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
| 533 |
-
++k_tile_iter;
|
| 534 |
-
|
| 535 |
-
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
| 536 |
-
launch_dep_grids = true;
|
| 537 |
-
cutlass::arch::launch_dependent_grids();
|
| 538 |
-
}
|
| 539 |
-
|
| 540 |
-
// Advance smem_pipe_write
|
| 541 |
-
++smem_pipe_write;
|
| 542 |
-
}
|
| 543 |
-
if (!disable_gdc && !launch_dep_grids) {
|
| 544 |
-
cutlass::arch::launch_dependent_grids();
|
| 545 |
-
}
|
| 546 |
-
}
|
| 547 |
-
}
|
| 548 |
-
|
| 549 |
-
template <
|
| 550 |
-
class TensorB,
|
| 551 |
-
class KTileIterator, class BlockCoord
|
| 552 |
-
>
|
| 553 |
-
CUTLASS_DEVICE void
|
| 554 |
-
load_NK(
|
| 555 |
-
Params const& mainloop_params,
|
| 556 |
-
MainloopPipeline pipeline,
|
| 557 |
-
PrefetcherPipeline prefetcher_pipeline,
|
| 558 |
-
PipelineState smem_pipe_write,
|
| 559 |
-
TensorB const& gB_nkl,
|
| 560 |
-
BlockCoord const& blk_coord,
|
| 561 |
-
KTileIterator k_tile_iter, int k_tile_count,
|
| 562 |
-
int thread_idx,
|
| 563 |
-
uint32_t block_rank_in_cluster,
|
| 564 |
-
TensorStorage& shared_tensors) {
|
| 565 |
-
int lane_predicate = cute::elect_one_sync();
|
| 566 |
-
|
| 567 |
-
if (lane_predicate) {
|
| 568 |
-
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
| 569 |
-
|
| 570 |
-
//
|
| 571 |
-
// Prepare the TMA loads for B
|
| 572 |
-
//
|
| 573 |
-
|
| 574 |
-
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
| 575 |
-
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
| 576 |
-
|
| 577 |
-
auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
| 578 |
-
|
| 579 |
-
// Partition the inputs based on the current block coordinates.
|
| 580 |
-
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
| 581 |
-
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
| 582 |
-
|
| 583 |
-
// Applies the mapping from cta_tma_b
|
| 584 |
-
Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
| 585 |
-
Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
| 586 |
-
|
| 587 |
-
uint16_t mcast_mask_b = 0;
|
| 588 |
-
|
| 589 |
-
// Issue TmaLoads
|
| 590 |
-
// Maps the tile -> block, value
|
| 591 |
-
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
| 592 |
-
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
| 593 |
-
for (int m = 0; m < size<0>(block_layout); ++m) {
|
| 594 |
-
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
| 595 |
-
}
|
| 596 |
-
}
|
| 597 |
-
|
| 598 |
-
// Ensure that the prefetched kernel does not touch
|
| 599 |
-
// unflushed global memory prior to this instruction
|
| 600 |
-
cutlass::arch::wait_on_dependent_grids();
|
| 601 |
-
|
| 602 |
-
// Signal prefetcher to stop
|
| 603 |
-
prefetcher_pipeline.producer_arrive();
|
| 604 |
-
|
| 605 |
-
// Mainloop
|
| 606 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 607 |
-
for (; k_tile_count > 0; --k_tile_count) {
|
| 608 |
-
// LOCK smem_pipe_write for _writing_
|
| 609 |
-
pipeline.producer_acquire(smem_pipe_write);
|
| 610 |
-
|
| 611 |
-
//
|
| 612 |
-
// Copy gmem to smem for *k_tile_iter
|
| 613 |
-
//
|
| 614 |
-
|
| 615 |
-
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
| 616 |
-
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
| 617 |
-
|
| 618 |
-
int write_stage = smem_pipe_write.index();
|
| 619 |
-
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
| 620 |
-
++k_tile_iter;
|
| 621 |
-
|
| 622 |
-
// Advance smem_pipe_write
|
| 623 |
-
++smem_pipe_write;
|
| 624 |
-
}
|
| 625 |
-
}
|
| 626 |
-
}
|
| 627 |
-
|
| 628 |
-
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
| 629 |
-
CUTLASS_DEVICE void
|
| 630 |
-
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
|
| 631 |
-
int lane_predicate = cute::elect_one_sync();
|
| 632 |
-
|
| 633 |
-
// Issue the epilogue waits
|
| 634 |
-
if (lane_predicate) {
|
| 635 |
-
/* This helps avoid early exit of blocks in Cluster
|
| 636 |
-
* Waits for all stages to either be released (all
|
| 637 |
-
* Consumer UNLOCKs), or if the stage was never used
|
| 638 |
-
* then would just be acquired since the phase was
|
| 639 |
-
* still inverted from make_producer_start_state
|
| 640 |
-
*/
|
| 641 |
-
pipeline.producer_tail(smem_pipe_write);
|
| 642 |
-
}
|
| 643 |
-
}
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
template <
|
| 647 |
-
class TensorA,
|
| 648 |
-
class KTileIterator, class BlockCoord
|
| 649 |
-
>
|
| 650 |
-
CUTLASS_DEVICE void
|
| 651 |
-
prefetch_MK(
|
| 652 |
-
Params const& mainloop_params,
|
| 653 |
-
PrefetcherPipeline prefetcher_pipeline,
|
| 654 |
-
PipelineState smem_pipe_write,
|
| 655 |
-
TensorA const& gA_mkl,
|
| 656 |
-
BlockCoord const& blk_coord,
|
| 657 |
-
KTileIterator k_tile_iter, int k_tile_count,
|
| 658 |
-
int thread_idx,
|
| 659 |
-
uint32_t block_rank_in_cluster,
|
| 660 |
-
TensorStorage& shared_tensors) {
|
| 661 |
-
int lane_predicate = cute::elect_one_sync();
|
| 662 |
-
|
| 663 |
-
if (lane_predicate) {
|
| 664 |
-
bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0;
|
| 665 |
-
float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio;
|
| 666 |
-
int prefetch_iters = static_cast<int>(static_cast<float>(k_tile_count) * 0.5 * prefetch_ratio);
|
| 667 |
-
prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages);
|
| 668 |
-
|
| 669 |
-
Tensor sA = make_tensor(
|
| 670 |
-
make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
| 671 |
-
|
| 672 |
-
//
|
| 673 |
-
// Prepare the TMA loads for A
|
| 674 |
-
//
|
| 675 |
-
|
| 676 |
-
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
| 677 |
-
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
| 678 |
-
|
| 679 |
-
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
| 680 |
-
|
| 681 |
-
// Partition the inputs based on the current block coordinates.
|
| 682 |
-
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
| 683 |
-
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
| 684 |
-
|
| 685 |
-
// Applies the mapping from cta_tma_a
|
| 686 |
-
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
| 687 |
-
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
| 688 |
-
|
| 689 |
-
uint16_t mcast_mask_a = 0;
|
| 690 |
-
|
| 691 |
-
// Issue TmaLoads
|
| 692 |
-
// Maps the tile -> block, value
|
| 693 |
-
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
| 694 |
-
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
| 695 |
-
for (int n = 0; n < size<1>(block_layout); ++n) {
|
| 696 |
-
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
| 697 |
-
}
|
| 698 |
-
}
|
| 699 |
-
|
| 700 |
-
uint32_t prefetcher_stage = 0;
|
| 701 |
-
uint32_t prefetcher_phase = 0;
|
| 702 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 703 |
-
for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) {
|
| 704 |
-
|
| 705 |
-
if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) {
|
| 706 |
-
break;
|
| 707 |
-
}
|
| 708 |
-
|
| 709 |
-
prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages);
|
| 710 |
-
using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType;
|
| 711 |
-
BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage);
|
| 712 |
-
|
| 713 |
-
int write_stage = 0;
|
| 714 |
-
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
| 715 |
-
++k_tile_iter;
|
| 716 |
-
++k_tile_iter;
|
| 717 |
-
|
| 718 |
-
prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase);
|
| 719 |
-
}
|
| 720 |
-
prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase);
|
| 721 |
-
}
|
| 722 |
-
}
|
| 723 |
-
|
| 724 |
-
/// Perform a collective-scoped matrix multiply-accumulate
|
| 725 |
-
/// Consumer Perspective
|
| 726 |
-
template <
|
| 727 |
-
class FrgTensorC
|
| 728 |
-
>
|
| 729 |
-
CUTLASS_DEVICE void
|
| 730 |
-
mma(MainloopPipeline pipeline,
|
| 731 |
-
PipelineState smem_pipe_read,
|
| 732 |
-
FrgTensorC& accum,
|
| 733 |
-
int k_tile_count,
|
| 734 |
-
int thread_idx,
|
| 735 |
-
TensorStorage& shared_tensors,
|
| 736 |
-
Params const& mainloop_params) {
|
| 737 |
-
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
| 738 |
-
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
| 739 |
-
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
| 740 |
-
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
| 741 |
-
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
| 742 |
-
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
| 743 |
-
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
| 744 |
-
|
| 745 |
-
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
| 746 |
-
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
| 747 |
-
|
| 748 |
-
//
|
| 749 |
-
// Define C accumulators and A/B partitioning
|
| 750 |
-
//
|
| 751 |
-
|
| 752 |
-
TiledMma tiled_mma;
|
| 753 |
-
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
| 754 |
-
|
| 755 |
-
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
| 756 |
-
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
| 757 |
-
|
| 758 |
-
// Allocate "fragments/descriptors"
|
| 759 |
-
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
| 760 |
-
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
| 761 |
-
|
| 762 |
-
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
| 763 |
-
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
| 764 |
-
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
| 765 |
-
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
| 766 |
-
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
| 767 |
-
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
| 768 |
-
|
| 769 |
-
//
|
| 770 |
-
// PIPELINED MAIN LOOP
|
| 771 |
-
//
|
| 772 |
-
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
| 773 |
-
"ERROR : Incorrect number of MMAs in flight");
|
| 774 |
-
|
| 775 |
-
// We release buffers to producer warps(dma load) with some mmas in flight
|
| 776 |
-
PipelineState smem_pipe_release = smem_pipe_read;
|
| 777 |
-
|
| 778 |
-
// Prologue GMMAs
|
| 779 |
-
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
| 780 |
-
|
| 781 |
-
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
| 782 |
-
|
| 783 |
-
warpgroup_fence_operand(accum);
|
| 784 |
-
CUTLASS_PRAGMA_UNROLL
|
| 785 |
-
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
| 786 |
-
{
|
| 787 |
-
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
| 788 |
-
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
| 789 |
-
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
| 790 |
-
|
| 791 |
-
int read_stage = smem_pipe_read.index();
|
| 792 |
-
warpgroup_arrive();
|
| 793 |
-
// Unroll the K mode manually to set scale D to 1
|
| 794 |
-
CUTLASS_PRAGMA_UNROLL
|
| 795 |
-
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
| 796 |
-
// (V,M,K) x (V,N,K) => (V,M,N)
|
| 797 |
-
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
|
| 798 |
-
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
| 799 |
-
}
|
| 800 |
-
|
| 801 |
-
warpgroup_commit_batch();
|
| 802 |
-
|
| 803 |
-
++smem_pipe_read;
|
| 804 |
-
}
|
| 805 |
-
|
| 806 |
-
warpgroup_fence_operand(accum);
|
| 807 |
-
// Mainloop GMMAs
|
| 808 |
-
k_tile_count -= prologue_mma_count;
|
| 809 |
-
|
| 810 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 811 |
-
for ( ; k_tile_count > 0; --k_tile_count)
|
| 812 |
-
{
|
| 813 |
-
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
| 814 |
-
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
| 815 |
-
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
| 816 |
-
|
| 817 |
-
//
|
| 818 |
-
// Compute on k_tile
|
| 819 |
-
//
|
| 820 |
-
|
| 821 |
-
int read_stage = smem_pipe_read.index();
|
| 822 |
-
warpgroup_fence_operand(accum);
|
| 823 |
-
warpgroup_arrive();
|
| 824 |
-
// Unroll the K mode manually to set scale D to 1
|
| 825 |
-
CUTLASS_PRAGMA_UNROLL
|
| 826 |
-
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
| 827 |
-
// (V,M,K) x (V,N,K) => (V,M,N)
|
| 828 |
-
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
|
| 829 |
-
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
| 830 |
-
}
|
| 831 |
-
warpgroup_commit_batch();
|
| 832 |
-
|
| 833 |
-
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
| 834 |
-
warpgroup_wait<K_PIPE_MMAS>();
|
| 835 |
-
warpgroup_fence_operand(accum);
|
| 836 |
-
|
| 837 |
-
// UNLOCK smem_pipe_release, done _computing_ on it
|
| 838 |
-
pipeline.consumer_release(smem_pipe_release);
|
| 839 |
-
|
| 840 |
-
// Advance smem_pipe_read and smem_pipe_release
|
| 841 |
-
++smem_pipe_read;
|
| 842 |
-
++smem_pipe_release;
|
| 843 |
-
}
|
| 844 |
-
|
| 845 |
-
warpgroup_fence_operand(accum);
|
| 846 |
-
}
|
| 847 |
-
|
| 848 |
-
/// Perform a Consumer Epilogue to release all buffers
|
| 849 |
-
CUTLASS_DEVICE void
|
| 850 |
-
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
| 851 |
-
// Prologue GMMAs
|
| 852 |
-
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
| 853 |
-
k_tile_count -= prologue_mma_count;
|
| 854 |
-
|
| 855 |
-
smem_pipe_release.advance(k_tile_count);
|
| 856 |
-
|
| 857 |
-
// Wait on all GMMAs to complete
|
| 858 |
-
warpgroup_wait<0>();
|
| 859 |
-
|
| 860 |
-
for (int count = 0; count < prologue_mma_count; ++count) {
|
| 861 |
-
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
| 862 |
-
++smem_pipe_release;
|
| 863 |
-
}
|
| 864 |
-
}
|
| 865 |
-
};
|
| 866 |
-
|
| 867 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 868 |
-
|
| 869 |
-
} // namespace cutlass::gemm::collective
|
| 870 |
-
|
| 871 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp
DELETED
|
@@ -1,117 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
// Command line options parsing
|
| 33 |
-
struct Options {
|
| 34 |
-
|
| 35 |
-
bool help = false;
|
| 36 |
-
|
| 37 |
-
float alpha = 1.f, beta = 0.f;
|
| 38 |
-
float overlap_ratio = 0.5f, prefetch_ratio = 0.5f;
|
| 39 |
-
int iterations = 1000;
|
| 40 |
-
int n = 64, m = 1280, k = 8192, l = 1;
|
| 41 |
-
|
| 42 |
-
// Parses the command line
|
| 43 |
-
void parse(int argc, char const **args) {
|
| 44 |
-
cutlass::CommandLine cmd(argc, args);
|
| 45 |
-
|
| 46 |
-
if (cmd.check_cmd_line_flag("help")) {
|
| 47 |
-
help = true;
|
| 48 |
-
return;
|
| 49 |
-
}
|
| 50 |
-
|
| 51 |
-
cmd.get_cmd_line_argument("m", m);
|
| 52 |
-
cmd.get_cmd_line_argument("n", n);
|
| 53 |
-
cmd.get_cmd_line_argument("k", k);
|
| 54 |
-
cmd.get_cmd_line_argument("l", l);
|
| 55 |
-
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
| 56 |
-
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
| 57 |
-
cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f);
|
| 58 |
-
cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f);
|
| 59 |
-
cmd.get_cmd_line_argument("iterations", iterations);
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
/// Prints the usage statement.
|
| 63 |
-
std::ostream & print_usage(std::ostream &out) const {
|
| 64 |
-
|
| 65 |
-
out << "63_hopper_gemm_with_weight_prefetch\n\n"
|
| 66 |
-
<< " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. \n"
|
| 67 |
-
<< " For more details please refer to the source file.\n\n"
|
| 68 |
-
<< "Options:\n\n"
|
| 69 |
-
<< " --help If specified, displays this usage statement\n\n"
|
| 70 |
-
<< " --m=<int> Sets the M extent of the GEMM\n"
|
| 71 |
-
<< " --n=<int> Sets the N extent of the GEMM\n"
|
| 72 |
-
<< " --k=<int> Sets the K extent of the GEMM\n"
|
| 73 |
-
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
| 74 |
-
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
| 75 |
-
<< " --beta=<f32> Epilogue scalar beta\n"
|
| 76 |
-
<< " --p=<f32> Prefetch ratio\n"
|
| 77 |
-
<< " --o=<f32> Overlap ratio\n"
|
| 78 |
-
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
| 79 |
-
|
| 80 |
-
out
|
| 81 |
-
<< "\n\nExamples:\n\n"
|
| 82 |
-
<< "$ " << "63_hopper_gemm_with_weight_prefetch" <<
|
| 83 |
-
" --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n";
|
| 84 |
-
|
| 85 |
-
return out;
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
/// Compute performance in GFLOP/s
|
| 89 |
-
double gflops(double runtime_s) const
|
| 90 |
-
{
|
| 91 |
-
// Two flops per multiply-add
|
| 92 |
-
uint64_t flop = uint64_t(2) * m * n * k * l;
|
| 93 |
-
double gflop = double(flop) / double(1.0e9);
|
| 94 |
-
return gflop / runtime_s;
|
| 95 |
-
}
|
| 96 |
-
|
| 97 |
-
/// Compute effective bandwidth in GB/sec
|
| 98 |
-
double effective_bandwidth(
|
| 99 |
-
double runtime_s,
|
| 100 |
-
size_t bytes_a,
|
| 101 |
-
size_t bytes_b,
|
| 102 |
-
size_t bytes_c,
|
| 103 |
-
size_t bytes_d
|
| 104 |
-
) const
|
| 105 |
-
{
|
| 106 |
-
static double const kBytesPerGiB = double(1ull << 30);
|
| 107 |
-
|
| 108 |
-
double bytes_in =
|
| 109 |
-
(double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A
|
| 110 |
-
(double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B
|
| 111 |
-
(beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C
|
| 112 |
-
double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D
|
| 113 |
-
|
| 114 |
-
double gb_total = (bytes_in + bytes_out) / kBytesPerGiB;
|
| 115 |
-
return gb_total / runtime_s;
|
| 116 |
-
}
|
| 117 |
-
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp
DELETED
|
@@ -1,561 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include "cutlass/cutlass.h"
|
| 35 |
-
#include "cutlass/fast_math.h"
|
| 36 |
-
#include "cutlass/kernel_hardware_info.hpp"
|
| 37 |
-
#include "cute/arch/cluster_sm90.hpp"
|
| 38 |
-
#include "cutlass/arch/reg_reconfig.h"
|
| 39 |
-
#include "cutlass/arch/mma_sm90.h"
|
| 40 |
-
#include "cutlass/epilogue/collective/detail.hpp"
|
| 41 |
-
#include "cutlass/gemm/gemm.h"
|
| 42 |
-
#include "cutlass/gemm/dispatch_policy.hpp"
|
| 43 |
-
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
| 44 |
-
#include "cutlass/pipeline/pipeline.hpp"
|
| 45 |
-
#include "cutlass/trace.h"
|
| 46 |
-
|
| 47 |
-
#include "cute/tensor.hpp"
|
| 48 |
-
|
| 49 |
-
#include "../collective/dispatch_policy_extra.hpp"
|
| 50 |
-
|
| 51 |
-
///////////////////////////////////////////////////////////////////////////////
|
| 52 |
-
|
| 53 |
-
namespace cutlass::gemm::kernel {
|
| 54 |
-
|
| 55 |
-
///////////////////////////////////////////////////////////////////////////////
|
| 56 |
-
|
| 57 |
-
// GEMM + Prefetch for the A tensor + (optional) split DMA warps
|
| 58 |
-
template <
|
| 59 |
-
class ProblemShape_,
|
| 60 |
-
class CollectiveMainloop_,
|
| 61 |
-
class CollectiveEpilogue_,
|
| 62 |
-
class TileScheduler_
|
| 63 |
-
>
|
| 64 |
-
class GemmUniversal<
|
| 65 |
-
ProblemShape_,
|
| 66 |
-
CollectiveMainloop_,
|
| 67 |
-
CollectiveEpilogue_,
|
| 68 |
-
TileScheduler_,
|
| 69 |
-
cute::enable_if_t<
|
| 70 |
-
cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA> ||
|
| 71 |
-
cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>
|
| 72 |
-
>
|
| 73 |
-
>
|
| 74 |
-
{
|
| 75 |
-
public:
|
| 76 |
-
//
|
| 77 |
-
// Type Aliases
|
| 78 |
-
//
|
| 79 |
-
using ProblemShape = ProblemShape_;
|
| 80 |
-
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
| 81 |
-
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
| 82 |
-
static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled;
|
| 83 |
-
|
| 84 |
-
static constexpr bool SplitWarps = cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>;
|
| 85 |
-
|
| 86 |
-
// Mainloop derived types
|
| 87 |
-
using CollectiveMainloop = CollectiveMainloop_;
|
| 88 |
-
using TileShape = typename CollectiveMainloop::TileShape;
|
| 89 |
-
using TiledMma = typename CollectiveMainloop::TiledMma;
|
| 90 |
-
using ArchTag = typename CollectiveMainloop::ArchTag;
|
| 91 |
-
using ElementA = typename CollectiveMainloop::ElementA;
|
| 92 |
-
using StrideA = typename CollectiveMainloop::StrideA;
|
| 93 |
-
using ElementB = typename CollectiveMainloop::ElementB;
|
| 94 |
-
using StrideB = typename CollectiveMainloop::StrideB;
|
| 95 |
-
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
| 96 |
-
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
| 97 |
-
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
| 98 |
-
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
| 99 |
-
using MainloopParams = typename CollectiveMainloop::Params;
|
| 100 |
-
static_assert(ArchTag::kMinComputeCapability >= 90);
|
| 101 |
-
|
| 102 |
-
// Epilogue derived types
|
| 103 |
-
using CollectiveEpilogue = CollectiveEpilogue_;
|
| 104 |
-
using ElementC = typename CollectiveEpilogue::ElementC;
|
| 105 |
-
using StrideC = typename CollectiveEpilogue::StrideC;
|
| 106 |
-
using ElementD = typename CollectiveEpilogue::ElementD;
|
| 107 |
-
using StrideD = typename CollectiveEpilogue::StrideD;
|
| 108 |
-
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
| 109 |
-
using EpilogueParams = typename CollectiveEpilogue::Params;
|
| 110 |
-
|
| 111 |
-
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
|
| 112 |
-
"TMA warp-specialized kernel does not support specializing the tile scheduler.");
|
| 113 |
-
using TileSchedulerTag = TileScheduler_;
|
| 114 |
-
using TileScheduler = typename detail::TileSchedulerSelector<
|
| 115 |
-
TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
|
| 116 |
-
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
| 117 |
-
|
| 118 |
-
// Kernel level shared memory storage
|
| 119 |
-
struct SharedStorage {
|
| 120 |
-
// Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union
|
| 121 |
-
union TensorStorage {
|
| 122 |
-
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
| 123 |
-
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
| 124 |
-
|
| 125 |
-
MainloopTensorStorage mainloop;
|
| 126 |
-
EpilogueTensorStorage epilogue;
|
| 127 |
-
} tensors;
|
| 128 |
-
|
| 129 |
-
struct PipelineStorage : cute::aligned_struct<16, _1> {
|
| 130 |
-
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
| 131 |
-
using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage;
|
| 132 |
-
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
| 133 |
-
|
| 134 |
-
alignas(16) MainloopPipelineStorage mainloop;
|
| 135 |
-
alignas(16) EpiLoadPipelineStorage epi_load;
|
| 136 |
-
alignas(16) PrefetcherPipelineStorage prefetcher;
|
| 137 |
-
} pipelines;
|
| 138 |
-
};
|
| 139 |
-
|
| 140 |
-
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 141 |
-
|
| 142 |
-
static constexpr uint32_t NumLoadWarpGroups = 1;
|
| 143 |
-
static constexpr uint32_t NumMmaWarpGroups = 1;
|
| 144 |
-
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
|
| 145 |
-
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
| 146 |
-
|
| 147 |
-
// Device side arguments
|
| 148 |
-
struct Arguments {
|
| 149 |
-
GemmUniversalMode mode{};
|
| 150 |
-
ProblemShape problem_shape{};
|
| 151 |
-
MainloopArguments mainloop{};
|
| 152 |
-
EpilogueArguments epilogue{};
|
| 153 |
-
KernelHardwareInfo hw_info{};
|
| 154 |
-
TileSchedulerArguments scheduler{};
|
| 155 |
-
};
|
| 156 |
-
|
| 157 |
-
// Kernel entry point API
|
| 158 |
-
struct Params {
|
| 159 |
-
GemmUniversalMode mode{};
|
| 160 |
-
ProblemShape problem_shape{};
|
| 161 |
-
MainloopParams mainloop{};
|
| 162 |
-
EpilogueParams epilogue{};
|
| 163 |
-
};
|
| 164 |
-
|
| 165 |
-
//
|
| 166 |
-
// Methods
|
| 167 |
-
//
|
| 168 |
-
|
| 169 |
-
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
| 170 |
-
static
|
| 171 |
-
Params
|
| 172 |
-
to_underlying_arguments(Arguments const& args, void* workspace) {
|
| 173 |
-
(void) workspace;
|
| 174 |
-
auto problem_shape = args.problem_shape;
|
| 175 |
-
if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
|
| 176 |
-
// swap M/N
|
| 177 |
-
get<0>(problem_shape) = get<1>(args.problem_shape);
|
| 178 |
-
get<1>(problem_shape) = get<0>(args.problem_shape);
|
| 179 |
-
}
|
| 180 |
-
return {
|
| 181 |
-
args.mode,
|
| 182 |
-
problem_shape,
|
| 183 |
-
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
|
| 184 |
-
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace)
|
| 185 |
-
};
|
| 186 |
-
}
|
| 187 |
-
|
| 188 |
-
static bool
|
| 189 |
-
can_implement(Arguments const& args) {
|
| 190 |
-
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
|
| 191 |
-
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
| 192 |
-
if (!implementable) {
|
| 193 |
-
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
|
| 194 |
-
return implementable;
|
| 195 |
-
}
|
| 196 |
-
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
| 197 |
-
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
| 198 |
-
implementable &= TileScheduler::can_implement(args.scheduler);
|
| 199 |
-
|
| 200 |
-
return implementable;
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
-
static
|
| 204 |
-
size_t
|
| 205 |
-
get_workspace_size(Arguments const& args) {
|
| 206 |
-
return 0;
|
| 207 |
-
}
|
| 208 |
-
|
| 209 |
-
static
|
| 210 |
-
cutlass::Status
|
| 211 |
-
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
| 212 |
-
CudaHostAdapter* cuda_adapter = nullptr) {
|
| 213 |
-
return Status::kSuccess;
|
| 214 |
-
}
|
| 215 |
-
|
| 216 |
-
// Computes the kernel launch grid shape based on runtime parameters
|
| 217 |
-
static dim3
|
| 218 |
-
get_grid_shape(Params const& params) {
|
| 219 |
-
auto cluster_shape = ClusterShape{};
|
| 220 |
-
auto tile_shape = TileShape{};
|
| 221 |
-
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
| 222 |
-
return TileScheduler::get_tiled_cta_shape_mnl(
|
| 223 |
-
problem_shape_MNKL, tile_shape, cluster_shape);
|
| 224 |
-
}
|
| 225 |
-
|
| 226 |
-
static dim3
|
| 227 |
-
get_block_shape() {
|
| 228 |
-
return dim3(MaxThreadsPerBlock, 1, 1);
|
| 229 |
-
}
|
| 230 |
-
|
| 231 |
-
CUTLASS_DEVICE
|
| 232 |
-
void
|
| 233 |
-
operator()(Params const& params, char* smem_buf) {
|
| 234 |
-
using namespace cute;
|
| 235 |
-
using X = Underscore;
|
| 236 |
-
|
| 237 |
-
#if defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
| 238 |
-
# define ENABLE_SM90_KERNEL_LEVEL 1
|
| 239 |
-
#endif
|
| 240 |
-
|
| 241 |
-
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
| 242 |
-
#if ! defined(ENABLE_SM90_KERNEL_LEVEL)
|
| 243 |
-
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
|
| 244 |
-
#else
|
| 245 |
-
|
| 246 |
-
enum class WarpGroupRole {
|
| 247 |
-
Producer = 0,
|
| 248 |
-
Consumer = 1,
|
| 249 |
-
};
|
| 250 |
-
// Split mode: use Warp0 to load NK and epilogue, Warp2 to load MK.
|
| 251 |
-
// Non-split mode: use Warp0 to load MK, NK and epilogue, Warp2 is unused.
|
| 252 |
-
// Both modes use Warp1 to prefetch.
|
| 253 |
-
enum class ProducerWarpRole {
|
| 254 |
-
Warp0 = 0,
|
| 255 |
-
PrefetchMK = 1,
|
| 256 |
-
Warp2 = 2,
|
| 257 |
-
UnusedWarp = 3
|
| 258 |
-
};
|
| 259 |
-
|
| 260 |
-
// Kernel level shared memory storage
|
| 261 |
-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 262 |
-
|
| 263 |
-
int thread_idx = int(threadIdx.x);
|
| 264 |
-
int lane_idx = canonical_lane_idx();
|
| 265 |
-
int warp_idx = canonical_warp_idx_sync();
|
| 266 |
-
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
| 267 |
-
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
| 268 |
-
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
| 269 |
-
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
| 270 |
-
int lane_predicate = cute::elect_one_sync();
|
| 271 |
-
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
// Issue Tma Descriptor Prefetch from a single thread
|
| 275 |
-
if ((warp_idx == 0) && lane_predicate) {
|
| 276 |
-
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
| 277 |
-
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
| 278 |
-
}
|
| 279 |
-
|
| 280 |
-
// Mainloop Load pipeline
|
| 281 |
-
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
| 282 |
-
typename MainloopPipeline::Params mainloop_pipeline_params;
|
| 283 |
-
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
| 284 |
-
if (warp_group_role == WarpGroupRole::Producer && (
|
| 285 |
-
producer_warp_role == ProducerWarpRole::Warp0 ||
|
| 286 |
-
producer_warp_role == ProducerWarpRole::Warp2)) {
|
| 287 |
-
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
| 288 |
-
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
|
| 289 |
-
}
|
| 290 |
-
if (warp_group_role == WarpGroupRole::Consumer) {
|
| 291 |
-
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
|
| 292 |
-
}
|
| 293 |
-
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
|
| 294 |
-
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
|
| 295 |
-
bool should_prefetch = params.mainloop.prefetch_ratio > 0;
|
| 296 |
-
using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline;
|
| 297 |
-
typename PrefetcherPipeline::Params prefetcher_pipeline_params;
|
| 298 |
-
prefetcher_pipeline_params.num_prefetchers = 1;
|
| 299 |
-
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
|
| 300 |
-
prefetcher_pipeline_params.should_prefetch = should_prefetch;
|
| 301 |
-
prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk;
|
| 302 |
-
}
|
| 303 |
-
PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params);
|
| 304 |
-
|
| 305 |
-
// Epilogue Load pipeline
|
| 306 |
-
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
| 307 |
-
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
| 308 |
-
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) {
|
| 309 |
-
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
| 310 |
-
}
|
| 311 |
-
if (warp_group_role == WarpGroupRole::Consumer) {
|
| 312 |
-
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
| 313 |
-
}
|
| 314 |
-
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
| 315 |
-
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
| 316 |
-
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
| 317 |
-
if constexpr (CollectiveEpilogue::RequiresTransactionBytes) {
|
| 318 |
-
epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes;
|
| 319 |
-
}
|
| 320 |
-
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
|
| 321 |
-
|
| 322 |
-
// Epilogue Store pipeline
|
| 323 |
-
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
| 324 |
-
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
| 325 |
-
epi_store_pipeline_params.always_wait = true;
|
| 326 |
-
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
| 327 |
-
|
| 328 |
-
// Initialize starting pipeline states for the collectives
|
| 329 |
-
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
|
| 330 |
-
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
| 331 |
-
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
| 332 |
-
|
| 333 |
-
// For the DMA Load (producer) we start with an opposite phase
|
| 334 |
-
// i.e., we skip all waits since we know that the buffer is indeed empty
|
| 335 |
-
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
| 336 |
-
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
| 337 |
-
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
| 338 |
-
|
| 339 |
-
auto cluster_wait_fn = [&] () {
|
| 340 |
-
// We need this to guarantee that the Pipeline init is visible
|
| 341 |
-
// To all producers and consumer thread blocks in the Cluster
|
| 342 |
-
if constexpr (size(ClusterShape{}) > 1) {
|
| 343 |
-
// Non-prefetcher warps arrive and wait,
|
| 344 |
-
// Prefetcher warp can go ahead without waiting.
|
| 345 |
-
cute::cluster_arrive_relaxed();
|
| 346 |
-
if (warp_group_role != WarpGroupRole::Producer ||
|
| 347 |
-
producer_warp_role != ProducerWarpRole::PrefetchMK) {
|
| 348 |
-
cute::cluster_wait();
|
| 349 |
-
}
|
| 350 |
-
return [] () {};
|
| 351 |
-
}
|
| 352 |
-
else {
|
| 353 |
-
// __syncthreads() but only for non prefetcher warps
|
| 354 |
-
if (should_prefetch) {
|
| 355 |
-
|
| 356 |
-
// Use a named barrier to let the prefetcher warp start loading into the L2
|
| 357 |
-
// without waiting to sync with all other warps.
|
| 358 |
-
// All other warps need to sync because the mainloop pipeline init
|
| 359 |
-
// should be visible to all of them.
|
| 360 |
-
// Prefetcher has its own barriers, and the only warps it would need to sync
|
| 361 |
-
// with would be the DMA warps.
|
| 362 |
-
using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
|
| 363 |
-
auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
| 364 |
-
blockDim.x * blockDim.y * blockDim.z,
|
| 365 |
-
/*id*/ 0);
|
| 366 |
-
// Prefetcher warp doesn't arrive on this barrier.
|
| 367 |
-
auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
| 368 |
-
blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
|
| 369 |
-
/*id*/ 1);
|
| 370 |
-
|
| 371 |
-
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
|
| 372 |
-
__syncwarp();
|
| 373 |
-
prefetcher_arrive_barrier.arrive();
|
| 374 |
-
}
|
| 375 |
-
else if (warp_group_role == WarpGroupRole::Producer) {
|
| 376 |
-
prefetcher_arrive_barrier.arrive_and_wait();
|
| 377 |
-
cluster_arrive_barrier.arrive_and_wait();
|
| 378 |
-
}
|
| 379 |
-
else {
|
| 380 |
-
prefetcher_arrive_barrier.arrive();
|
| 381 |
-
cluster_arrive_barrier.arrive_and_wait();
|
| 382 |
-
}
|
| 383 |
-
} else {
|
| 384 |
-
__syncthreads();
|
| 385 |
-
}
|
| 386 |
-
return [] () {};
|
| 387 |
-
}
|
| 388 |
-
} ();
|
| 389 |
-
|
| 390 |
-
// Preconditions
|
| 391 |
-
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
| 392 |
-
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
| 393 |
-
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
| 394 |
-
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
| 395 |
-
|
| 396 |
-
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
| 397 |
-
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
| 398 |
-
|
| 399 |
-
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
| 400 |
-
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
| 401 |
-
TiledMma tiled_mma;
|
| 402 |
-
|
| 403 |
-
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
| 404 |
-
CollectiveMainloop collective_mainloop;
|
| 405 |
-
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
|
| 406 |
-
|
| 407 |
-
// Prepare and partition the input tensors. Expects a tuple of tensors where:
|
| 408 |
-
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
|
| 409 |
-
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
| 410 |
-
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
| 411 |
-
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 2, "Output of load_init must have at least two elements (A, B)");
|
| 412 |
-
|
| 413 |
-
// Extract out partitioned A and B.
|
| 414 |
-
Tensor gA_mkl = get<0>(load_inputs);
|
| 415 |
-
Tensor gB_nkl = get<1>(load_inputs);
|
| 416 |
-
|
| 417 |
-
// Compute m_coord, n_coord, and l_coord with their post-tiled shapes
|
| 418 |
-
auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl));
|
| 419 |
-
auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl));
|
| 420 |
-
auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl));
|
| 421 |
-
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
| 422 |
-
|
| 423 |
-
// Get pipeline iterators and increments from tensor shapes
|
| 424 |
-
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
|
| 425 |
-
auto k_tile_count = size<3>(gA_mkl);
|
| 426 |
-
|
| 427 |
-
// Wait for all thread blocks in the Cluster
|
| 428 |
-
cluster_wait_fn();
|
| 429 |
-
|
| 430 |
-
if (warp_group_role == WarpGroupRole::Producer) {
|
| 431 |
-
if (producer_warp_role == ProducerWarpRole::Warp0) {
|
| 432 |
-
if constexpr(SplitWarps) {
|
| 433 |
-
collective_mainloop.load_NK(
|
| 434 |
-
params.mainloop,
|
| 435 |
-
mainloop_pipeline,
|
| 436 |
-
prefetcher_pipeline,
|
| 437 |
-
mainloop_pipe_producer_state,
|
| 438 |
-
gB_nkl,
|
| 439 |
-
blk_coord,
|
| 440 |
-
k_tile_iter, k_tile_count,
|
| 441 |
-
lane_idx,
|
| 442 |
-
block_rank_in_cluster,
|
| 443 |
-
shared_storage.tensors.mainloop
|
| 444 |
-
);
|
| 445 |
-
}
|
| 446 |
-
else {
|
| 447 |
-
collective_mainloop.load(
|
| 448 |
-
params.mainloop,
|
| 449 |
-
mainloop_pipeline,
|
| 450 |
-
prefetcher_pipeline,
|
| 451 |
-
mainloop_pipe_producer_state,
|
| 452 |
-
gA_mkl, gB_nkl,
|
| 453 |
-
blk_coord,
|
| 454 |
-
k_tile_iter, k_tile_count,
|
| 455 |
-
lane_idx,
|
| 456 |
-
block_rank_in_cluster,
|
| 457 |
-
shared_storage.tensors.mainloop
|
| 458 |
-
);
|
| 459 |
-
}
|
| 460 |
-
// Update starting mainloop pipeline state for the pipeline drain
|
| 461 |
-
mainloop_pipe_producer_state.advance(k_tile_count);
|
| 462 |
-
// Make sure mainloop consumer has been waited upon before issuing epilogue load
|
| 463 |
-
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
| 464 |
-
|
| 465 |
-
if (collective_epilogue.is_producer_load_needed()) {
|
| 466 |
-
// Ensure warp is converged before issuing epilogue loads
|
| 467 |
-
__syncwarp();
|
| 468 |
-
epi_load_pipe_producer_state = collective_epilogue.load(
|
| 469 |
-
epi_load_pipeline,
|
| 470 |
-
epi_load_pipe_producer_state,
|
| 471 |
-
problem_shape_MNKL,
|
| 472 |
-
blk_shape,
|
| 473 |
-
blk_coord,
|
| 474 |
-
tiled_mma,
|
| 475 |
-
lane_idx,
|
| 476 |
-
shared_storage.tensors.epilogue
|
| 477 |
-
);
|
| 478 |
-
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
| 479 |
-
}
|
| 480 |
-
}
|
| 481 |
-
else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) {
|
| 482 |
-
collective_mainloop.load_MK(
|
| 483 |
-
params.mainloop,
|
| 484 |
-
mainloop_pipeline,
|
| 485 |
-
prefetcher_pipeline,
|
| 486 |
-
mainloop_pipe_producer_state,
|
| 487 |
-
gA_mkl,
|
| 488 |
-
blk_coord,
|
| 489 |
-
k_tile_iter, k_tile_count,
|
| 490 |
-
lane_idx,
|
| 491 |
-
block_rank_in_cluster,
|
| 492 |
-
shared_storage.tensors.mainloop
|
| 493 |
-
);
|
| 494 |
-
// Update starting mainloop pipeline state for the pipeline drain
|
| 495 |
-
mainloop_pipe_producer_state.advance(k_tile_count);
|
| 496 |
-
// Make sure mainloop consumer has been waited upon before issuing epilogue load
|
| 497 |
-
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
| 498 |
-
} else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) {
|
| 499 |
-
collective_mainloop.prefetch_MK(
|
| 500 |
-
params.mainloop,
|
| 501 |
-
prefetcher_pipeline,
|
| 502 |
-
mainloop_pipe_producer_state,
|
| 503 |
-
gA_mkl,
|
| 504 |
-
blk_coord,
|
| 505 |
-
k_tile_iter, k_tile_count,
|
| 506 |
-
lane_idx,
|
| 507 |
-
block_rank_in_cluster,
|
| 508 |
-
shared_storage.tensors.mainloop
|
| 509 |
-
);
|
| 510 |
-
}
|
| 511 |
-
}
|
| 512 |
-
else if (warp_group_role == WarpGroupRole::Consumer) {
|
| 513 |
-
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
| 514 |
-
|
| 515 |
-
collective_mainloop.mma(
|
| 516 |
-
mainloop_pipeline,
|
| 517 |
-
mainloop_pipe_consumer_state,
|
| 518 |
-
accumulators,
|
| 519 |
-
k_tile_count,
|
| 520 |
-
warp_group_thread_idx,
|
| 521 |
-
shared_storage.tensors.mainloop,
|
| 522 |
-
params.mainloop
|
| 523 |
-
);
|
| 524 |
-
|
| 525 |
-
// Make sure the math instructions are done and free buffers before entering the epilogue
|
| 526 |
-
collective_mainloop.mma_tail(
|
| 527 |
-
mainloop_pipeline,
|
| 528 |
-
mainloop_pipe_consumer_state,
|
| 529 |
-
k_tile_count
|
| 530 |
-
);
|
| 531 |
-
|
| 532 |
-
// Epilogue and write to gD
|
| 533 |
-
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
|
| 534 |
-
collective_epilogue.store(
|
| 535 |
-
epi_load_pipeline,
|
| 536 |
-
epi_load_pipe_consumer_state,
|
| 537 |
-
epi_store_pipeline,
|
| 538 |
-
epi_store_pipe_producer_state,
|
| 539 |
-
problem_shape_MNKL,
|
| 540 |
-
blk_shape,
|
| 541 |
-
blk_coord,
|
| 542 |
-
accumulators,
|
| 543 |
-
tiled_mma,
|
| 544 |
-
warp_group_thread_idx,
|
| 545 |
-
shared_storage.tensors.epilogue
|
| 546 |
-
);
|
| 547 |
-
|
| 548 |
-
collective_epilogue.store_tail(
|
| 549 |
-
epi_load_pipeline,
|
| 550 |
-
epi_load_pipe_consumer_state_next,
|
| 551 |
-
epi_store_pipeline,
|
| 552 |
-
epi_store_pipe_producer_state_next
|
| 553 |
-
);
|
| 554 |
-
}
|
| 555 |
-
#endif
|
| 556 |
-
}
|
| 557 |
-
};
|
| 558 |
-
|
| 559 |
-
///////////////////////////////////////////////////////////////////////////////
|
| 560 |
-
|
| 561 |
-
} // namespace cutlass::gemm::kernel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|