Build uploaded using `kernels` (batch 7/10).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h +432 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h +252 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h +142 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h +661 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h +789 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h +1500 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h +641 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h +241 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h +1234 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h +406 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h +89 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h +125 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h +104 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h +147 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp +1271 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h +218 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h +132 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h +206 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h +203 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h +2860 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h +89 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h +561 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h +377 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h +679 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h +143 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h +78 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h +267 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp +394 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h +821 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h +1239 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp +94 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp +63 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp +271 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp +917 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +785 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h +658 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h +519 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp +601 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h +194 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp +137 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp +448 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h +270 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h +388 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h +269 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp +136 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp +65 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h +322 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h +1927 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h +2007 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h +357 -0
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/arch/mma.h"
|
| 38 |
+
#include "cutlass/complex.h"
|
| 39 |
+
#include "cutlass/quaternion.h"
|
| 40 |
+
#include "cutlass/functional.h"
|
| 41 |
+
|
| 42 |
+
#include "cutlass/layout/matrix.h"
|
| 43 |
+
#include "cutlass/gemm/gemm.h"
|
| 44 |
+
|
| 45 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
namespace cutlass {
|
| 48 |
+
namespace arch {
|
| 49 |
+
|
| 50 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
/// Matrix multiply-add operation
|
| 53 |
+
template <
|
| 54 |
+
/// Layout of A matrix
|
| 55 |
+
typename LayoutA,
|
| 56 |
+
/// Layout of B matrix
|
| 57 |
+
typename LayoutB,
|
| 58 |
+
/// Layout of C matrix
|
| 59 |
+
typename LayoutC
|
| 60 |
+
>
|
| 61 |
+
struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> {
|
| 62 |
+
|
| 63 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 64 |
+
using Operator = OpMultiplyAdd;
|
| 65 |
+
using ElementC = float;
|
| 66 |
+
|
| 67 |
+
CUTLASS_HOST_DEVICE
|
| 68 |
+
void operator()(
|
| 69 |
+
Array<float, 1> &d,
|
| 70 |
+
Array<float, 1> const &a,
|
| 71 |
+
Array<float, 1> const &b,
|
| 72 |
+
Array<float, 1> const &c
|
| 73 |
+
) {
|
| 74 |
+
d[0] = a[0] * b[0] + c[0];
|
| 75 |
+
}
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 79 |
+
|
| 80 |
+
/// Matrix multiply-add operation
|
| 81 |
+
template <
|
| 82 |
+
/// Layout of A matrix
|
| 83 |
+
typename LayoutA,
|
| 84 |
+
/// Layout of B matrix
|
| 85 |
+
typename LayoutB,
|
| 86 |
+
/// Layout of C matrix
|
| 87 |
+
typename LayoutC
|
| 88 |
+
>
|
| 89 |
+
struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> {
|
| 90 |
+
|
| 91 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 92 |
+
using Operator = OpMultiplyAdd;
|
| 93 |
+
using ElementC = double;
|
| 94 |
+
|
| 95 |
+
CUTLASS_HOST_DEVICE
|
| 96 |
+
void operator()(
|
| 97 |
+
Array<double, 1> &d,
|
| 98 |
+
Array<double, 1> const &a,
|
| 99 |
+
Array<double, 1> const &b,
|
| 100 |
+
Array<double, 1> const &c
|
| 101 |
+
) {
|
| 102 |
+
|
| 103 |
+
d[0] = a[0] * b[0] + c[0];
|
| 104 |
+
}
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 108 |
+
|
| 109 |
+
/// Matrix multiply-add operation
|
| 110 |
+
template <
|
| 111 |
+
/// Layout of A matrix
|
| 112 |
+
typename LayoutA,
|
| 113 |
+
/// Layout of B matrix
|
| 114 |
+
typename LayoutB,
|
| 115 |
+
/// Layout of C matrix
|
| 116 |
+
typename LayoutC
|
| 117 |
+
>
|
| 118 |
+
struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> {
|
| 119 |
+
|
| 120 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 121 |
+
using Operator = OpMultiplyAdd;
|
| 122 |
+
using ElementC = int;
|
| 123 |
+
|
| 124 |
+
CUTLASS_HOST_DEVICE
|
| 125 |
+
void operator()(
|
| 126 |
+
Array<int, 1> &d,
|
| 127 |
+
Array<int, 1> const &a,
|
| 128 |
+
Array<int, 1> const &b,
|
| 129 |
+
Array<int, 1> const &c
|
| 130 |
+
) {
|
| 131 |
+
|
| 132 |
+
d[0] = a[0] * b[0] + c[0];
|
| 133 |
+
}
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 137 |
+
|
| 138 |
+
/// Matrix multiply-add operation
|
| 139 |
+
template <
|
| 140 |
+
/// Layout of A matrix
|
| 141 |
+
typename LayoutA,
|
| 142 |
+
/// Layout of B matrix
|
| 143 |
+
typename LayoutB,
|
| 144 |
+
/// Layout of C matrix
|
| 145 |
+
typename LayoutC
|
| 146 |
+
>
|
| 147 |
+
struct Mma<
|
| 148 |
+
gemm::GemmShape<1, 1, 1>,
|
| 149 |
+
1,
|
| 150 |
+
complex<float>,
|
| 151 |
+
LayoutA,
|
| 152 |
+
complex<float>,
|
| 153 |
+
LayoutB,
|
| 154 |
+
complex<float>,
|
| 155 |
+
LayoutC,
|
| 156 |
+
OpMultiplyAdd> {
|
| 157 |
+
|
| 158 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 159 |
+
using Operator = OpMultiplyAddComplex;
|
| 160 |
+
using ElementC = complex<float>;
|
| 161 |
+
|
| 162 |
+
CUTLASS_HOST_DEVICE
|
| 163 |
+
void operator()(
|
| 164 |
+
Array<complex<float>, 1> &d,
|
| 165 |
+
Array<complex<float>, 1> const &a,
|
| 166 |
+
Array<complex<float>, 1> const &b,
|
| 167 |
+
Array<complex<float>, 1> const &c
|
| 168 |
+
) {
|
| 169 |
+
|
| 170 |
+
d[0].real() = a[0].real() * b[0].real() + c[0].real();
|
| 171 |
+
d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
|
| 172 |
+
d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
|
| 173 |
+
d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
|
| 174 |
+
}
|
| 175 |
+
};
|
| 176 |
+
|
| 177 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 178 |
+
|
| 179 |
+
/// Matrix multiply-add operation
|
| 180 |
+
template <
|
| 181 |
+
/// Layout of A matrix
|
| 182 |
+
typename LayoutA,
|
| 183 |
+
/// Layout of B matrix
|
| 184 |
+
typename LayoutB,
|
| 185 |
+
/// Layout of C matrix
|
| 186 |
+
typename LayoutC
|
| 187 |
+
>
|
| 188 |
+
struct Mma<
|
| 189 |
+
gemm::GemmShape<1, 1, 1>,
|
| 190 |
+
1,
|
| 191 |
+
complex<float>,
|
| 192 |
+
LayoutA,
|
| 193 |
+
float,
|
| 194 |
+
LayoutB,
|
| 195 |
+
complex<float>,
|
| 196 |
+
LayoutC,
|
| 197 |
+
OpMultiplyAdd> {
|
| 198 |
+
|
| 199 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 200 |
+
using Operator = OpMultiplyAddComplex;
|
| 201 |
+
using ElementC = complex<float>;
|
| 202 |
+
|
| 203 |
+
CUTLASS_HOST_DEVICE
|
| 204 |
+
void operator()(
|
| 205 |
+
Array<complex<float>, 1> &d,
|
| 206 |
+
Array<complex<float>, 1> const &a,
|
| 207 |
+
Array<float, 1> const &b,
|
| 208 |
+
Array<complex<float>, 1> const &c
|
| 209 |
+
) {
|
| 210 |
+
|
| 211 |
+
d[0].real() = a[0].real() * b[0] + c[0].real();
|
| 212 |
+
d[0].imag() = a[0].imag() * b[0] + c[0].imag();
|
| 213 |
+
}
|
| 214 |
+
};
|
| 215 |
+
|
| 216 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 217 |
+
|
| 218 |
+
/// Matrix multiply-add operation
|
| 219 |
+
template <
|
| 220 |
+
/// Layout of A matrix
|
| 221 |
+
typename LayoutA,
|
| 222 |
+
/// Layout of B matrix
|
| 223 |
+
typename LayoutB,
|
| 224 |
+
/// Layout of C matrix
|
| 225 |
+
typename LayoutC
|
| 226 |
+
>
|
| 227 |
+
struct Mma<
|
| 228 |
+
gemm::GemmShape<1, 1, 1>,
|
| 229 |
+
1,
|
| 230 |
+
float,
|
| 231 |
+
LayoutA,
|
| 232 |
+
complex<float>,
|
| 233 |
+
LayoutB,
|
| 234 |
+
complex<float>,
|
| 235 |
+
LayoutC,
|
| 236 |
+
OpMultiplyAdd> {
|
| 237 |
+
|
| 238 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 239 |
+
using Operator = OpMultiplyAddComplex;
|
| 240 |
+
using ElementC = complex<float>;
|
| 241 |
+
|
| 242 |
+
CUTLASS_HOST_DEVICE
|
| 243 |
+
void operator()(
|
| 244 |
+
Array<complex<float>, 1> &d,
|
| 245 |
+
Array<float, 1> const &a,
|
| 246 |
+
Array<complex<float>, 1> const &b,
|
| 247 |
+
Array<complex<float>, 1> const &c
|
| 248 |
+
) {
|
| 249 |
+
|
| 250 |
+
d[0].real() = a[0] * b[0].real() + c[0].real();
|
| 251 |
+
d[0].imag() = a[0] * b[0].imag() + d[0].imag();
|
| 252 |
+
}
|
| 253 |
+
};
|
| 254 |
+
|
| 255 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 256 |
+
|
| 257 |
+
/// Matrix multiply-add operation
|
| 258 |
+
template <
|
| 259 |
+
/// Layout of A matrix
|
| 260 |
+
typename LayoutA,
|
| 261 |
+
/// Layout of B matrix
|
| 262 |
+
typename LayoutB,
|
| 263 |
+
/// Layout of C matrix
|
| 264 |
+
typename LayoutC
|
| 265 |
+
>
|
| 266 |
+
struct Mma<
|
| 267 |
+
gemm::GemmShape<1, 1, 1>,
|
| 268 |
+
1,
|
| 269 |
+
complex<double>,
|
| 270 |
+
LayoutA,
|
| 271 |
+
complex<double>,
|
| 272 |
+
LayoutB,
|
| 273 |
+
complex<double>,
|
| 274 |
+
LayoutC,
|
| 275 |
+
OpMultiplyAdd> {
|
| 276 |
+
|
| 277 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 278 |
+
using Operator = OpMultiplyAddComplex;
|
| 279 |
+
using ElementC = complex<double>;
|
| 280 |
+
|
| 281 |
+
CUTLASS_HOST_DEVICE
|
| 282 |
+
void operator()(
|
| 283 |
+
Array<complex<double>, 1> &d,
|
| 284 |
+
Array<complex<double>, 1> const &a,
|
| 285 |
+
Array<complex<double>, 1> const &b,
|
| 286 |
+
Array<complex<double>, 1> const &c
|
| 287 |
+
) {
|
| 288 |
+
|
| 289 |
+
d[0].real() = a[0].real() * b[0].real() + c[0].real();
|
| 290 |
+
d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
|
| 291 |
+
d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
|
| 292 |
+
d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
|
| 293 |
+
}
|
| 294 |
+
};
|
| 295 |
+
|
| 296 |
+
/// Matrix multiply-add operation
|
| 297 |
+
template <
|
| 298 |
+
/// Layout of A matrix
|
| 299 |
+
typename LayoutA,
|
| 300 |
+
/// Layout of B matrix
|
| 301 |
+
typename LayoutB,
|
| 302 |
+
/// Layout of C matrix
|
| 303 |
+
typename LayoutC
|
| 304 |
+
>
|
| 305 |
+
struct Mma<
|
| 306 |
+
gemm::GemmShape<1, 1, 1>,
|
| 307 |
+
1,
|
| 308 |
+
complex<double>,
|
| 309 |
+
LayoutA,
|
| 310 |
+
double,
|
| 311 |
+
LayoutB,
|
| 312 |
+
complex<double>,
|
| 313 |
+
LayoutC,
|
| 314 |
+
OpMultiplyAdd> {
|
| 315 |
+
|
| 316 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 317 |
+
using Operator = OpMultiplyAddComplex;
|
| 318 |
+
using ElementC = complex<double>;
|
| 319 |
+
|
| 320 |
+
CUTLASS_HOST_DEVICE
|
| 321 |
+
void operator()(
|
| 322 |
+
Array<complex<double>, 1> &d,
|
| 323 |
+
Array<complex<double>, 1> const &a,
|
| 324 |
+
Array<double, 1> const &b,
|
| 325 |
+
Array<complex<double>, 1> const &c
|
| 326 |
+
) {
|
| 327 |
+
|
| 328 |
+
d[0].real() = a[0].real() * b[0] + c[0].real();
|
| 329 |
+
d[0].imag() = a[0].imag() * b[0] + c[0].imag();
|
| 330 |
+
}
|
| 331 |
+
};
|
| 332 |
+
|
| 333 |
+
/// Matrix multiply-add operation
|
| 334 |
+
template <
|
| 335 |
+
/// Layout of A matrix
|
| 336 |
+
typename LayoutA,
|
| 337 |
+
/// Layout of B matrix
|
| 338 |
+
typename LayoutB,
|
| 339 |
+
/// Layout of C matrix
|
| 340 |
+
typename LayoutC
|
| 341 |
+
>
|
| 342 |
+
struct Mma<
|
| 343 |
+
gemm::GemmShape<1, 1, 1>,
|
| 344 |
+
1,
|
| 345 |
+
double,
|
| 346 |
+
LayoutA,
|
| 347 |
+
complex<double>,
|
| 348 |
+
LayoutB,
|
| 349 |
+
complex<double>,
|
| 350 |
+
LayoutC,
|
| 351 |
+
OpMultiplyAdd> {
|
| 352 |
+
|
| 353 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 354 |
+
using Operator = OpMultiplyAddComplex;
|
| 355 |
+
using ElementC = complex<double>;
|
| 356 |
+
|
| 357 |
+
CUTLASS_HOST_DEVICE
|
| 358 |
+
void operator()(
|
| 359 |
+
Array<complex<double>, 1> &d,
|
| 360 |
+
Array<double, 1> const &a,
|
| 361 |
+
Array<complex<double>, 1> const &b,
|
| 362 |
+
Array<complex<double>, 1> const &c
|
| 363 |
+
) {
|
| 364 |
+
|
| 365 |
+
d[0].real() = a[0] * b[0].real() + c[0].real();
|
| 366 |
+
d[0].imag() = a[0] * b[0].imag() + d[0].imag();
|
| 367 |
+
}
|
| 368 |
+
};
|
| 369 |
+
|
| 370 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 371 |
+
|
| 372 |
+
/// Matrix multiply-add operation
|
| 373 |
+
template <
|
| 374 |
+
/// Layout of A matrix
|
| 375 |
+
typename LayoutA,
|
| 376 |
+
/// Layout of B matrix
|
| 377 |
+
typename LayoutB,
|
| 378 |
+
/// Layout of C matrix
|
| 379 |
+
typename LayoutC
|
| 380 |
+
>
|
| 381 |
+
struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> {
|
| 382 |
+
|
| 383 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 384 |
+
using Operator = OpMultiplyAdd;
|
| 385 |
+
using ElementC = float;
|
| 386 |
+
|
| 387 |
+
CUTLASS_HOST_DEVICE
|
| 388 |
+
void operator()(
|
| 389 |
+
Array<float, 1> &d,
|
| 390 |
+
Array<half_t, 1> const &a,
|
| 391 |
+
Array<half_t, 1> const &b,
|
| 392 |
+
Array<float, 1> const &c
|
| 393 |
+
) {
|
| 394 |
+
d[0] = float(a[0]) * float(b[0]) + c[0];
|
| 395 |
+
}
|
| 396 |
+
};
|
| 397 |
+
|
| 398 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 399 |
+
|
| 400 |
+
/// Matrix multiply-add operation for Quaternions
|
| 401 |
+
template <
|
| 402 |
+
/// Layout of A matrix
|
| 403 |
+
typename LayoutA,
|
| 404 |
+
/// Layout of B matrix
|
| 405 |
+
typename LayoutB,
|
| 406 |
+
/// Layout of C matrix
|
| 407 |
+
typename LayoutC
|
| 408 |
+
>
|
| 409 |
+
struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<float>, LayoutB, Quaternion<float>, LayoutC, OpMultiplyAdd> {
|
| 410 |
+
|
| 411 |
+
using Shape = gemm::GemmShape<1, 1, 1>;
|
| 412 |
+
using Operator = OpMultiplyAdd;
|
| 413 |
+
using Element = Quaternion<float>;
|
| 414 |
+
using ElementC = Element;
|
| 415 |
+
|
| 416 |
+
CUTLASS_HOST_DEVICE
|
| 417 |
+
void operator()(
|
| 418 |
+
Array<Element, 1> &d,
|
| 419 |
+
Array<Element, 1> const &a,
|
| 420 |
+
Array<Element, 1> const &b,
|
| 421 |
+
Array<Element, 1> const &c
|
| 422 |
+
) {
|
| 423 |
+
multiply_add<Element, Element, Element> op;
|
| 424 |
+
d[0] = op(a[0], b[0], c[0]);
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
};
|
| 428 |
+
|
| 429 |
+
}
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include <cuda_fp16.h>
|
| 38 |
+
|
| 39 |
+
#include "cutlass/arch/mma.h"
|
| 40 |
+
|
| 41 |
+
#include "cutlass/layout/matrix.h"
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
namespace cutlass {
|
| 46 |
+
namespace arch {
|
| 47 |
+
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
/// Matrix multiply-add operation
|
| 51 |
+
template <typename LayoutA, typename LayoutB, typename LayoutC>
|
| 52 |
+
struct Mma<
|
| 53 |
+
gemm::GemmShape<2,1,1>,
|
| 54 |
+
1,
|
| 55 |
+
half_t,
|
| 56 |
+
LayoutA,
|
| 57 |
+
half_t,
|
| 58 |
+
LayoutB,
|
| 59 |
+
half_t,
|
| 60 |
+
LayoutC,
|
| 61 |
+
OpMultiplyAdd> {
|
| 62 |
+
|
| 63 |
+
using Shape = gemm::GemmShape<2, 1, 1>;
|
| 64 |
+
using Operator = OpMultiplyAdd;
|
| 65 |
+
using ElementC = half_t;
|
| 66 |
+
|
| 67 |
+
CUTLASS_HOST_DEVICE
|
| 68 |
+
void operator()(
|
| 69 |
+
Array<half_t, 2> &d,
|
| 70 |
+
Array<half_t, 2> const &a,
|
| 71 |
+
Array<half_t, 1> const &b,
|
| 72 |
+
Array<half_t, 2> const &c
|
| 73 |
+
) {
|
| 74 |
+
|
| 75 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
|
| 76 |
+
|
| 77 |
+
__half2 const & A = reinterpret_cast<__half2 const &>(a);
|
| 78 |
+
__half2 B = __half2half2(reinterpret_cast<__half const &>(b));
|
| 79 |
+
__half2 const & C = reinterpret_cast<__half2 const &>(c);
|
| 80 |
+
|
| 81 |
+
__half2 D = __hfma2(A, B, C);
|
| 82 |
+
|
| 83 |
+
d = reinterpret_cast<Array<half_t, 2> &>(D);
|
| 84 |
+
|
| 85 |
+
#else
|
| 86 |
+
CUTLASS_PRAGMA_UNROLL
|
| 87 |
+
for (int i = 0; i < 2; ++i) {
|
| 88 |
+
d[i] = a[i] * b[0] + c[i];
|
| 89 |
+
}
|
| 90 |
+
#endif
|
| 91 |
+
}
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 95 |
+
|
| 96 |
+
/// Matrix multiply-add operation
|
| 97 |
+
template <typename LayoutA, typename LayoutB>
|
| 98 |
+
struct Mma<
|
| 99 |
+
gemm::GemmShape<1,2,1>,
|
| 100 |
+
1,
|
| 101 |
+
half_t,
|
| 102 |
+
LayoutA,
|
| 103 |
+
half_t,
|
| 104 |
+
LayoutB,
|
| 105 |
+
half_t,
|
| 106 |
+
layout::RowMajor,
|
| 107 |
+
OpMultiplyAdd> {
|
| 108 |
+
|
| 109 |
+
using Shape = gemm::GemmShape<1, 2, 1>;
|
| 110 |
+
using Operator = OpMultiplyAdd;
|
| 111 |
+
using ElementC = half_t;
|
| 112 |
+
|
| 113 |
+
CUTLASS_HOST_DEVICE
|
| 114 |
+
void operator()(
|
| 115 |
+
Array<half_t, 2> &d,
|
| 116 |
+
Array<half_t, 1> const &a,
|
| 117 |
+
Array<half_t, 2> const &b,
|
| 118 |
+
Array<half_t, 2> const &c
|
| 119 |
+
) {
|
| 120 |
+
|
| 121 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
|
| 122 |
+
|
| 123 |
+
__half2 const & A = __half2half2(reinterpret_cast<__half const &>(a));
|
| 124 |
+
__half2 B = reinterpret_cast<__half2 const &>(b);
|
| 125 |
+
__half2 const & C = reinterpret_cast<__half2 const &>(c);
|
| 126 |
+
|
| 127 |
+
__half2 D = __hfma2(A, B, C);
|
| 128 |
+
|
| 129 |
+
d = reinterpret_cast<Array<half_t, 2> &>(D);
|
| 130 |
+
|
| 131 |
+
#else
|
| 132 |
+
CUTLASS_PRAGMA_UNROLL
|
| 133 |
+
for (int i = 0; i < 2; ++i) {
|
| 134 |
+
d[i] = a[0] * b[i] + c[i];
|
| 135 |
+
}
|
| 136 |
+
#endif
|
| 137 |
+
}
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 141 |
+
|
| 142 |
+
/// Matrix multiply-add operation
|
| 143 |
+
template <>
|
| 144 |
+
struct Mma <
|
| 145 |
+
gemm::GemmShape<2, 2, 1>,
|
| 146 |
+
1,
|
| 147 |
+
half_t,
|
| 148 |
+
layout::ColumnMajor,
|
| 149 |
+
half_t,
|
| 150 |
+
layout::RowMajor,
|
| 151 |
+
half_t,
|
| 152 |
+
layout::ColumnMajor,
|
| 153 |
+
OpMultiplyAdd> {
|
| 154 |
+
|
| 155 |
+
using Shape = gemm::GemmShape<2, 2, 1>;
|
| 156 |
+
using Operator = OpMultiplyAdd;
|
| 157 |
+
using ElementC = half_t;
|
| 158 |
+
|
| 159 |
+
CUTLASS_HOST_DEVICE
|
| 160 |
+
void operator()(
|
| 161 |
+
Array<half_t, 4> &d,
|
| 162 |
+
Array<half_t, 2> const &a,
|
| 163 |
+
Array<half_t, 2> const &b,
|
| 164 |
+
Array<half_t, 4> const &c
|
| 165 |
+
) {
|
| 166 |
+
|
| 167 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
|
| 168 |
+
|
| 169 |
+
__half2 const & A = reinterpret_cast<__half2 const &>(a);
|
| 170 |
+
__half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b));
|
| 171 |
+
__half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b));
|
| 172 |
+
|
| 173 |
+
__half2 const *C = reinterpret_cast<__half2 const *>(&c);
|
| 174 |
+
|
| 175 |
+
__half2 Dlo = __hfma2(A, Blo, C[0]);
|
| 176 |
+
__half2 Dhi = __hfma2(A, Bhi, C[1]);
|
| 177 |
+
|
| 178 |
+
Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);
|
| 179 |
+
|
| 180 |
+
D[0] = reinterpret_cast<Array<half_t, 2> const &>(Dlo);
|
| 181 |
+
D[1] = reinterpret_cast<Array<half_t, 2> const &>(Dhi);
|
| 182 |
+
|
| 183 |
+
#else
|
| 184 |
+
CUTLASS_PRAGMA_UNROLL
|
| 185 |
+
for (int j = 0; j < 2; ++j) {
|
| 186 |
+
CUTLASS_PRAGMA_UNROLL
|
| 187 |
+
for (int i = 0; i < 2; ++i) {
|
| 188 |
+
d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j];
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
#endif
|
| 192 |
+
}
|
| 193 |
+
};
|
| 194 |
+
|
| 195 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 196 |
+
|
| 197 |
+
/// Matrix multiply-add operation
|
| 198 |
+
template <>
|
| 199 |
+
struct Mma<
|
| 200 |
+
gemm::GemmShape<2, 2, 1>,
|
| 201 |
+
1,
|
| 202 |
+
half_t,
|
| 203 |
+
layout::ColumnMajor,
|
| 204 |
+
half_t,
|
| 205 |
+
layout::RowMajor,
|
| 206 |
+
half_t,
|
| 207 |
+
layout::RowMajor,
|
| 208 |
+
OpMultiplyAdd> {
|
| 209 |
+
|
| 210 |
+
using Shape = gemm::GemmShape<2, 2, 1>;
|
| 211 |
+
using Operator = OpMultiplyAdd;
|
| 212 |
+
using ElementC = half_t;
|
| 213 |
+
|
| 214 |
+
CUTLASS_HOST_DEVICE
|
| 215 |
+
void operator()(
|
| 216 |
+
Array<half_t, 4> &d,
|
| 217 |
+
Array<half_t, 2> const &a,
|
| 218 |
+
Array<half_t, 2> const &b,
|
| 219 |
+
Array<half_t, 4> const &c
|
| 220 |
+
) {
|
| 221 |
+
|
| 222 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
|
| 223 |
+
|
| 224 |
+
__half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a));
|
| 225 |
+
__half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a));
|
| 226 |
+
__half2 const & B = reinterpret_cast<__half2 const &>(b);
|
| 227 |
+
|
| 228 |
+
__half2 const *C = reinterpret_cast<__half2 const *>(&c);
|
| 229 |
+
|
| 230 |
+
__half2 Dlo = __hfma2(Alo, B, C[0]);
|
| 231 |
+
__half2 Dhi = __hfma2(Ahi, B, C[1]);
|
| 232 |
+
|
| 233 |
+
Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);
|
| 234 |
+
|
| 235 |
+
D[0] = reinterpret_cast<Array<half_t, 2> &>(Dlo);
|
| 236 |
+
D[1] = reinterpret_cast<Array<half_t, 2> &>(Dhi);
|
| 237 |
+
#else
|
| 238 |
+
CUTLASS_PRAGMA_UNROLL
|
| 239 |
+
for (int i = 0; i < 2; ++i) {
|
| 240 |
+
CUTLASS_PRAGMA_UNROLL
|
| 241 |
+
for (int j = 0; j < 2; ++j) {
|
| 242 |
+
d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j];
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
#endif
|
| 246 |
+
}
|
| 247 |
+
};
|
| 248 |
+
|
| 249 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 250 |
+
|
| 251 |
+
}
|
| 252 |
+
}
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/layout/matrix.h"
|
| 38 |
+
|
| 39 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
+
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
namespace arch {
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
/// Matrix multiply-add operation
|
| 47 |
+
template <typename LayoutA, typename LayoutB, typename LayoutC>
|
| 48 |
+
struct Mma<
|
| 49 |
+
gemm::GemmShape<1,1,4>,
|
| 50 |
+
1,
|
| 51 |
+
int8_t,
|
| 52 |
+
LayoutA,
|
| 53 |
+
int8_t,
|
| 54 |
+
LayoutB,
|
| 55 |
+
int,
|
| 56 |
+
LayoutC,
|
| 57 |
+
OpMultiplyAdd> {
|
| 58 |
+
|
| 59 |
+
using Shape = gemm::GemmShape<1, 1, 4>;
|
| 60 |
+
using Operator = OpMultiplyAdd;
|
| 61 |
+
using ElementC = int;
|
| 62 |
+
|
| 63 |
+
CUTLASS_HOST_DEVICE
|
| 64 |
+
void operator()(
|
| 65 |
+
Array<int, 1> &d,
|
| 66 |
+
Array<int8_t, 4> const &a,
|
| 67 |
+
Array<int8_t, 4> const &b,
|
| 68 |
+
Array<int, 1> const &c
|
| 69 |
+
) {
|
| 70 |
+
|
| 71 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))
|
| 72 |
+
|
| 73 |
+
unsigned const &A = reinterpret_cast<unsigned const &>(a);
|
| 74 |
+
unsigned const &B = reinterpret_cast<unsigned const &>(b);
|
| 75 |
+
|
| 76 |
+
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
|
| 77 |
+
: "=r"(d[0])
|
| 78 |
+
: "r"(A), "r"(B), "r"(c[0]));
|
| 79 |
+
|
| 80 |
+
#else
|
| 81 |
+
|
| 82 |
+
d[0] = c[0];
|
| 83 |
+
|
| 84 |
+
CUTLASS_PRAGMA_UNROLL
|
| 85 |
+
for (int k = 0; k < 4; ++k) {
|
| 86 |
+
d[0] += a[k] * b[k];
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
#endif
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 94 |
+
|
| 95 |
+
/// Matrix multiply-add operation
|
| 96 |
+
template <typename LayoutC>
|
| 97 |
+
struct Mma<
|
| 98 |
+
gemm::GemmShape<1, 1, 2>,
|
| 99 |
+
1,
|
| 100 |
+
int16_t,
|
| 101 |
+
layout::RowMajor,
|
| 102 |
+
int16_t,
|
| 103 |
+
layout::ColumnMajor,
|
| 104 |
+
int,
|
| 105 |
+
LayoutC,
|
| 106 |
+
OpMultiplyAdd> {
|
| 107 |
+
|
| 108 |
+
using Shape = gemm::GemmShape<1, 1, 2>;
|
| 109 |
+
using Operator = OpMultiplyAdd;
|
| 110 |
+
using ElementC = int;
|
| 111 |
+
|
| 112 |
+
CUTLASS_HOST_DEVICE
|
| 113 |
+
void operator()(
|
| 114 |
+
Array<int, 1> &d,
|
| 115 |
+
Array<int16_t, 2> const &a,
|
| 116 |
+
Array<int16_t, 2> const &b,
|
| 117 |
+
Array<int, 1> const &c
|
| 118 |
+
) {
|
| 119 |
+
|
| 120 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))
|
| 121 |
+
|
| 122 |
+
unsigned const &A = reinterpret_cast<unsigned const &>(a);
|
| 123 |
+
unsigned const &B = reinterpret_cast<unsigned const &>(b);
|
| 124 |
+
|
| 125 |
+
asm volatile("dp2a.s32.s32 %0, %1, %2, %3;"
|
| 126 |
+
: "=r"(d[0])
|
| 127 |
+
: "r"(A), "r"(B), "r"(c[0]));
|
| 128 |
+
#else
|
| 129 |
+
d[0] = c[0];
|
| 130 |
+
|
| 131 |
+
CUTLASS_PRAGMA_UNROLL
|
| 132 |
+
for (int k = 0; k < 2; ++k) {
|
| 133 |
+
d[0] += a[k] * b[k];
|
| 134 |
+
}
|
| 135 |
+
#endif
|
| 136 |
+
}
|
| 137 |
+
};
|
| 138 |
+
|
| 139 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 140 |
+
|
| 141 |
+
}
|
| 142 |
+
}
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h
ADDED
|
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
#include "cutlass/cutlass.h"
|
| 36 |
+
#include CUDA_STD_HEADER(cassert)
|
| 37 |
+
|
| 38 |
+
#include "mma.h"
|
| 39 |
+
#include "cutlass/layout/matrix.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
|
| 42 |
+
#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))
|
| 43 |
+
#define CUTLASS_ARCH_MMA_SM70_SUPPORTED
|
| 44 |
+
#endif
|
| 45 |
+
|
| 46 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
|
| 47 |
+
|
| 48 |
+
#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1))
|
| 49 |
+
#define CUTLASS_ARCH_MMA_SM70_ENABLED
|
| 50 |
+
#endif
|
| 51 |
+
|
| 52 |
+
#endif
|
| 53 |
+
|
| 54 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
namespace cutlass {
|
| 57 |
+
namespace arch {
|
| 58 |
+
|
| 59 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
//
|
| 61 |
+
// Matrix multiply accumulate 884 - FP16 accumulation
|
| 62 |
+
//
|
| 63 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
|
| 66 |
+
template <>
|
| 67 |
+
struct Mma<
|
| 68 |
+
gemm::GemmShape<8,8,4>,
|
| 69 |
+
8,
|
| 70 |
+
half_t,
|
| 71 |
+
layout::ColumnMajor,
|
| 72 |
+
half_t,
|
| 73 |
+
layout::ColumnMajor,
|
| 74 |
+
half_t,
|
| 75 |
+
layout::RowMajor,
|
| 76 |
+
OpMultiplyAdd> {
|
| 77 |
+
|
| 78 |
+
using Shape = gemm::GemmShape<8, 8, 4>;
|
| 79 |
+
|
| 80 |
+
using ElementA = half_t;
|
| 81 |
+
using LayoutA = layout::ColumnMajor;
|
| 82 |
+
using FragmentA = Array<half_t, 4>;
|
| 83 |
+
|
| 84 |
+
using ElementB = half_t;
|
| 85 |
+
using LayoutB = layout::ColumnMajor;
|
| 86 |
+
using FragmentB = Array<half_t, 4>;
|
| 87 |
+
|
| 88 |
+
using ElementC = half_t;
|
| 89 |
+
using LayoutC = layout::RowMajor;
|
| 90 |
+
using FragmentC = Array<half_t, 8>;
|
| 91 |
+
|
| 92 |
+
using Operator = OpMultiplyAdd;
|
| 93 |
+
using ArchTag = arch::Sm70;
|
| 94 |
+
|
| 95 |
+
CUTLASS_HOST_DEVICE
|
| 96 |
+
void operator()(
|
| 97 |
+
FragmentC &d,
|
| 98 |
+
FragmentA const &a,
|
| 99 |
+
FragmentB const &b,
|
| 100 |
+
FragmentC const &c
|
| 101 |
+
) {
|
| 102 |
+
|
| 103 |
+
#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
|
| 104 |
+
|
| 105 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 106 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 107 |
+
unsigned const *C = reinterpret_cast<unsigned const *>(&c);
|
| 108 |
+
unsigned *D = reinterpret_cast<unsigned *>(&d);
|
| 109 |
+
|
| 110 |
+
asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
|
| 111 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 112 |
+
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
|
| 113 |
+
);
|
| 114 |
+
|
| 115 |
+
#else
|
| 116 |
+
assert(0);
|
| 117 |
+
#if defined(__CUDA_ARCH__)
|
| 118 |
+
asm volatile ("brkpt;\n" ::);
|
| 119 |
+
#endif
|
| 120 |
+
#endif
|
| 121 |
+
}
|
| 122 |
+
};
|
| 123 |
+
|
| 124 |
+
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
|
| 125 |
+
template <>
|
| 126 |
+
struct Mma<
|
| 127 |
+
gemm::GemmShape<8, 8, 4>,
|
| 128 |
+
8,
|
| 129 |
+
half_t,
|
| 130 |
+
layout::ColumnMajor,
|
| 131 |
+
half_t,
|
| 132 |
+
layout::RowMajor,
|
| 133 |
+
half_t,
|
| 134 |
+
layout::RowMajor,
|
| 135 |
+
OpMultiplyAdd> {
|
| 136 |
+
|
| 137 |
+
using Shape = gemm::GemmShape<8, 8, 4>;
|
| 138 |
+
|
| 139 |
+
using ElementA = half_t;
|
| 140 |
+
using LayoutA = layout::ColumnMajor;
|
| 141 |
+
using FragmentA = Array<half_t, 4>;
|
| 142 |
+
|
| 143 |
+
using ElementB = half_t;
|
| 144 |
+
using LayoutB = layout::RowMajor;
|
| 145 |
+
using FragmentB = Array<half_t, 4>;
|
| 146 |
+
|
| 147 |
+
using ElementC = half_t;
|
| 148 |
+
using LayoutC = layout::RowMajor;
|
| 149 |
+
using FragmentC = Array<half_t, 8>;
|
| 150 |
+
|
| 151 |
+
using Operator = OpMultiplyAdd;
|
| 152 |
+
using ArchTag = arch::Sm70;
|
| 153 |
+
|
| 154 |
+
CUTLASS_HOST_DEVICE
|
| 155 |
+
void operator()(
|
| 156 |
+
FragmentC &d,
|
| 157 |
+
FragmentA const &a,
|
| 158 |
+
FragmentB const &b,
|
| 159 |
+
FragmentC const &c
|
| 160 |
+
) {
|
| 161 |
+
|
| 162 |
+
#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
|
| 163 |
+
|
| 164 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 165 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 166 |
+
unsigned const *C = reinterpret_cast<unsigned const *>(&c);
|
| 167 |
+
unsigned *D = reinterpret_cast<unsigned *>(&d);
|
| 168 |
+
|
| 169 |
+
asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
|
| 170 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 171 |
+
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
|
| 172 |
+
);
|
| 173 |
+
|
| 174 |
+
#else
|
| 175 |
+
assert(0);
|
| 176 |
+
#if defined(__CUDA_ARCH__)
|
| 177 |
+
asm volatile ("brkpt;\n" ::);
|
| 178 |
+
#endif
|
| 179 |
+
#endif
|
| 180 |
+
}
|
| 181 |
+
};
|
| 182 |
+
|
| 183 |
+
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
|
| 184 |
+
template <>
|
| 185 |
+
struct Mma<
|
| 186 |
+
gemm::GemmShape<8, 8, 4>,
|
| 187 |
+
8,
|
| 188 |
+
half_t,
|
| 189 |
+
layout::RowMajor,
|
| 190 |
+
half_t,
|
| 191 |
+
layout::ColumnMajor,
|
| 192 |
+
half_t,
|
| 193 |
+
layout::RowMajor,
|
| 194 |
+
OpMultiplyAdd> {
|
| 195 |
+
|
| 196 |
+
using Shape = gemm::GemmShape<8, 8, 4>;
|
| 197 |
+
|
| 198 |
+
using ElementA = half_t;
|
| 199 |
+
using LayoutA = layout::RowMajor;
|
| 200 |
+
using FragmentA = Array<half_t, 4>;
|
| 201 |
+
|
| 202 |
+
using ElementB = half_t;
|
| 203 |
+
using LayoutB = layout::ColumnMajor;
|
| 204 |
+
using FragmentB = Array<half_t, 4>;
|
| 205 |
+
|
| 206 |
+
using ElementC = half_t;
|
| 207 |
+
using LayoutC = layout::RowMajor;
|
| 208 |
+
using FragmentC = Array<half_t, 8>;
|
| 209 |
+
|
| 210 |
+
using Operator = OpMultiplyAdd;
|
| 211 |
+
using ArchTag = arch::Sm70;
|
| 212 |
+
|
| 213 |
+
CUTLASS_HOST_DEVICE
|
| 214 |
+
void operator()(
|
| 215 |
+
FragmentC &d,
|
| 216 |
+
FragmentA const &a,
|
| 217 |
+
FragmentB const &b,
|
| 218 |
+
FragmentC const &c
|
| 219 |
+
) {
|
| 220 |
+
|
| 221 |
+
#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
|
| 222 |
+
|
| 223 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 224 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 225 |
+
unsigned const *C = reinterpret_cast<unsigned const *>(&c);
|
| 226 |
+
unsigned *D = reinterpret_cast<unsigned *>(&d);
|
| 227 |
+
|
| 228 |
+
asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
|
| 229 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 230 |
+
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
|
| 231 |
+
);
|
| 232 |
+
|
| 233 |
+
#else
|
| 234 |
+
assert(0);
|
| 235 |
+
#if defined(__CUDA_ARCH__)
|
| 236 |
+
asm volatile ("brkpt;\n" ::);
|
| 237 |
+
#endif
|
| 238 |
+
#endif
|
| 239 |
+
}
|
| 240 |
+
};
|
| 241 |
+
|
| 242 |
+
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
|
| 243 |
+
template <>
|
| 244 |
+
struct Mma<
|
| 245 |
+
gemm::GemmShape<8, 8, 4>,
|
| 246 |
+
8,
|
| 247 |
+
half_t,
|
| 248 |
+
layout::RowMajor,
|
| 249 |
+
half_t,
|
| 250 |
+
layout::RowMajor,
|
| 251 |
+
half_t,
|
| 252 |
+
layout::RowMajor,
|
| 253 |
+
OpMultiplyAdd> {
|
| 254 |
+
|
| 255 |
+
using Shape = gemm::GemmShape<8, 8, 4>;
|
| 256 |
+
|
| 257 |
+
using ElementA = half_t;
|
| 258 |
+
using LayoutA = layout::RowMajor;
|
| 259 |
+
using FragmentA = Array<half_t, 4>;
|
| 260 |
+
|
| 261 |
+
using ElementB = half_t;
|
| 262 |
+
using LayoutB = layout::RowMajor;
|
| 263 |
+
using FragmentB = Array<half_t, 4>;
|
| 264 |
+
|
| 265 |
+
using ElementC = half_t;
|
| 266 |
+
using LayoutC = layout::RowMajor;
|
| 267 |
+
using FragmentC = Array<half_t, 8>;
|
| 268 |
+
|
| 269 |
+
using Operator = OpMultiplyAdd;
|
| 270 |
+
using ArchTag = arch::Sm70;
|
| 271 |
+
|
| 272 |
+
CUTLASS_HOST_DEVICE
|
| 273 |
+
void operator()(
|
| 274 |
+
FragmentC &d,
|
| 275 |
+
FragmentA const &a,
|
| 276 |
+
FragmentB const &b,
|
| 277 |
+
FragmentC const &c
|
| 278 |
+
) {
|
| 279 |
+
|
| 280 |
+
#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
|
| 281 |
+
|
| 282 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 283 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 284 |
+
unsigned const *C = reinterpret_cast<unsigned const *>(&c);
|
| 285 |
+
unsigned *D = reinterpret_cast<unsigned *>(&d);
|
| 286 |
+
|
| 287 |
+
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
|
| 288 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 289 |
+
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
|
| 290 |
+
);
|
| 291 |
+
|
| 292 |
+
#else
|
| 293 |
+
assert(0);
|
| 294 |
+
#if defined(__CUDA_ARCH__)
|
| 295 |
+
asm volatile ("brkpt;\n" ::);
|
| 296 |
+
#endif
|
| 297 |
+
#endif
|
| 298 |
+
}
|
| 299 |
+
};
|
| 300 |
+
|
| 301 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 302 |
+
//
|
| 303 |
+
// Matrix multiply accumulate 884 - FP32 accumulation
|
| 304 |
+
//
|
| 305 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 306 |
+
|
| 307 |
+
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
|
| 308 |
+
template <>
|
| 309 |
+
struct Mma<
|
| 310 |
+
gemm::GemmShape<8, 8, 4>,
|
| 311 |
+
8,
|
| 312 |
+
half_t,
|
| 313 |
+
layout::ColumnMajor,
|
| 314 |
+
half_t,
|
| 315 |
+
layout::ColumnMajor,
|
| 316 |
+
float,
|
| 317 |
+
layout::RowMajor,
|
| 318 |
+
OpMultiplyAdd> {
|
| 319 |
+
|
| 320 |
+
using Shape = gemm::GemmShape<8, 8, 4>;
|
| 321 |
+
|
| 322 |
+
using ElementA = half_t;
|
| 323 |
+
using LayoutA = layout::ColumnMajor;
|
| 324 |
+
using FragmentA = Array<half_t, 4>;
|
| 325 |
+
|
| 326 |
+
using ElementB = half_t;
|
| 327 |
+
using LayoutB = layout::ColumnMajor;
|
| 328 |
+
using FragmentB = Array<half_t, 4>;
|
| 329 |
+
|
| 330 |
+
using ElementC = float;
|
| 331 |
+
using LayoutC = layout::RowMajor;
|
| 332 |
+
using FragmentC = Array<float, 8>;
|
| 333 |
+
|
| 334 |
+
using Operator = OpMultiplyAdd;
|
| 335 |
+
using ArchTag = arch::Sm70;
|
| 336 |
+
|
| 337 |
+
/// Multiply-add
|
| 338 |
+
CUTLASS_HOST_DEVICE
|
| 339 |
+
void operator()(
|
| 340 |
+
FragmentC &d,
|
| 341 |
+
FragmentA const &a,
|
| 342 |
+
FragmentB const &b,
|
| 343 |
+
FragmentC const &c
|
| 344 |
+
) {
|
| 345 |
+
|
| 346 |
+
#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
|
| 347 |
+
|
| 348 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 349 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 350 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 351 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 352 |
+
|
| 353 |
+
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
|
| 354 |
+
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
|
| 355 |
+
: "=f"(D[0]),
|
| 356 |
+
"=f"(D[1]),
|
| 357 |
+
"=f"(D[2]),
|
| 358 |
+
"=f"(D[3]),
|
| 359 |
+
"=f"(D[4]),
|
| 360 |
+
"=f"(D[5]),
|
| 361 |
+
"=f"(D[6]),
|
| 362 |
+
"=f"(D[7])
|
| 363 |
+
: "r"(A[0]),
|
| 364 |
+
"r"(A[1]),
|
| 365 |
+
"r"(B[0]),
|
| 366 |
+
"r"(B[1]),
|
| 367 |
+
"f"(C[0]),
|
| 368 |
+
"f"(C[1]),
|
| 369 |
+
"f"(C[2]),
|
| 370 |
+
"f"(C[3]),
|
| 371 |
+
"f"(C[4]),
|
| 372 |
+
"f"(C[5]),
|
| 373 |
+
"f"(C[6]),
|
| 374 |
+
"f"(C[7])
|
| 375 |
+
);
|
| 376 |
+
|
| 377 |
+
#else
|
| 378 |
+
assert(0);
|
| 379 |
+
#if defined(__CUDA_ARCH__)
|
| 380 |
+
asm volatile ("brkpt;\n" ::);
|
| 381 |
+
#endif
|
| 382 |
+
#endif
|
| 383 |
+
}
|
| 384 |
+
};
|
| 385 |
+
|
| 386 |
+
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
|
| 387 |
+
template <>
|
| 388 |
+
struct Mma<
|
| 389 |
+
gemm::GemmShape<8, 8, 4>,
|
| 390 |
+
8,
|
| 391 |
+
half_t,
|
| 392 |
+
layout::ColumnMajor,
|
| 393 |
+
half_t,
|
| 394 |
+
layout::RowMajor,
|
| 395 |
+
float,
|
| 396 |
+
layout::RowMajor,
|
| 397 |
+
OpMultiplyAdd> {
|
| 398 |
+
|
| 399 |
+
using Shape = gemm::GemmShape<8, 8, 4>;
|
| 400 |
+
|
| 401 |
+
using ElementA = half_t;
|
| 402 |
+
using LayoutA = layout::ColumnMajor;
|
| 403 |
+
using FragmentA = Array<half_t, 4>;
|
| 404 |
+
|
| 405 |
+
using ElementB = half_t;
|
| 406 |
+
using LayoutB = layout::RowMajor;
|
| 407 |
+
using FragmentB = Array<half_t, 4>;
|
| 408 |
+
|
| 409 |
+
using ElementC = float;
|
| 410 |
+
using LayoutC = layout::RowMajor;
|
| 411 |
+
using FragmentC = Array<float, 8>;
|
| 412 |
+
|
| 413 |
+
using Operator = OpMultiplyAdd;
|
| 414 |
+
using ArchTag = arch::Sm70;
|
| 415 |
+
|
| 416 |
+
/// Multiply-add
|
| 417 |
+
CUTLASS_HOST_DEVICE
|
| 418 |
+
void operator()(
|
| 419 |
+
FragmentC &d,
|
| 420 |
+
FragmentA const &a,
|
| 421 |
+
FragmentB const &b,
|
| 422 |
+
FragmentC const &c
|
| 423 |
+
) {
|
| 424 |
+
|
| 425 |
+
#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
|
| 426 |
+
|
| 427 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 428 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 429 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 430 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 431 |
+
|
| 432 |
+
asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
|
| 433 |
+
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
|
| 434 |
+
: "=f"(D[0]),
|
| 435 |
+
"=f"(D[1]),
|
| 436 |
+
"=f"(D[2]),
|
| 437 |
+
"=f"(D[3]),
|
| 438 |
+
"=f"(D[4]),
|
| 439 |
+
"=f"(D[5]),
|
| 440 |
+
"=f"(D[6]),
|
| 441 |
+
"=f"(D[7])
|
| 442 |
+
: "r"(A[0]),
|
| 443 |
+
"r"(A[1]),
|
| 444 |
+
"r"(B[0]),
|
| 445 |
+
"r"(B[1]),
|
| 446 |
+
"f"(C[0]),
|
| 447 |
+
"f"(C[1]),
|
| 448 |
+
"f"(C[2]),
|
| 449 |
+
"f"(C[3]),
|
| 450 |
+
"f"(C[4]),
|
| 451 |
+
"f"(C[5]),
|
| 452 |
+
"f"(C[6]),
|
| 453 |
+
"f"(C[7])
|
| 454 |
+
);
|
| 455 |
+
|
| 456 |
+
#else
|
| 457 |
+
assert(0);
|
| 458 |
+
#if defined(__CUDA_ARCH__)
|
| 459 |
+
asm volatile ("brkpt;\n" ::);
|
| 460 |
+
#endif
|
| 461 |
+
#endif
|
| 462 |
+
}
|
| 463 |
+
};
|
| 464 |
+
|
| 465 |
+
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
|
| 466 |
+
template <>
|
| 467 |
+
struct Mma<
|
| 468 |
+
gemm::GemmShape<8, 8, 4>,
|
| 469 |
+
8,
|
| 470 |
+
half_t,
|
| 471 |
+
layout::RowMajor,
|
| 472 |
+
half_t,
|
| 473 |
+
layout::ColumnMajor,
|
| 474 |
+
float,
|
| 475 |
+
layout::RowMajor,
|
| 476 |
+
OpMultiplyAdd> {
|
| 477 |
+
|
| 478 |
+
using Shape = gemm::GemmShape<8, 8, 4>;
|
| 479 |
+
|
| 480 |
+
using ElementA = half_t;
|
| 481 |
+
using LayoutA = layout::RowMajor;
|
| 482 |
+
using FragmentA = Array<half_t, 4>;
|
| 483 |
+
|
| 484 |
+
using ElementB = half_t;
|
| 485 |
+
using LayoutB = layout::ColumnMajor;
|
| 486 |
+
using FragmentB = Array<half_t, 4>;
|
| 487 |
+
|
| 488 |
+
using ElementC = float;
|
| 489 |
+
using LayoutC = layout::RowMajor;
|
| 490 |
+
using FragmentC = Array<float, 8>;
|
| 491 |
+
|
| 492 |
+
using Operator = OpMultiplyAdd;
|
| 493 |
+
using ArchTag = arch::Sm70;
|
| 494 |
+
|
| 495 |
+
/// Multiply-add
|
| 496 |
+
CUTLASS_HOST_DEVICE
|
| 497 |
+
void operator()(
|
| 498 |
+
FragmentC &d,
|
| 499 |
+
FragmentA const &a,
|
| 500 |
+
FragmentB const &b,
|
| 501 |
+
FragmentC const &c
|
| 502 |
+
) {
|
| 503 |
+
|
| 504 |
+
#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
|
| 505 |
+
|
| 506 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 507 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 508 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 509 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 510 |
+
|
| 511 |
+
asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
|
| 512 |
+
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
|
| 513 |
+
: "=f"(D[0]),
|
| 514 |
+
"=f"(D[1]),
|
| 515 |
+
"=f"(D[2]),
|
| 516 |
+
"=f"(D[3]),
|
| 517 |
+
"=f"(D[4]),
|
| 518 |
+
"=f"(D[5]),
|
| 519 |
+
"=f"(D[6]),
|
| 520 |
+
"=f"(D[7])
|
| 521 |
+
: "r"(A[0]),
|
| 522 |
+
"r"(A[1]),
|
| 523 |
+
"r"(B[0]),
|
| 524 |
+
"r"(B[1]),
|
| 525 |
+
"f"(C[0]),
|
| 526 |
+
"f"(C[1]),
|
| 527 |
+
"f"(C[2]),
|
| 528 |
+
"f"(C[3]),
|
| 529 |
+
"f"(C[4]),
|
| 530 |
+
"f"(C[5]),
|
| 531 |
+
"f"(C[6]),
|
| 532 |
+
"f"(C[7])
|
| 533 |
+
);
|
| 534 |
+
|
| 535 |
+
#else
|
| 536 |
+
assert(0);
|
| 537 |
+
#if defined(__CUDA_ARCH__)
|
| 538 |
+
asm volatile ("brkpt;\n" ::);
|
| 539 |
+
#endif
|
| 540 |
+
#endif
|
| 541 |
+
}
|
| 542 |
+
};
|
| 543 |
+
|
| 544 |
+
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
|
| 545 |
+
template <>
|
| 546 |
+
struct Mma<
|
| 547 |
+
gemm::GemmShape<8, 8, 4>,
|
| 548 |
+
8,
|
| 549 |
+
half_t,
|
| 550 |
+
layout::RowMajor,
|
| 551 |
+
half_t,
|
| 552 |
+
layout::RowMajor,
|
| 553 |
+
float,
|
| 554 |
+
layout::RowMajor,
|
| 555 |
+
OpMultiplyAdd> {
|
| 556 |
+
|
| 557 |
+
using Shape = gemm::GemmShape<8, 8, 4>;
|
| 558 |
+
|
| 559 |
+
using ElementA = half_t;
|
| 560 |
+
using LayoutA = layout::RowMajor;
|
| 561 |
+
using FragmentA = Array<half_t, 4>;
|
| 562 |
+
|
| 563 |
+
using ElementB = half_t;
|
| 564 |
+
using LayoutB = layout::RowMajor;
|
| 565 |
+
using FragmentB = Array<half_t, 4>;
|
| 566 |
+
|
| 567 |
+
using ElementC = float;
|
| 568 |
+
using LayoutC = layout::RowMajor;
|
| 569 |
+
using FragmentC = Array<float, 8>;
|
| 570 |
+
|
| 571 |
+
using Operator = OpMultiplyAdd;
|
| 572 |
+
using ArchTag = arch::Sm70;
|
| 573 |
+
|
| 574 |
+
/// Multiply-add
|
| 575 |
+
CUTLASS_HOST_DEVICE
|
| 576 |
+
void operator()(
|
| 577 |
+
FragmentC &d,
|
| 578 |
+
FragmentA const &a,
|
| 579 |
+
FragmentB const &b,
|
| 580 |
+
FragmentC const &c
|
| 581 |
+
) {
|
| 582 |
+
|
| 583 |
+
#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
|
| 584 |
+
|
| 585 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 586 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 587 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 588 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 589 |
+
|
| 590 |
+
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
|
| 591 |
+
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
|
| 592 |
+
: "=f"(D[0]),
|
| 593 |
+
"=f"(D[1]),
|
| 594 |
+
"=f"(D[2]),
|
| 595 |
+
"=f"(D[3]),
|
| 596 |
+
"=f"(D[4]),
|
| 597 |
+
"=f"(D[5]),
|
| 598 |
+
"=f"(D[6]),
|
| 599 |
+
"=f"(D[7])
|
| 600 |
+
: "r"(A[0]),
|
| 601 |
+
"r"(A[1]),
|
| 602 |
+
"r"(B[0]),
|
| 603 |
+
"r"(B[1]),
|
| 604 |
+
"f"(C[0]),
|
| 605 |
+
"f"(C[1]),
|
| 606 |
+
"f"(C[2]),
|
| 607 |
+
"f"(C[3]),
|
| 608 |
+
"f"(C[4]),
|
| 609 |
+
"f"(C[5]),
|
| 610 |
+
"f"(C[6]),
|
| 611 |
+
"f"(C[7])
|
| 612 |
+
);
|
| 613 |
+
|
| 614 |
+
#else
|
| 615 |
+
assert(0);
|
| 616 |
+
#if defined(__CUDA_ARCH__)
|
| 617 |
+
asm volatile ("brkpt;\n" ::);
|
| 618 |
+
#endif
|
| 619 |
+
#endif
|
| 620 |
+
}
|
| 621 |
+
};
|
| 622 |
+
|
| 623 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 624 |
+
|
| 625 |
+
/// Matrix multiply-add operation specialized for the entire warp
|
| 626 |
+
template <
|
| 627 |
+
typename LayoutA,
|
| 628 |
+
typename LayoutB,
|
| 629 |
+
typename ElementC,
|
| 630 |
+
typename LayoutC,
|
| 631 |
+
typename Operator
|
| 632 |
+
>
|
| 633 |
+
struct Mma<
|
| 634 |
+
gemm::GemmShape<16, 16, 4>,
|
| 635 |
+
32,
|
| 636 |
+
half_t,
|
| 637 |
+
LayoutA,
|
| 638 |
+
half_t,
|
| 639 |
+
LayoutB,
|
| 640 |
+
ElementC,
|
| 641 |
+
LayoutC,
|
| 642 |
+
Operator
|
| 643 |
+
> :
|
| 644 |
+
public Mma<
|
| 645 |
+
gemm::GemmShape<8, 8, 4>,
|
| 646 |
+
8,
|
| 647 |
+
half_t,
|
| 648 |
+
LayoutA,
|
| 649 |
+
half_t,
|
| 650 |
+
LayoutB,
|
| 651 |
+
ElementC,
|
| 652 |
+
LayoutC,
|
| 653 |
+
Operator> {
|
| 654 |
+
|
| 655 |
+
using Shape = gemm::GemmShape<16, 16, 4>;
|
| 656 |
+
};
|
| 657 |
+
|
| 658 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 659 |
+
|
| 660 |
+
} // namespace arch
|
| 661 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h
ADDED
|
@@ -0,0 +1,789 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply for SM75
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include CUDA_STD_HEADER(cassert)
|
| 38 |
+
|
| 39 |
+
#include "cutlass/arch/wmma.h"
|
| 40 |
+
|
| 41 |
+
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
| 42 |
+
// CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply.
|
| 43 |
+
#include <mma.h>
|
| 44 |
+
#include "cutlass/wmma_array.h"
|
| 45 |
+
#endif
|
| 46 |
+
|
| 47 |
+
// CUTLASS includes
|
| 48 |
+
#include "cutlass/arch/mma.h"
|
| 49 |
+
#include "cutlass/layout/matrix.h"
|
| 50 |
+
#include "cutlass/numeric_types.h"
|
| 51 |
+
|
| 52 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))
|
| 55 |
+
|
| 56 |
+
#define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1
|
| 57 |
+
|
| 58 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
|
| 59 |
+
#define CUTLASS_ARCH_MMA_SM75_ENABLED
|
| 60 |
+
#endif
|
| 61 |
+
#endif
|
| 62 |
+
|
| 63 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
namespace cutlass {
|
| 66 |
+
namespace arch {
|
| 67 |
+
|
| 68 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 69 |
+
//
|
| 70 |
+
// Matrix Multiply 1688 - FP16 accumulation
|
| 71 |
+
//
|
| 72 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 73 |
+
|
| 74 |
+
/// Matrix multiply-add operation - F16 = F16 * F16 + F16
|
| 75 |
+
template <>
|
| 76 |
+
struct Mma<
|
| 77 |
+
gemm::GemmShape<16, 8, 8>,
|
| 78 |
+
32,
|
| 79 |
+
half_t,
|
| 80 |
+
layout::RowMajor,
|
| 81 |
+
half_t,
|
| 82 |
+
layout::ColumnMajor,
|
| 83 |
+
half_t,
|
| 84 |
+
layout::RowMajor,
|
| 85 |
+
OpMultiplyAdd> {
|
| 86 |
+
|
| 87 |
+
using Shape = gemm::GemmShape<16, 8, 8>;
|
| 88 |
+
|
| 89 |
+
using ElementA = half_t;
|
| 90 |
+
using LayoutA = layout::RowMajor;
|
| 91 |
+
using FragmentA = Array<half_t, 4>;
|
| 92 |
+
|
| 93 |
+
using ElementB = half_t;
|
| 94 |
+
using LayoutB = layout::ColumnMajor;
|
| 95 |
+
using FragmentB = Array<half_t, 2>;
|
| 96 |
+
|
| 97 |
+
using ElementC = half_t;
|
| 98 |
+
using LayoutC = layout::RowMajor;
|
| 99 |
+
using FragmentC = Array<half_t, 4>;
|
| 100 |
+
|
| 101 |
+
using Operator = OpMultiplyAdd;
|
| 102 |
+
using ArchTag = arch::Sm75;
|
| 103 |
+
|
| 104 |
+
CUTLASS_HOST_DEVICE
|
| 105 |
+
void operator()(
|
| 106 |
+
FragmentC &d,
|
| 107 |
+
FragmentA const &a,
|
| 108 |
+
FragmentB const &b,
|
| 109 |
+
FragmentC const &c
|
| 110 |
+
) const {
|
| 111 |
+
|
| 112 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 113 |
+
|
| 114 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 115 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 116 |
+
unsigned const *C = reinterpret_cast<unsigned const *>(&c);
|
| 117 |
+
unsigned *D = reinterpret_cast<unsigned *>(&d);
|
| 118 |
+
|
| 119 |
+
asm volatile(
|
| 120 |
+
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
| 121 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 122 |
+
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));
|
| 123 |
+
|
| 124 |
+
#else
|
| 125 |
+
CUTLASS_UNUSED(a);
|
| 126 |
+
CUTLASS_UNUSED(b);
|
| 127 |
+
CUTLASS_UNUSED(c);
|
| 128 |
+
CUTLASS_UNUSED(d);
|
| 129 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 130 |
+
#endif
|
| 131 |
+
}
|
| 132 |
+
};
|
| 133 |
+
|
| 134 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 135 |
+
//
|
| 136 |
+
// Matrix Multiply 1688 - FP32 accumulation
|
| 137 |
+
//
|
| 138 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 139 |
+
|
| 140 |
+
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
|
| 141 |
+
template <>
|
| 142 |
+
struct Mma<
|
| 143 |
+
gemm::GemmShape<16, 8, 8>,
|
| 144 |
+
32,
|
| 145 |
+
half_t,
|
| 146 |
+
layout::RowMajor,
|
| 147 |
+
half_t,
|
| 148 |
+
layout::ColumnMajor,
|
| 149 |
+
float,
|
| 150 |
+
layout::RowMajor,
|
| 151 |
+
OpMultiplyAdd> {
|
| 152 |
+
|
| 153 |
+
using Shape = gemm::GemmShape<16, 8, 8>;
|
| 154 |
+
|
| 155 |
+
using ElementA = half_t;
|
| 156 |
+
using LayoutA = layout::RowMajor;
|
| 157 |
+
using FragmentA = Array<half_t, 4>;
|
| 158 |
+
|
| 159 |
+
using ElementB = half_t;
|
| 160 |
+
using LayoutB = layout::ColumnMajor;
|
| 161 |
+
using FragmentB = Array<half_t, 2>;
|
| 162 |
+
|
| 163 |
+
using ElementC = float;
|
| 164 |
+
using LayoutC = layout::RowMajor;
|
| 165 |
+
using FragmentC = Array<float, 4>;
|
| 166 |
+
|
| 167 |
+
using Operator = OpMultiplyAdd;
|
| 168 |
+
using ArchTag = arch::Sm75;
|
| 169 |
+
|
| 170 |
+
/// Computes multiply-add
|
| 171 |
+
CUTLASS_HOST_DEVICE
|
| 172 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 173 |
+
FragmentC const &c) const {
|
| 174 |
+
|
| 175 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 176 |
+
|
| 177 |
+
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
|
| 178 |
+
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
|
| 179 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 180 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 181 |
+
|
| 182 |
+
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
| 183 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 184 |
+
:
|
| 185 |
+
"r"(A[0]), "r"(A[1]),
|
| 186 |
+
"r"(B[0]),
|
| 187 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
|
| 188 |
+
);
|
| 189 |
+
|
| 190 |
+
#else
|
| 191 |
+
CUTLASS_UNUSED(a);
|
| 192 |
+
CUTLASS_UNUSED(b);
|
| 193 |
+
CUTLASS_UNUSED(c);
|
| 194 |
+
CUTLASS_UNUSED(d);
|
| 195 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 196 |
+
#endif
|
| 197 |
+
}
|
| 198 |
+
};
|
| 199 |
+
|
| 200 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 201 |
+
//
|
| 202 |
+
// Integer matrix multiply (8b) with SATURATE
|
| 203 |
+
//
|
| 204 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 205 |
+
|
| 206 |
+
/// Matrix multiply-add operation: S32 = S8 * S8 + S32
|
| 207 |
+
template <>
|
| 208 |
+
struct Mma<
|
| 209 |
+
gemm::GemmShape<8, 8, 16>,
|
| 210 |
+
32,
|
| 211 |
+
int8_t,
|
| 212 |
+
layout::RowMajor,
|
| 213 |
+
int8_t,
|
| 214 |
+
layout::ColumnMajor,
|
| 215 |
+
int,
|
| 216 |
+
layout::RowMajor,
|
| 217 |
+
OpMultiplyAddSaturate> {
|
| 218 |
+
|
| 219 |
+
using Shape = gemm::GemmShape<8, 8, 16>;
|
| 220 |
+
|
| 221 |
+
using ElementA = int8_t;
|
| 222 |
+
using LayoutA = layout::RowMajor;
|
| 223 |
+
using FragmentA = Array<int8_t, 4>;
|
| 224 |
+
|
| 225 |
+
using ElementB = int8_t;
|
| 226 |
+
using LayoutB = layout::ColumnMajor;
|
| 227 |
+
using FragmentB = Array<int8_t, 4>;
|
| 228 |
+
|
| 229 |
+
using ElementC = int;
|
| 230 |
+
using LayoutC = layout::RowMajor;
|
| 231 |
+
using FragmentC = Array<int, 2>;
|
| 232 |
+
|
| 233 |
+
using Operator = OpMultiplyAddSaturate;
|
| 234 |
+
using ArchTag = arch::Sm75;
|
| 235 |
+
|
| 236 |
+
/// Computes multiply-add
|
| 237 |
+
CUTLASS_HOST_DEVICE
|
| 238 |
+
void operator()(
|
| 239 |
+
FragmentC &d,
|
| 240 |
+
FragmentA const &a,
|
| 241 |
+
FragmentB const &b,
|
| 242 |
+
FragmentC const &c
|
| 243 |
+
) const {
|
| 244 |
+
|
| 245 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 246 |
+
|
| 247 |
+
unsigned const & A = reinterpret_cast<unsigned const &>(a);
|
| 248 |
+
unsigned const & B = reinterpret_cast<unsigned const &>(b);
|
| 249 |
+
|
| 250 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 251 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 252 |
+
|
| 253 |
+
asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
| 254 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 255 |
+
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
|
| 256 |
+
#else
|
| 257 |
+
CUTLASS_UNUSED(a);
|
| 258 |
+
CUTLASS_UNUSED(b);
|
| 259 |
+
CUTLASS_UNUSED(c);
|
| 260 |
+
CUTLASS_UNUSED(d);
|
| 261 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 262 |
+
#endif
|
| 263 |
+
}
|
| 264 |
+
};
|
| 265 |
+
|
| 266 |
+
/// Matrix multiply-add operation: S32 = U8 * S8 + S32
|
| 267 |
+
template <>
|
| 268 |
+
struct Mma<
|
| 269 |
+
gemm::GemmShape<8, 8, 16>,
|
| 270 |
+
32,
|
| 271 |
+
uint8_t,
|
| 272 |
+
layout::RowMajor,
|
| 273 |
+
int8_t,
|
| 274 |
+
layout::ColumnMajor,
|
| 275 |
+
int,
|
| 276 |
+
layout::RowMajor,
|
| 277 |
+
OpMultiplyAddSaturate> {
|
| 278 |
+
|
| 279 |
+
using Shape = gemm::GemmShape<8, 8, 16>;
|
| 280 |
+
|
| 281 |
+
using ElementA = uint8_t;
|
| 282 |
+
using LayoutA = layout::RowMajor;
|
| 283 |
+
using FragmentA = Array<uint8_t, 4>;
|
| 284 |
+
|
| 285 |
+
using ElementB = int8_t;
|
| 286 |
+
using LayoutB = layout::ColumnMajor;
|
| 287 |
+
using FragmentB = Array<int8_t, 4>;
|
| 288 |
+
|
| 289 |
+
using ElementC = int;
|
| 290 |
+
using LayoutC = layout::RowMajor;
|
| 291 |
+
using FragmentC = Array<int, 2>;
|
| 292 |
+
|
| 293 |
+
using Operator = OpMultiplyAddSaturate;
|
| 294 |
+
using ArchTag = arch::Sm75;
|
| 295 |
+
|
| 296 |
+
/// Computes multiply-add
|
| 297 |
+
CUTLASS_HOST_DEVICE
|
| 298 |
+
void operator()(
|
| 299 |
+
FragmentC &d,
|
| 300 |
+
FragmentA const &a,
|
| 301 |
+
FragmentB const &b,
|
| 302 |
+
FragmentC const &c
|
| 303 |
+
) const {
|
| 304 |
+
|
| 305 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 306 |
+
|
| 307 |
+
unsigned const & A = reinterpret_cast<unsigned const &>(a);
|
| 308 |
+
unsigned const & B = reinterpret_cast<unsigned const &>(b);
|
| 309 |
+
|
| 310 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 311 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 312 |
+
|
| 313 |
+
asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
| 314 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 315 |
+
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
|
| 316 |
+
#else
|
| 317 |
+
CUTLASS_UNUSED(a);
|
| 318 |
+
CUTLASS_UNUSED(b);
|
| 319 |
+
CUTLASS_UNUSED(c);
|
| 320 |
+
CUTLASS_UNUSED(d);
|
| 321 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 322 |
+
#endif
|
| 323 |
+
}
|
| 324 |
+
};
|
| 325 |
+
|
| 326 |
+
/// Matrix multiply-add operation: S32 = S8 * U8 + S32
|
| 327 |
+
template <>
|
| 328 |
+
struct Mma<
|
| 329 |
+
gemm::GemmShape<8, 8, 16>,
|
| 330 |
+
32,
|
| 331 |
+
int8_t,
|
| 332 |
+
layout::RowMajor,
|
| 333 |
+
uint8_t,
|
| 334 |
+
layout::ColumnMajor,
|
| 335 |
+
int,
|
| 336 |
+
layout::RowMajor,
|
| 337 |
+
OpMultiplyAddSaturate> {
|
| 338 |
+
|
| 339 |
+
using Shape = gemm::GemmShape<8, 8, 16>;
|
| 340 |
+
|
| 341 |
+
using ElementA = int8_t;
|
| 342 |
+
using LayoutA = layout::RowMajor;
|
| 343 |
+
using FragmentA = Array<int8_t, 4>;
|
| 344 |
+
|
| 345 |
+
using ElementB = uint8_t;
|
| 346 |
+
using LayoutB = layout::ColumnMajor;
|
| 347 |
+
using FragmentB = Array<uint8_t, 4>;
|
| 348 |
+
|
| 349 |
+
using ElementC = int;
|
| 350 |
+
using LayoutC = layout::RowMajor;
|
| 351 |
+
using FragmentC = Array<int, 2>;
|
| 352 |
+
|
| 353 |
+
using Operator = OpMultiplyAddSaturate;
|
| 354 |
+
using ArchTag = arch::Sm75;
|
| 355 |
+
|
| 356 |
+
/// Computes multiply-add
|
| 357 |
+
CUTLASS_HOST_DEVICE
|
| 358 |
+
void operator()(
|
| 359 |
+
FragmentC &d,
|
| 360 |
+
FragmentA const &a,
|
| 361 |
+
FragmentB const &b,
|
| 362 |
+
FragmentC const &c
|
| 363 |
+
) const {
|
| 364 |
+
|
| 365 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 366 |
+
|
| 367 |
+
unsigned const & A = reinterpret_cast<unsigned const &>(a);
|
| 368 |
+
unsigned const & B = reinterpret_cast<unsigned const &>(b);
|
| 369 |
+
|
| 370 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 371 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 372 |
+
|
| 373 |
+
asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
| 374 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 375 |
+
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
|
| 376 |
+
#else
|
| 377 |
+
CUTLASS_UNUSED(a);
|
| 378 |
+
CUTLASS_UNUSED(b);
|
| 379 |
+
CUTLASS_UNUSED(c);
|
| 380 |
+
CUTLASS_UNUSED(d);
|
| 381 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 382 |
+
#endif
|
| 383 |
+
}
|
| 384 |
+
};
|
| 385 |
+
|
| 386 |
+
/// Matrix multiply-add operation: S32 = U8 * U8 + S32
|
| 387 |
+
template <>
|
| 388 |
+
struct Mma<
|
| 389 |
+
gemm::GemmShape<8, 8, 16>,
|
| 390 |
+
32,
|
| 391 |
+
uint8_t,
|
| 392 |
+
layout::RowMajor,
|
| 393 |
+
uint8_t,
|
| 394 |
+
layout::ColumnMajor,
|
| 395 |
+
int,
|
| 396 |
+
layout::RowMajor,
|
| 397 |
+
OpMultiplyAddSaturate> {
|
| 398 |
+
|
| 399 |
+
using Shape = gemm::GemmShape<8, 8, 16>;
|
| 400 |
+
|
| 401 |
+
using ElementA = uint8_t;
|
| 402 |
+
using LayoutA = layout::RowMajor;
|
| 403 |
+
using FragmentA = Array<uint8_t, 4>;
|
| 404 |
+
|
| 405 |
+
using ElementB = uint8_t;
|
| 406 |
+
using LayoutB = layout::ColumnMajor;
|
| 407 |
+
using FragmentB = Array<uint8_t, 4>;
|
| 408 |
+
|
| 409 |
+
using ElementC = int;
|
| 410 |
+
using LayoutC = layout::RowMajor;
|
| 411 |
+
using FragmentC = Array<int, 2>;
|
| 412 |
+
|
| 413 |
+
using Operator = OpMultiplyAddSaturate;
|
| 414 |
+
using ArchTag = arch::Sm75;
|
| 415 |
+
|
| 416 |
+
/// Computes multiply-add
|
| 417 |
+
CUTLASS_HOST_DEVICE
|
| 418 |
+
void operator()(
|
| 419 |
+
FragmentC &d,
|
| 420 |
+
FragmentA const &a,
|
| 421 |
+
FragmentB const &b,
|
| 422 |
+
FragmentC const &c
|
| 423 |
+
) const {
|
| 424 |
+
|
| 425 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 426 |
+
|
| 427 |
+
unsigned const & A = reinterpret_cast<unsigned const &>(a);
|
| 428 |
+
unsigned const & B = reinterpret_cast<unsigned const &>(b);
|
| 429 |
+
|
| 430 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 431 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 432 |
+
|
| 433 |
+
asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
| 434 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 435 |
+
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
|
| 436 |
+
#else
|
| 437 |
+
CUTLASS_UNUSED(a);
|
| 438 |
+
CUTLASS_UNUSED(b);
|
| 439 |
+
CUTLASS_UNUSED(c);
|
| 440 |
+
CUTLASS_UNUSED(d);
|
| 441 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 442 |
+
#endif
|
| 443 |
+
}
|
| 444 |
+
};
|
| 445 |
+
|
| 446 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 447 |
+
//
|
| 448 |
+
// Integer matrix multiply (4b) - SATURATE
|
| 449 |
+
//
|
| 450 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 451 |
+
|
| 452 |
+
/// Matrix multiply-add operation: S32 = S4 * S4 + S32
|
| 453 |
+
template <>
|
| 454 |
+
struct Mma<
|
| 455 |
+
gemm::GemmShape<8, 8, 32>,
|
| 456 |
+
32,
|
| 457 |
+
int4b_t,
|
| 458 |
+
layout::RowMajor,
|
| 459 |
+
int4b_t,
|
| 460 |
+
layout::ColumnMajor,
|
| 461 |
+
int,
|
| 462 |
+
layout::RowMajor,
|
| 463 |
+
OpMultiplyAddSaturate> {
|
| 464 |
+
|
| 465 |
+
using Shape = gemm::GemmShape<8, 8, 32>;
|
| 466 |
+
|
| 467 |
+
using ElementA = int4b_t;
|
| 468 |
+
using LayoutA = layout::RowMajor;
|
| 469 |
+
using FragmentA = Array<int4b_t, 8>;
|
| 470 |
+
|
| 471 |
+
using ElementB = int4b_t;
|
| 472 |
+
using LayoutB = layout::ColumnMajor;
|
| 473 |
+
using FragmentB = Array<int4b_t, 8>;
|
| 474 |
+
|
| 475 |
+
using ElementC = int;
|
| 476 |
+
using LayoutC = layout::RowMajor;
|
| 477 |
+
using FragmentC = Array<int, 2>;
|
| 478 |
+
|
| 479 |
+
using Operator = OpMultiplyAddSaturate;
|
| 480 |
+
using ArchTag = arch::Sm75;
|
| 481 |
+
|
| 482 |
+
/// Computes multiply-add
|
| 483 |
+
CUTLASS_HOST_DEVICE
|
| 484 |
+
void operator()(
|
| 485 |
+
FragmentC &d,
|
| 486 |
+
FragmentA const &a,
|
| 487 |
+
FragmentB const &b,
|
| 488 |
+
FragmentC const &c
|
| 489 |
+
) const {
|
| 490 |
+
|
| 491 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 492 |
+
|
| 493 |
+
unsigned const & A = reinterpret_cast<unsigned const &>(a);
|
| 494 |
+
unsigned const & B = reinterpret_cast<unsigned const &>(b);
|
| 495 |
+
|
| 496 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 497 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 498 |
+
|
| 499 |
+
asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
| 500 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 501 |
+
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
|
| 502 |
+
#else
|
| 503 |
+
CUTLASS_UNUSED(a);
|
| 504 |
+
CUTLASS_UNUSED(b);
|
| 505 |
+
CUTLASS_UNUSED(c);
|
| 506 |
+
CUTLASS_UNUSED(d);
|
| 507 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 508 |
+
#endif
|
| 509 |
+
}
|
| 510 |
+
};
|
| 511 |
+
|
| 512 |
+
/// Matrix multiply-add operation: S32 = U4 * S4 + S32
|
| 513 |
+
template <>
|
| 514 |
+
struct Mma<
|
| 515 |
+
gemm::GemmShape<8, 8, 32>,
|
| 516 |
+
32,
|
| 517 |
+
uint4b_t,
|
| 518 |
+
layout::RowMajor,
|
| 519 |
+
int4b_t,
|
| 520 |
+
layout::ColumnMajor,
|
| 521 |
+
int,
|
| 522 |
+
layout::RowMajor,
|
| 523 |
+
OpMultiplyAddSaturate> {
|
| 524 |
+
|
| 525 |
+
using Shape = gemm::GemmShape<8, 8, 32>;
|
| 526 |
+
|
| 527 |
+
using ElementA = uint4b_t;
|
| 528 |
+
using LayoutA = layout::RowMajor;
|
| 529 |
+
using FragmentA = Array<uint4b_t, 8>;
|
| 530 |
+
|
| 531 |
+
using ElementB = int4b_t;
|
| 532 |
+
using LayoutB = layout::ColumnMajor;
|
| 533 |
+
using FragmentB = Array<int4b_t, 8>;
|
| 534 |
+
|
| 535 |
+
using ElementC = int;
|
| 536 |
+
using LayoutC = layout::RowMajor;
|
| 537 |
+
using FragmentC = Array<int, 2>;
|
| 538 |
+
|
| 539 |
+
using Operator = OpMultiplyAddSaturate;
|
| 540 |
+
using ArchTag = arch::Sm75;
|
| 541 |
+
|
| 542 |
+
/// Computes multiply-add
|
| 543 |
+
CUTLASS_HOST_DEVICE
|
| 544 |
+
void operator()(
|
| 545 |
+
FragmentC &d,
|
| 546 |
+
FragmentA const &a,
|
| 547 |
+
FragmentB const &b,
|
| 548 |
+
FragmentC const &c
|
| 549 |
+
) const {
|
| 550 |
+
|
| 551 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 552 |
+
|
| 553 |
+
unsigned const & A = reinterpret_cast<unsigned const &>(a);
|
| 554 |
+
unsigned const & B = reinterpret_cast<unsigned const &>(b);
|
| 555 |
+
|
| 556 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 557 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 558 |
+
|
| 559 |
+
asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
| 560 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 561 |
+
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
|
| 562 |
+
#else
|
| 563 |
+
CUTLASS_UNUSED(a);
|
| 564 |
+
CUTLASS_UNUSED(b);
|
| 565 |
+
CUTLASS_UNUSED(c);
|
| 566 |
+
CUTLASS_UNUSED(d);
|
| 567 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 568 |
+
#endif
|
| 569 |
+
}
|
| 570 |
+
};
|
| 571 |
+
|
| 572 |
+
/// Matrix multiply-add operation: S32 = S4 * U4 + S32
|
| 573 |
+
template <>
|
| 574 |
+
struct Mma<
|
| 575 |
+
gemm::GemmShape<8, 8, 32>,
|
| 576 |
+
32,
|
| 577 |
+
int4b_t,
|
| 578 |
+
layout::RowMajor,
|
| 579 |
+
uint4b_t,
|
| 580 |
+
layout::ColumnMajor,
|
| 581 |
+
int,
|
| 582 |
+
layout::RowMajor,
|
| 583 |
+
OpMultiplyAddSaturate> {
|
| 584 |
+
|
| 585 |
+
using Shape = gemm::GemmShape<8, 8, 32>;
|
| 586 |
+
|
| 587 |
+
using ElementA = int4b_t;
|
| 588 |
+
using LayoutA = layout::RowMajor;
|
| 589 |
+
using FragmentA = Array<int4b_t, 8>;
|
| 590 |
+
|
| 591 |
+
using ElementB = uint4b_t;
|
| 592 |
+
using LayoutB = layout::ColumnMajor;
|
| 593 |
+
using FragmentB = Array<uint4b_t, 8>;
|
| 594 |
+
|
| 595 |
+
using ElementC = int;
|
| 596 |
+
using LayoutC = layout::RowMajor;
|
| 597 |
+
using FragmentC = Array<int, 2>;
|
| 598 |
+
|
| 599 |
+
using Operator = OpMultiplyAddSaturate;
|
| 600 |
+
using ArchTag = arch::Sm75;
|
| 601 |
+
|
| 602 |
+
/// Computes multiply-add
|
| 603 |
+
CUTLASS_HOST_DEVICE
|
| 604 |
+
void operator()(
|
| 605 |
+
FragmentC &d,
|
| 606 |
+
FragmentA const &a,
|
| 607 |
+
FragmentB const &b,
|
| 608 |
+
FragmentC const &c
|
| 609 |
+
) const {
|
| 610 |
+
|
| 611 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 612 |
+
|
| 613 |
+
unsigned const & A = reinterpret_cast<unsigned const &>(a);
|
| 614 |
+
unsigned const & B = reinterpret_cast<unsigned const &>(b);
|
| 615 |
+
|
| 616 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 617 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 618 |
+
|
| 619 |
+
asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
| 620 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 621 |
+
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
|
| 622 |
+
#else
|
| 623 |
+
CUTLASS_UNUSED(a);
|
| 624 |
+
CUTLASS_UNUSED(b);
|
| 625 |
+
CUTLASS_UNUSED(c);
|
| 626 |
+
CUTLASS_UNUSED(d);
|
| 627 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 628 |
+
#endif
|
| 629 |
+
}
|
| 630 |
+
};
|
| 631 |
+
|
| 632 |
+
/// Matrix multiply-add operation: S32 = U4 * U4 + S32
|
| 633 |
+
template <>
|
| 634 |
+
struct Mma<
|
| 635 |
+
gemm::GemmShape<8, 8, 32>,
|
| 636 |
+
32,
|
| 637 |
+
uint4b_t,
|
| 638 |
+
layout::RowMajor,
|
| 639 |
+
uint4b_t,
|
| 640 |
+
layout::ColumnMajor,
|
| 641 |
+
int,
|
| 642 |
+
layout::RowMajor,
|
| 643 |
+
OpMultiplyAddSaturate> {
|
| 644 |
+
|
| 645 |
+
using Shape = gemm::GemmShape<8, 8, 32>;
|
| 646 |
+
|
| 647 |
+
using ElementA = uint4b_t;
|
| 648 |
+
using LayoutA = layout::RowMajor;
|
| 649 |
+
using FragmentA = Array<uint4b_t, 8>;
|
| 650 |
+
|
| 651 |
+
using ElementB = uint4b_t;
|
| 652 |
+
using LayoutB = layout::ColumnMajor;
|
| 653 |
+
using FragmentB = Array<uint4b_t, 8>;
|
| 654 |
+
|
| 655 |
+
using ElementC = int;
|
| 656 |
+
using LayoutC = layout::RowMajor;
|
| 657 |
+
using FragmentC = Array<int, 2>;
|
| 658 |
+
|
| 659 |
+
using Operator = OpMultiplyAddSaturate;
|
| 660 |
+
using ArchTag = arch::Sm75;
|
| 661 |
+
|
| 662 |
+
/// Computes multiply-add
|
| 663 |
+
CUTLASS_HOST_DEVICE
|
| 664 |
+
void operator()(
|
| 665 |
+
FragmentC &d,
|
| 666 |
+
FragmentA const &a,
|
| 667 |
+
FragmentB const &b,
|
| 668 |
+
FragmentC const &c
|
| 669 |
+
) const {
|
| 670 |
+
|
| 671 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 672 |
+
|
| 673 |
+
unsigned const & A = reinterpret_cast<unsigned const &>(a);
|
| 674 |
+
unsigned const & B = reinterpret_cast<unsigned const &>(b);
|
| 675 |
+
|
| 676 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 677 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 678 |
+
|
| 679 |
+
asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
| 680 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 681 |
+
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
|
| 682 |
+
#else
|
| 683 |
+
CUTLASS_UNUSED(a);
|
| 684 |
+
CUTLASS_UNUSED(b);
|
| 685 |
+
CUTLASS_UNUSED(c);
|
| 686 |
+
CUTLASS_UNUSED(d);
|
| 687 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 688 |
+
#endif
|
| 689 |
+
}
|
| 690 |
+
};
|
| 691 |
+
|
| 692 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 693 |
+
//
|
| 694 |
+
// b1 ^ b1 + s32 => s32
|
| 695 |
+
//
|
| 696 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 697 |
+
|
| 698 |
+
/// Matrix multiply-add operation
|
| 699 |
+
template <>
|
| 700 |
+
struct Mma<
|
| 701 |
+
gemm::GemmShape<8,8,128>,
|
| 702 |
+
32,
|
| 703 |
+
uint1b_t,
|
| 704 |
+
layout::RowMajor,
|
| 705 |
+
uint1b_t,
|
| 706 |
+
layout::ColumnMajor,
|
| 707 |
+
int,
|
| 708 |
+
layout::RowMajor,
|
| 709 |
+
OpXorPopc> {
|
| 710 |
+
|
| 711 |
+
using Shape = gemm::GemmShape<8,8,128>;
|
| 712 |
+
|
| 713 |
+
using ElementA = uint1b_t;
|
| 714 |
+
using LayoutA = layout::RowMajor;
|
| 715 |
+
using FragmentA = Array<uint1b_t, 32>;
|
| 716 |
+
|
| 717 |
+
using ElementB = uint1b_t;
|
| 718 |
+
using LayoutB = layout::ColumnMajor;
|
| 719 |
+
using FragmentB = Array<uint1b_t, 32>;
|
| 720 |
+
|
| 721 |
+
using ElementC = int;
|
| 722 |
+
using LayoutC = layout::RowMajor;
|
| 723 |
+
using FragmentC = Array<int, 2>;
|
| 724 |
+
|
| 725 |
+
using Operator = OpXorPopc;
|
| 726 |
+
using ArchTag = arch::Sm75;
|
| 727 |
+
|
| 728 |
+
/// Computes multiply-add
|
| 729 |
+
CUTLASS_HOST_DEVICE
|
| 730 |
+
void operator()(
|
| 731 |
+
FragmentC &d,
|
| 732 |
+
FragmentA const &a,
|
| 733 |
+
FragmentB const &b,
|
| 734 |
+
FragmentC const &c
|
| 735 |
+
) const {
|
| 736 |
+
|
| 737 |
+
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
|
| 738 |
+
|
| 739 |
+
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
| 740 |
+
using WmmaFragmentA = nvcuda::wmma::fragment<
|
| 741 |
+
nvcuda::wmma::matrix_a,
|
| 742 |
+
Shape::kM,
|
| 743 |
+
Shape::kN,
|
| 744 |
+
Shape::kK,
|
| 745 |
+
nvcuda::wmma::experimental::precision::b1,
|
| 746 |
+
nvcuda::wmma::row_major>;
|
| 747 |
+
|
| 748 |
+
using WmmaFragmentB = nvcuda::wmma::fragment<
|
| 749 |
+
nvcuda::wmma::matrix_b,
|
| 750 |
+
Shape::kM,
|
| 751 |
+
Shape::kN,
|
| 752 |
+
Shape::kK,
|
| 753 |
+
nvcuda::wmma::experimental::precision::b1,
|
| 754 |
+
nvcuda::wmma::col_major>;
|
| 755 |
+
|
| 756 |
+
using WmmaFragmentC = nvcuda::wmma::fragment<
|
| 757 |
+
nvcuda::wmma::accumulator,
|
| 758 |
+
Shape::kM,
|
| 759 |
+
Shape::kN,
|
| 760 |
+
Shape::kK,
|
| 761 |
+
int>;
|
| 762 |
+
|
| 763 |
+
WmmaFragmentA const & A = reinterpret_cast<WmmaFragmentA const &>(a);
|
| 764 |
+
WmmaFragmentB const & B = reinterpret_cast<WmmaFragmentB const &>(b);
|
| 765 |
+
|
| 766 |
+
WmmaFragmentC const & C = reinterpret_cast<WmmaFragmentC const &>(c);
|
| 767 |
+
WmmaFragmentC & D = reinterpret_cast<WmmaFragmentC &>(d);
|
| 768 |
+
|
| 769 |
+
nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
|
| 770 |
+
nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
|
| 771 |
+
|
| 772 |
+
#else
|
| 773 |
+
|
| 774 |
+
CUTLASS_UNUSED(a);
|
| 775 |
+
CUTLASS_UNUSED(b);
|
| 776 |
+
CUTLASS_UNUSED(c);
|
| 777 |
+
CUTLASS_UNUSED(d);
|
| 778 |
+
CUTLASS_NOT_IMPLEMENTED(); // WMMA must be supported to issue binary matrix multiply-accumulate instructions.
|
| 779 |
+
|
| 780 |
+
#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
|
| 781 |
+
|
| 782 |
+
#endif
|
| 783 |
+
}
|
| 784 |
+
};
|
| 785 |
+
|
| 786 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 787 |
+
|
| 788 |
+
} // namespace arch
|
| 789 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h
ADDED
|
@@ -0,0 +1,1500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include CUDA_STD_HEADER(cassert)
|
| 38 |
+
|
| 39 |
+
#include "mma.h"
|
| 40 |
+
#include "cutlass/layout/matrix.h"
|
| 41 |
+
#include "cutlass/numeric_types.h"
|
| 42 |
+
|
| 43 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
|
| 46 |
+
|
| 47 |
+
#define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1
|
| 48 |
+
|
| 49 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
| 50 |
+
#define CUTLASS_ARCH_MMA_SM80_ENABLED
|
| 51 |
+
|
| 52 |
+
#if (__CUDA_ARCH__ <= 900)
|
| 53 |
+
#define CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED
|
| 54 |
+
#endif
|
| 55 |
+
#if (__CUDA_ARCH__ <= 890)
|
| 56 |
+
#define CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED
|
| 57 |
+
#endif
|
| 58 |
+
|
| 59 |
+
#endif
|
| 60 |
+
|
| 61 |
+
#endif
|
| 62 |
+
|
| 63 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
namespace cutlass {
|
| 66 |
+
namespace arch {
|
| 67 |
+
|
| 68 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 69 |
+
//
|
| 70 |
+
// Matrix Multiply 1688 - Float BF16, FP32 accumulation
|
| 71 |
+
//
|
| 72 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 73 |
+
|
| 74 |
+
/// Matrix multiply-add operation - F32 = bf16 * bf16 + F32
|
| 75 |
+
template <>
|
| 76 |
+
struct Mma<
|
| 77 |
+
gemm::GemmShape<16, 8, 8>,
|
| 78 |
+
32,
|
| 79 |
+
bfloat16_t,
|
| 80 |
+
layout::RowMajor,
|
| 81 |
+
bfloat16_t,
|
| 82 |
+
layout::ColumnMajor,
|
| 83 |
+
float,
|
| 84 |
+
layout::RowMajor,
|
| 85 |
+
OpMultiplyAdd> {
|
| 86 |
+
|
| 87 |
+
using Shape = gemm::GemmShape<16, 8, 8>;
|
| 88 |
+
|
| 89 |
+
using ElementA = bfloat16_t;
|
| 90 |
+
using LayoutA = layout::RowMajor;
|
| 91 |
+
using FragmentA = Array<bfloat16_t, 4>;
|
| 92 |
+
|
| 93 |
+
using ElementB = bfloat16_t;
|
| 94 |
+
using LayoutB = layout::ColumnMajor;
|
| 95 |
+
using FragmentB = Array<bfloat16_t, 2>;
|
| 96 |
+
|
| 97 |
+
using ElementC = float;
|
| 98 |
+
using LayoutC = layout::RowMajor;
|
| 99 |
+
using FragmentC = Array<float, 4>;
|
| 100 |
+
|
| 101 |
+
using Operator = OpMultiplyAdd;
|
| 102 |
+
using ArchTag = arch::Sm80;
|
| 103 |
+
|
| 104 |
+
CUTLASS_HOST_DEVICE
|
| 105 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 106 |
+
FragmentC const &c) const {
|
| 107 |
+
|
| 108 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 109 |
+
|
| 110 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 111 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 112 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 113 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 114 |
+
|
| 115 |
+
asm(
|
| 116 |
+
"mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
|
| 117 |
+
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
| 118 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 119 |
+
:
|
| 120 |
+
"r"(A[0]), "r"(A[1]),
|
| 121 |
+
"r"(B[0]),
|
| 122 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
|
| 123 |
+
);
|
| 124 |
+
|
| 125 |
+
#else
|
| 126 |
+
|
| 127 |
+
CUTLASS_UNUSED(d);
|
| 128 |
+
CUTLASS_UNUSED(a);
|
| 129 |
+
CUTLASS_UNUSED(b);
|
| 130 |
+
CUTLASS_UNUSED(c);
|
| 131 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 132 |
+
|
| 133 |
+
#endif
|
| 134 |
+
}
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 138 |
+
//
|
| 139 |
+
// Matrix Multiply 1684 - Float TF32
|
| 140 |
+
//
|
| 141 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 142 |
+
|
| 143 |
+
/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32
|
| 144 |
+
template <>
|
| 145 |
+
struct Mma<
|
| 146 |
+
gemm::GemmShape<16, 8, 4>,
|
| 147 |
+
32,
|
| 148 |
+
tfloat32_t,
|
| 149 |
+
layout::RowMajor,
|
| 150 |
+
tfloat32_t,
|
| 151 |
+
layout::ColumnMajor,
|
| 152 |
+
float,
|
| 153 |
+
layout::RowMajor,
|
| 154 |
+
OpMultiplyAdd> {
|
| 155 |
+
|
| 156 |
+
using Shape = gemm::GemmShape<16, 8, 4>;
|
| 157 |
+
|
| 158 |
+
using ElementA = tfloat32_t;
|
| 159 |
+
using LayoutA = layout::RowMajor;
|
| 160 |
+
using FragmentA = Array<tfloat32_t, 2>;
|
| 161 |
+
|
| 162 |
+
using ElementB = tfloat32_t;
|
| 163 |
+
using LayoutB = layout::ColumnMajor;
|
| 164 |
+
using FragmentB = Array<tfloat32_t, 1>;
|
| 165 |
+
|
| 166 |
+
using ElementC = float;
|
| 167 |
+
using LayoutC = layout::RowMajor;
|
| 168 |
+
using FragmentC = Array<float, 4>;
|
| 169 |
+
|
| 170 |
+
using Operator = OpMultiplyAdd;
|
| 171 |
+
using ArchTag = arch::Sm80;
|
| 172 |
+
|
| 173 |
+
CUTLASS_HOST_DEVICE
|
| 174 |
+
void operator()(
|
| 175 |
+
FragmentC &d,
|
| 176 |
+
FragmentA const &a,
|
| 177 |
+
FragmentB const &b,
|
| 178 |
+
FragmentC const &c
|
| 179 |
+
) const {
|
| 180 |
+
|
| 181 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 182 |
+
|
| 183 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 184 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 185 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 186 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 187 |
+
|
| 188 |
+
asm volatile(
|
| 189 |
+
"mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
| 190 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 191 |
+
:
|
| 192 |
+
"r"(A[0]), "r"(A[1]),
|
| 193 |
+
"r"(B[0]),
|
| 194 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
|
| 195 |
+
);
|
| 196 |
+
|
| 197 |
+
#else
|
| 198 |
+
|
| 199 |
+
CUTLASS_UNUSED(d);
|
| 200 |
+
CUTLASS_UNUSED(a);
|
| 201 |
+
CUTLASS_UNUSED(b);
|
| 202 |
+
CUTLASS_UNUSED(c);
|
| 203 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 204 |
+
|
| 205 |
+
#endif
|
| 206 |
+
}
|
| 207 |
+
};
|
| 208 |
+
|
| 209 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 210 |
+
//
|
| 211 |
+
// Matrix Multiply 1688 - Float TF32
|
| 212 |
+
//
|
| 213 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 214 |
+
|
| 215 |
+
/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32
|
| 216 |
+
template <>
|
| 217 |
+
struct Mma<gemm::GemmShape<16, 8, 8>, 32, tfloat32_t, layout::RowMajor,
|
| 218 |
+
tfloat32_t, layout::ColumnMajor, float, layout::RowMajor,
|
| 219 |
+
OpMultiplyAdd> {
|
| 220 |
+
using Shape = gemm::GemmShape<16, 8, 8>;
|
| 221 |
+
|
| 222 |
+
using ElementA = tfloat32_t;
|
| 223 |
+
using LayoutA = layout::RowMajor;
|
| 224 |
+
using FragmentA = Array<tfloat32_t, 4>;
|
| 225 |
+
|
| 226 |
+
using ElementB = tfloat32_t;
|
| 227 |
+
using LayoutB = layout::ColumnMajor;
|
| 228 |
+
using FragmentB = Array<tfloat32_t, 2>;
|
| 229 |
+
|
| 230 |
+
using ElementC = float;
|
| 231 |
+
using LayoutC = layout::RowMajor;
|
| 232 |
+
using FragmentC = Array<float, 4>;
|
| 233 |
+
|
| 234 |
+
using Operator = OpMultiplyAdd;
|
| 235 |
+
using ArchTag = arch::Sm80;
|
| 236 |
+
|
| 237 |
+
CUTLASS_HOST_DEVICE
|
| 238 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 239 |
+
FragmentC const &c) const {
|
| 240 |
+
|
| 241 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 242 |
+
|
| 243 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 244 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 245 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 246 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 247 |
+
|
| 248 |
+
asm volatile(
|
| 249 |
+
"mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 "
|
| 250 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 251 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 252 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
|
| 253 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
|
| 254 |
+
|
| 255 |
+
#else
|
| 256 |
+
|
| 257 |
+
CUTLASS_UNUSED(d);
|
| 258 |
+
CUTLASS_UNUSED(a);
|
| 259 |
+
CUTLASS_UNUSED(b);
|
| 260 |
+
CUTLASS_UNUSED(c);
|
| 261 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 262 |
+
|
| 263 |
+
#endif
|
| 264 |
+
}
|
| 265 |
+
};
|
| 266 |
+
|
| 267 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 268 |
+
//
|
| 269 |
+
// Matrix Multiply 16816
|
| 270 |
+
//
|
| 271 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 272 |
+
|
| 273 |
+
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
|
| 274 |
+
template <>
|
| 275 |
+
struct Mma<
|
| 276 |
+
gemm::GemmShape<16, 8, 16>,
|
| 277 |
+
32,
|
| 278 |
+
half_t,
|
| 279 |
+
layout::RowMajor,
|
| 280 |
+
half_t,
|
| 281 |
+
layout::ColumnMajor,
|
| 282 |
+
half_t,
|
| 283 |
+
layout::RowMajor,
|
| 284 |
+
OpMultiplyAdd> {
|
| 285 |
+
|
| 286 |
+
using Shape = gemm::GemmShape<16, 8, 16>;
|
| 287 |
+
|
| 288 |
+
using ElementA = half_t;
|
| 289 |
+
using LayoutA = layout::RowMajor;
|
| 290 |
+
using FragmentA = Array<half_t, 8>;
|
| 291 |
+
|
| 292 |
+
using ElementB = half_t;
|
| 293 |
+
using LayoutB = layout::ColumnMajor;
|
| 294 |
+
using FragmentB = Array<half_t, 4>;
|
| 295 |
+
|
| 296 |
+
using ElementC = half_t;
|
| 297 |
+
using LayoutC = layout::RowMajor;
|
| 298 |
+
using FragmentC = Array<half_t, 4>;
|
| 299 |
+
|
| 300 |
+
using Operator = OpMultiplyAdd;
|
| 301 |
+
using ArchTag = arch::Sm80;
|
| 302 |
+
|
| 303 |
+
/// Computes multiply-add
|
| 304 |
+
CUTLASS_HOST_DEVICE
|
| 305 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 306 |
+
FragmentC const &c) const {
|
| 307 |
+
|
| 308 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 309 |
+
|
| 310 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 311 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 312 |
+
uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
|
| 313 |
+
uint32_t *D = reinterpret_cast<uint32_t *>(&d);
|
| 314 |
+
|
| 315 |
+
asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
| 316 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 317 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
|
| 318 |
+
"r"(B[0]), "r"(B[1]),
|
| 319 |
+
"r"(C[0]), "r"(C[1])
|
| 320 |
+
);
|
| 321 |
+
|
| 322 |
+
#else
|
| 323 |
+
|
| 324 |
+
CUTLASS_UNUSED(d);
|
| 325 |
+
CUTLASS_UNUSED(a);
|
| 326 |
+
CUTLASS_UNUSED(b);
|
| 327 |
+
CUTLASS_UNUSED(c);
|
| 328 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 329 |
+
|
| 330 |
+
#endif
|
| 331 |
+
}
|
| 332 |
+
};
|
| 333 |
+
|
| 334 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 335 |
+
|
| 336 |
+
/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32
|
| 337 |
+
template <>
|
| 338 |
+
struct Mma<
|
| 339 |
+
gemm::GemmShape<16, 8, 16>,
|
| 340 |
+
32,
|
| 341 |
+
bfloat16_t,
|
| 342 |
+
layout::RowMajor,
|
| 343 |
+
bfloat16_t,
|
| 344 |
+
layout::ColumnMajor,
|
| 345 |
+
float,
|
| 346 |
+
layout::RowMajor,
|
| 347 |
+
OpMultiplyAdd> {
|
| 348 |
+
|
| 349 |
+
using Shape = gemm::GemmShape<16, 8, 16>;
|
| 350 |
+
|
| 351 |
+
using ElementA = bfloat16_t;
|
| 352 |
+
using LayoutA = layout::RowMajor;
|
| 353 |
+
using FragmentA = Array<bfloat16_t, 8>;
|
| 354 |
+
|
| 355 |
+
using ElementB = bfloat16_t;
|
| 356 |
+
using LayoutB = layout::ColumnMajor;
|
| 357 |
+
using FragmentB = Array<bfloat16_t, 4>;
|
| 358 |
+
|
| 359 |
+
using ElementC = float;
|
| 360 |
+
using LayoutC = layout::RowMajor;
|
| 361 |
+
using FragmentC = Array<float, 4>;
|
| 362 |
+
|
| 363 |
+
using Operator = OpMultiplyAdd;
|
| 364 |
+
using ArchTag = arch::Sm80;
|
| 365 |
+
|
| 366 |
+
/// Computes multiply-add
|
| 367 |
+
CUTLASS_HOST_DEVICE
|
| 368 |
+
void operator()(
|
| 369 |
+
FragmentC &d,
|
| 370 |
+
FragmentA const &a,
|
| 371 |
+
FragmentB const &b,
|
| 372 |
+
FragmentC const &c
|
| 373 |
+
) const {
|
| 374 |
+
|
| 375 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 376 |
+
|
| 377 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 378 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 379 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 380 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 381 |
+
|
| 382 |
+
asm volatile(
|
| 383 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
| 384 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 385 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 386 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
|
| 387 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
|
| 388 |
+
|
| 389 |
+
#else
|
| 390 |
+
|
| 391 |
+
CUTLASS_UNUSED(d);
|
| 392 |
+
CUTLASS_UNUSED(a);
|
| 393 |
+
CUTLASS_UNUSED(b);
|
| 394 |
+
CUTLASS_UNUSED(c);
|
| 395 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 396 |
+
|
| 397 |
+
#endif
|
| 398 |
+
}
|
| 399 |
+
};
|
| 400 |
+
|
| 401 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 402 |
+
|
| 403 |
+
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
|
| 404 |
+
template <>
|
| 405 |
+
struct Mma<
|
| 406 |
+
gemm::GemmShape<16, 8, 16>,
|
| 407 |
+
32,
|
| 408 |
+
half_t,
|
| 409 |
+
layout::RowMajor,
|
| 410 |
+
half_t,
|
| 411 |
+
layout::ColumnMajor,
|
| 412 |
+
float,
|
| 413 |
+
layout::RowMajor,
|
| 414 |
+
OpMultiplyAdd> {
|
| 415 |
+
|
| 416 |
+
using Shape = gemm::GemmShape<16, 8, 16>;
|
| 417 |
+
|
| 418 |
+
using ElementA = half_t;
|
| 419 |
+
using LayoutA = layout::RowMajor;
|
| 420 |
+
using FragmentA = Array<half_t, 8>;
|
| 421 |
+
|
| 422 |
+
using ElementB = half_t;
|
| 423 |
+
using LayoutB = layout::ColumnMajor;
|
| 424 |
+
using FragmentB = Array<half_t, 4>;
|
| 425 |
+
|
| 426 |
+
using ElementC = float;
|
| 427 |
+
using LayoutC = layout::RowMajor;
|
| 428 |
+
using FragmentC = Array<float, 4>;
|
| 429 |
+
|
| 430 |
+
using Operator = OpMultiplyAdd;
|
| 431 |
+
using ArchTag = arch::Sm80;
|
| 432 |
+
|
| 433 |
+
/// Computes multiply-add
|
| 434 |
+
CUTLASS_HOST_DEVICE
|
| 435 |
+
void operator()(
|
| 436 |
+
FragmentC &d,
|
| 437 |
+
FragmentA const &a,
|
| 438 |
+
FragmentB const &b,
|
| 439 |
+
FragmentC const &c
|
| 440 |
+
) const {
|
| 441 |
+
|
| 442 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 443 |
+
|
| 444 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 445 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 446 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 447 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 448 |
+
|
| 449 |
+
asm volatile(
|
| 450 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
|
| 451 |
+
"{%10,%11,%12,%13};\n"
|
| 452 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 453 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
|
| 454 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
|
| 455 |
+
|
| 456 |
+
#else
|
| 457 |
+
|
| 458 |
+
CUTLASS_UNUSED(d);
|
| 459 |
+
CUTLASS_UNUSED(a);
|
| 460 |
+
CUTLASS_UNUSED(b);
|
| 461 |
+
CUTLASS_UNUSED(c);
|
| 462 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 463 |
+
|
| 464 |
+
#endif
|
| 465 |
+
}
|
| 466 |
+
};
|
| 467 |
+
|
| 468 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 469 |
+
//
|
| 470 |
+
// Matrix Multiply 884 - F64
|
| 471 |
+
//
|
| 472 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 473 |
+
|
| 474 |
+
/// Matrix multiply-add operation: F64 = F64 * F64 + F64
template <>
struct Mma<
  gemm::GemmShape<8,8,4>,
  32,
  double,
  layout::RowMajor,
  double,
  layout::ColumnMajor,
  double,
  layout::RowMajor,
  OpMultiplyAdd> {

  // Warp-wide tile shape of one mma.sync instruction (M-by-N-by-K).
  using Shape = gemm::GemmShape<8,8,4>;

  using ElementA = double;
  using LayoutA = layout::RowMajor;
  // Per-thread A fragment: a single f64 register.
  using FragmentA = Array<double, 1>;

  using ElementB = double;
  using LayoutB = layout::ColumnMajor;
  // Per-thread B fragment: a single f64 register.
  using FragmentB = Array<double, 1>;

  using ElementC = double;
  using LayoutC = layout::RowMajor;
  // Per-thread accumulator fragment: two f64 registers.
  using FragmentC = Array<double, 2>;

  using Operator = OpMultiplyAdd;

  using ArchTag = arch::Sm80;

  /// Computes D = A * B + C with one mma.sync.m8n8k4 f64 instruction.
  /// Issued cooperatively by all 32 threads of the warp.
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

    // Bind the one-element fragments directly as scalar f64 operands.
    double const & A = reinterpret_cast<double const &>(a);
    double const & B = reinterpret_cast<double const &>(b);

    double const *C = reinterpret_cast<double const *>(&c);
    double *D = reinterpret_cast<double *>(&d);

    asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=d"(D[0]), "=d"(D[1])
        : "d"(A), "d"(B), "d"(C[0]), "d"(C[1]));

#else

    // SM80 path compiled out (e.g. host pass or older arch): keep the
    // parameters referenced, then trap.
    CUTLASS_UNUSED(d);
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};
|
| 532 |
+
|
| 533 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 534 |
+
//
|
| 535 |
+
// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE
|
| 536 |
+
//
|
| 537 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 538 |
+
|
| 539 |
+
/// Matrix multiply-add operation: S32 = S8 * S8 + S32
|
| 540 |
+
template <>
|
| 541 |
+
struct Mma<
|
| 542 |
+
gemm::GemmShape<16,8,16>,
|
| 543 |
+
32,
|
| 544 |
+
int8_t,
|
| 545 |
+
layout::RowMajor,
|
| 546 |
+
int8_t,
|
| 547 |
+
layout::ColumnMajor,
|
| 548 |
+
int,
|
| 549 |
+
layout::RowMajor,
|
| 550 |
+
OpMultiplyAddSaturate> {
|
| 551 |
+
|
| 552 |
+
using Shape = gemm::GemmShape<16,8,16>;
|
| 553 |
+
|
| 554 |
+
using ElementA = int8_t;
|
| 555 |
+
using LayoutA = layout::RowMajor;
|
| 556 |
+
using FragmentA = Array<int8_t, 8>;
|
| 557 |
+
|
| 558 |
+
using ElementB = int8_t;
|
| 559 |
+
using LayoutB = layout::ColumnMajor;
|
| 560 |
+
using FragmentB = Array<int8_t, 4>;
|
| 561 |
+
|
| 562 |
+
using ElementC = int;
|
| 563 |
+
using LayoutC = layout::RowMajor;
|
| 564 |
+
using FragmentC = Array<int, 4>;
|
| 565 |
+
|
| 566 |
+
using Operator = OpMultiplyAddSaturate;
|
| 567 |
+
using ArchTag = arch::Sm80;
|
| 568 |
+
|
| 569 |
+
/// Computes multiply-add
|
| 570 |
+
CUTLASS_HOST_DEVICE
|
| 571 |
+
void operator()(
|
| 572 |
+
FragmentC &d,
|
| 573 |
+
FragmentA const &a,
|
| 574 |
+
FragmentB const &b,
|
| 575 |
+
FragmentC const &c
|
| 576 |
+
) const {
|
| 577 |
+
|
| 578 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 579 |
+
|
| 580 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 581 |
+
uint32_t const &B = reinterpret_cast<uint32_t const &>(b);
|
| 582 |
+
|
| 583 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 584 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 585 |
+
|
| 586 |
+
asm volatile(
|
| 587 |
+
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, "
|
| 588 |
+
"{%6}, {%7,%8,%9,%10};\n"
|
| 589 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 590 |
+
: "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]),
|
| 591 |
+
"r"(C[3]));
|
| 592 |
+
|
| 593 |
+
#else
|
| 594 |
+
assert(0);
|
| 595 |
+
#endif
|
| 596 |
+
}
|
| 597 |
+
};
|
| 598 |
+
|
| 599 |
+
/// Matrix multiply-add operation: S32 = U8 * S8 + S32
|
| 600 |
+
template <>
|
| 601 |
+
struct Mma<
|
| 602 |
+
gemm::GemmShape<16,8,16>,
|
| 603 |
+
32,
|
| 604 |
+
uint8_t,
|
| 605 |
+
layout::RowMajor,
|
| 606 |
+
int8_t,
|
| 607 |
+
layout::ColumnMajor,
|
| 608 |
+
int,
|
| 609 |
+
layout::RowMajor,
|
| 610 |
+
OpMultiplyAddSaturate> {
|
| 611 |
+
|
| 612 |
+
using Shape = gemm::GemmShape<16,8,16>;
|
| 613 |
+
|
| 614 |
+
using ElementA = uint8_t;
|
| 615 |
+
using LayoutA = layout::RowMajor;
|
| 616 |
+
using FragmentA = Array<uint8_t, 8>;
|
| 617 |
+
|
| 618 |
+
using ElementB = int8_t;
|
| 619 |
+
using LayoutB = layout::ColumnMajor;
|
| 620 |
+
using FragmentB = Array<int8_t, 4>;
|
| 621 |
+
|
| 622 |
+
using ElementC = int;
|
| 623 |
+
using LayoutC = layout::RowMajor;
|
| 624 |
+
using FragmentC = Array<int, 4>;
|
| 625 |
+
|
| 626 |
+
using Operator = OpMultiplyAddSaturate;
|
| 627 |
+
using ArchTag = arch::Sm80;
|
| 628 |
+
|
| 629 |
+
/// Computes multiply-add
|
| 630 |
+
CUTLASS_HOST_DEVICE
|
| 631 |
+
void operator()(
|
| 632 |
+
FragmentC &d,
|
| 633 |
+
FragmentA const &a,
|
| 634 |
+
FragmentB const &b,
|
| 635 |
+
FragmentC const &c
|
| 636 |
+
) const {
|
| 637 |
+
|
| 638 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 639 |
+
|
| 640 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 641 |
+
uint32_t const &B = reinterpret_cast<uint32_t const &>(b);
|
| 642 |
+
|
| 643 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 644 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 645 |
+
|
| 646 |
+
asm volatile(
|
| 647 |
+
"mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, "
|
| 648 |
+
"{%6}, {%7,%8,%9,%10};\n"
|
| 649 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 650 |
+
: "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]),
|
| 651 |
+
"r"(C[3]));
|
| 652 |
+
|
| 653 |
+
#else
|
| 654 |
+
assert(0);
|
| 655 |
+
#endif
|
| 656 |
+
}
|
| 657 |
+
};
|
| 658 |
+
|
| 659 |
+
/// Matrix multiply-add operation: S32 = S8 * U8 + S32
|
| 660 |
+
template <>
|
| 661 |
+
struct Mma<
|
| 662 |
+
gemm::GemmShape<16,8,16>,
|
| 663 |
+
32,
|
| 664 |
+
int8_t,
|
| 665 |
+
layout::RowMajor,
|
| 666 |
+
uint8_t,
|
| 667 |
+
layout::ColumnMajor,
|
| 668 |
+
int,
|
| 669 |
+
layout::RowMajor,
|
| 670 |
+
OpMultiplyAddSaturate> {
|
| 671 |
+
|
| 672 |
+
using Shape = gemm::GemmShape<16,8,16>;
|
| 673 |
+
|
| 674 |
+
using ElementA = int8_t;
|
| 675 |
+
using LayoutA = layout::RowMajor;
|
| 676 |
+
using FragmentA = Array<int8_t, 8>;
|
| 677 |
+
|
| 678 |
+
using ElementB = uint8_t;
|
| 679 |
+
using LayoutB = layout::ColumnMajor;
|
| 680 |
+
using FragmentB = Array<uint8_t, 4>;
|
| 681 |
+
|
| 682 |
+
using ElementC = int;
|
| 683 |
+
using LayoutC = layout::RowMajor;
|
| 684 |
+
using FragmentC = Array<int, 4>;
|
| 685 |
+
|
| 686 |
+
using Operator = OpMultiplyAddSaturate;
|
| 687 |
+
using ArchTag = arch::Sm80;
|
| 688 |
+
|
| 689 |
+
/// Computes multiply-add
|
| 690 |
+
CUTLASS_HOST_DEVICE
|
| 691 |
+
void operator()(
|
| 692 |
+
FragmentC &d,
|
| 693 |
+
FragmentA const &a,
|
| 694 |
+
FragmentB const &b,
|
| 695 |
+
FragmentC const &c
|
| 696 |
+
) const {
|
| 697 |
+
|
| 698 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 699 |
+
|
| 700 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 701 |
+
uint32_t const &B = reinterpret_cast<uint32_t const &>(b);
|
| 702 |
+
|
| 703 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 704 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 705 |
+
|
| 706 |
+
asm volatile(
|
| 707 |
+
"mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, "
|
| 708 |
+
"{%6}, {%7,%8,%9,%10};\n"
|
| 709 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 710 |
+
: "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]),
|
| 711 |
+
"r"(C[3]));
|
| 712 |
+
|
| 713 |
+
#else
|
| 714 |
+
assert(0);
|
| 715 |
+
#endif
|
| 716 |
+
}
|
| 717 |
+
};
|
| 718 |
+
|
| 719 |
+
/// Matrix multiply-add operation: S32 = U8 * U8 + S32
|
| 720 |
+
template <>
|
| 721 |
+
struct Mma<
|
| 722 |
+
gemm::GemmShape<16,8,16>,
|
| 723 |
+
32,
|
| 724 |
+
uint8_t,
|
| 725 |
+
layout::RowMajor,
|
| 726 |
+
uint8_t,
|
| 727 |
+
layout::ColumnMajor,
|
| 728 |
+
int,
|
| 729 |
+
layout::RowMajor,
|
| 730 |
+
OpMultiplyAddSaturate> {
|
| 731 |
+
|
| 732 |
+
using Shape = gemm::GemmShape<16,8,16>;
|
| 733 |
+
|
| 734 |
+
using ElementA = uint8_t;
|
| 735 |
+
using LayoutA = layout::RowMajor;
|
| 736 |
+
using FragmentA = Array<uint8_t, 8>;
|
| 737 |
+
|
| 738 |
+
using ElementB = uint8_t;
|
| 739 |
+
using LayoutB = layout::ColumnMajor;
|
| 740 |
+
using FragmentB = Array<uint8_t, 4>;
|
| 741 |
+
|
| 742 |
+
using ElementC = int;
|
| 743 |
+
using LayoutC = layout::RowMajor;
|
| 744 |
+
using FragmentC = Array<int, 4>;
|
| 745 |
+
|
| 746 |
+
using Operator = OpMultiplyAddSaturate;
|
| 747 |
+
using ArchTag = arch::Sm80;
|
| 748 |
+
|
| 749 |
+
/// Computes multiply-add
|
| 750 |
+
CUTLASS_HOST_DEVICE
|
| 751 |
+
void operator()(
|
| 752 |
+
FragmentC &d,
|
| 753 |
+
FragmentA const &a,
|
| 754 |
+
FragmentB const &b,
|
| 755 |
+
FragmentC const &c
|
| 756 |
+
) const {
|
| 757 |
+
|
| 758 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 759 |
+
|
| 760 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 761 |
+
uint32_t const &B = reinterpret_cast<uint32_t const &>(b);
|
| 762 |
+
|
| 763 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 764 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 765 |
+
|
| 766 |
+
asm volatile(
|
| 767 |
+
"mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, "
|
| 768 |
+
"{%6}, {%7,%8,%9,%10};\n"
|
| 769 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 770 |
+
: "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]),
|
| 771 |
+
"r"(C[3]));
|
| 772 |
+
|
| 773 |
+
#else
|
| 774 |
+
assert(0);
|
| 775 |
+
#endif
|
| 776 |
+
}
|
| 777 |
+
};
|
| 778 |
+
|
| 779 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 780 |
+
//
|
| 781 |
+
// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE
|
| 782 |
+
//
|
| 783 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 784 |
+
|
| 785 |
+
/// Matrix multiply-add operation: S32 = S8 * S8 + S32
|
| 786 |
+
template <>
|
| 787 |
+
struct Mma<
|
| 788 |
+
gemm::GemmShape<16,8,32>,
|
| 789 |
+
32,
|
| 790 |
+
int8_t,
|
| 791 |
+
layout::RowMajor,
|
| 792 |
+
int8_t,
|
| 793 |
+
layout::ColumnMajor,
|
| 794 |
+
int,
|
| 795 |
+
layout::RowMajor,
|
| 796 |
+
OpMultiplyAddSaturate> {
|
| 797 |
+
|
| 798 |
+
using Shape = gemm::GemmShape<16,8,32>;
|
| 799 |
+
|
| 800 |
+
using ElementA = int8_t;
|
| 801 |
+
using LayoutA = layout::RowMajor;
|
| 802 |
+
using FragmentA = Array<int8_t, 16>;
|
| 803 |
+
|
| 804 |
+
using ElementB = int8_t;
|
| 805 |
+
using LayoutB = layout::ColumnMajor;
|
| 806 |
+
using FragmentB = Array<int8_t, 8>;
|
| 807 |
+
|
| 808 |
+
using ElementC = int;
|
| 809 |
+
using LayoutC = layout::RowMajor;
|
| 810 |
+
using FragmentC = Array<int, 4>;
|
| 811 |
+
|
| 812 |
+
using Operator = OpMultiplyAddSaturate;
|
| 813 |
+
using ArchTag = arch::Sm80;
|
| 814 |
+
|
| 815 |
+
/// Computes multiply-add
|
| 816 |
+
CUTLASS_HOST_DEVICE
|
| 817 |
+
void operator()(
|
| 818 |
+
FragmentC &d,
|
| 819 |
+
FragmentA const &a,
|
| 820 |
+
FragmentB const &b,
|
| 821 |
+
FragmentC const &c
|
| 822 |
+
) const {
|
| 823 |
+
|
| 824 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 825 |
+
|
| 826 |
+
uint32_t const * A = reinterpret_cast<uint32_t const *>(&a);
|
| 827 |
+
uint32_t const * B = reinterpret_cast<uint32_t const *>(&b);
|
| 828 |
+
|
| 829 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 830 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 831 |
+
|
| 832 |
+
asm volatile(
|
| 833 |
+
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, "
|
| 834 |
+
"{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 835 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 836 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
|
| 837 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
|
| 838 |
+
|
| 839 |
+
#else
|
| 840 |
+
assert(0);
|
| 841 |
+
#endif
|
| 842 |
+
}
|
| 843 |
+
};
|
| 844 |
+
|
| 845 |
+
/// Matrix multiply-add operation: S32 = U8 * S8 + S32
|
| 846 |
+
template <>
|
| 847 |
+
struct Mma<
|
| 848 |
+
gemm::GemmShape<16,8,32>,
|
| 849 |
+
32,
|
| 850 |
+
uint8_t,
|
| 851 |
+
layout::RowMajor,
|
| 852 |
+
int8_t,
|
| 853 |
+
layout::ColumnMajor,
|
| 854 |
+
int,
|
| 855 |
+
layout::RowMajor,
|
| 856 |
+
OpMultiplyAddSaturate> {
|
| 857 |
+
|
| 858 |
+
using Shape = gemm::GemmShape<16,8,32>;
|
| 859 |
+
|
| 860 |
+
using ElementA = uint8_t;
|
| 861 |
+
using LayoutA = layout::RowMajor;
|
| 862 |
+
using FragmentA = Array<uint8_t, 16>;
|
| 863 |
+
|
| 864 |
+
using ElementB = int8_t;
|
| 865 |
+
using LayoutB = layout::ColumnMajor;
|
| 866 |
+
using FragmentB = Array<int8_t, 8>;
|
| 867 |
+
|
| 868 |
+
using ElementC = int;
|
| 869 |
+
using LayoutC = layout::RowMajor;
|
| 870 |
+
using FragmentC = Array<int, 4>;
|
| 871 |
+
|
| 872 |
+
using Operator = OpMultiplyAddSaturate;
|
| 873 |
+
using ArchTag = arch::Sm80;
|
| 874 |
+
|
| 875 |
+
/// Computes multiply-add
|
| 876 |
+
CUTLASS_HOST_DEVICE
|
| 877 |
+
void operator()(
|
| 878 |
+
FragmentC &d,
|
| 879 |
+
FragmentA const &a,
|
| 880 |
+
FragmentB const &b,
|
| 881 |
+
FragmentC const &c
|
| 882 |
+
) const {
|
| 883 |
+
|
| 884 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 885 |
+
|
| 886 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 887 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 888 |
+
|
| 889 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 890 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 891 |
+
|
| 892 |
+
asm volatile(
|
| 893 |
+
"mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, "
|
| 894 |
+
"{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 895 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 896 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
|
| 897 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
|
| 898 |
+
|
| 899 |
+
#else
|
| 900 |
+
assert(0);
|
| 901 |
+
#endif
|
| 902 |
+
}
|
| 903 |
+
};
|
| 904 |
+
|
| 905 |
+
/// Matrix multiply-add operation: S32 = S8 * U8 + S32
|
| 906 |
+
template <>
|
| 907 |
+
struct Mma<
|
| 908 |
+
gemm::GemmShape<16,8,32>,
|
| 909 |
+
32,
|
| 910 |
+
int8_t,
|
| 911 |
+
layout::RowMajor,
|
| 912 |
+
uint8_t,
|
| 913 |
+
layout::ColumnMajor,
|
| 914 |
+
int,
|
| 915 |
+
layout::RowMajor,
|
| 916 |
+
OpMultiplyAddSaturate> {
|
| 917 |
+
|
| 918 |
+
using Shape = gemm::GemmShape<16,8,32>;
|
| 919 |
+
|
| 920 |
+
using ElementA = int8_t;
|
| 921 |
+
using LayoutA = layout::RowMajor;
|
| 922 |
+
using FragmentA = Array<int8_t, 16>;
|
| 923 |
+
|
| 924 |
+
using ElementB = uint8_t;
|
| 925 |
+
using LayoutB = layout::ColumnMajor;
|
| 926 |
+
using FragmentB = Array<uint8_t, 8>;
|
| 927 |
+
|
| 928 |
+
using ElementC = int;
|
| 929 |
+
using LayoutC = layout::RowMajor;
|
| 930 |
+
using FragmentC = Array<int, 4>;
|
| 931 |
+
|
| 932 |
+
using Operator = OpMultiplyAddSaturate;
|
| 933 |
+
using ArchTag = arch::Sm80;
|
| 934 |
+
|
| 935 |
+
/// Computes multiply-add
|
| 936 |
+
CUTLASS_HOST_DEVICE
|
| 937 |
+
void operator()(
|
| 938 |
+
FragmentC &d,
|
| 939 |
+
FragmentA const &a,
|
| 940 |
+
FragmentB const &b,
|
| 941 |
+
FragmentC const &c
|
| 942 |
+
) const {
|
| 943 |
+
|
| 944 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 945 |
+
|
| 946 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 947 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 948 |
+
|
| 949 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 950 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 951 |
+
|
| 952 |
+
asm volatile(
|
| 953 |
+
"mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, "
|
| 954 |
+
"{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 955 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 956 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
|
| 957 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
|
| 958 |
+
|
| 959 |
+
#else
|
| 960 |
+
assert(0);
|
| 961 |
+
#endif
|
| 962 |
+
}
|
| 963 |
+
};
|
| 964 |
+
|
| 965 |
+
/// Matrix multiply-add operation: S32 = U8 * U8 + S32
|
| 966 |
+
template <>
|
| 967 |
+
struct Mma<
|
| 968 |
+
gemm::GemmShape<16,8,32>,
|
| 969 |
+
32,
|
| 970 |
+
uint8_t,
|
| 971 |
+
layout::RowMajor,
|
| 972 |
+
uint8_t,
|
| 973 |
+
layout::ColumnMajor,
|
| 974 |
+
int,
|
| 975 |
+
layout::RowMajor,
|
| 976 |
+
OpMultiplyAddSaturate> {
|
| 977 |
+
|
| 978 |
+
using Shape = gemm::GemmShape<16,8,32>;
|
| 979 |
+
|
| 980 |
+
using ElementA = uint8_t;
|
| 981 |
+
using LayoutA = layout::RowMajor;
|
| 982 |
+
using FragmentA = Array<uint8_t, 16>;
|
| 983 |
+
|
| 984 |
+
using ElementB = uint8_t;
|
| 985 |
+
using LayoutB = layout::ColumnMajor;
|
| 986 |
+
using FragmentB = Array<uint8_t, 8>;
|
| 987 |
+
|
| 988 |
+
using ElementC = int;
|
| 989 |
+
using LayoutC = layout::RowMajor;
|
| 990 |
+
using FragmentC = Array<int, 4>;
|
| 991 |
+
|
| 992 |
+
using Operator = OpMultiplyAddSaturate;
|
| 993 |
+
using ArchTag = arch::Sm80;
|
| 994 |
+
|
| 995 |
+
/// Computes multiply-add
|
| 996 |
+
CUTLASS_HOST_DEVICE
|
| 997 |
+
void operator()(
|
| 998 |
+
FragmentC &d,
|
| 999 |
+
FragmentA const &a,
|
| 1000 |
+
FragmentB const &b,
|
| 1001 |
+
FragmentC const &c
|
| 1002 |
+
) const {
|
| 1003 |
+
|
| 1004 |
+
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
|
| 1005 |
+
|
| 1006 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 1007 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 1008 |
+
|
| 1009 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 1010 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 1011 |
+
|
| 1012 |
+
asm volatile(
|
| 1013 |
+
"mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, "
|
| 1014 |
+
"{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 1015 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 1016 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
|
| 1017 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
|
| 1018 |
+
|
| 1019 |
+
#else
|
| 1020 |
+
assert(0);
|
| 1021 |
+
#endif
|
| 1022 |
+
}
|
| 1023 |
+
};
|
| 1024 |
+
|
| 1025 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1026 |
+
//
|
| 1027 |
+
// Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE
|
| 1028 |
+
//
|
| 1029 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1030 |
+
|
| 1031 |
+
/// Matrix multiply-add operation: S32 = S4 * S4 + S32
template <>
struct Mma<
  gemm::GemmShape<16, 8, 64>,
  32,
  cutlass::int4b_t,
  layout::RowMajor,
  cutlass::int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  // Warp-wide tile shape of one mma.sync instruction (M-by-N-by-K).
  using Shape = gemm::GemmShape<16, 8, 64>;

  using ElementA = cutlass::int4b_t;
  using LayoutA = layout::RowMajor;
  // Per-thread A fragment: 32 x s4 packed into four 32-bit registers.
  using FragmentA = Array<cutlass::int4b_t, 32>;

  using ElementB = cutlass::int4b_t;
  using LayoutB = layout::ColumnMajor;
  // Per-thread B fragment: 16 x s4 packed into two 32-bit registers.
  using FragmentB = Array<cutlass::int4b_t, 16>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  // Per-thread accumulator fragment: four 32-bit integers.
  using FragmentC = Array<int, 4>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add
  // D = A * B + C with accumulator saturation (PTX .satfinite), one
  // mma.sync.m16n8k64 instruction issued by all 32 threads of the warp.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Reinterpret the packed 4-bit fragments as the 32-bit registers the
  // PTX operands require.
  uint32_t const * A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const * B = reinterpret_cast<uint32_t const *>(&b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, "
      "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    // SM80 path compiled out: keep parameters referenced, then trap.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 1094 |
+
|
| 1095 |
+
/// Matrix multiply-add operation: S32 = U4 * S4 + S32
template <>
struct Mma<
  gemm::GemmShape<16, 8, 64>,
  32,
  cutlass::uint4b_t,
  layout::RowMajor,
  cutlass::int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  // Warp-wide tile shape of one mma.sync instruction (M-by-N-by-K).
  using Shape = gemm::GemmShape<16, 8, 64>;

  using ElementA = cutlass::uint4b_t;
  using LayoutA = layout::RowMajor;
  // Per-thread A fragment: 32 x u4 packed into four 32-bit registers.
  using FragmentA = Array<cutlass::uint4b_t, 32>;

  using ElementB = cutlass::int4b_t;
  using LayoutB = layout::ColumnMajor;
  // Per-thread B fragment: 16 x s4 packed into two 32-bit registers.
  using FragmentB = Array<cutlass::int4b_t, 16>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  // Per-thread accumulator fragment: four 32-bit integers.
  using FragmentC = Array<int, 4>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add
  // D = A * B + C with accumulator saturation (PTX .satfinite), one
  // mma.sync.m16n8k64 instruction issued by all 32 threads of the warp.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Reinterpret the packed 4-bit fragments as the 32-bit registers the
  // PTX operands require.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, "
      "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    // SM80 path compiled out: keep parameters referenced, then trap.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 1158 |
+
|
| 1159 |
+
/// Matrix multiply-add operation: S32 = S4 * U4 + S32
template <>
struct Mma<
  gemm::GemmShape<16, 8, 64>,
  32,
  cutlass::int4b_t,
  layout::RowMajor,
  cutlass::uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  // Warp-wide tile shape of one mma.sync instruction (M-by-N-by-K).
  using Shape = gemm::GemmShape<16, 8, 64>;

  using ElementA = cutlass::int4b_t;
  using LayoutA = layout::RowMajor;
  // Per-thread A fragment: 32 x s4 packed into four 32-bit registers.
  using FragmentA = Array<cutlass::int4b_t, 32>;

  using ElementB = cutlass::uint4b_t;
  using LayoutB = layout::ColumnMajor;
  // Per-thread B fragment: 16 x u4 packed into two 32-bit registers.
  using FragmentB = Array<cutlass::uint4b_t, 16>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  // Per-thread accumulator fragment: four 32-bit integers.
  using FragmentC = Array<int, 4>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add
  // D = A * B + C with accumulator saturation (PTX .satfinite), one
  // mma.sync.m16n8k64 instruction issued by all 32 threads of the warp.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Reinterpret the packed 4-bit fragments as the 32-bit registers the
  // PTX operands require.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, "
      "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    // SM80 path compiled out: keep parameters referenced, then trap.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 1222 |
+
|
| 1223 |
+
/// Matrix multiply-add operation: S32 = U4 * U4 + S32
template <>
struct Mma<
  gemm::GemmShape<16, 8, 64>,
  32,
  cutlass::uint4b_t,
  layout::RowMajor,
  cutlass::uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  // Warp-wide tile shape of one mma.sync instruction (M-by-N-by-K).
  using Shape = gemm::GemmShape<16, 8, 64>;

  using ElementA = cutlass::uint4b_t;
  using LayoutA = layout::RowMajor;
  // Per-thread A fragment: 32 x u4 packed into four 32-bit registers.
  using FragmentA = Array<cutlass::uint4b_t, 32>;

  using ElementB = cutlass::uint4b_t;
  using LayoutB = layout::ColumnMajor;
  // Per-thread B fragment: 16 x u4 packed into two 32-bit registers.
  using FragmentB = Array<cutlass::uint4b_t, 16>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  // Per-thread accumulator fragment: four 32-bit integers.
  using FragmentC = Array<int, 4>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add
  // D = A * B + C with accumulator saturation (PTX .satfinite), one
  // mma.sync.m16n8k64 instruction issued by all 32 threads of the warp.
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

  // Reinterpret the packed 4-bit fragments as the 32-bit registers the
  // PTX operands require.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, "
      "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    // SM80 path compiled out: keep parameters referenced, then trap.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 1286 |
+
|
| 1287 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1288 |
+
//
|
| 1289 |
+
// Matrix Multiply 168256 - B1 input, S32 accumulation - AND,POPC
|
| 1290 |
+
//
|
| 1291 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1292 |
+
|
| 1293 |
+
/// Matrix multiply-add operation: S32 = B1 & B1 + S32
template <>
struct Mma<
  gemm::GemmShape<16,8,256>,
  32,
  cutlass::uint1b_t,
  layout::RowMajor,
  cutlass::uint1b_t,
  layout::ColumnMajor,
  int32_t,
  layout::RowMajor,
  OpAndPopc> {

  // Warp-wide tile shape of one mma.sync instruction (M-by-N-by-K).
  using Shape = gemm::GemmShape<16,8,256>;

  using ElementA = cutlass::uint1b_t;
  using LayoutA = layout::RowMajor;
  // Per-thread A fragment: 128 x b1 packed into four 32-bit registers.
  using FragmentA = Array<cutlass::uint1b_t, 128>;

  using ElementB = cutlass::uint1b_t;
  using LayoutB = layout::ColumnMajor;
  // Per-thread B fragment: 64 x b1 packed into two 32-bit registers.
  using FragmentB = Array<cutlass::uint1b_t, 64>;

  using ElementC = int32_t;
  using LayoutC = layout::RowMajor;
  // Per-thread accumulator fragment: four 32-bit integers.
  using FragmentC = Array<int32_t, 4>;

  using Operator = OpAndPopc;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add
  // Binary "multiply-add": bitwise AND of A and B followed by population
  // count, accumulated into C (PTX mma ... .and.popc).
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED)

  // Reinterpret the packed 1-bit fragments as the 32-bit registers the
  // PTX operands require.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, "
      "{%4,%5,%6,%7}, "
      "{%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    // AND-popc path compiled out: keep parameters referenced, then trap.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 1357 |
+
|
| 1358 |
+
/// Matrix multiply-add operation: S32 = B1 & B1 + S32
// NOTE: this specialization is keyed on OpMultiplyAdd but emits the same
// and.popc PTX instruction as the OpAndPopc specialization above —
// presumably so binary GEMMs requesting the generic multiply-add operator
// still map onto the AND/popcount tensor-core path; confirm against callers.
template <>
struct Mma<
  gemm::GemmShape<16,8,256>,
  32,
  cutlass::uint1b_t,
  layout::RowMajor,
  cutlass::uint1b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAdd> {

  // Warp-wide tile shape of one mma.sync instruction (M-by-N-by-K).
  using Shape = gemm::GemmShape<16,8,256>;

  using ElementA = cutlass::uint1b_t;
  using LayoutA = layout::RowMajor;
  // Per-thread A fragment: 128 x b1 packed into four 32-bit registers.
  using FragmentA = Array<cutlass::uint1b_t, 128>;

  using ElementB = cutlass::uint1b_t;
  using LayoutB = layout::ColumnMajor;
  // Per-thread B fragment: 64 x b1 packed into two 32-bit registers.
  using FragmentB = Array<cutlass::uint1b_t, 64>;

  using ElementC = int32_t;
  using LayoutC = layout::RowMajor;
  // Per-thread accumulator fragment: four 32-bit integers.
  using FragmentC = Array<int32_t, 4>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  /// Computes multiply-add
  // Bitwise AND of A and B followed by population count, accumulated
  // into C (PTX mma ... .and.popc).
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED)

  // Reinterpret the packed 1-bit fragments as the 32-bit registers the
  // PTX operands require.
  uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
  uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, "
      "{%4,%5,%6,%7}, "
      "{%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    // AND-popc path compiled out: keep parameters referenced, then trap.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 1422 |
+
|
| 1423 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1424 |
+
//
|
| 1425 |
+
// Matrix Multiply 168256 - B1 input, S32 accumulation - XOR,POPC
|
| 1426 |
+
//
|
| 1427 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1428 |
+
|
| 1429 |
+
/// Matrix multiply-add operation: S32 = B1 & B1 + S32
|
| 1430 |
+
template <>
|
| 1431 |
+
struct Mma<
|
| 1432 |
+
gemm::GemmShape<16,8,256>,
|
| 1433 |
+
32,
|
| 1434 |
+
cutlass::uint1b_t,
|
| 1435 |
+
layout::RowMajor,
|
| 1436 |
+
cutlass::uint1b_t,
|
| 1437 |
+
layout::ColumnMajor,
|
| 1438 |
+
int,
|
| 1439 |
+
layout::RowMajor,
|
| 1440 |
+
OpXorPopc> {
|
| 1441 |
+
|
| 1442 |
+
using Shape = gemm::GemmShape<16,8,256>;
|
| 1443 |
+
|
| 1444 |
+
using ElementA = cutlass::uint1b_t;
|
| 1445 |
+
using LayoutA = layout::RowMajor;
|
| 1446 |
+
using FragmentA = Array<cutlass::uint1b_t, 128>;
|
| 1447 |
+
|
| 1448 |
+
using ElementB = cutlass::uint1b_t;
|
| 1449 |
+
using LayoutB = layout::ColumnMajor;
|
| 1450 |
+
using FragmentB = Array<cutlass::uint1b_t, 64>;
|
| 1451 |
+
|
| 1452 |
+
using ElementC = int;
|
| 1453 |
+
using LayoutC = layout::RowMajor;
|
| 1454 |
+
using FragmentC = Array<int, 4>;
|
| 1455 |
+
|
| 1456 |
+
using Operator = OpXorPopc;
|
| 1457 |
+
using ArchTag = arch::Sm80;
|
| 1458 |
+
|
| 1459 |
+
/// Computes multiply-add
|
| 1460 |
+
CUTLASS_HOST_DEVICE
|
| 1461 |
+
void operator()(
|
| 1462 |
+
FragmentC &d,
|
| 1463 |
+
FragmentA const &a,
|
| 1464 |
+
FragmentB const &b,
|
| 1465 |
+
FragmentC const &c
|
| 1466 |
+
) const {
|
| 1467 |
+
|
| 1468 |
+
#if defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED)
|
| 1469 |
+
|
| 1470 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 1471 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 1472 |
+
|
| 1473 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 1474 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 1475 |
+
|
| 1476 |
+
asm volatile(
|
| 1477 |
+
"mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, "
|
| 1478 |
+
"{%4,%5,%6,%7}, "
|
| 1479 |
+
"{%8,%9}, {%10,%11,%12,%13};\n"
|
| 1480 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 1481 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
|
| 1482 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
|
| 1483 |
+
|
| 1484 |
+
#else
|
| 1485 |
+
|
| 1486 |
+
CUTLASS_UNUSED(a);
|
| 1487 |
+
CUTLASS_UNUSED(b);
|
| 1488 |
+
CUTLASS_UNUSED(c);
|
| 1489 |
+
CUTLASS_UNUSED(d);
|
| 1490 |
+
assert(0);
|
| 1491 |
+
|
| 1492 |
+
#endif // defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED)
|
| 1493 |
+
}
|
| 1494 |
+
};
|
| 1495 |
+
|
| 1496 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1497 |
+
|
| 1498 |
+
} // namespace arch
|
| 1499 |
+
} // namespace cutlass
|
| 1500 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Matrix multiply-accumulate specialized for SM89
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include CUDA_STD_HEADER(cassert)
|
| 39 |
+
|
| 40 |
+
#include "mma.h"
|
| 41 |
+
#include "cutlass/layout/matrix.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
|
| 44 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)
|
| 47 |
+
# define CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED
|
| 48 |
+
#endif
|
| 49 |
+
|
| 50 |
+
#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)
|
| 51 |
+
# define CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED
|
| 52 |
+
#endif
|
| 53 |
+
|
| 54 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
|
| 55 |
+
# if defined(CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED)
|
| 56 |
+
# define CUTLASS_ARCH_MMA_F32_SM89_ENABLED
|
| 57 |
+
# endif
|
| 58 |
+
|
| 59 |
+
# if defined(CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED)
|
| 60 |
+
# define CUTLASS_ARCH_MMA_F16_SM89_ENABLED
|
| 61 |
+
# endif
|
| 62 |
+
#endif
|
| 63 |
+
|
| 64 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
|
| 66 |
+
namespace cutlass {
|
| 67 |
+
namespace arch {
|
| 68 |
+
|
| 69 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 70 |
+
|
| 71 |
+
namespace detail {
|
| 72 |
+
|
| 73 |
+
// Whether the Mma uses as SM89 staged accumulation policy
|
| 74 |
+
template <class Operator>
|
| 75 |
+
static constexpr bool is_sm89_staged_policy_v =
|
| 76 |
+
(
|
| 77 |
+
// ElementA must be FP8
|
| 78 |
+
platform::is_same<typename Operator::ElementA, cutlass::float_e4m3_t>::value ||
|
| 79 |
+
platform::is_same<typename Operator::ElementA, cutlass::float_e5m2_t>::value
|
| 80 |
+
) &&
|
| 81 |
+
(
|
| 82 |
+
// ElementB must be FP8
|
| 83 |
+
platform::is_same<typename Operator::ElementB, cutlass::float_e4m3_t>::value ||
|
| 84 |
+
platform::is_same<typename Operator::ElementB, cutlass::float_e5m2_t>::value
|
| 85 |
+
) &&
|
| 86 |
+
(
|
| 87 |
+
// The instruction shape must be 16x8x32
|
| 88 |
+
Operator::ArchMmaOperator::Shape::kM == 16 &&
|
| 89 |
+
Operator::ArchMmaOperator::Shape::kN == 8 &&
|
| 90 |
+
Operator::ArchMmaOperator::Shape::kK == 32
|
| 91 |
+
) &&
|
| 92 |
+
(
|
| 93 |
+
// The operator must be OpMultiplyAdd (default)
|
| 94 |
+
platform::is_same<typename Operator::MathOperator, OpMultiplyAdd>::value
|
| 95 |
+
);
|
| 96 |
+
} // namespace detail
|
| 97 |
+
|
| 98 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 99 |
+
|
| 100 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 101 |
+
//
|
| 102 |
+
// Matrix Multiply 16832 - Float {E4M3, E5M2}, FP32 accumulation
|
| 103 |
+
//
|
| 104 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 105 |
+
|
| 106 |
+
/// Matrix multiply-add operation - F32 = fe4m3 * fe4m3 + F32
|
| 107 |
+
template <typename Operator_>
|
| 108 |
+
struct Mma<
|
| 109 |
+
gemm::GemmShape<16, 8, 32>,
|
| 110 |
+
32,
|
| 111 |
+
cutlass::float_e4m3_t,
|
| 112 |
+
layout::RowMajor,
|
| 113 |
+
cutlass::float_e4m3_t,
|
| 114 |
+
layout::ColumnMajor,
|
| 115 |
+
float,
|
| 116 |
+
layout::RowMajor,
|
| 117 |
+
Operator_> {
|
| 118 |
+
static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
|
| 119 |
+
platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
|
| 120 |
+
"Invalid operator for SM89 FP8 instruction");
|
| 121 |
+
|
| 122 |
+
using Shape = gemm::GemmShape<16, 8, 32>;
|
| 123 |
+
|
| 124 |
+
using ElementA = cutlass::float_e4m3_t;
|
| 125 |
+
using LayoutA = layout::RowMajor;
|
| 126 |
+
using FragmentA = Array<ElementA, 16>;
|
| 127 |
+
|
| 128 |
+
using ElementB = cutlass::float_e4m3_t;
|
| 129 |
+
using LayoutB = layout::ColumnMajor;
|
| 130 |
+
using FragmentB = Array<ElementB, 8>;
|
| 131 |
+
|
| 132 |
+
using ElementC = float;
|
| 133 |
+
using LayoutC = layout::RowMajor;
|
| 134 |
+
using FragmentC = Array<float, 4>;
|
| 135 |
+
|
| 136 |
+
using Operator = Operator_;
|
| 137 |
+
using ArchTag = arch::Sm89;
|
| 138 |
+
|
| 139 |
+
CUTLASS_HOST_DEVICE
|
| 140 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 141 |
+
FragmentC const &c) const {
|
| 142 |
+
|
| 143 |
+
#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED)
|
| 144 |
+
|
| 145 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 146 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 147 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 148 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 149 |
+
|
| 150 |
+
asm(
|
| 151 |
+
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
| 152 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 153 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 154 |
+
:
|
| 155 |
+
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
|
| 156 |
+
"r"(B[0]), "r"(B[1]),
|
| 157 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
|
| 158 |
+
);
|
| 159 |
+
|
| 160 |
+
#else
|
| 161 |
+
|
| 162 |
+
CUTLASS_UNUSED(d);
|
| 163 |
+
CUTLASS_UNUSED(a);
|
| 164 |
+
CUTLASS_UNUSED(b);
|
| 165 |
+
CUTLASS_UNUSED(c);
|
| 166 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 167 |
+
|
| 168 |
+
#endif
|
| 169 |
+
}
|
| 170 |
+
};
|
| 171 |
+
|
| 172 |
+
/// Matrix multiply-add operation - F32 = fe4m3 * fe5m2 + F32
|
| 173 |
+
template <typename Operator_>
|
| 174 |
+
struct Mma<
|
| 175 |
+
gemm::GemmShape<16, 8, 32>,
|
| 176 |
+
32,
|
| 177 |
+
cutlass::float_e4m3_t,
|
| 178 |
+
layout::RowMajor,
|
| 179 |
+
cutlass::float_e5m2_t,
|
| 180 |
+
layout::ColumnMajor,
|
| 181 |
+
float,
|
| 182 |
+
layout::RowMajor,
|
| 183 |
+
Operator_> {
|
| 184 |
+
static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
|
| 185 |
+
platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
|
| 186 |
+
"Invalid operator for SM89 FP8 instruction");
|
| 187 |
+
|
| 188 |
+
using Shape = gemm::GemmShape<16, 8, 32>;
|
| 189 |
+
|
| 190 |
+
using ElementA = cutlass::float_e4m3_t;
|
| 191 |
+
using LayoutA = layout::RowMajor;
|
| 192 |
+
using FragmentA = Array<ElementA, 16>;
|
| 193 |
+
|
| 194 |
+
using ElementB = cutlass::float_e5m2_t;
|
| 195 |
+
using LayoutB = layout::ColumnMajor;
|
| 196 |
+
using FragmentB = Array<ElementB, 8>;
|
| 197 |
+
|
| 198 |
+
using ElementC = float;
|
| 199 |
+
using LayoutC = layout::RowMajor;
|
| 200 |
+
using FragmentC = Array<float, 4>;
|
| 201 |
+
|
| 202 |
+
using Operator = Operator_;
|
| 203 |
+
using ArchTag = arch::Sm89;
|
| 204 |
+
|
| 205 |
+
CUTLASS_HOST_DEVICE
|
| 206 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 207 |
+
FragmentC const &c) const {
|
| 208 |
+
|
| 209 |
+
#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED)
|
| 210 |
+
|
| 211 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 212 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 213 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 214 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 215 |
+
|
| 216 |
+
asm(
|
| 217 |
+
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32 "
|
| 218 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 219 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 220 |
+
:
|
| 221 |
+
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
|
| 222 |
+
"r"(B[0]), "r"(B[1]),
|
| 223 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
|
| 224 |
+
);
|
| 225 |
+
|
| 226 |
+
#else
|
| 227 |
+
|
| 228 |
+
CUTLASS_UNUSED(d);
|
| 229 |
+
CUTLASS_UNUSED(a);
|
| 230 |
+
CUTLASS_UNUSED(b);
|
| 231 |
+
CUTLASS_UNUSED(c);
|
| 232 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 233 |
+
|
| 234 |
+
#endif
|
| 235 |
+
}
|
| 236 |
+
};
|
| 237 |
+
|
| 238 |
+
/// Matrix multiply-add operation - F32 = fe5m2 * fe4m3 + F32
|
| 239 |
+
template <typename Operator_>
|
| 240 |
+
struct Mma<
|
| 241 |
+
gemm::GemmShape<16, 8, 32>,
|
| 242 |
+
32,
|
| 243 |
+
cutlass::float_e5m2_t,
|
| 244 |
+
layout::RowMajor,
|
| 245 |
+
cutlass::float_e4m3_t,
|
| 246 |
+
layout::ColumnMajor,
|
| 247 |
+
float,
|
| 248 |
+
layout::RowMajor,
|
| 249 |
+
Operator_> {
|
| 250 |
+
static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
|
| 251 |
+
platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
|
| 252 |
+
"Invalid operator for SM89 FP8 instruction");
|
| 253 |
+
|
| 254 |
+
using Shape = gemm::GemmShape<16, 8, 32>;
|
| 255 |
+
|
| 256 |
+
using ElementA = cutlass::float_e5m2_t;
|
| 257 |
+
using LayoutA = layout::RowMajor;
|
| 258 |
+
using FragmentA = Array<ElementA, 16>;
|
| 259 |
+
|
| 260 |
+
using ElementB = cutlass::float_e4m3_t;
|
| 261 |
+
using LayoutB = layout::ColumnMajor;
|
| 262 |
+
using FragmentB = Array<ElementB, 8>;
|
| 263 |
+
|
| 264 |
+
using ElementC = float;
|
| 265 |
+
using LayoutC = layout::RowMajor;
|
| 266 |
+
using FragmentC = Array<float, 4>;
|
| 267 |
+
|
| 268 |
+
using Operator = Operator_;
|
| 269 |
+
using ArchTag = arch::Sm89;
|
| 270 |
+
|
| 271 |
+
CUTLASS_HOST_DEVICE
|
| 272 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 273 |
+
FragmentC const &c) const {
|
| 274 |
+
|
| 275 |
+
#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED)
|
| 276 |
+
|
| 277 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 278 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 279 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 280 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 281 |
+
|
| 282 |
+
asm(
|
| 283 |
+
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32 "
|
| 284 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 285 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 286 |
+
:
|
| 287 |
+
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
|
| 288 |
+
"r"(B[0]), "r"(B[1]),
|
| 289 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
|
| 290 |
+
);
|
| 291 |
+
|
| 292 |
+
#else
|
| 293 |
+
|
| 294 |
+
CUTLASS_UNUSED(d);
|
| 295 |
+
CUTLASS_UNUSED(a);
|
| 296 |
+
CUTLASS_UNUSED(b);
|
| 297 |
+
CUTLASS_UNUSED(c);
|
| 298 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 299 |
+
|
| 300 |
+
#endif
|
| 301 |
+
}
|
| 302 |
+
};
|
| 303 |
+
|
| 304 |
+
/// Matrix multiply-add operation - F32 = fe5m2 * fe5m2 + F32
|
| 305 |
+
template <typename Operator_>
|
| 306 |
+
struct Mma<
|
| 307 |
+
gemm::GemmShape<16, 8, 32>,
|
| 308 |
+
32,
|
| 309 |
+
cutlass::float_e5m2_t,
|
| 310 |
+
layout::RowMajor,
|
| 311 |
+
cutlass::float_e5m2_t,
|
| 312 |
+
layout::ColumnMajor,
|
| 313 |
+
float,
|
| 314 |
+
layout::RowMajor,
|
| 315 |
+
Operator_> {
|
| 316 |
+
static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
|
| 317 |
+
platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
|
| 318 |
+
"Invalid operator for SM89 FP8 instruction");
|
| 319 |
+
|
| 320 |
+
using Shape = gemm::GemmShape<16, 8, 32>;
|
| 321 |
+
|
| 322 |
+
using ElementA = cutlass::float_e5m2_t;
|
| 323 |
+
using LayoutA = layout::RowMajor;
|
| 324 |
+
using FragmentA = Array<ElementA, 16>;
|
| 325 |
+
|
| 326 |
+
using ElementB = cutlass::float_e5m2_t;
|
| 327 |
+
using LayoutB = layout::ColumnMajor;
|
| 328 |
+
using FragmentB = Array<ElementB, 8>;
|
| 329 |
+
|
| 330 |
+
using ElementC = float;
|
| 331 |
+
using LayoutC = layout::RowMajor;
|
| 332 |
+
using FragmentC = Array<float, 4>;
|
| 333 |
+
|
| 334 |
+
using Operator = Operator_;
|
| 335 |
+
using ArchTag = arch::Sm89;
|
| 336 |
+
|
| 337 |
+
CUTLASS_HOST_DEVICE
|
| 338 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 339 |
+
FragmentC const &c) const {
|
| 340 |
+
|
| 341 |
+
#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED)
|
| 342 |
+
|
| 343 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 344 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 345 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 346 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 347 |
+
|
| 348 |
+
asm(
|
| 349 |
+
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
|
| 350 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 351 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 352 |
+
:
|
| 353 |
+
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
|
| 354 |
+
"r"(B[0]), "r"(B[1]),
|
| 355 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
|
| 356 |
+
);
|
| 357 |
+
|
| 358 |
+
#else
|
| 359 |
+
|
| 360 |
+
CUTLASS_UNUSED(d);
|
| 361 |
+
CUTLASS_UNUSED(a);
|
| 362 |
+
CUTLASS_UNUSED(b);
|
| 363 |
+
CUTLASS_UNUSED(c);
|
| 364 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 365 |
+
|
| 366 |
+
#endif
|
| 367 |
+
}
|
| 368 |
+
};
|
| 369 |
+
|
| 370 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 371 |
+
//
|
| 372 |
+
// Matrix Multiply 16832 - Float {E4M3, E5M2}, FP16 accumulation
|
| 373 |
+
//
|
| 374 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 375 |
+
|
| 376 |
+
/// Matrix multiply-add operation - F16 = fe4m3 * fe4m3 + F16
|
| 377 |
+
template <typename Operator_>
|
| 378 |
+
struct Mma<
|
| 379 |
+
gemm::GemmShape<16, 8, 32>,
|
| 380 |
+
32,
|
| 381 |
+
cutlass::float_e4m3_t,
|
| 382 |
+
layout::RowMajor,
|
| 383 |
+
cutlass::float_e4m3_t,
|
| 384 |
+
layout::ColumnMajor,
|
| 385 |
+
cutlass::half_t,
|
| 386 |
+
layout::RowMajor,
|
| 387 |
+
Operator_> {
|
| 388 |
+
static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
|
| 389 |
+
platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
|
| 390 |
+
"Invalid operator for SM89 FP8 instruction");
|
| 391 |
+
|
| 392 |
+
using Shape = gemm::GemmShape<16, 8, 32>;
|
| 393 |
+
|
| 394 |
+
using ElementA = cutlass::float_e4m3_t;
|
| 395 |
+
using LayoutA = layout::RowMajor;
|
| 396 |
+
using FragmentA = Array<ElementA, 16>;
|
| 397 |
+
|
| 398 |
+
using ElementB = cutlass::float_e4m3_t;
|
| 399 |
+
using LayoutB = layout::ColumnMajor;
|
| 400 |
+
using FragmentB = Array<ElementB, 8>;
|
| 401 |
+
|
| 402 |
+
using ElementC = cutlass::half_t;
|
| 403 |
+
using LayoutC = layout::RowMajor;
|
| 404 |
+
using FragmentC = Array<cutlass::half_t, 4>;
|
| 405 |
+
|
| 406 |
+
using Operator = Operator_;
|
| 407 |
+
using ArchTag = arch::Sm89;
|
| 408 |
+
|
| 409 |
+
CUTLASS_HOST_DEVICE
|
| 410 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 411 |
+
FragmentC const &c) const {
|
| 412 |
+
|
| 413 |
+
#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED)
|
| 414 |
+
|
| 415 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 416 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 417 |
+
uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
|
| 418 |
+
uint32_t *D = reinterpret_cast<uint32_t *>(&d);
|
| 419 |
+
|
| 420 |
+
asm(
|
| 421 |
+
"mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 "
|
| 422 |
+
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
| 423 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 424 |
+
:
|
| 425 |
+
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
|
| 426 |
+
"r"(B[0]), "r"(B[1]),
|
| 427 |
+
"r"(C[0]), "r"(C[1])
|
| 428 |
+
);
|
| 429 |
+
|
| 430 |
+
#else
|
| 431 |
+
|
| 432 |
+
CUTLASS_UNUSED(d);
|
| 433 |
+
CUTLASS_UNUSED(a);
|
| 434 |
+
CUTLASS_UNUSED(b);
|
| 435 |
+
CUTLASS_UNUSED(c);
|
| 436 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 437 |
+
|
| 438 |
+
#endif
|
| 439 |
+
}
|
| 440 |
+
};
|
| 441 |
+
|
| 442 |
+
/// Matrix multiply-add operation - F16 = fe4m3 * fe5m2 + F16
|
| 443 |
+
template <typename Operator_>
|
| 444 |
+
struct Mma<
|
| 445 |
+
gemm::GemmShape<16, 8, 32>,
|
| 446 |
+
32,
|
| 447 |
+
cutlass::float_e4m3_t,
|
| 448 |
+
layout::RowMajor,
|
| 449 |
+
cutlass::float_e5m2_t,
|
| 450 |
+
layout::ColumnMajor,
|
| 451 |
+
cutlass::half_t,
|
| 452 |
+
layout::RowMajor,
|
| 453 |
+
Operator_> {
|
| 454 |
+
static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
|
| 455 |
+
platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
|
| 456 |
+
"Invalid operator for SM89 FP8 instruction");
|
| 457 |
+
|
| 458 |
+
using Shape = gemm::GemmShape<16, 8, 32>;
|
| 459 |
+
|
| 460 |
+
using ElementA = cutlass::float_e4m3_t;
|
| 461 |
+
using LayoutA = layout::RowMajor;
|
| 462 |
+
using FragmentA = Array<ElementA, 16>;
|
| 463 |
+
|
| 464 |
+
using ElementB = cutlass::float_e5m2_t;
|
| 465 |
+
using LayoutB = layout::ColumnMajor;
|
| 466 |
+
using FragmentB = Array<ElementB, 8>;
|
| 467 |
+
|
| 468 |
+
using ElementC = cutlass::half_t;
|
| 469 |
+
using LayoutC = layout::RowMajor;
|
| 470 |
+
using FragmentC = Array<cutlass::half_t, 4>;
|
| 471 |
+
|
| 472 |
+
using Operator = Operator_;
|
| 473 |
+
using ArchTag = arch::Sm89;
|
| 474 |
+
|
| 475 |
+
CUTLASS_HOST_DEVICE
|
| 476 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 477 |
+
FragmentC const &c) const {
|
| 478 |
+
|
| 479 |
+
#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED)
|
| 480 |
+
|
| 481 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 482 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 483 |
+
uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
|
| 484 |
+
uint32_t *D = reinterpret_cast<uint32_t *>(&d);
|
| 485 |
+
|
| 486 |
+
asm(
|
| 487 |
+
"mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 "
|
| 488 |
+
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
| 489 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 490 |
+
:
|
| 491 |
+
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
|
| 492 |
+
"r"(B[0]), "r"(B[1]),
|
| 493 |
+
"r"(C[0]), "r"(C[1])
|
| 494 |
+
);
|
| 495 |
+
|
| 496 |
+
#else
|
| 497 |
+
|
| 498 |
+
CUTLASS_UNUSED(d);
|
| 499 |
+
CUTLASS_UNUSED(a);
|
| 500 |
+
CUTLASS_UNUSED(b);
|
| 501 |
+
CUTLASS_UNUSED(c);
|
| 502 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 503 |
+
|
| 504 |
+
#endif
|
| 505 |
+
}
|
| 506 |
+
};
|
| 507 |
+
|
| 508 |
+
/// Matrix multiply-add operation - F16 = fe5m2 * fe4m3 + F16
|
| 509 |
+
template <typename Operator_>
|
| 510 |
+
struct Mma<
|
| 511 |
+
gemm::GemmShape<16, 8, 32>,
|
| 512 |
+
32,
|
| 513 |
+
cutlass::float_e5m2_t,
|
| 514 |
+
layout::RowMajor,
|
| 515 |
+
cutlass::float_e4m3_t,
|
| 516 |
+
layout::ColumnMajor,
|
| 517 |
+
cutlass::half_t,
|
| 518 |
+
layout::RowMajor,
|
| 519 |
+
Operator_> {
|
| 520 |
+
static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
|
| 521 |
+
platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
|
| 522 |
+
"Invalid operator for SM89 FP8 instruction");
|
| 523 |
+
|
| 524 |
+
using Shape = gemm::GemmShape<16, 8, 32>;
|
| 525 |
+
|
| 526 |
+
using ElementA = cutlass::float_e5m2_t;
|
| 527 |
+
using LayoutA = layout::RowMajor;
|
| 528 |
+
using FragmentA = Array<ElementA, 16>;
|
| 529 |
+
|
| 530 |
+
using ElementB = cutlass::float_e4m3_t;
|
| 531 |
+
using LayoutB = layout::ColumnMajor;
|
| 532 |
+
using FragmentB = Array<ElementB, 8>;
|
| 533 |
+
|
| 534 |
+
using ElementC = cutlass::half_t;
|
| 535 |
+
using LayoutC = layout::RowMajor;
|
| 536 |
+
using FragmentC = Array<cutlass::half_t, 4>;
|
| 537 |
+
|
| 538 |
+
using Operator = Operator_;
|
| 539 |
+
using ArchTag = arch::Sm89;
|
| 540 |
+
|
| 541 |
+
CUTLASS_HOST_DEVICE
|
| 542 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 543 |
+
FragmentC const &c) const {
|
| 544 |
+
|
| 545 |
+
#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED)
|
| 546 |
+
|
| 547 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 548 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 549 |
+
uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
|
| 550 |
+
uint32_t *D = reinterpret_cast<uint32_t *>(&d);
|
| 551 |
+
|
| 552 |
+
asm(
|
| 553 |
+
"mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 "
|
| 554 |
+
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
| 555 |
+
: "=r"(D[0]), "=r"(D[1])
|
| 556 |
+
:
|
| 557 |
+
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
|
| 558 |
+
"r"(B[0]), "r"(B[1]),
|
| 559 |
+
"r"(C[0]), "r"(C[1])
|
| 560 |
+
);
|
| 561 |
+
|
| 562 |
+
#else
|
| 563 |
+
|
| 564 |
+
CUTLASS_UNUSED(d);
|
| 565 |
+
CUTLASS_UNUSED(a);
|
| 566 |
+
CUTLASS_UNUSED(b);
|
| 567 |
+
CUTLASS_UNUSED(c);
|
| 568 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 569 |
+
|
| 570 |
+
#endif
|
| 571 |
+
}
|
| 572 |
+
};
|
| 573 |
+
|
| 574 |
+
/// Matrix multiply-add operation - F16 = fe5m2 * fe5m2 + F16
|
| 575 |
+
/// Matrix multiply-add operation - F16 = fe5m2 * fe5m2 + F16
///
/// SM89 FP8 tensor core instruction: m16n8k32, both operands e5m2,
/// accumulation in f16. Accepts either OpMultiplyAdd or
/// OpMultiplyAddFastAccum as the operator tag (enforced below).
template <typename Operator_>
struct Mma<
  gemm::GemmShape<16, 8, 32>,
  32,
  cutlass::float_e5m2_t,
  layout::RowMajor,
  cutlass::float_e5m2_t,
  layout::ColumnMajor,
  cutlass::half_t,
  layout::RowMajor,
  Operator_> {
  // Only the two FP8-capable operator tags are valid for this instruction.
  static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
                platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
                "Invalid operator for SM89 FP8 instruction");

  using Shape = gemm::GemmShape<16, 8, 32>;

  using ElementA = cutlass::float_e5m2_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<ElementA, 16>;

  using ElementB = cutlass::float_e5m2_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<ElementB, 8>;

  using ElementC = cutlass::half_t;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<cutlass::half_t, 4>;

  using Operator = Operator_;
  using ArchTag = arch::Sm89;

  /// Computes D = A * B + C via a single mma.sync instruction.
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {

#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED)

    // View the fragments as packed 32-bit registers as required by PTX.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
    uint32_t *D = reinterpret_cast<uint32_t *>(&d);

    asm(
      "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 "
      "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
        : "=r"(D[0]), "=r"(D[1])
        :
          "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
          "r"(B[0]), "r"(B[1]),
          "r"(C[0]), "r"(C[1])
    );

#else

    // Not compiled for SM89: avoid unused warnings and report unimplemented.
    CUTLASS_UNUSED(d);
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};
|
| 639 |
+
|
| 640 |
+
} // namespace arch
|
| 641 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include CUDA_STD_HEADER(cassert)
|
| 38 |
+
|
| 39 |
+
#include "mma.h"
|
| 40 |
+
#include "cutlass/layout/matrix.h"
|
| 41 |
+
#include "cutlass/numeric_types.h"
|
| 42 |
+
#include "cutlass/arch/config.h"
|
| 43 |
+
|
| 44 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
namespace arch {
|
| 48 |
+
|
| 49 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
/// Matrix Multiply-Add 16x8x4 fp64
|
| 51 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
/// Matrix multiply-add operation: F64 = F64 * F64 + F64
|
| 54 |
+
/// Matrix multiply-add operation: F64 = F64 * F64 + F64
///
/// SM90 double-precision tensor core instruction, shape m16n8k4
/// (round-to-nearest, ".rn" suffix on the PTX mnemonic).
template <>
struct Mma<
  gemm::GemmShape<16,8,4>,
  32,
  double,
  layout::RowMajor,
  double,
  layout::ColumnMajor,
  double,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<16,8,4>;

  using ElementA = double;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<double, 2>;

  using ElementB = double;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<double, 1>;

  using ElementC = double;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<double, 4>;

  using Operator = OpMultiplyAdd;

  using ArchTag = arch::Sm90;

  /// Computes D = A * B + C via a single mma.sync instruction.
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {

#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)

    // fp64 operands are passed to PTX directly as "d" (double) registers.
    double const *A = reinterpret_cast<double const *>(&a);
    double const *B = reinterpret_cast<double const *>(&b);

    double const *C = reinterpret_cast<double const *>(&c);
    double *D = reinterpret_cast<double *>(&d);

    asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64.rn {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
        : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
        : "d"(A[0]), "d"(A[1]),
          "d"(B[0]),
          "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));

#else
    // Instruction unavailable on this architecture/toolchain.
    CUTLASS_UNUSED(d);
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_NOT_IMPLEMENTED();
#endif
  }
};
|
| 111 |
+
|
| 112 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 113 |
+
/// Matrix Multiply-Add 16x8x8 fp64
|
| 114 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 115 |
+
|
| 116 |
+
/// Matrix multiply-add operation: F64 = F64 * F64 + F64
|
| 117 |
+
template <>
|
| 118 |
+
struct Mma<
|
| 119 |
+
gemm::GemmShape<16,8,8>,
|
| 120 |
+
32,
|
| 121 |
+
double,
|
| 122 |
+
layout::RowMajor,
|
| 123 |
+
double,
|
| 124 |
+
layout::ColumnMajor,
|
| 125 |
+
double,
|
| 126 |
+
layout::RowMajor,
|
| 127 |
+
OpMultiplyAdd> {
|
| 128 |
+
|
| 129 |
+
using Shape = gemm::GemmShape<16,8,8>;
|
| 130 |
+
|
| 131 |
+
using ElementA = double;
|
| 132 |
+
using LayoutA = layout::RowMajor;
|
| 133 |
+
using FragmentA = Array<double, 4>;
|
| 134 |
+
|
| 135 |
+
using ElementB = double;
|
| 136 |
+
using LayoutB = layout::ColumnMajor;
|
| 137 |
+
using FragmentB = Array<double, 2>;
|
| 138 |
+
|
| 139 |
+
using ElementC = double;
|
| 140 |
+
using LayoutC = layout::RowMajor;
|
| 141 |
+
using FragmentC = Array<double, 4>;
|
| 142 |
+
|
| 143 |
+
using Operator = OpMultiplyAdd;
|
| 144 |
+
|
| 145 |
+
using ArchTag = arch::Sm90;
|
| 146 |
+
|
| 147 |
+
CUTLASS_HOST_DEVICE
|
| 148 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 149 |
+
FragmentC const &c) const {
|
| 150 |
+
|
| 151 |
+
#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
|
| 152 |
+
|
| 153 |
+
double const *A = reinterpret_cast<double const *>(&a);
|
| 154 |
+
double const *B = reinterpret_cast<double const *>(&b);
|
| 155 |
+
|
| 156 |
+
double const *C = reinterpret_cast<double const *>(&c);
|
| 157 |
+
double *D = reinterpret_cast<double *>(&d);
|
| 158 |
+
|
| 159 |
+
asm volatile("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
| 160 |
+
: "=d"(D[0]), "=d"(d[1]), "=d"(d[2]), "=d"(d[3])
|
| 161 |
+
: "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]),
|
| 162 |
+
"d"(B[0]), "d"(B[1]),
|
| 163 |
+
"d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
|
| 164 |
+
|
| 165 |
+
#else
|
| 166 |
+
|
| 167 |
+
CUTLASS_UNUSED(d);
|
| 168 |
+
CUTLASS_UNUSED(a);
|
| 169 |
+
CUTLASS_UNUSED(b);
|
| 170 |
+
CUTLASS_UNUSED(c);
|
| 171 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 172 |
+
#endif
|
| 173 |
+
}
|
| 174 |
+
};
|
| 175 |
+
|
| 176 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 177 |
+
/// Matrix Multiply-Add 16x8x16 fp64
|
| 178 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 179 |
+
|
| 180 |
+
/// Matrix multiply-add operation: F64 = F64 * F64 + F64
|
| 181 |
+
template <>
|
| 182 |
+
struct Mma<
|
| 183 |
+
gemm::GemmShape<16,8,16>,
|
| 184 |
+
32,
|
| 185 |
+
double,
|
| 186 |
+
layout::RowMajor,
|
| 187 |
+
double,
|
| 188 |
+
layout::ColumnMajor,
|
| 189 |
+
double,
|
| 190 |
+
layout::RowMajor,
|
| 191 |
+
OpMultiplyAdd> {
|
| 192 |
+
|
| 193 |
+
using Shape = gemm::GemmShape<16,8,16>;
|
| 194 |
+
|
| 195 |
+
using ElementA = double;
|
| 196 |
+
using LayoutA = layout::RowMajor;
|
| 197 |
+
using FragmentA = Array<double, 8>;
|
| 198 |
+
|
| 199 |
+
using ElementB = double;
|
| 200 |
+
using LayoutB = layout::ColumnMajor;
|
| 201 |
+
using FragmentB = Array<double, 4>;
|
| 202 |
+
|
| 203 |
+
using ElementC = double;
|
| 204 |
+
using LayoutC = layout::RowMajor;
|
| 205 |
+
using FragmentC = Array<double, 4>;
|
| 206 |
+
|
| 207 |
+
using Operator = OpMultiplyAdd;
|
| 208 |
+
|
| 209 |
+
using ArchTag = arch::Sm90;
|
| 210 |
+
|
| 211 |
+
CUTLASS_HOST_DEVICE
|
| 212 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 213 |
+
FragmentC const &c) const {
|
| 214 |
+
|
| 215 |
+
#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
|
| 216 |
+
|
| 217 |
+
double const *A = reinterpret_cast<double const *>(&a);
|
| 218 |
+
double const *B = reinterpret_cast<double const *>(&b);
|
| 219 |
+
|
| 220 |
+
double const *C = reinterpret_cast<double const *>(&c);
|
| 221 |
+
double *D = reinterpret_cast<double *>(&d);
|
| 222 |
+
|
| 223 |
+
asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n"
|
| 224 |
+
: "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
|
| 225 |
+
: "d"(A[0]), "d"(A[2]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]),
|
| 226 |
+
"d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]),
|
| 227 |
+
"d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
|
| 228 |
+
|
| 229 |
+
#else
|
| 230 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 231 |
+
#endif
|
| 232 |
+
}
|
| 233 |
+
};
|
| 234 |
+
|
| 235 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 236 |
+
|
| 237 |
+
} // namespace arch
|
| 238 |
+
} // namespace cutlass
|
| 239 |
+
|
| 240 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 241 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h
ADDED
|
@@ -0,0 +1,1234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Sparse matrix multiply accumulate for SM80
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include CUDA_STD_HEADER(cassert)
|
| 39 |
+
|
| 40 |
+
#include "mma.h"
|
| 41 |
+
#include "cutlass/layout/matrix.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))
|
| 47 |
+
|
| 48 |
+
#define CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED 1
|
| 49 |
+
|
| 50 |
+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
| 51 |
+
#define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED
|
| 52 |
+
#endif
|
| 53 |
+
|
| 54 |
+
#endif
|
| 55 |
+
|
| 56 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
namespace cutlass {
|
| 59 |
+
namespace arch {
|
| 60 |
+
|
| 61 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
|
| 63 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
//
|
| 65 |
+
// Sparse Matrix Multiply 16832
|
| 66 |
+
//
|
| 67 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 68 |
+
|
| 69 |
+
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
|
| 70 |
+
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
///
/// SM80 structured-sparse tensor core instruction (mma.sp), shape m16n8k32.
/// In addition to the dense fragments, the caller supplies the packed sparsity
/// metadata `E` and the metadata thread-selector `id2` (must be 0 or 1; see
/// kMaxID2). On CUDA 12.5+ the ordered-metadata variant of the instruction
/// is emitted.
template <>
struct SparseMma<
  gemm::GemmShape<16, 8, 32>,
  32,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd,
  SPFormatType::Thread
> {

  using Shape = gemm::GemmShape<16, 8, 32>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 8>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 8>;

  using ElementC = half_t;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<half_t, 4>;

  // Packed sparsity metadata carried per thread.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  // Ratio of logical-K to stored-K elements of operand A
  // (presumably 2:4 structured sparsity — 2 stored of every 4).
  static int const kSparse = 2;

  // Bits of metadata per stored element.
  static int const kMetaSizeInBits = 2;

  // Number of valid values for the id2 selector (0 or 1).
  static int const kMaxID2 = 2;

  /// Computes multiply-add
  ///
  /// @param d    Destination accumulator fragment.
  /// @param a    Sparse (compressed) A fragment.
  /// @param b    Dense B fragment.
  /// @param c    Source accumulator fragment.
  /// @param E    Packed sparsity metadata.
  /// @param id2  Metadata selector immediate (0 or 1); any other value asserts.
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c, uint32_t const &E, int const id2) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // View the half_t fragments as packed 32-bit registers for PTX.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
    uint32_t *D = reinterpret_cast<uint32_t *>(&d);

// CUDA 12.5+ supports the ordered-metadata form of the sparse MMA.
// The id2 selector is a PTX immediate, so each value needs its own asm block.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, "
          "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E));
    }
    else if (id2 == 1) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, "
          "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n"
          : "=r"(D[0]), "=r"(D[1])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, "
          "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E));
    }
    else if (id2 == 1) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, "
          "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n"
          : "=r"(D[0]), "=r"(D[1])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E));
    }
    else {
      assert(0);
    }
#endif

#else
    // Instruction unavailable on this architecture/toolchain.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 172 |
+
|
| 173 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 174 |
+
|
| 175 |
+
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
|
| 176 |
+
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
///
/// SM80 structured-sparse tensor core instruction (mma.sp), shape m16n8k32,
/// with single-precision accumulation. Metadata `E` and selector `id2`
/// behave as in the F16-accumulate specialization above; on CUDA 12.5+
/// the ordered-metadata instruction variant is emitted.
template <>
struct SparseMma<
  gemm::GemmShape<16, 8, 32>,
  32,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd,
  SPFormatType::Thread
> {

  using Shape = gemm::GemmShape<16, 8, 32>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 8>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 8>;

  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 4>;

  // Packed sparsity metadata carried per thread.
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm80;

  // Ratio of logical-K to stored-K elements of operand A
  // (presumably 2:4 structured sparsity — 2 stored of every 4).
  static int const kSparse = 2;

  // Bits of metadata per stored element.
  static int const kMetaSizeInBits = 2;

  // Number of valid values for the id2 selector (0 or 1).
  static int const kMaxID2 = 2;

  /// Computes multiply-add
  ///
  /// @param d    Destination accumulator fragment (float).
  /// @param a    Sparse (compressed) A fragment.
  /// @param b    Dense B fragment.
  /// @param c    Source accumulator fragment (float).
  /// @param E    Packed sparsity metadata.
  /// @param id2  Metadata selector immediate (0 or 1); any other value asserts.
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c, uint32_t const &E, int const id2) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // A/B are packed 32-bit registers; C/D stay as float ("f") registers.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

// CUDA 12.5+ supports the ordered-metadata form of the sparse MMA.
// The id2 selector is a PTX immediate, so each value needs its own asm block.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
          "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
            "r"(E));
    }
    else if (id2 == 1) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
          "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
            "r"(E));
    }
    else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
          "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
            "r"(E));
    }
    else if (id2 == 1) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
          "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
            "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
            "r"(E));
    }
    else {
      assert(0);
    }

#endif

#else
    // Instruction unavailable on this architecture/toolchain.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 283 |
+
|
| 284 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 285 |
+
//
|
| 286 |
+
// Sparse Matrix Multiply 16832 - Float BF16, FP32 accumulation
|
| 287 |
+
//
|
| 288 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 289 |
+
|
| 290 |
+
/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32
|
| 291 |
+
template <>
|
| 292 |
+
struct SparseMma<gemm::GemmShape<16, 8, 32>, 32, bfloat16_t, layout::RowMajor,
|
| 293 |
+
bfloat16_t, layout::ColumnMajor, float, layout::RowMajor,
|
| 294 |
+
OpMultiplyAdd, SPFormatType::Thread> {
|
| 295 |
+
using Shape = gemm::GemmShape<16, 8, 32>;
|
| 296 |
+
|
| 297 |
+
using ElementA = bfloat16_t;
|
| 298 |
+
using LayoutA = layout::RowMajor;
|
| 299 |
+
using FragmentA = Array<bfloat16_t, 8>;
|
| 300 |
+
|
| 301 |
+
using ElementB = bfloat16_t;
|
| 302 |
+
using LayoutB = layout::ColumnMajor;
|
| 303 |
+
using FragmentB = Array<bfloat16_t, 8>;
|
| 304 |
+
|
| 305 |
+
using ElementC = float;
|
| 306 |
+
using LayoutC = layout::RowMajor;
|
| 307 |
+
using FragmentC = Array<float, 4>;
|
| 308 |
+
|
| 309 |
+
using FragmentE = uint32_t;
|
| 310 |
+
|
| 311 |
+
using Operator = OpMultiplyAdd;
|
| 312 |
+
using ArchTag = arch::Sm80;
|
| 313 |
+
|
| 314 |
+
static int const kSparse = 2;
|
| 315 |
+
|
| 316 |
+
static int const kMetaSizeInBits = 2;
|
| 317 |
+
|
| 318 |
+
static int const kMaxID2 = 2;
|
| 319 |
+
|
| 320 |
+
CUTLASS_HOST_DEVICE
|
| 321 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 322 |
+
FragmentC const &c, uint32_t const &E, int const id2) const {
|
| 323 |
+
|
| 324 |
+
#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)
|
| 325 |
+
|
| 326 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 327 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 328 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 329 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 330 |
+
|
| 331 |
+
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
|
| 332 |
+
if (id2 == 0) {
|
| 333 |
+
asm volatile(
|
| 334 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 "
|
| 335 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 336 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 337 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 338 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
|
| 339 |
+
} else if (id2 == 1) {
|
| 340 |
+
asm volatile(
|
| 341 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 "
|
| 342 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
|
| 343 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 344 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 345 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
|
| 346 |
+
} else {
|
| 347 |
+
assert(0);
|
| 348 |
+
}
|
| 349 |
+
#else
|
| 350 |
+
if (id2 == 0) {
|
| 351 |
+
asm volatile(
|
| 352 |
+
"mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 "
|
| 353 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 354 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 355 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 356 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
|
| 357 |
+
} else if (id2 == 1) {
|
| 358 |
+
asm volatile(
|
| 359 |
+
"mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 "
|
| 360 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
|
| 361 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 362 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 363 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
|
| 364 |
+
} else {
|
| 365 |
+
assert(0);
|
| 366 |
+
}
|
| 367 |
+
#endif
|
| 368 |
+
|
| 369 |
+
#else
|
| 370 |
+
|
| 371 |
+
CUTLASS_UNUSED(a);
|
| 372 |
+
CUTLASS_UNUSED(b);
|
| 373 |
+
CUTLASS_UNUSED(c);
|
| 374 |
+
CUTLASS_UNUSED(d);
|
| 375 |
+
assert(0);
|
| 376 |
+
#endif
|
| 377 |
+
}
|
| 378 |
+
};
|
| 379 |
+
|
| 380 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 381 |
+
//
|
| 382 |
+
// Sparse Matrix Multiply 16816 - Float TF32
|
| 383 |
+
//
|
| 384 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 385 |
+
|
| 386 |
+
/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32
|
| 387 |
+
template <>
|
| 388 |
+
struct SparseMma<gemm::GemmShape<16, 8, 16>, 32, tfloat32_t, layout::RowMajor,
|
| 389 |
+
tfloat32_t, layout::ColumnMajor, float, layout::RowMajor,
|
| 390 |
+
OpMultiplyAdd, SPFormatType::Thread> {
|
| 391 |
+
using Shape = gemm::GemmShape<16, 8, 16>;
|
| 392 |
+
|
| 393 |
+
using ElementA = tfloat32_t;
|
| 394 |
+
using LayoutA = layout::RowMajor;
|
| 395 |
+
using FragmentA = Array<tfloat32_t, 4>;
|
| 396 |
+
|
| 397 |
+
using ElementB = tfloat32_t;
|
| 398 |
+
using LayoutB = layout::ColumnMajor;
|
| 399 |
+
using FragmentB = Array<tfloat32_t, 4>;
|
| 400 |
+
|
| 401 |
+
using ElementC = float;
|
| 402 |
+
using LayoutC = layout::RowMajor;
|
| 403 |
+
using FragmentC = Array<float, 4>;
|
| 404 |
+
|
| 405 |
+
using FragmentE = uint32_t;
|
| 406 |
+
|
| 407 |
+
using Operator = OpMultiplyAdd;
|
| 408 |
+
using ArchTag = arch::Sm80;
|
| 409 |
+
|
| 410 |
+
static int const kSparse = 2;
|
| 411 |
+
|
| 412 |
+
static int const kMetaSizeInBits = 4;
|
| 413 |
+
|
| 414 |
+
static int const kMaxID2 = 2;
|
| 415 |
+
|
| 416 |
+
CUTLASS_HOST_DEVICE
|
| 417 |
+
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
|
| 418 |
+
FragmentC const &c, uint32_t const &E, int const id2) const {
|
| 419 |
+
|
| 420 |
+
#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)
|
| 421 |
+
|
| 422 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 423 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 424 |
+
float const *C = reinterpret_cast<float const *>(&c);
|
| 425 |
+
float *D = reinterpret_cast<float *>(&d);
|
| 426 |
+
|
| 427 |
+
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
|
| 428 |
+
if (id2 == 0) {
|
| 429 |
+
asm volatile(
|
| 430 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 "
|
| 431 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 432 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 433 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 434 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
|
| 435 |
+
} else if (id2 == 1) {
|
| 436 |
+
asm volatile(
|
| 437 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 "
|
| 438 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
|
| 439 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 440 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 441 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
|
| 442 |
+
} else {
|
| 443 |
+
assert(0);
|
| 444 |
+
}
|
| 445 |
+
#else
|
| 446 |
+
if (id2 == 0) {
|
| 447 |
+
asm volatile(
|
| 448 |
+
"mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 "
|
| 449 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 450 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 451 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 452 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
|
| 453 |
+
} else if (id2 == 1) {
|
| 454 |
+
asm volatile(
|
| 455 |
+
"mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 "
|
| 456 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n"
|
| 457 |
+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
| 458 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 459 |
+
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
|
| 460 |
+
} else {
|
| 461 |
+
assert(0);
|
| 462 |
+
}
|
| 463 |
+
#endif
|
| 464 |
+
|
| 465 |
+
#else
|
| 466 |
+
|
| 467 |
+
CUTLASS_UNUSED(a);
|
| 468 |
+
CUTLASS_UNUSED(b);
|
| 469 |
+
CUTLASS_UNUSED(c);
|
| 470 |
+
CUTLASS_UNUSED(d);
|
| 471 |
+
assert(0);
|
| 472 |
+
#endif
|
| 473 |
+
}
|
| 474 |
+
};
|
| 475 |
+
|
| 476 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 477 |
+
//
|
| 478 |
+
// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE
|
| 479 |
+
//
|
| 480 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 481 |
+
|
| 482 |
+
/// Matrix multiply-add operation: S32 = S8 * S8 + S32
|
| 483 |
+
template <>
|
| 484 |
+
struct SparseMma<
|
| 485 |
+
gemm::GemmShape<16,8,64>,
|
| 486 |
+
32,
|
| 487 |
+
int8_t,
|
| 488 |
+
layout::RowMajor,
|
| 489 |
+
int8_t,
|
| 490 |
+
layout::ColumnMajor,
|
| 491 |
+
int,
|
| 492 |
+
layout::RowMajor,
|
| 493 |
+
OpMultiplyAddSaturate,
|
| 494 |
+
SPFormatType::Thread> {
|
| 495 |
+
|
| 496 |
+
using Shape = gemm::GemmShape<16,8,64>;
|
| 497 |
+
|
| 498 |
+
using ElementA = int8_t;
|
| 499 |
+
using LayoutA = layout::RowMajor;
|
| 500 |
+
using FragmentA = Array<int8_t, 16>;
|
| 501 |
+
|
| 502 |
+
using ElementB = int8_t;
|
| 503 |
+
using LayoutB = layout::ColumnMajor;
|
| 504 |
+
using FragmentB = Array<int8_t, 16>;
|
| 505 |
+
|
| 506 |
+
using ElementC = int;
|
| 507 |
+
using LayoutC = layout::RowMajor;
|
| 508 |
+
using FragmentC = Array<int, 4>;
|
| 509 |
+
|
| 510 |
+
using FragmentE = uint32_t;
|
| 511 |
+
|
| 512 |
+
using Operator = OpMultiplyAddSaturate;
|
| 513 |
+
using ArchTag = arch::Sm80;
|
| 514 |
+
|
| 515 |
+
static int const kSparse = 2;
|
| 516 |
+
|
| 517 |
+
static int const kMetaSizeInBits = 2;
|
| 518 |
+
|
| 519 |
+
static int const kMaxID2 = 1;
|
| 520 |
+
|
| 521 |
+
/// Computes multiply-add
|
| 522 |
+
CUTLASS_HOST_DEVICE
|
| 523 |
+
void operator()(
|
| 524 |
+
FragmentC &d,
|
| 525 |
+
FragmentA const &a,
|
| 526 |
+
FragmentB const &b,
|
| 527 |
+
FragmentC const &c,
|
| 528 |
+
uint32_t const &E,
|
| 529 |
+
int const id2
|
| 530 |
+
) const {
|
| 531 |
+
|
| 532 |
+
#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)
|
| 533 |
+
|
| 534 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 535 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 536 |
+
|
| 537 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 538 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 539 |
+
|
| 540 |
+
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
|
| 541 |
+
if (id2 == 0) {
|
| 542 |
+
asm volatile(
|
| 543 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 544 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 545 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 546 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 547 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 548 |
+
} else {
|
| 549 |
+
assert(0);
|
| 550 |
+
}
|
| 551 |
+
#else
|
| 552 |
+
if (id2 == 0) {
|
| 553 |
+
asm volatile(
|
| 554 |
+
"mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 555 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 556 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 557 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 558 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 559 |
+
} else {
|
| 560 |
+
assert(0);
|
| 561 |
+
}
|
| 562 |
+
#endif
|
| 563 |
+
|
| 564 |
+
#else
|
| 565 |
+
CUTLASS_UNUSED(a);
|
| 566 |
+
CUTLASS_UNUSED(b);
|
| 567 |
+
CUTLASS_UNUSED(c);
|
| 568 |
+
CUTLASS_UNUSED(d);
|
| 569 |
+
assert(0);
|
| 570 |
+
#endif
|
| 571 |
+
}
|
| 572 |
+
};
|
| 573 |
+
|
| 574 |
+
/// Matrix multiply-add operation: S32 = S8 * U8 + S32
|
| 575 |
+
template <>
|
| 576 |
+
struct SparseMma<
|
| 577 |
+
gemm::GemmShape<16,8,64>,
|
| 578 |
+
32,
|
| 579 |
+
int8_t,
|
| 580 |
+
layout::RowMajor,
|
| 581 |
+
uint8_t,
|
| 582 |
+
layout::ColumnMajor,
|
| 583 |
+
int,
|
| 584 |
+
layout::RowMajor,
|
| 585 |
+
OpMultiplyAddSaturate,
|
| 586 |
+
SPFormatType::Thread> {
|
| 587 |
+
|
| 588 |
+
using Shape = gemm::GemmShape<16,8,64>;
|
| 589 |
+
|
| 590 |
+
using ElementA = int8_t;
|
| 591 |
+
using LayoutA = layout::RowMajor;
|
| 592 |
+
using FragmentA = Array<int8_t, 16>;
|
| 593 |
+
|
| 594 |
+
using ElementB = uint8_t;
|
| 595 |
+
using LayoutB = layout::ColumnMajor;
|
| 596 |
+
using FragmentB = Array<uint8_t, 16>;
|
| 597 |
+
|
| 598 |
+
using ElementC = int;
|
| 599 |
+
using LayoutC = layout::RowMajor;
|
| 600 |
+
using FragmentC = Array<int, 4>;
|
| 601 |
+
|
| 602 |
+
using FragmentE = uint32_t;
|
| 603 |
+
|
| 604 |
+
using Operator = OpMultiplyAddSaturate;
|
| 605 |
+
using ArchTag = arch::Sm80;
|
| 606 |
+
|
| 607 |
+
static int const kSparse = 2;
|
| 608 |
+
|
| 609 |
+
static int const kMetaSizeInBits = 2;
|
| 610 |
+
|
| 611 |
+
static int const kMaxID2 = 1;
|
| 612 |
+
|
| 613 |
+
/// Computes multiply-add
|
| 614 |
+
CUTLASS_HOST_DEVICE
|
| 615 |
+
void operator()(
|
| 616 |
+
FragmentC &d,
|
| 617 |
+
FragmentA const &a,
|
| 618 |
+
FragmentB const &b,
|
| 619 |
+
FragmentC const &c,
|
| 620 |
+
uint32_t const &E,
|
| 621 |
+
int const id2
|
| 622 |
+
) const {
|
| 623 |
+
|
| 624 |
+
#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)
|
| 625 |
+
|
| 626 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 627 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 628 |
+
|
| 629 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 630 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 631 |
+
|
| 632 |
+
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
|
| 633 |
+
if (id2 == 0) {
|
| 634 |
+
asm volatile(
|
| 635 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 636 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 637 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 638 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 639 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 640 |
+
} else {
|
| 641 |
+
assert(0);
|
| 642 |
+
}
|
| 643 |
+
#else
|
| 644 |
+
if (id2 == 0) {
|
| 645 |
+
asm volatile(
|
| 646 |
+
"mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 647 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 648 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 649 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 650 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 651 |
+
} else {
|
| 652 |
+
assert(0);
|
| 653 |
+
}
|
| 654 |
+
#endif
|
| 655 |
+
|
| 656 |
+
#else
|
| 657 |
+
|
| 658 |
+
CUTLASS_UNUSED(a);
|
| 659 |
+
CUTLASS_UNUSED(b);
|
| 660 |
+
CUTLASS_UNUSED(c);
|
| 661 |
+
CUTLASS_UNUSED(d);
|
| 662 |
+
assert(0);
|
| 663 |
+
#endif
|
| 664 |
+
}
|
| 665 |
+
};
|
| 666 |
+
|
| 667 |
+
/// Matrix multiply-add operation: S32 = U8 * S8 + S32
|
| 668 |
+
template <>
|
| 669 |
+
struct SparseMma<
|
| 670 |
+
gemm::GemmShape<16,8,64>,
|
| 671 |
+
32,
|
| 672 |
+
uint8_t,
|
| 673 |
+
layout::RowMajor,
|
| 674 |
+
int8_t,
|
| 675 |
+
layout::ColumnMajor,
|
| 676 |
+
int,
|
| 677 |
+
layout::RowMajor,
|
| 678 |
+
OpMultiplyAddSaturate,
|
| 679 |
+
SPFormatType::Thread> {
|
| 680 |
+
|
| 681 |
+
using Shape = gemm::GemmShape<16,8,64>;
|
| 682 |
+
|
| 683 |
+
using ElementA = uint8_t;
|
| 684 |
+
using LayoutA = layout::RowMajor;
|
| 685 |
+
using FragmentA = Array<uint8_t, 16>;
|
| 686 |
+
|
| 687 |
+
using ElementB = int8_t;
|
| 688 |
+
using LayoutB = layout::ColumnMajor;
|
| 689 |
+
using FragmentB = Array<int8_t, 16>;
|
| 690 |
+
|
| 691 |
+
using ElementC = int;
|
| 692 |
+
using LayoutC = layout::RowMajor;
|
| 693 |
+
using FragmentC = Array<int, 4>;
|
| 694 |
+
|
| 695 |
+
using FragmentE = uint32_t;
|
| 696 |
+
|
| 697 |
+
using Operator = OpMultiplyAddSaturate;
|
| 698 |
+
using ArchTag = arch::Sm80;
|
| 699 |
+
|
| 700 |
+
static int const kSparse = 2;
|
| 701 |
+
|
| 702 |
+
static int const kMetaSizeInBits = 2;
|
| 703 |
+
|
| 704 |
+
static int const kMaxID2 = 1;
|
| 705 |
+
|
| 706 |
+
/// Computes multiply-add
|
| 707 |
+
CUTLASS_HOST_DEVICE
|
| 708 |
+
void operator()(
|
| 709 |
+
FragmentC &d,
|
| 710 |
+
FragmentA const &a,
|
| 711 |
+
FragmentB const &b,
|
| 712 |
+
FragmentC const &c,
|
| 713 |
+
uint32_t const &E,
|
| 714 |
+
int const id2
|
| 715 |
+
) const {
|
| 716 |
+
|
| 717 |
+
#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)
|
| 718 |
+
|
| 719 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 720 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 721 |
+
|
| 722 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 723 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 724 |
+
|
| 725 |
+
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
|
| 726 |
+
if (id2 == 0) {
|
| 727 |
+
asm volatile(
|
| 728 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 729 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 730 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 731 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 732 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 733 |
+
} else {
|
| 734 |
+
assert(0);
|
| 735 |
+
}
|
| 736 |
+
#else
|
| 737 |
+
if (id2 == 0) {
|
| 738 |
+
asm volatile(
|
| 739 |
+
"mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 740 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 741 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 742 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 743 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 744 |
+
} else {
|
| 745 |
+
assert(0);
|
| 746 |
+
}
|
| 747 |
+
#endif
|
| 748 |
+
|
| 749 |
+
#else
|
| 750 |
+
CUTLASS_UNUSED(a);
|
| 751 |
+
CUTLASS_UNUSED(b);
|
| 752 |
+
CUTLASS_UNUSED(c);
|
| 753 |
+
CUTLASS_UNUSED(d);
|
| 754 |
+
assert(0);
|
| 755 |
+
#endif
|
| 756 |
+
}
|
| 757 |
+
};
|
| 758 |
+
|
| 759 |
+
/// Matrix multiply-add operation: S32 = U8 * U8 + S32
|
| 760 |
+
template <>
|
| 761 |
+
struct SparseMma<
|
| 762 |
+
gemm::GemmShape<16,8,64>,
|
| 763 |
+
32,
|
| 764 |
+
uint8_t,
|
| 765 |
+
layout::RowMajor,
|
| 766 |
+
uint8_t,
|
| 767 |
+
layout::ColumnMajor,
|
| 768 |
+
int,
|
| 769 |
+
layout::RowMajor,
|
| 770 |
+
OpMultiplyAddSaturate,
|
| 771 |
+
SPFormatType::Thread> {
|
| 772 |
+
|
| 773 |
+
using Shape = gemm::GemmShape<16,8,64>;
|
| 774 |
+
|
| 775 |
+
using ElementA = uint8_t;
|
| 776 |
+
using LayoutA = layout::RowMajor;
|
| 777 |
+
using FragmentA = Array<uint8_t, 16>;
|
| 778 |
+
|
| 779 |
+
using ElementB = uint8_t;
|
| 780 |
+
using LayoutB = layout::ColumnMajor;
|
| 781 |
+
using FragmentB = Array<uint8_t, 16>;
|
| 782 |
+
|
| 783 |
+
using ElementC = int;
|
| 784 |
+
using LayoutC = layout::RowMajor;
|
| 785 |
+
using FragmentC = Array<int, 4>;
|
| 786 |
+
|
| 787 |
+
using FragmentE = uint32_t;
|
| 788 |
+
|
| 789 |
+
using Operator = OpMultiplyAddSaturate;
|
| 790 |
+
using ArchTag = arch::Sm80;
|
| 791 |
+
|
| 792 |
+
static int const kSparse = 2;
|
| 793 |
+
|
| 794 |
+
static int const kMetaSizeInBits = 2;
|
| 795 |
+
|
| 796 |
+
static int const kMaxID2 = 1;
|
| 797 |
+
|
| 798 |
+
/// Computes multiply-add
|
| 799 |
+
CUTLASS_HOST_DEVICE
|
| 800 |
+
void operator()(
|
| 801 |
+
FragmentC &d,
|
| 802 |
+
FragmentA const &a,
|
| 803 |
+
FragmentB const &b,
|
| 804 |
+
FragmentC const &c,
|
| 805 |
+
uint32_t const &E,
|
| 806 |
+
int const id2
|
| 807 |
+
) const {
|
| 808 |
+
|
| 809 |
+
#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)
|
| 810 |
+
|
| 811 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 812 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 813 |
+
|
| 814 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 815 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 816 |
+
|
| 817 |
+
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
|
| 818 |
+
if (id2 == 0) {
|
| 819 |
+
asm volatile(
|
| 820 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 821 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 822 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 823 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 824 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 825 |
+
} else {
|
| 826 |
+
assert(0);
|
| 827 |
+
}
|
| 828 |
+
#else
|
| 829 |
+
if (id2 == 0) {
|
| 830 |
+
asm volatile(
|
| 831 |
+
"mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 832 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 833 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 834 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 835 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 836 |
+
} else {
|
| 837 |
+
assert(0);
|
| 838 |
+
}
|
| 839 |
+
#endif
|
| 840 |
+
|
| 841 |
+
#else
|
| 842 |
+
CUTLASS_UNUSED(a);
|
| 843 |
+
CUTLASS_UNUSED(b);
|
| 844 |
+
CUTLASS_UNUSED(c);
|
| 845 |
+
CUTLASS_UNUSED(d);
|
| 846 |
+
assert(0);
|
| 847 |
+
#endif
|
| 848 |
+
}
|
| 849 |
+
};
|
| 850 |
+
|
| 851 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 852 |
+
//
|
| 853 |
+
// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE
|
| 854 |
+
//
|
| 855 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 856 |
+
|
| 857 |
+
/// Matrix multiply-add operation: S32 = S4 * S4 + S32
|
| 858 |
+
template <>
|
| 859 |
+
struct SparseMma<
|
| 860 |
+
gemm::GemmShape<16,8,128>,
|
| 861 |
+
32,
|
| 862 |
+
cutlass::int4b_t,
|
| 863 |
+
layout::RowMajor,
|
| 864 |
+
cutlass::int4b_t,
|
| 865 |
+
layout::ColumnMajor,
|
| 866 |
+
int,
|
| 867 |
+
layout::RowMajor,
|
| 868 |
+
OpMultiplyAddSaturate,
|
| 869 |
+
SPFormatType::Thread> {
|
| 870 |
+
|
| 871 |
+
using Shape = gemm::GemmShape<16,8,128>;
|
| 872 |
+
|
| 873 |
+
using ElementA = cutlass::int4b_t;
|
| 874 |
+
using LayoutA = layout::RowMajor;
|
| 875 |
+
using FragmentA = Array<cutlass::int4b_t, 32>;
|
| 876 |
+
|
| 877 |
+
using ElementB = cutlass::int4b_t;
|
| 878 |
+
using LayoutB = layout::ColumnMajor;
|
| 879 |
+
using FragmentB = Array<cutlass::int4b_t, 32>;
|
| 880 |
+
|
| 881 |
+
using ElementC = int;
|
| 882 |
+
using LayoutC = layout::RowMajor;
|
| 883 |
+
using FragmentC = Array<int, 4>;
|
| 884 |
+
|
| 885 |
+
using FragmentE = uint32_t;
|
| 886 |
+
|
| 887 |
+
using Operator = OpMultiplyAddSaturate;
|
| 888 |
+
using ArchTag = arch::Sm80;
|
| 889 |
+
|
| 890 |
+
static int const kSparse = 2;
|
| 891 |
+
|
| 892 |
+
static int const kMetaSizeInBits = 2;
|
| 893 |
+
|
| 894 |
+
static int const kMaxID2 = 1;
|
| 895 |
+
|
| 896 |
+
/// Computes multiply-add
|
| 897 |
+
CUTLASS_HOST_DEVICE
|
| 898 |
+
void operator()(
|
| 899 |
+
FragmentC &d,
|
| 900 |
+
FragmentA const &a,
|
| 901 |
+
FragmentB const &b,
|
| 902 |
+
FragmentC const &c,
|
| 903 |
+
uint32_t const &E,
|
| 904 |
+
int const id2
|
| 905 |
+
) const {
|
| 906 |
+
|
| 907 |
+
#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)
|
| 908 |
+
|
| 909 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 910 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 911 |
+
|
| 912 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 913 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 914 |
+
|
| 915 |
+
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
|
| 916 |
+
if (id2 == 0) {
|
| 917 |
+
asm volatile(
|
| 918 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 919 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 920 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 921 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 922 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 923 |
+
} else {
|
| 924 |
+
assert(0);
|
| 925 |
+
}
|
| 926 |
+
#else
|
| 927 |
+
if (id2 == 0) {
|
| 928 |
+
asm volatile(
|
| 929 |
+
"mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 930 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 931 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 932 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 933 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 934 |
+
} else {
|
| 935 |
+
assert(0);
|
| 936 |
+
}
|
| 937 |
+
#endif
|
| 938 |
+
|
| 939 |
+
#else
|
| 940 |
+
|
| 941 |
+
CUTLASS_UNUSED(a);
|
| 942 |
+
CUTLASS_UNUSED(b);
|
| 943 |
+
CUTLASS_UNUSED(c);
|
| 944 |
+
CUTLASS_UNUSED(d);
|
| 945 |
+
assert(0);
|
| 946 |
+
#endif
|
| 947 |
+
}
|
| 948 |
+
};
|
| 949 |
+
|
| 950 |
+
/// Matrix multiply-add operation: S32 = S4 * U4 + S32
|
| 951 |
+
template <>
|
| 952 |
+
struct SparseMma<
|
| 953 |
+
gemm::GemmShape<16,8,128>,
|
| 954 |
+
32,
|
| 955 |
+
cutlass::int4b_t,
|
| 956 |
+
layout::RowMajor,
|
| 957 |
+
cutlass::uint4b_t,
|
| 958 |
+
layout::ColumnMajor,
|
| 959 |
+
int,
|
| 960 |
+
layout::RowMajor,
|
| 961 |
+
OpMultiplyAddSaturate,
|
| 962 |
+
SPFormatType::Thread> {
|
| 963 |
+
|
| 964 |
+
using Shape = gemm::GemmShape<16,8,128>;
|
| 965 |
+
|
| 966 |
+
using ElementA = cutlass::int4b_t;
|
| 967 |
+
using LayoutA = layout::RowMajor;
|
| 968 |
+
using FragmentA = Array<cutlass::int4b_t, 32>;
|
| 969 |
+
|
| 970 |
+
using ElementB = cutlass::uint4b_t;
|
| 971 |
+
using LayoutB = layout::ColumnMajor;
|
| 972 |
+
using FragmentB = Array<cutlass::uint4b_t, 32>;
|
| 973 |
+
|
| 974 |
+
using ElementC = int;
|
| 975 |
+
using LayoutC = layout::RowMajor;
|
| 976 |
+
using FragmentC = Array<int, 4>;
|
| 977 |
+
|
| 978 |
+
using FragmentE = uint32_t;
|
| 979 |
+
|
| 980 |
+
using Operator = OpMultiplyAddSaturate;
|
| 981 |
+
using ArchTag = arch::Sm80;
|
| 982 |
+
|
| 983 |
+
static int const kSparse = 2;
|
| 984 |
+
|
| 985 |
+
static int const kMetaSizeInBits = 2;
|
| 986 |
+
|
| 987 |
+
static int const kMaxID2 = 1;
|
| 988 |
+
|
| 989 |
+
/// Computes multiply-add
|
| 990 |
+
CUTLASS_HOST_DEVICE
|
| 991 |
+
void operator()(
|
| 992 |
+
FragmentC &d,
|
| 993 |
+
FragmentA const &a,
|
| 994 |
+
FragmentB const &b,
|
| 995 |
+
FragmentC const &c,
|
| 996 |
+
uint32_t const &E,
|
| 997 |
+
int const id2
|
| 998 |
+
) const {
|
| 999 |
+
|
| 1000 |
+
#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)
|
| 1001 |
+
|
| 1002 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 1003 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 1004 |
+
|
| 1005 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 1006 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 1007 |
+
|
| 1008 |
+
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
|
| 1009 |
+
if (id2 == 0) {
|
| 1010 |
+
asm volatile(
|
| 1011 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 1012 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 1013 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 1014 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 1015 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 1016 |
+
} else {
|
| 1017 |
+
assert(0);
|
| 1018 |
+
}
|
| 1019 |
+
#else
|
| 1020 |
+
if (id2 == 0) {
|
| 1021 |
+
asm volatile(
|
| 1022 |
+
"mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 1023 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 1024 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 1025 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 1026 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 1027 |
+
} else {
|
| 1028 |
+
assert(0);
|
| 1029 |
+
}
|
| 1030 |
+
#endif
|
| 1031 |
+
|
| 1032 |
+
#else
|
| 1033 |
+
|
| 1034 |
+
CUTLASS_UNUSED(a);
|
| 1035 |
+
CUTLASS_UNUSED(b);
|
| 1036 |
+
CUTLASS_UNUSED(c);
|
| 1037 |
+
CUTLASS_UNUSED(d);
|
| 1038 |
+
assert(0);
|
| 1039 |
+
#endif
|
| 1040 |
+
}
|
| 1041 |
+
};
|
| 1042 |
+
|
| 1043 |
+
/// Matrix multiply-add operation: S32 = U4 * S4 + S32
|
| 1044 |
+
template <>
|
| 1045 |
+
struct SparseMma<
|
| 1046 |
+
gemm::GemmShape<16,8,128>,
|
| 1047 |
+
32,
|
| 1048 |
+
cutlass::uint4b_t,
|
| 1049 |
+
layout::RowMajor,
|
| 1050 |
+
cutlass::int4b_t,
|
| 1051 |
+
layout::ColumnMajor,
|
| 1052 |
+
int,
|
| 1053 |
+
layout::RowMajor,
|
| 1054 |
+
OpMultiplyAddSaturate,
|
| 1055 |
+
SPFormatType::Thread> {
|
| 1056 |
+
|
| 1057 |
+
using Shape = gemm::GemmShape<16,8,128>;
|
| 1058 |
+
|
| 1059 |
+
using ElementA = cutlass::uint4b_t;
|
| 1060 |
+
using LayoutA = layout::RowMajor;
|
| 1061 |
+
using FragmentA = Array<cutlass::uint4b_t, 32>;
|
| 1062 |
+
|
| 1063 |
+
using ElementB = cutlass::int4b_t;
|
| 1064 |
+
using LayoutB = layout::ColumnMajor;
|
| 1065 |
+
using FragmentB = Array<cutlass::int4b_t, 32>;
|
| 1066 |
+
|
| 1067 |
+
using ElementC = int;
|
| 1068 |
+
using LayoutC = layout::RowMajor;
|
| 1069 |
+
using FragmentC = Array<int, 4>;
|
| 1070 |
+
|
| 1071 |
+
using FragmentE = uint32_t;
|
| 1072 |
+
|
| 1073 |
+
using Operator = OpMultiplyAddSaturate;
|
| 1074 |
+
using ArchTag = arch::Sm80;
|
| 1075 |
+
|
| 1076 |
+
static int const kSparse = 2;
|
| 1077 |
+
|
| 1078 |
+
static int const kMetaSizeInBits = 2;
|
| 1079 |
+
|
| 1080 |
+
static int const kMaxID2 = 1;
|
| 1081 |
+
|
| 1082 |
+
/// Computes multiply-add
|
| 1083 |
+
CUTLASS_HOST_DEVICE
|
| 1084 |
+
void operator()(
|
| 1085 |
+
FragmentC &d,
|
| 1086 |
+
FragmentA const &a,
|
| 1087 |
+
FragmentB const &b,
|
| 1088 |
+
FragmentC const &c,
|
| 1089 |
+
uint32_t const &E,
|
| 1090 |
+
int const id2
|
| 1091 |
+
) const {
|
| 1092 |
+
|
| 1093 |
+
#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)
|
| 1094 |
+
|
| 1095 |
+
uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
|
| 1096 |
+
uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
|
| 1097 |
+
|
| 1098 |
+
int const *C = reinterpret_cast<int const *>(&c);
|
| 1099 |
+
int *D = reinterpret_cast<int *>(&d);
|
| 1100 |
+
|
| 1101 |
+
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
|
| 1102 |
+
if (id2 == 0) {
|
| 1103 |
+
asm volatile(
|
| 1104 |
+
"mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 1105 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 1106 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 1107 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 1108 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 1109 |
+
} else {
|
| 1110 |
+
assert(0);
|
| 1111 |
+
}
|
| 1112 |
+
#else
|
| 1113 |
+
if (id2 == 0) {
|
| 1114 |
+
asm volatile(
|
| 1115 |
+
"mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
|
| 1116 |
+
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
|
| 1117 |
+
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
| 1118 |
+
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
|
| 1119 |
+
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
|
| 1120 |
+
} else {
|
| 1121 |
+
assert(0);
|
| 1122 |
+
}
|
| 1123 |
+
#endif
|
| 1124 |
+
|
| 1125 |
+
#else
|
| 1126 |
+
|
| 1127 |
+
CUTLASS_UNUSED(a);
|
| 1128 |
+
CUTLASS_UNUSED(b);
|
| 1129 |
+
CUTLASS_UNUSED(c);
|
| 1130 |
+
CUTLASS_UNUSED(d);
|
| 1131 |
+
assert(0);
|
| 1132 |
+
#endif
|
| 1133 |
+
}
|
| 1134 |
+
};
|
| 1135 |
+
|
| 1136 |
+
/// Matrix multiply-add operation: S32 = U4 * U4 + S32
///
/// Sparse tensor-core MMA specialization (SM80) for the m16n8k128 shape with
/// unsigned 4-bit integer operands, signed 32-bit accumulation, and saturating
/// arithmetic (OpMultiplyAddSaturate -> ".satfinite" in the PTX below).
/// The A operand is stored compressed by a factor of kSparse, with per-thread
/// sparsity metadata passed via operand E (SPFormatType::Thread).
template <>
struct SparseMma<
  gemm::GemmShape<16,8,128>,
  32,
  cutlass::uint4b_t,
  layout::RowMajor,
  cutlass::uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate,
  SPFormatType::Thread> {

  using Shape = gemm::GemmShape<16,8,128>;

  // A operand: compressed (non-zero) values only.
  using ElementA = cutlass::uint4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<cutlass::uint4b_t, 32>;

  // B operand: dense.
  using ElementB = cutlass::uint4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<cutlass::uint4b_t, 32>;

  // C/D accumulators: four s32 values per thread.
  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 4>;

  // Sparsity metadata fragment (operand E of mma.sp).
  using FragmentE = uint32_t;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm80;

  // Compression ratio of the A operand.
  static int const kSparse = 2;

  // Bits of metadata per element group.
  static int const kMetaSizeInBits = 2;

  // Only metadata selector id2 == 0 is supported by this specialization.
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c with sparsity metadata E.
  ///
  /// \param d   destination accumulator fragment
  /// \param a   compressed A fragment
  /// \param b   dense B fragment
  /// \param c   source accumulator fragment
  /// \param E   sparsity metadata word
  /// \param id2 metadata selector; must be 0 (kMaxID2 == 1), otherwise asserts
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED)

    // Reinterpret the fragments as the packed 32-bit registers the PTX expects.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

// CUDA 12.5+ provides the "ordered_metadata" qualified form of mma.sp.
#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
    if (id2 == 0) {
      asm volatile(
          "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#else
    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E));
    } else {
      assert(0);
    }
#endif

#else
    // Instruction unavailable on this architecture/toolchain: trap at runtime.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 1228 |
+
|
| 1229 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1230 |
+
|
| 1231 |
+
} // namespace arch
|
| 1232 |
+
} // namespace cutlass
|
| 1233 |
+
|
| 1234 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Sparse matrix multiply accumulate for SM89
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include CUDA_STD_HEADER(cassert)
|
| 39 |
+
|
| 40 |
+
#include "mma.h"
|
| 41 |
+
#include "cutlass/layout/matrix.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)
|
| 47 |
+
# define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED
|
| 48 |
+
#endif
|
| 49 |
+
|
| 50 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
|
| 51 |
+
# if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED)
|
| 52 |
+
# define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED
|
| 53 |
+
# endif
|
| 54 |
+
#endif
|
| 55 |
+
|
| 56 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
namespace cutlass {
|
| 59 |
+
namespace arch {
|
| 60 |
+
|
| 61 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
|
| 63 |
+
/// Matrix multiply-add operation: F32 = fe4m3 * fe4m3 + F32
///
/// Sparse tensor-core MMA specialization (SM89) for the m16n8k64 shape with
/// FP8 e4m3 operands and FP32 accumulation. The A operand is stored
/// compressed by a factor of kSparse, with per-thread sparsity metadata
/// supplied through operand E (SPFormatType::Thread).
template <typename Operator_>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  cutlass::float_e4m3_t,
  layout::RowMajor,
  cutlass::float_e4m3_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  Operator_,
  SPFormatType::Thread> {

  // Only plain and fast-accumulation multiply-add are valid for SM89 FP8.
  static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
                platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
                "Invalid operator for SM89 FP8 instruction");

  using Shape = gemm::GemmShape<16,8,64>;

  // A operand: compressed (non-zero) values only.
  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<ElementA, 16>;

  // B operand: dense.
  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<ElementB, 16>;

  // C/D accumulators: four f32 values per thread.
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<ElementC, 4>;

  // Sparsity metadata fragment (operand E of mma.sp).
  using FragmentE = uint32_t;

  using Operator = Operator_;
  using ArchTag = arch::Sm89;

  // Compression ratio of the A operand.
  static int const kSparse = 2;

  // Bits of metadata per element group.
  static int const kMetaSizeInBits = 2;

  // Only metadata selector id2 == 0 is supported by this specialization.
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c with sparsity metadata E.
  ///
  /// \param d   destination accumulator fragment
  /// \param a   compressed A fragment
  /// \param b   dense B fragment
  /// \param c   source accumulator fragment
  /// \param E   sparsity metadata word
  /// \param id2 metadata selector; must be 0 (kMaxID2 == 1), otherwise asserts
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED)

    // Reinterpret the fragments as the packed 32-bit registers the PTX expects.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    // Instruction unavailable on this architecture/toolchain: trap at runtime.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 145 |
+
|
| 146 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 147 |
+
|
| 148 |
+
/// Matrix multiply-add operation: F32 = fe4m3 * fe5m2 + F32
///
/// Sparse tensor-core MMA specialization (SM89) for the m16n8k64 shape with
/// mixed FP8 operands (A: e4m3, B: e5m2) and FP32 accumulation. The A operand
/// is stored compressed by a factor of kSparse, with per-thread sparsity
/// metadata supplied through operand E (SPFormatType::Thread).
template <typename Operator_>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  cutlass::float_e4m3_t,
  layout::RowMajor,
  cutlass::float_e5m2_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  Operator_,
  SPFormatType::Thread> {

  // Only plain and fast-accumulation multiply-add are valid for SM89 FP8.
  static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
                platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
                "Invalid operator for SM89 FP8 instruction");

  using Shape = gemm::GemmShape<16,8,64>;

  // A operand: compressed (non-zero) values only.
  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<ElementA, 16>;

  // B operand: dense.
  using ElementB = cutlass::float_e5m2_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<ElementB, 16>;

  // C/D accumulators: four f32 values per thread.
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<ElementC, 4>;

  // Sparsity metadata fragment (operand E of mma.sp).
  using FragmentE = uint32_t;

  using Operator = Operator_;
  using ArchTag = arch::Sm89;

  // Compression ratio of the A operand.
  static int const kSparse = 2;

  // Bits of metadata per element group.
  static int const kMetaSizeInBits = 2;

  // Only metadata selector id2 == 0 is supported by this specialization.
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c with sparsity metadata E.
  ///
  /// \param d   destination accumulator fragment
  /// \param a   compressed A fragment
  /// \param b   dense B fragment
  /// \param c   source accumulator fragment
  /// \param E   sparsity metadata word
  /// \param id2 metadata selector; must be 0 (kMaxID2 == 1), otherwise asserts
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED)

    // Reinterpret the fragments as the packed 32-bit registers the PTX expects.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    // Instruction unavailable on this architecture/toolchain: trap at runtime.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 230 |
+
|
| 231 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 232 |
+
|
| 233 |
+
/// Matrix multiply-add operation: F32 = fe5m2 * fe4m3 + F32
///
/// Sparse tensor-core MMA specialization (SM89) for the m16n8k64 shape with
/// mixed FP8 operands (A: e5m2, B: e4m3) and FP32 accumulation. The A operand
/// is stored compressed by a factor of kSparse, with per-thread sparsity
/// metadata supplied through operand E (SPFormatType::Thread).
template <typename Operator_>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  cutlass::float_e5m2_t,
  layout::RowMajor,
  cutlass::float_e4m3_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  Operator_,
  SPFormatType::Thread> {

  // Only plain and fast-accumulation multiply-add are valid for SM89 FP8.
  static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
                platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
                "Invalid operator for SM89 FP8 instruction");

  using Shape = gemm::GemmShape<16,8,64>;

  // A operand: compressed (non-zero) values only.
  using ElementA = cutlass::float_e5m2_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<ElementA, 16>;

  // B operand: dense.
  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<ElementB, 16>;

  // C/D accumulators: four f32 values per thread.
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<ElementC, 4>;

  // Sparsity metadata fragment (operand E of mma.sp).
  using FragmentE = uint32_t;

  using Operator = Operator_;
  using ArchTag = arch::Sm89;

  // Compression ratio of the A operand.
  static int const kSparse = 2;

  // Bits of metadata per element group.
  static int const kMetaSizeInBits = 2;

  // Only metadata selector id2 == 0 is supported by this specialization.
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c with sparsity metadata E.
  ///
  /// \param d   destination accumulator fragment
  /// \param a   compressed A fragment
  /// \param b   dense B fragment
  /// \param c   source accumulator fragment
  /// \param E   sparsity metadata word
  /// \param id2 metadata selector; must be 0 (kMaxID2 == 1), otherwise asserts
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED)

    // Reinterpret the fragments as the packed 32-bit registers the PTX expects.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    // Instruction unavailable on this architecture/toolchain: trap at runtime.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 315 |
+
|
| 316 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 317 |
+
|
| 318 |
+
/// Matrix multiply-add operation: F32 = fe5m2 * fe5m2 + F32
///
/// Sparse tensor-core MMA specialization (SM89) for the m16n8k64 shape with
/// FP8 e5m2 operands and FP32 accumulation. The A operand is stored
/// compressed by a factor of kSparse, with per-thread sparsity metadata
/// supplied through operand E (SPFormatType::Thread).
template <typename Operator_>
struct SparseMma<
  gemm::GemmShape<16,8,64>,
  32,
  cutlass::float_e5m2_t,
  layout::RowMajor,
  cutlass::float_e5m2_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  Operator_,
  SPFormatType::Thread> {

  // Only plain and fast-accumulation multiply-add are valid for SM89 FP8.
  static_assert(platform::is_same<Operator_, OpMultiplyAdd>::value ||
                platform::is_same<Operator_, OpMultiplyAddFastAccum>::value,
                "Invalid operator for SM89 FP8 instruction");

  using Shape = gemm::GemmShape<16,8,64>;

  // A operand: compressed (non-zero) values only.
  using ElementA = cutlass::float_e5m2_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<ElementA, 16>;

  // B operand: dense.
  using ElementB = cutlass::float_e5m2_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<ElementB, 16>;

  // C/D accumulators: four f32 values per thread.
  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<ElementC, 4>;

  // Sparsity metadata fragment (operand E of mma.sp).
  using FragmentE = uint32_t;

  using Operator = Operator_;
  using ArchTag = arch::Sm89;

  // Compression ratio of the A operand.
  static int const kSparse = 2;

  // Bits of metadata per element group.
  static int const kMetaSizeInBits = 2;

  // Only metadata selector id2 == 0 is supported by this specialization.
  static int const kMaxID2 = 1;

  /// Computes multiply-add: d = a * b + c with sparsity metadata E.
  ///
  /// \param d   destination accumulator fragment
  /// \param a   compressed A fragment
  /// \param b   dense B fragment
  /// \param c   source accumulator fragment
  /// \param E   sparsity metadata word
  /// \param id2 metadata selector; must be 0 (kMaxID2 == 1), otherwise asserts
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c,
    uint32_t const &E,
    int const id2
  ) const {

#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED)

    // Reinterpret the fragments as the packed 32-bit registers the PTX expects.
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);

    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    if (id2 == 0) {
      asm volatile(
          "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, "
          "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
          : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E));
    }
    else {
      assert(0);
    }
#else
    // Instruction unavailable on this architecture/toolchain: trap at runtime.
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
};
|
| 400 |
+
|
| 401 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 402 |
+
|
| 403 |
+
} // namespace arch
|
| 404 |
+
} // namespace cutlass
|
| 405 |
+
|
| 406 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief PTX for CTA Reconfiguration
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#if defined(__CUDACC_RTC__)
|
| 40 |
+
#include <cuda/std/cstdint>
|
| 41 |
+
#else
|
| 42 |
+
#include <cstdint>
|
| 43 |
+
#endif
|
| 44 |
+
|
| 45 |
+
#ifndef CUDA_CTA_RECONFIG_ACTIVATED
|
| 46 |
+
#if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \
|
| 47 |
+
(__CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) \
|
| 48 |
+
|| (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) \
|
| 49 |
+
|| (__CUDA_ARCH__ == 1010 && defined(__CUDA_ARCH_FEAT_SM101_ALL)) \
|
| 50 |
+
|| (__CUDA_ARCH__ == 1030 && defined(__CUDA_ARCH_FEAT_SM103_ALL)) \
|
| 51 |
+
|| (__CUDA_ARCH__ == 1200 && defined(__CUDA_ARCH_FEAT_SM120_ALL)) \
|
| 52 |
+
|| (__CUDA_ARCH__ == 1210 && defined(__CUDA_ARCH_FEAT_SM121_ALL)) \
|
| 53 |
+
)
|
| 54 |
+
#define CUDA_CTA_RECONFIG_ACTIVATED 1
|
| 55 |
+
#endif
|
| 56 |
+
|
| 57 |
+
#if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \
|
| 58 |
+
(__CUDA_ARCH__ == 1000 && CUDA_ARCH_FAMILY(1000)) \
|
| 59 |
+
|| (__CUDA_ARCH__ == 1010 && CUDA_ARCH_FAMILY(1010)) \
|
| 60 |
+
|| (__CUDA_ARCH__ == 1030 && CUDA_ARCH_FAMILY(1030)) \
|
| 61 |
+
|| (__CUDA_ARCH__ == 1200 && CUDA_ARCH_FAMILY(1200)) \
|
| 62 |
+
|| (__CUDA_ARCH__ == 1210 && CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) \
|
| 63 |
+
)
|
| 64 |
+
#define CUDA_CTA_RECONFIG_ACTIVATED 1
|
| 65 |
+
#endif
|
| 66 |
+
|
| 67 |
+
#endif
|
| 68 |
+
|
| 69 |
+
namespace cutlass {
|
| 70 |
+
namespace arch {
|
| 71 |
+
|
| 72 |
+
/// Requests an increase of the per-thread register budget to RegCount via the
/// PTX `setmaxnreg.inc.sync.aligned.u32` instruction. Compiled to a no-op when
/// CUDA_CTA_RECONFIG_ACTIVATED is not set (see the guard logic above).
///
/// \tparam RegCount target register count, passed as an immediate ("n" constraint)
template<uint32_t RegCount>
CUTLASS_DEVICE
void warpgroup_reg_alloc(){
#if CUDA_CTA_RECONFIG_ACTIVATED
  asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) );
#endif
}
|
| 79 |
+
|
| 80 |
+
/// Requests a decrease of the per-thread register budget to RegCount via the
/// PTX `setmaxnreg.dec.sync.aligned.u32` instruction. Compiled to a no-op when
/// CUDA_CTA_RECONFIG_ACTIVATED is not set (see the guard logic above).
///
/// \tparam RegCount target register count, passed as an immediate ("n" constraint)
template<uint32_t RegCount>
CUTLASS_DEVICE
void warpgroup_reg_dealloc(){
#if CUDA_CTA_RECONFIG_ACTIVATED
  asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) );
#endif
}
|
| 87 |
+
|
| 88 |
+
} // namespace arch
|
| 89 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates exposing SIMD operators
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/arch/array.h"
|
| 38 |
+
#include "cutlass/arch/numeric_types.h"
|
| 39 |
+
|
| 40 |
+
namespace cutlass {
|
| 41 |
+
namespace arch {
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
//
|
| 46 |
+
// Element-wise operators
|
| 47 |
+
//
|
| 48 |
+
|
| 49 |
+
/// Element-wise product of two arrays: d[i] = a[i] * b[i].
///
/// \param a left-hand operand
/// \param b right-hand operand
/// \return array of per-element products
//
// Fix: CUTLASS_HOST_DEVICE expands to declaration specifiers
// (__host__ __device__ / inline), which may not precede the `template`
// keyword; it must follow the template parameter list.
template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator*(Array<T, N> const &a, Array<T, N> const &b) {
  Array<T, N> d;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < N; ++i) {
    d[i] = a[i] * b[i];
  }
  return d;
}
|
| 59 |
+
|
| 60 |
+
/// Element-wise sum of two arrays: d[i] = a[i] + b[i].
///
/// \param a left-hand operand
/// \param b right-hand operand
/// \return array of per-element sums
//
// Fix: CUTLASS_HOST_DEVICE expands to declaration specifiers
// (__host__ __device__ / inline), which may not precede the `template`
// keyword; it must follow the template parameter list.
template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator+(Array<T, N> const &a, Array<T, N> const &b) {
  Array<T, N> d;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < N; ++i) {
    d[i] = a[i] + b[i];
  }
  return d;
}
|
| 70 |
+
|
| 71 |
+
/// Element-wise difference of two Arrays: d[i] = a[i] - b[i].
///
/// Fix: moved CUTLASS_HOST_DEVICE after the template clause — qualifier
/// macros before `template` are ill-formed.
template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator-(Array<T, N> const &a, Array<T, N> const &b) {
  Array<T, N> d;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < N; ++i) {
    d[i] = a[i] - b[i];
  }
  return d;
}
|
| 81 |
+
|
| 82 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 83 |
+
|
| 84 |
+
//
|
| 85 |
+
// Multiply-accumulate operators
|
| 86 |
+
//
|
| 87 |
+
|
| 88 |
+
/// Element-wise multiply-accumulate: d[i] = a[i] * b[i] + c[i].
///
/// Fix: moved CUTLASS_HOST_DEVICE after the template clause — qualifier
/// macros before `template` are ill-formed.
template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> mac(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) {
  Array<T, N> d;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < N; ++i) {
    d[i] = a[i] * b[i] + c[i];
  }
  return d;
}
|
| 98 |
+
|
| 99 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 100 |
+
|
| 101 |
+
//
|
| 102 |
+
// Dot product operator
|
| 103 |
+
//
|
| 104 |
+
|
| 105 |
+
/// Dot-product reduction: accum += a[i] * b[i] over all N elements.
///
/// The accumulator type may differ from the element type (e.g. float
/// accumulation of half_t products).
///
/// Fixes: the parameter arrays referenced undeclared names `T`/`N`
/// (`Array<T, N>`) — the template parameter is `Element`; and
/// CUTLASS_HOST_DEVICE must follow the template clause.
template <typename Element, typename Accumulator, int N>
CUTLASS_HOST_DEVICE
Accumulator dot(Array<Element, N> const &a, Array<Element, N> const &b, Accumulator accum) {
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < N; ++i) {
    accum += a[i] * b[i];
  }
  return accum;
}
|
| 114 |
+
|
| 115 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 116 |
+
|
| 117 |
+
} // namespace arch
|
| 118 |
+
} // namespace cutlass
|
| 119 |
+
|
| 120 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 121 |
+
|
| 122 |
+
#include "simd_sm60.h"
|
| 123 |
+
#include "simd_sm61.h"
|
| 124 |
+
|
| 125 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates exposing SIMD operators for SM60
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "simd.h"
|
| 38 |
+
|
| 39 |
+
namespace cutlass {
|
| 40 |
+
namespace arch {
|
| 41 |
+
|
| 42 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
//
|
| 45 |
+
// Element-wise operators - specialized for half_t x 2
|
| 46 |
+
//
|
| 47 |
+
|
| 48 |
+
/// Element-wise product specialized for half_t x 2.
///
/// Fixes: CUTLASS_HOST_DEVICE moved after `template <>` (ill-formed before
/// it), and the stub returned `d` uninitialized — compute the product so the
/// specialization matches the primary template's contract.
/// NOTE(review): no packed HMUL2 path is present in this view; this is the
/// scalar form — confirm whether a SIMD implementation was intended.
template <>
CUTLASS_HOST_DEVICE
Array<half_t, 2> operator*(Array<half_t, 2> const &a, Array<half_t, 2> const &b) {
  Array<half_t, 2> d;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 2; ++i) {
    d[i] = a[i] * b[i];
  }
  return d;
}
|
| 55 |
+
|
| 56 |
+
/// Element-wise sum specialized for half_t x 2.
///
/// Fixes: `AArray<half_t, 2>` typo in the first parameter (must be `Array`
/// or the specialization cannot match the primary template);
/// CUTLASS_HOST_DEVICE moved after `template <>`; and the stub returned `d`
/// uninitialized — compute the sum to match the primary template's contract.
template <>
CUTLASS_HOST_DEVICE
Array<half_t, 2> operator+(Array<half_t, 2> const &a, Array<half_t, 2> const &b) {
  Array<half_t, 2> d;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 2; ++i) {
    d[i] = a[i] + b[i];
  }
  return d;
}
|
| 63 |
+
|
| 64 |
+
/// Element-wise difference specialized for half_t x 2.
///
/// Fixes: local `Array<T, N> d;` referenced undeclared `T`/`N` (this is a
/// full specialization — the concrete type is `Array<half_t, 2>`);
/// CUTLASS_HOST_DEVICE moved after `template <>`; and the stub returned `d`
/// uninitialized — compute the difference to match the primary template.
template <>
CUTLASS_HOST_DEVICE
Array<half_t, 2> operator-(Array<half_t, 2> const &a, Array<half_t, 2> const &b) {
  Array<half_t, 2> d;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 2; ++i) {
    d[i] = a[i] - b[i];
  }
  return d;
}
|
| 71 |
+
|
| 72 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 73 |
+
|
| 74 |
+
/// Multiply-accumulate operators - specialized for half_t x 2
|
| 75 |
+
CUTLASS_HOST_DEVICE
|
| 76 |
+
template <>
|
| 77 |
+
Array<half_t, 2> mac(Array<half_t, 2> const &a, Array<half_t, 2> const &b, Array<half_t, 2> const &c) {
|
| 78 |
+
Array<half_t, 2> d;
|
| 79 |
+
|
| 80 |
+
return d;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 84 |
+
|
| 85 |
+
/// Dot product operator - specialized for half_t <- (half_t * half_t) x 2 + half_t
|
| 86 |
+
CUTLASS_HOST_DEVICE
|
| 87 |
+
template <>
|
| 88 |
+
half_t dot(Array<half_t, 2> const &a, Array<half_t, 2> const &b, half_t accum) {
|
| 89 |
+
|
| 90 |
+
return accum;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
/// Dot product operator - specialized for float <- (half_t * half_t) x 2 + float
|
| 94 |
+
CUTLASS_HOST_DEVICE
|
| 95 |
+
template <>
|
| 96 |
+
float dot(Array<half_t, 2> const &a, Array<half_t, 2> const &b, float accum) {
|
| 97 |
+
|
| 98 |
+
return accum;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 102 |
+
|
| 103 |
+
} // namespace arch
|
| 104 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates exposing SIMD operators for SM61
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "simd.h"
|
| 38 |
+
|
| 39 |
+
namespace cutlass {
|
| 40 |
+
namespace arch {
|
| 41 |
+
|
| 42 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
/// Dot product operator - specialized for int32_t <- (int8_t * int8_t) x 4 + int32_t
|
| 45 |
+
CUTLASS_HOST_DEVICE
|
| 46 |
+
template <>
|
| 47 |
+
int32_t dot(Array<int8_t, 4> const &a, Array<int8_t, 4> const &b, int32_t accum) {
|
| 48 |
+
|
| 49 |
+
return accum;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
/// Dot product operator - specialized for int32_t <- (uint8_t * int8_t) x 4 + int32_t
|
| 53 |
+
CUTLASS_HOST_DEVICE
|
| 54 |
+
template <>
|
| 55 |
+
int32_t dot(Array<uint8_t, 4> const &a, Array<int8_t, 4> const &b, int32_t accum) {
|
| 56 |
+
|
| 57 |
+
return accum;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
/// Dot product operator - specialized for int32_t <- (int8_t * uint8_t) x 4 + int32_t
|
| 61 |
+
CUTLASS_HOST_DEVICE
|
| 62 |
+
template <>
|
| 63 |
+
int32_t dot(Array<int8_t, 4> const &a, Array<uint8_t, 4> const &b, int32_t accum) {
|
| 64 |
+
|
| 65 |
+
return accum;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
/// Dot product operator - specialized for int32_t <- (uint8_t * uint8_t) x 4 + int32_t
|
| 69 |
+
CUTLASS_HOST_DEVICE
|
| 70 |
+
template <>
|
| 71 |
+
int32_t dot(Array<uint8_t, 4> const &a, Array<uint8_t, 4> const &b, int32_t accum) {
|
| 72 |
+
|
| 73 |
+
return accum;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 77 |
+
|
| 78 |
+
/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t
|
| 79 |
+
CUTLASS_HOST_DEVICE
|
| 80 |
+
template <>
|
| 81 |
+
int32_t dot(Array<int16_t, 2> const &a, Array<int8_t, 2> const &b, int32_t accum) {
|
| 82 |
+
|
| 83 |
+
return accum;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t
|
| 87 |
+
CUTLASS_HOST_DEVICE
|
| 88 |
+
template <>
|
| 89 |
+
int32_t dot(Array<uint16_t, 2> const &a, Array<int8_t, 2> const &b, int32_t accum) {
|
| 90 |
+
|
| 91 |
+
return accum;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t
|
| 95 |
+
CUTLASS_HOST_DEVICE
|
| 96 |
+
template <>
|
| 97 |
+
int32_t dot(Array<int16_t, 2> const &a, Array<uint8_t, 2> const &b, int32_t accum) {
|
| 98 |
+
|
| 99 |
+
return accum;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t
|
| 103 |
+
CUTLASS_HOST_DEVICE
|
| 104 |
+
template <>
|
| 105 |
+
int32_t dot(Array<uint16_t, 2> const &a, Array<uint8_t, 2> const &b, int32_t accum) {
|
| 106 |
+
|
| 107 |
+
return accum;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 111 |
+
|
| 112 |
+
/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t
|
| 113 |
+
CUTLASS_HOST_DEVICE
|
| 114 |
+
template <>
|
| 115 |
+
int32_t dot(Array<int16_t, 2> const &a, Array<int16_t, 2> const &b, int32_t accum) {
|
| 116 |
+
|
| 117 |
+
return accum;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t
|
| 121 |
+
CUTLASS_HOST_DEVICE
|
| 122 |
+
template <>
|
| 123 |
+
int32_t dot(Array<uint16_t, 2> const &a, Array<int16_t, 2> const &b, int32_t accum) {
|
| 124 |
+
|
| 125 |
+
return accum;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t
|
| 129 |
+
CUTLASS_HOST_DEVICE
|
| 130 |
+
template <>
|
| 131 |
+
int32_t dot(Array<int16_t, 2> const &a, Array<uint16_t, 2> const &b, int32_t accum) {
|
| 132 |
+
|
| 133 |
+
return accum;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t
|
| 137 |
+
CUTLASS_HOST_DEVICE
|
| 138 |
+
template <>
|
| 139 |
+
int32_t dot(Array<uint16_t, 2> const &a, Array<uint16_t, 2> const &b, int32_t accum) {
|
| 140 |
+
|
| 141 |
+
return accum;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 145 |
+
|
| 146 |
+
} // namespace arch
|
| 147 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp
ADDED
|
@@ -0,0 +1,1271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Synchronization event logging for race condition debugging.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/detail/helper_macros.hpp"
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#if defined(__CUDACC_RTC__)
|
| 40 |
+
#include CUDA_STD_HEADER(cstdint)
|
| 41 |
+
#else
|
| 42 |
+
#include <cstdint>
|
| 43 |
+
#endif
|
| 44 |
+
|
| 45 |
+
#if !defined(__CUDACC_RTC__)
|
| 46 |
+
#include <mutex>
|
| 47 |
+
#include <vector>
|
| 48 |
+
#endif
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace arch {
|
| 52 |
+
|
| 53 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
#if defined(CUTLASS_ENABLE_SYNCLOG)
|
| 56 |
+
|
| 57 |
+
constexpr uint32_t synclog_cap = 1 << 26;
|
| 58 |
+
|
| 59 |
+
inline std::mutex synclog_mutex;
|
| 60 |
+
inline std::vector<uint32_t*> synclog_buf_list;
|
| 61 |
+
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
|
| 62 |
+
CUTLASS_DEVICE uint32_t* synclog_buf;
|
| 63 |
+
#endif
|
| 64 |
+
|
| 65 |
+
CUTLASS_DEVICE
|
| 66 |
+
uint32_t* synclog_alloc(uint32_t n) {
|
| 67 |
+
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
|
| 68 |
+
uint32_t* buf = synclog_buf;
|
| 69 |
+
if (buf == nullptr) return nullptr;
|
| 70 |
+
uint32_t last = atomicAdd(&buf[0], n);
|
| 71 |
+
if (last + n < synclog_cap) return buf + last + 1;
|
| 72 |
+
if (last >= synclog_cap) atomicAdd(&buf[0], -n);
|
| 73 |
+
#endif
|
| 74 |
+
return nullptr;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
CUTLASS_DEVICE
// Write the common record prefix (header id, source line, 64-bit timestamp,
// thread and block coordinates) into the reserved log slot `to`.
void synclog_emit_prefix(uint32_t* to, uint32_t header, uint32_t line) {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  // Sample the 64-bit global timer via inline PTX.
  uint64_t globaltimer;
  asm volatile (
    "mov.u64 %0, %%globaltimer;\n"
    : "=l"(globaltimer) :
  );
  to[0] = header;
  to[1] = line;
  to[2] = static_cast<uint32_t>(globaltimer);        // timestamp low word
  to[3] = static_cast<uint32_t>(globaltimer >> 32);  // timestamp high word
  to[4] = threadIdx.x;
  to[5] = threadIdx.y;
  to[6] = threadIdx.z;
  to[7] = blockIdx.x;
  to[8] = blockIdx.y;
  to[9] = blockIdx.z;
#endif
}
|
| 97 |
+
|
| 98 |
+
// Event table. Every logged event type has three constants:
//   synclog_enable_<event>  -- compile-time switch for this event type,
//   synclog_header_<event>  -- unique record tag written as word 0,
//   synclog_length_<event>  -- total record length in words:
//                              the common prefix plus the event's payload.
// Prefix layout (see synclog_emit_prefix): header(1) + line(1) +
// timestamp(2) + threadIdx(3) + blockIdx(3).
constexpr uint32_t synclog_header_none = 0;
constexpr uint32_t synclog_length_prefix = 1 + 1 + 2 + 3 + 3;

constexpr bool synclog_enable_syncthreads = true;
constexpr uint32_t synclog_header_syncthreads = 1;
constexpr uint32_t synclog_length_syncthreads = synclog_length_prefix + 0;

constexpr bool synclog_enable_syncwarp = true;
constexpr uint32_t synclog_header_syncwarp = 2;
constexpr uint32_t synclog_length_syncwarp = synclog_length_prefix + 0;

constexpr bool synclog_enable_named_barrier_arrive_and_wait = true;
constexpr uint32_t synclog_header_named_barrier_arrive_and_wait = 3;
constexpr uint32_t synclog_length_named_barrier_arrive_and_wait = synclog_length_prefix + 2;

constexpr bool synclog_enable_named_barrier_arrive = true;
constexpr uint32_t synclog_header_named_barrier_arrive = 4;
constexpr uint32_t synclog_length_named_barrier_arrive = synclog_length_prefix + 2;

constexpr bool synclog_enable_cluster_barrier_init = true;
constexpr uint32_t synclog_header_cluster_barrier_init = 5;
constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + 2;

constexpr bool synclog_enable_cluster_barrier_wait = true;
constexpr uint32_t synclog_header_cluster_barrier_wait = 6;
constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 2;

constexpr bool synclog_enable_cluster_barrier_test_wait = true;
constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7;
constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 3;

constexpr bool synclog_enable_cluster_barrier_try_wait = true;
constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8;
constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 2;

constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true;
constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9;
constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 3;

constexpr bool synclog_enable_cluster_barrier_arrive = true;
constexpr uint32_t synclog_header_cluster_barrier_arrive = 10;
constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 1;

constexpr bool synclog_enable_cluster_barrier_invalidate = true;
constexpr uint32_t synclog_header_cluster_barrier_invalidate = 11;
constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 1;

constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true;
constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12;
constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 2;

constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true;
constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13;
constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 4;

constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true;
constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14;
constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 2;

constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true;
constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15;
constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 4;

constexpr bool synclog_enable_fence_barrier_init = true;
constexpr uint32_t synclog_header_fence_barrier_init = 16;
constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0;

constexpr bool synclog_enable_fence_view_async_shared = true;
constexpr uint32_t synclog_header_fence_view_async_shared = 17;
constexpr uint32_t synclog_length_fence_view_async_shared = synclog_length_prefix + 0;

constexpr bool synclog_enable_cp_async_wait = true;
constexpr uint32_t synclog_header_cp_async_wait = 18;
constexpr uint32_t synclog_length_cp_async_wait = synclog_length_prefix + 1;

constexpr bool synclog_enable_cp_async_wait_all = true;
constexpr uint32_t synclog_header_cp_async_wait_all = 19;
constexpr uint32_t synclog_length_cp_async_wait_all = synclog_length_prefix + 0;

constexpr bool synclog_enable_cp_async_fence = true;
constexpr uint32_t synclog_header_cp_async_fence = 20;
constexpr uint32_t synclog_length_cp_async_fence = synclog_length_prefix + 0;

constexpr bool synclog_enable_cp_async_nan = true;
constexpr uint32_t synclog_header_cp_async_nan = 21;
constexpr uint32_t synclog_length_cp_async_nan = synclog_length_prefix + 4;

constexpr bool synclog_enable_cp_async_zfill = true;
constexpr uint32_t synclog_header_cp_async_zfill = 22;
constexpr uint32_t synclog_length_cp_async_zfill = synclog_length_prefix + 5;

constexpr bool synclog_enable_cp_async = true;
constexpr uint32_t synclog_header_cp_async = 23;
constexpr uint32_t synclog_length_cp_async = synclog_length_prefix + 5;

constexpr bool synclog_enable_tma_load = true;
constexpr uint32_t synclog_header_tma_load = 24;
constexpr uint32_t synclog_length_tma_load = synclog_length_prefix + 4;

constexpr bool synclog_enable_tma_store = true;
constexpr uint32_t synclog_header_tma_store = 25;
constexpr uint32_t synclog_length_tma_store = synclog_length_prefix + 3;

constexpr bool synclog_enable_tma_store_arrive = true;
constexpr uint32_t synclog_header_tma_store_arrive = 26;
constexpr uint32_t synclog_length_tma_store_arrive = synclog_length_prefix + 0;

constexpr bool synclog_enable_tma_store_wait = true;
constexpr uint32_t synclog_header_tma_store_wait = 27;
constexpr uint32_t synclog_length_tma_store_wait = synclog_length_prefix + 1;

constexpr bool synclog_enable_warpgroup_arrive = true;
constexpr uint32_t synclog_header_warpgroup_arrive = 28;
constexpr uint32_t synclog_length_warpgroup_arrive = synclog_length_prefix + 0;

constexpr bool synclog_enable_warpgroup_wait = true;
constexpr uint32_t synclog_header_warpgroup_wait = 29;
constexpr uint32_t synclog_length_warpgroup_wait = synclog_length_prefix + 1;

constexpr bool synclog_enable_warpgroup_commit_batch = true;
constexpr uint32_t synclog_header_warpgroup_commit_batch = 30;
constexpr uint32_t synclog_length_warpgroup_commit_batch = synclog_length_prefix + 0;

constexpr bool synclog_enable_wgmma_reg_smem = true;
constexpr uint32_t synclog_header_wgmma_reg_smem = 31;
constexpr uint32_t synclog_length_wgmma_reg_smem = synclog_length_prefix + 2;

constexpr bool synclog_enable_wgmma_smem_smem = true;
constexpr uint32_t synclog_header_wgmma_smem_smem = 32;
constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4;

constexpr bool synclog_enable_cpasync_barrier_arrive = true;
constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33;
constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 1;
|
| 222 |
+
CUTLASS_DEVICE
// Decide whether the calling thread records log entries.
// Only the first lane of each warp in block (0,0,0) emits, which bounds the
// log volume while still capturing every warp of that block.
bool synclog_condition_emit() {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  return threadIdx.x % NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 &&
    blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0;
#else
  // Return false (not 0) for consistency with synclog_condition_print().
  return false;
#endif
}
|
| 231 |
+
|
| 232 |
+
CUTLASS_DEVICE
// Decide whether the calling thread prints: only thread (0,0,0) of
// block (0,0,0) does, so the dump is emitted exactly once.
bool synclog_condition_print() {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  bool first_thread = threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
  bool first_block = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0;
  return first_thread && first_block;
#else
  return false;
#endif
}
|
| 241 |
+
|
| 242 |
+
CUTLASS_DEVICE
// Decode and print the common record prefix stored at synclog_buf[at..at+9],
// mirroring the layout written by synclog_emit_prefix().
void synclog_print_prefix(char const* header, uint32_t at) {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  uint32_t src_line = synclog_buf[at + 1];
  uint32_t time_lo = synclog_buf[at + 2];
  uint32_t time_hi = synclog_buf[at + 3];
  uint32_t tid_x = synclog_buf[at + 4];
  uint32_t tid_y = synclog_buf[at + 5];
  uint32_t tid_z = synclog_buf[at + 6];
  uint32_t bid_x = synclog_buf[at + 7];
  uint32_t bid_y = synclog_buf[at + 8];
  uint32_t bid_z = synclog_buf[at + 9];
  printf(
    "%s line=%u time=%lu thread=%u,%u,%u block=%u,%u,%u ",
    header, src_line,
    (uint64_t)time_hi << 32 | time_lo,  // reassemble the 64-bit timestamp
    tid_x, tid_y, tid_z,
    bid_x, bid_y, bid_z
  );
#endif
}
|
| 263 |
+
|
| 264 |
+
CUTLASS_DEVICE
// Print the shared-memory pointer encoded in a wgmma descriptor's low word.
// Bits [13:0] hold the address in 16-byte units; shift back to bytes.
void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) {
  CUTLASS_UNUSED(hi);
  uint32_t addr_field = lo & ((1 << 14) - 1);
  uint32_t smem_int_ptr = addr_field << 4;
  printf("%s_smem_int_ptr=%u%s", str, smem_int_ptr, sep);
}
|
| 270 |
+
|
| 271 |
+
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
|
| 272 |
+
|
| 273 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 274 |
+
|
| 275 |
+
// Host-side initialization of the synclog.
// Allocates one log buffer per CUDA device on first call, zeroes every
// buffer, and publishes each buffer's pointer to the device-side
// synclog_buf symbol. Safe to call repeatedly; later calls reset the logs.
// Aborts the process (via std::terminate) on any CUDA API failure.
inline void synclog_setup() {
#if defined(CUTLASS_ENABLE_SYNCLOG)
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  // Serialize against concurrent setup calls.
  std::scoped_lock lock(synclog_mutex);
  auto fail = [] () {
    fprintf(stderr, "synclog_setup() failed\n");
    std::terminate();
  };
  // Remember the caller's current device so it can be restored at the end.
  int orig_device = 0;
  if (cudaGetDevice(&orig_device) != cudaSuccess) {
    fail();
  }
  int device_count = 0;
  if (cudaGetDeviceCount(&device_count) != cudaSuccess) {
    fail();
  }
  // First call only: allocate one buffer per device.
  if (synclog_buf_list.size() == 0) {
    for (int device = 0; device < device_count; device++) {
      uint32_t* buf = 0;
      if (cudaSetDevice(device) != cudaSuccess ||
        cudaMalloc(&buf, synclog_cap * sizeof(uint32_t)) != cudaSuccess) {
        fail();
      }
      synclog_buf_list.push_back(buf);
    }
  }
  // Every call: clear each buffer (resetting the word-0 cursor) and point
  // the device-side synclog_buf symbol at it.
  for (int device = 0; device < device_count; device++) {
    uint32_t* buf = synclog_buf_list.at(device);
    if (cudaSetDevice(device) != cudaSuccess ||
      cudaMemset(buf, 0, synclog_cap * sizeof(uint32_t)) != cudaSuccess ||
      cudaMemcpyToSymbol(synclog_buf, &buf, sizeof(buf)) != cudaSuccess) {
      fail();
    }
  }
  if (cudaSetDevice(orig_device) != cudaSuccess) {
    fail();
  }
#endif
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 315 |
+
|
| 316 |
+
CUTLASS_DEVICE
// Append a syncthreads record (prefix only, no payload) to the synclog.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_syncthreads(uint32_t line) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_syncthreads) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_syncthreads);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_syncthreads, line);
#else
  CUTLASS_UNUSED(line);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 328 |
+
|
| 329 |
+
CUTLASS_DEVICE
// Append a syncwarp record (prefix only, no payload) to the synclog.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_syncwarp(uint32_t line) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_syncwarp) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_syncwarp);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_syncwarp, line);
#else
  CUTLASS_UNUSED(line);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 341 |
+
|
| 342 |
+
CUTLASS_DEVICE
// Append a named-barrier arrive-and-wait record with payload
// {num_threads, barrier_id}. No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_named_barrier_arrive_and_wait(
  uint32_t line,
  uint32_t num_threads,
  uint32_t barrier_id) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_named_barrier_arrive_and_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_named_barrier_arrive_and_wait);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_named_barrier_arrive_and_wait, line);
  slot[synclog_length_prefix + 0] = num_threads;
  slot[synclog_length_prefix + 1] = barrier_id;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(num_threads);
  CUTLASS_UNUSED(barrier_id);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 361 |
+
|
| 362 |
+
CUTLASS_DEVICE
// Append a named-barrier arrive record with payload {num_threads, barrier_id}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_named_barrier_arrive(
  uint32_t line,
  uint32_t num_threads,
  uint32_t barrier_id) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_named_barrier_arrive) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_named_barrier_arrive);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_named_barrier_arrive, line);
  slot[synclog_length_prefix + 0] = num_threads;
  slot[synclog_length_prefix + 1] = barrier_id;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(num_threads);
  CUTLASS_UNUSED(barrier_id);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 381 |
+
|
| 382 |
+
CUTLASS_DEVICE
// Append a cluster-barrier init record with payload {smem_addr, arrive_count}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_barrier_init(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t arrive_count) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_init) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_barrier_init);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_barrier_init, line);
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = arrive_count;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(arrive_count);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 401 |
+
|
| 402 |
+
CUTLASS_DEVICE
// Append a cluster-barrier wait record with payload {smem_addr, phase}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_barrier_wait(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t phase) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_barrier_wait);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_barrier_wait, line);
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = phase;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(phase);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 421 |
+
|
| 422 |
+
CUTLASS_DEVICE
// Append a cluster-barrier test-wait record with payload
// {smem_addr, phase, pred}. No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_barrier_test_wait(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t phase,
  uint32_t pred) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_test_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_barrier_test_wait);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_barrier_test_wait, line);
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = phase;
  slot[synclog_length_prefix + 2] = pred;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(phase);
  CUTLASS_UNUSED(pred);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 444 |
+
|
| 445 |
+
CUTLASS_DEVICE
// Append a cluster-barrier try-wait record with payload {smem_addr, phase}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_barrier_try_wait(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t phase) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_try_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_barrier_try_wait);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_barrier_try_wait, line);
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = phase;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(phase);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 464 |
+
|
| 465 |
+
CUTLASS_DEVICE
// Append a cluster-barrier arrive-on-cluster record with payload
// {smem_addr, cta_id, pred}. No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_barrier_arrive_cluster(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t cta_id,
  uint32_t pred) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_arrive_cluster) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_barrier_arrive_cluster, line);
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = cta_id;
  slot[synclog_length_prefix + 2] = pred;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(cta_id);
  CUTLASS_UNUSED(pred);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 487 |
+
|
| 488 |
+
CUTLASS_DEVICE
// Append a cluster-barrier arrive record with payload {smem_addr}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_barrier_arrive(
  uint32_t line,
  uint32_t smem_addr) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_arrive) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_barrier_arrive);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_barrier_arrive, line);
  slot[synclog_length_prefix + 0] = smem_addr;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 504 |
+
|
| 505 |
+
CUTLASS_DEVICE
// Append a cluster-barrier invalidate record with payload {smem_addr}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_barrier_invalidate(
  uint32_t line,
  uint32_t smem_addr) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_invalidate) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_barrier_invalidate);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_barrier_invalidate, line);
  slot[synclog_length_prefix + 0] = smem_addr;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 521 |
+
|
| 522 |
+
CUTLASS_DEVICE
// Append a transaction-barrier arrive-and-expect-tx record with payload
// {smem_addr, transaction_bytes}. No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t transaction_bytes) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line);
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = transaction_bytes;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(transaction_bytes);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 541 |
+
|
| 542 |
+
CUTLASS_DEVICE
// Append a transaction-barrier arrive-and-expect-tx-on-cluster record with
// payload {smem_addr, transaction_bytes, cta_id, pred}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t transaction_bytes,
  uint32_t cta_id,
  uint32_t pred) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster, line);
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = transaction_bytes;
  slot[synclog_length_prefix + 2] = cta_id;
  slot[synclog_length_prefix + 3] = pred;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(transaction_bytes);
  CUTLASS_UNUSED(cta_id);
  CUTLASS_UNUSED(pred);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 567 |
+
|
| 568 |
+
CUTLASS_DEVICE
// Append a transaction-barrier expect-transaction record with payload
// {smem_addr, transaction_bytes}. No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_transaction_barrier_expect_transaction(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t transaction_bytes) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_transaction_barrier_expect_transaction, line);
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = transaction_bytes;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(transaction_bytes);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 587 |
+
|
| 588 |
+
CUTLASS_DEVICE
// Append a transaction-barrier complete-transaction record with payload
// {smem_addr, dst_cta_id, transaction_bytes, pred}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cluster_transaction_barrier_complete_transaction(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t dst_cta_id,
  uint32_t transaction_bytes,
  uint32_t pred) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cluster_transaction_barrier_complete_transaction, line);
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = dst_cta_id;
  slot[synclog_length_prefix + 2] = transaction_bytes;
  slot[synclog_length_prefix + 3] = pred;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(dst_cta_id);
  CUTLASS_UNUSED(transaction_bytes);
  CUTLASS_UNUSED(pred);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 613 |
+
|
| 614 |
+
CUTLASS_DEVICE
// Append a fence-barrier-init record (prefix only, no payload).
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_fence_barrier_init(uint32_t line) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_fence_barrier_init) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_fence_barrier_init);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_fence_barrier_init, line);
#else
  CUTLASS_UNUSED(line);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 626 |
+
|
| 627 |
+
CUTLASS_DEVICE
// Append a fence-view-async-shared record (prefix only, no payload).
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_fence_view_async_shared(uint32_t line) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_fence_view_async_shared) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_fence_view_async_shared);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_fence_view_async_shared, line);
#else
  CUTLASS_UNUSED(line);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 639 |
+
|
| 640 |
+
CUTLASS_DEVICE
// Append a cp.async wait record with payload {n}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cp_async_wait(
  uint32_t line,
  uint32_t n) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cp_async_wait);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cp_async_wait, line);
  slot[synclog_length_prefix + 0] = n;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(n);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 656 |
+
|
| 657 |
+
CUTLASS_DEVICE
// Append a cp.async wait-all record (prefix only, no payload).
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cp_async_wait_all(uint32_t line) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_wait_all) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cp_async_wait_all);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cp_async_wait_all, line);
#else
  CUTLASS_UNUSED(line);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 669 |
+
|
| 670 |
+
CUTLASS_DEVICE
// Append a cp.async fence record (prefix only, no payload).
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cp_async_fence(uint32_t line) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_fence) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cp_async_fence);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cp_async_fence, line);
#else
  CUTLASS_UNUSED(line);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 682 |
+
|
| 683 |
+
CUTLASS_DEVICE
// Append a cp.async-NaN-fill record with payload
// {smem_addr, gmem_ptr lo, gmem_ptr hi, pred}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cp_async_nan(
  uint32_t line,
  uint32_t smem_addr,
  const void* gmem_ptr,
  uint32_t pred) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_nan) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cp_async_nan);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cp_async_nan, line);
  // Split the 64-bit global pointer across two 32-bit words.
  uint64_t gmem_int = (uint64_t)gmem_ptr;
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = (uint32_t)gmem_int;
  slot[synclog_length_prefix + 2] = (uint32_t)(gmem_int >> 32);
  slot[synclog_length_prefix + 3] = pred;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(gmem_ptr);
  CUTLASS_UNUSED(pred);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 706 |
+
|
| 707 |
+
CUTLASS_DEVICE
// Append a cp.async-zfill record with payload
// {smem_addr, gmem_ptr lo, gmem_ptr hi, pred, size}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cp_async_zfill(
  uint32_t line,
  uint32_t smem_addr,
  const void* gmem_ptr,
  uint32_t pred,
  uint32_t size) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_zfill) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cp_async_zfill);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cp_async_zfill, line);
  // Split the 64-bit global pointer across two 32-bit words.
  uint64_t gmem_int = (uint64_t)gmem_ptr;
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = (uint32_t)gmem_int;
  slot[synclog_length_prefix + 2] = (uint32_t)(gmem_int >> 32);
  slot[synclog_length_prefix + 3] = pred;
  slot[synclog_length_prefix + 4] = size;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(gmem_ptr);
  CUTLASS_UNUSED(pred);
  CUTLASS_UNUSED(size);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 733 |
+
|
| 734 |
+
CUTLASS_DEVICE
// Append a cp.async record with payload
// {smem_addr, gmem_ptr lo, gmem_ptr hi, pred, size}.
// No-op unless CUTLASS_ENABLE_SYNCLOG is defined.
void synclog_emit_cp_async(
  uint32_t line,
  uint32_t smem_addr,
  const void* gmem_ptr,
  uint32_t pred,
  uint32_t size) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cp_async);
  if (!slot) return;  // logging disabled or buffer full
  synclog_emit_prefix(slot, synclog_header_cp_async, line);
  // Split the 64-bit global pointer across two 32-bit words.
  uint64_t gmem_int = (uint64_t)gmem_ptr;
  slot[synclog_length_prefix + 0] = smem_addr;
  slot[synclog_length_prefix + 1] = (uint32_t)gmem_int;
  slot[synclog_length_prefix + 2] = (uint32_t)(gmem_int >> 32);
  slot[synclog_length_prefix + 3] = pred;
  slot[synclog_length_prefix + 4] = size;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(gmem_ptr);
  CUTLASS_UNUSED(pred);
  CUTLASS_UNUSED(size);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 760 |
+
|
| 761 |
+
/// Appends a TMA-load record to the synclog. No-op unless CUTLASS_ENABLE_SYNCLOG.
/// Payload layout: 64-bit TMA descriptor (low word first), mbarrier smem
/// address, then destination smem address.
CUTLASS_DEVICE
void synclog_emit_tma_load(
    uint32_t line,
    uint64_t gmem_int_desc,
    uint32_t smem_int_mbar,
    uint32_t smem_int_ptr) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_tma_load) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_tma_load);
  if (slot == nullptr) return;
  synclog_emit_prefix(slot, synclog_header_tma_load, line);
  uint32_t* payload = slot + synclog_length_prefix;
  payload[0] = (uint32_t)((uint64_t)gmem_int_desc);
  payload[1] = (uint32_t)((uint64_t)gmem_int_desc >> 32);
  payload[2] = smem_int_mbar;
  payload[3] = smem_int_ptr;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(gmem_int_desc);
  CUTLASS_UNUSED(smem_int_mbar);
  CUTLASS_UNUSED(smem_int_ptr);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 784 |
+
|
| 785 |
+
/// Appends a TMA-store record to the synclog. No-op unless CUTLASS_ENABLE_SYNCLOG.
/// Payload layout: 64-bit TMA descriptor (low word first), then source smem address.
CUTLASS_DEVICE
void synclog_emit_tma_store(
    uint32_t line,
    uint64_t gmem_int_desc,
    uint32_t smem_int_ptr) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_tma_store) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_tma_store);
  if (slot == nullptr) return;
  synclog_emit_prefix(slot, synclog_header_tma_store, line);
  uint32_t* payload = slot + synclog_length_prefix;
  payload[0] = (uint32_t)((uint64_t)gmem_int_desc);
  payload[1] = (uint32_t)((uint64_t)gmem_int_desc >> 32);
  payload[2] = smem_int_ptr;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(gmem_int_desc);
  CUTLASS_UNUSED(smem_int_ptr);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 805 |
+
|
| 806 |
+
/// Appends a TMA store-arrive record (header only, no payload) to the synclog.
/// No-op unless CUTLASS_ENABLE_SYNCLOG.
CUTLASS_DEVICE
void synclog_emit_tma_store_arrive(uint32_t line) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_tma_store_arrive) return;
  if (!synclog_condition_emit()) return;
  if (uint32_t* slot = synclog_alloc(synclog_length_tma_store_arrive)) {
    synclog_emit_prefix(slot, synclog_header_tma_store_arrive, line);
  }
#else
  CUTLASS_UNUSED(line);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 818 |
+
|
| 819 |
+
/// Appends a TMA store-wait record to the synclog. No-op unless
/// CUTLASS_ENABLE_SYNCLOG. Payload: the wait count.
CUTLASS_DEVICE
void synclog_emit_tma_store_wait(
    uint32_t line,
    uint32_t count) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_tma_store_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_tma_store_wait);
  if (slot == nullptr) return;
  synclog_emit_prefix(slot, synclog_header_tma_store_wait, line);
  slot[synclog_length_prefix] = count;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(count);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 835 |
+
|
| 836 |
+
/// Appends a warpgroup-arrive record (header only, no payload) to the synclog.
/// No-op unless CUTLASS_ENABLE_SYNCLOG.
CUTLASS_DEVICE
void synclog_emit_warpgroup_arrive(
    uint32_t line) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_warpgroup_arrive) return;
  if (!synclog_condition_emit()) return;
  if (uint32_t* slot = synclog_alloc(synclog_length_warpgroup_arrive)) {
    synclog_emit_prefix(slot, synclog_header_warpgroup_arrive, line);
  }
#else
  CUTLASS_UNUSED(line);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 849 |
+
|
| 850 |
+
/// Appends a warpgroup-wait record to the synclog. No-op unless
/// CUTLASS_ENABLE_SYNCLOG. Payload: the wait operand n.
CUTLASS_DEVICE
void synclog_emit_warpgroup_wait(
    uint32_t line,
    uint32_t n) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_warpgroup_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_warpgroup_wait);
  if (slot == nullptr) return;
  synclog_emit_prefix(slot, synclog_header_warpgroup_wait, line);
  slot[synclog_length_prefix] = n;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(n);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 866 |
+
|
| 867 |
+
/// Appends a warpgroup commit-batch record (header only, no payload) to the
/// synclog. No-op unless CUTLASS_ENABLE_SYNCLOG.
CUTLASS_DEVICE
void synclog_emit_warpgroup_commit_batch(
    uint32_t line) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_warpgroup_commit_batch) return;
  if (!synclog_condition_emit()) return;
  if (uint32_t* slot = synclog_alloc(synclog_length_warpgroup_commit_batch)) {
    synclog_emit_prefix(slot, synclog_header_warpgroup_commit_batch, line);
  }
#else
  CUTLASS_UNUSED(line);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 880 |
+
|
| 881 |
+
/// Appends a WGMMA (register x smem) record to the synclog. No-op unless
/// CUTLASS_ENABLE_SYNCLOG. Payload: the 64-bit B-operand smem descriptor,
/// low word first.
CUTLASS_DEVICE
void synclog_emit_wgmma_reg_smem(
    uint32_t line,
    uint64_t desc_b) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_wgmma_reg_smem) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_wgmma_reg_smem);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_wgmma_reg_smem, line);
  // Explicit casts take the low/high 32 bits of the descriptor, matching the
  // other emitters (e.g. synclog_emit_cp_async) and avoiding implicit
  // uint64_t -> uint32_t narrowing warnings. Stored values are unchanged.
  to[synclog_length_prefix + 0] = (uint32_t)desc_b;
  to[synclog_length_prefix + 1] = (uint32_t)(desc_b >> 32);
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(desc_b);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 898 |
+
|
| 899 |
+
/// Appends a WGMMA (smem x smem) record to the synclog. No-op unless
/// CUTLASS_ENABLE_SYNCLOG. Payload: the 64-bit A and B smem descriptors,
/// each stored low word first.
CUTLASS_DEVICE
void synclog_emit_wgmma_smem_smem(
    uint32_t line,
    uint64_t desc_a,
    uint64_t desc_b) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_wgmma_smem_smem) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_wgmma_smem_smem);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_wgmma_smem_smem, line);
  // Explicit casts take the low/high 32 bits of each descriptor, matching the
  // other emitters and avoiding implicit uint64_t -> uint32_t narrowing
  // warnings. Stored values are unchanged.
  to[synclog_length_prefix + 0] = (uint32_t)desc_a;
  to[synclog_length_prefix + 1] = (uint32_t)(desc_a >> 32);
  to[synclog_length_prefix + 2] = (uint32_t)desc_b;
  to[synclog_length_prefix + 3] = (uint32_t)(desc_b >> 32);
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(desc_a);
  CUTLASS_UNUSED(desc_b);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 920 |
+
|
| 921 |
+
/// Appends a cp.async barrier-arrive record to the synclog. No-op unless
/// CUTLASS_ENABLE_SYNCLOG. Payload: the barrier's smem address.
CUTLASS_DEVICE
void synclog_emit_cpasync_barrier_arrive(
    uint32_t line,
    uint32_t smem_addr) {
#if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cpasync_barrier_arrive) return;
  if (!synclog_condition_emit()) return;
  uint32_t* slot = synclog_alloc(synclog_length_cpasync_barrier_arrive);
  if (slot == nullptr) return;
  synclog_emit_prefix(slot, synclog_header_cpasync_barrier_arrive, line);
  slot[synclog_length_prefix] = smem_addr;
#else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 937 |
+
|
| 938 |
+
// Declaration specifiers for synclog_print():
// - synclog disabled: an ordinary (empty-bodied) CUTLASS_DEVICE function;
// - nvcc / clang-CUDA: a non-inlined static __device__ function, so the large
//   decoder body is emitted once instead of being inlined at every call site;
// - other compilers: a non-inlined static function.
#if !defined(CUTLASS_ENABLE_SYNCLOG)
CUTLASS_DEVICE
#elif defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
static __attribute__((__noinline__)) __device__
#else
static __attribute__((__noinline__))
#endif
// Decodes and prints every record in the synclog buffer. Records are stored
// back to back starting at index 1; each record's fixed length constant
// (synclog_length_*) matches what the corresponding synclog_emit_* wrote.
// NOTE(review): synclog_buf[0] appears to hold the allocation cursor — it is
// compared against synclog_cap for the truncation message below; confirm
// against synclog_alloc.
void synclog_print() {
#if defined(CUTLASS_ENABLE_SYNCLOG)
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  if (synclog_buf == nullptr || !synclog_condition_print()) {
    return;
  }
  printf("synclog start\n");
  for (uint32_t at = 1; at < synclog_cap; ) {
    uint32_t header = synclog_buf[at];
    if (header == synclog_header_none) {
      // First unwritten slot: end of the log.
      break;
    }
    printf("synclog at %u: ", at);
    // Each branch below prints the record name/prefix, advances `at` past the
    // whole record, then reads the payload words by indexing backwards from
    // the advanced position (the emitter wrote the payload last).
    if constexpr (synclog_enable_syncthreads) {
      if (header == synclog_header_syncthreads) {
        synclog_print_prefix("syncthreads", at);
        at += synclog_length_syncthreads;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_syncwarp) {
      if (header == synclog_header_syncwarp) {
        synclog_print_prefix("syncwarp", at);
        at += synclog_length_syncwarp;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_named_barrier_arrive_and_wait) {
      if (header == synclog_header_named_barrier_arrive_and_wait) {
        synclog_print_prefix("named_barrier_arrive_and_wait", at);
        at += synclog_length_named_barrier_arrive_and_wait;
        printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_named_barrier_arrive) {
      if (header == synclog_header_named_barrier_arrive) {
        synclog_print_prefix("named_barrier_arrive", at);
        at += synclog_length_named_barrier_arrive;
        printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_init) {
      if (header == synclog_header_cluster_barrier_init) {
        synclog_print_prefix("cluster_barrier_init", at);
        at += synclog_length_cluster_barrier_init;
        printf("smem_addr=%u arrive_count=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_wait) {
      if (header == synclog_header_cluster_barrier_wait) {
        synclog_print_prefix("cluster_barrier_wait", at);
        at += synclog_length_cluster_barrier_wait;
        printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_test_wait) {
      if (header == synclog_header_cluster_barrier_test_wait) {
        synclog_print_prefix("cluster_barrier_test_wait", at);
        at += synclog_length_cluster_barrier_test_wait;
        printf("smem_addr=%u phase=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_try_wait) {
      if (header == synclog_header_cluster_barrier_try_wait) {
        synclog_print_prefix("cluster_barrier_try_wait", at);
        at += synclog_length_cluster_barrier_try_wait;
        printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_arrive_cluster) {
      if (header == synclog_header_cluster_barrier_arrive_cluster) {
        synclog_print_prefix("cluster_barrier_arrive_cluster", at);
        at += synclog_length_cluster_barrier_arrive_cluster;
        printf("smem_addr=%u cta_id=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_arrive) {
      if (header == synclog_header_cluster_barrier_arrive) {
        synclog_print_prefix("cluster_barrier_arrive", at);
        at += synclog_length_cluster_barrier_arrive;
        printf("smem_addr=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_invalidate) {
      if (header == synclog_header_cluster_barrier_invalidate) {
        synclog_print_prefix("cluster_barrier_invalidate", at);
        at += synclog_length_cluster_barrier_invalidate;
        printf("smem_addr=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) {
      if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) {
        synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at);
        at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx;
        printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) {
      if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) {
        synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", at);
        at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster;
        printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_transaction_barrier_expect_transaction) {
      if (header == synclog_header_cluster_transaction_barrier_expect_transaction) {
        synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at);
        at += synclog_length_cluster_transaction_barrier_expect_transaction;
        printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_transaction_barrier_complete_transaction) {
      if (header == synclog_header_cluster_transaction_barrier_complete_transaction) {
        synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at);
        at += synclog_length_cluster_transaction_barrier_complete_transaction;
        printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_fence_barrier_init) {
      if (header == synclog_header_fence_barrier_init) {
        synclog_print_prefix("fence_barrier_init", at);
        at += synclog_length_fence_barrier_init;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_fence_view_async_shared) {
      if (header == synclog_header_fence_view_async_shared) {
        synclog_print_prefix("fence_view_async_shared", at);
        at += synclog_length_fence_view_async_shared;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_wait) {
      if (header == synclog_header_cp_async_wait) {
        synclog_print_prefix("cp_async_wait", at);
        at += synclog_length_cp_async_wait;
        printf("n=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_wait_all) {
      if (header == synclog_header_cp_async_wait_all) {
        synclog_print_prefix("cp_async_wait_all", at);
        at += synclog_length_cp_async_wait_all;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_fence) {
      if (header == synclog_header_cp_async_fence) {
        synclog_print_prefix("cp_async_fence", at);
        at += synclog_length_cp_async_fence;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_nan) {
      if (header == synclog_header_cp_async_nan) {
        synclog_print_prefix("cp_async_nan", at);
        at += synclog_length_cp_async_nan;
        // Reassemble the 64-bit gmem address from the two 32-bit payload words
        // (low word stored first by the emitter).
        uint64_t gmem_addr = synclog_buf[at-3];
        gmem_addr += (uint64_t)synclog_buf[at-2] << 32;
        printf("smem_addr=%u gmem_addr=%llu pred=%u\n", synclog_buf[at-4], gmem_addr, synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_zfill) {
      if (header == synclog_header_cp_async_zfill) {
        synclog_print_prefix("cp_async_zfill", at);
        at += synclog_length_cp_async_zfill;
        uint64_t gmem_addr = synclog_buf[at-4];
        gmem_addr += (uint64_t)synclog_buf[at-3] << 32;
        printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async) {
      if (header == synclog_header_cp_async) {
        synclog_print_prefix("cp_async", at);
        at += synclog_length_cp_async;
        uint64_t gmem_addr = synclog_buf[at-4];
        gmem_addr += (uint64_t)synclog_buf[at-3] << 32;
        printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_tma_load) {
      if (header == synclog_header_tma_load) {
        synclog_print_prefix("tma_load", at);
        at += synclog_length_tma_load;
        uint64_t gmem_int_desc = synclog_buf[at-4];
        gmem_int_desc += (uint64_t)synclog_buf[at-3] << 32;
        printf("gmem_int_desc=%llu smem_int_mbar=%u smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_tma_store) {
      if (header == synclog_header_tma_store) {
        synclog_print_prefix("tma_store", at);
        at += synclog_length_tma_store;
        uint64_t gmem_int_desc = synclog_buf[at-3];
        gmem_int_desc += (uint64_t)synclog_buf[at-2] << 32;
        printf("gmem_int_desc=%llu smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_tma_store_arrive) {
      if (header == synclog_header_tma_store_arrive) {
        synclog_print_prefix("tma_store_arrive", at);
        at += synclog_length_tma_store_arrive;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_tma_store_wait) {
      if (header == synclog_header_tma_store_wait) {
        synclog_print_prefix("tma_store_wait", at);
        at += synclog_length_tma_store_wait;
        printf("count=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_warpgroup_arrive) {
      if (header == synclog_header_warpgroup_arrive) {
        synclog_print_prefix("warpgroup_arrive", at);
        at += synclog_length_warpgroup_arrive;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_warpgroup_wait) {
      if (header == synclog_header_warpgroup_wait) {
        synclog_print_prefix("warpgroup_wait", at);
        at += synclog_length_warpgroup_wait;
        printf("n=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_warpgroup_commit_batch) {
      if (header == synclog_header_warpgroup_commit_batch) {
        synclog_print_prefix("warpgroup_commit_batch", at);
        at += synclog_length_warpgroup_commit_batch;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_wgmma_reg_smem) {
      if (header == synclog_header_wgmma_reg_smem) {
        synclog_print_prefix("wgmma_reg_smem", at);
        at += synclog_length_wgmma_reg_smem;
        synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], "");
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_wgmma_smem_smem) {
      if (header == synclog_header_wgmma_smem_smem) {
        synclog_print_prefix("wgmma_smem_smem", at);
        at += synclog_length_wgmma_smem_smem;
        synclog_print_wgmma_desc("desc_a", synclog_buf[at-4], synclog_buf[at-3], " ");
        synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], "");
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_cpasync_barrier_arrive) {
      if (header == synclog_header_cpasync_barrier_arrive) {
        synclog_print_prefix("cpasync_barrier_arrive", at);
        at += synclog_length_cpasync_barrier_arrive;
        printf("smem_addr=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    // No branch consumed the record: unknown header value, so the buffer is
    // corrupt or this decoder is out of sync with the emitters. Trap.
    asm volatile ("brkpt;\n" ::);
  }
  if (synclog_buf[0] >= synclog_cap) {
    // NOTE(review): %lu paired with a size_t-typed expression — matches LP64
    // targets; %zu would be the portable specifier.
    printf(
      "synclog was truncated (exceeded capacity of %lu bytes)\n",
      (synclog_cap - 1) * sizeof(uint32_t)
    );
  }
  printf("synclog end\n");
#endif
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
}
|
| 1247 |
+
|
| 1248 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1249 |
+
|
| 1250 |
+
|
| 1251 |
+
#if defined(CUTLASS_ENABLE_SYNCLOG)
|
| 1252 |
+
#undef __syncthreads
|
| 1253 |
+
#define __syncthreads() do {\
|
| 1254 |
+
cutlass::arch::synclog_emit_syncthreads(__LINE__);\
|
| 1255 |
+
__syncthreads();\
|
| 1256 |
+
} while (0)
|
| 1257 |
+
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
|
| 1258 |
+
|
| 1259 |
+
#if defined(CUTLASS_ENABLE_SYNCLOG)
|
| 1260 |
+
#undef __syncwarp
|
| 1261 |
+
#define __syncwarp(...) do {\
|
| 1262 |
+
cutlass::arch::synclog_emit_syncwarp(__LINE__);\
|
| 1263 |
+
__syncwarp(__VA_ARGS__);\
|
| 1264 |
+
} while (0)
|
| 1265 |
+
#endif // defined(CUTLASS_ENABLE_SYNCLOG)
|
| 1266 |
+
|
| 1267 |
+
|
| 1268 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1269 |
+
|
| 1270 |
+
} // namespace arch
|
| 1271 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates exposing architecture support for warp matrix multiply-add (WMMA) operations
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
// Baseline WMMA support: CUDA 9+ compiling for SM70+ (or host pass).
#if (__CUDACC_VER_MAJOR__ >= 9)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700))
#define CUTLASS_ARCH_WMMA_ENABLED
#define CUTLASS_ARCH_WMMA_SM70_ENABLED
#endif
#endif

// Integer (8-bit) WMMA: CUDA 10+ compiling for SM72+ (or host pass).
#if (__CUDACC_VER_MAJOR__ >= 10)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720))
#define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED
#define CUTLASS_ARCH_WMMA_SM72_ENABLED
#endif
#endif

// Sub-byte integer WMMA (s4/u4/b1): CUDA 10+ compiling for SM75+ (or host pass).
#if (__CUDACC_VER_MAJOR__ >= 10)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750))
#define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED
#define CUTLASS_ARCH_WMMA_SM75_ENABLED
#endif
#endif
|
| 57 |
+
|
| 58 |
+
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
| 59 |
+
|
| 60 |
+
#include <mma.h>
|
| 61 |
+
#include "cutlass/arch/mma.h"
|
| 62 |
+
#include "cutlass/array.h"
|
| 63 |
+
#include "cutlass/numeric_types.h"
|
| 64 |
+
#include "cutlass/gemm/gemm.h"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 68 |
+
|
| 69 |
+
namespace cutlass {
|
| 70 |
+
namespace arch {
|
| 71 |
+
|
| 72 |
+
////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps cutlass data types => nvcuda::wmma data types
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Primary template: identity mapping for types wmma accepts directly.
template <typename Type_>
struct CutlassToWmmaDataType{
  using Type = Type_;
};

/// Statically maps cutlass::half_t => __half
template<>
struct CutlassToWmmaDataType<cutlass::half_t> {
  using Type = __half;
};

// bfloat16 mapping is only declared when compiling for SM80+ with CUDA 11+.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
/// Statically maps cutlass::bfloat16_t => __nv_bfloat16
template<>
struct CutlassToWmmaDataType<cutlass::bfloat16_t> {
  using Type = __nv_bfloat16;
};
#endif

/// Statically maps int8_t => signed char
template<>
struct CutlassToWmmaDataType<int8_t> {
  using Type = signed char;
};

/// Statically maps uint8_t => unsigned char
template<>
struct CutlassToWmmaDataType<uint8_t> {
  using Type = unsigned char;
};

/// Statically maps int32_t => int
template<>
struct CutlassToWmmaDataType<int32_t> {
  using Type = int;
};

// Sub-byte mappings require the experimental wmma precisions (guard defined
// above for CUDA 10+ / SM75+).
#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED)
/// Statically maps cutlass::int4b_t => experimental::precision::s4
template<>
struct CutlassToWmmaDataType<cutlass::int4b_t> {
  using Type = nvcuda::wmma::experimental::precision::s4;
};

/// Statically maps cutlass::uint4b_t => experimental::precision::u4
template<>
struct CutlassToWmmaDataType<cutlass::uint4b_t> {
  using Type = nvcuda::wmma::experimental::precision::u4;
};

/// Statically maps cutlass::uint1b_t => experimental::precision::b1
template<>
struct CutlassToWmmaDataType<cutlass::uint1b_t> {
  using Type = nvcuda::wmma::experimental::precision::b1;
};
#endif
|
| 130 |
+
|
| 131 |
+
////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps cutlass::layout => nvcuda::wmma layout tags
////////////////////////////////////////////////////////////////////////////////////////////////
/// Primary template: intentionally empty so unsupported layouts fail to compile.
template <typename Layout_>
struct CutlassToWmmaLayout {
};

/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags
template <>
struct CutlassToWmmaLayout<cutlass::layout::RowMajor> {
  using Layout = nvcuda::wmma::row_major;
  static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_row_major;
};

////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps cutlass::layout::ColumnMajor => nvcuda::wmma::col_major layout tags
////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct CutlassToWmmaLayout<cutlass::layout::ColumnMajor> {
  using Layout = nvcuda::wmma::col_major;
  static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_col_major;
};
////////////////////////////////////////////////////////////////////////////////////////////////
|
| 154 |
+
|
| 155 |
+
////////////////////////////////////////////////////////////////////////////////////////////////
|
| 156 |
+
/// Statically maps nvcuda::wmma data types => cutlass data types
|
| 157 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 158 |
+
template <typename Type_>
|
| 159 |
+
struct WmmaToCutlassDataType{
|
| 160 |
+
using Type = Type_;
|
| 161 |
+
};
|
| 162 |
+
|
| 163 |
+
/// Statically maps __half => cutlass::half_t
|
| 164 |
+
template<>
|
| 165 |
+
struct WmmaToCutlassDataType<__half> {
|
| 166 |
+
using Type = cutlass::half_t;
|
| 167 |
+
};
|
| 168 |
+
|
| 169 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
|
| 170 |
+
template<>
|
| 171 |
+
struct WmmaToCutlassDataType<__nv_bfloat16> {
|
| 172 |
+
using Type = cutlass::bfloat16_t;
|
| 173 |
+
};
|
| 174 |
+
#endif
|
| 175 |
+
|
| 176 |
+
////////////////////////////////////////////////////////////////////////////////////////////////
|
| 177 |
+
|
| 178 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 179 |
+
// WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks
|
| 180 |
+
// for a specific template parameterized data type (Element[A|B|C]), layout (Layout[A|B|C]),
|
| 181 |
+
// and native wmma size (Shape)
|
| 182 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 183 |
+
template <
|
| 184 |
+
typename Shape_, ///< Size of the matrix product (concept: GemmShape)
|
| 185 |
+
typename ElementA_, ///< Data type of A elements
|
| 186 |
+
typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout)
|
| 187 |
+
typename ElementB_, ///< Data type of B elements
|
| 188 |
+
typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout)
|
| 189 |
+
typename ElementC_, ///< Element type of C matrix
|
| 190 |
+
typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout)
|
| 191 |
+
typename Operator_ = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc)
|
| 192 |
+
>
|
| 193 |
+
struct Wmma;
|
| 194 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 195 |
+
|
| 196 |
+
} // namespace arch
|
| 197 |
+
} // namespace cutlass
|
| 198 |
+
|
| 199 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 200 |
+
|
| 201 |
+
//
|
| 202 |
+
// Specializations for each compute capability
|
| 203 |
+
//
|
| 204 |
+
#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED
|
| 205 |
+
#include "cutlass/arch/wmma_sm70.h"
|
| 206 |
+
#endif
|
| 207 |
+
|
| 208 |
+
#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED
|
| 209 |
+
#include "cutlass/arch/wmma_sm72.h"
|
| 210 |
+
#endif
|
| 211 |
+
|
| 212 |
+
#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED
|
| 213 |
+
#include "cutlass/arch/wmma_sm75.h"
|
| 214 |
+
#endif
|
| 215 |
+
|
| 216 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 217 |
+
|
| 218 |
+
#endif //CUTLASS_ARCH_WMMA_ENABLED
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include CUDA_STD_HEADER(cassert)
|
| 38 |
+
#include "cutlass/layout/matrix.h"
|
| 39 |
+
|
| 40 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
namespace arch {
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
//
|
| 47 |
+
// WMMA template structure defines nvcuda::wmma::fragments and static assert for
|
| 48 |
+
// wmma native instruction sizes supported for half
|
| 49 |
+
//
|
| 50 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
template <
|
| 52 |
+
typename Shape_,
|
| 53 |
+
typename LayoutA_,
|
| 54 |
+
typename LayoutB_,
|
| 55 |
+
typename ElementC_,
|
| 56 |
+
typename LayoutC_>
|
| 57 |
+
struct Wmma<
|
| 58 |
+
Shape_, ///< Size of the matrix product (concept: GemmShape)
|
| 59 |
+
cutlass::half_t, ///< ElementA
|
| 60 |
+
LayoutA_, ///< LayoutA
|
| 61 |
+
cutlass::half_t, ///< ElementB
|
| 62 |
+
LayoutB_, ///< LayoutB
|
| 63 |
+
ElementC_, ///< ElementC
|
| 64 |
+
LayoutC_, ///< LayoutC
|
| 65 |
+
cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc)
|
| 66 |
+
> {
|
| 67 |
+
|
| 68 |
+
#if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED)
|
| 69 |
+
using Shape = Shape_;
|
| 70 |
+
using ElementA = cutlass::half_t;
|
| 71 |
+
using LayoutA = LayoutA_;
|
| 72 |
+
using ElementB = cutlass::half_t;
|
| 73 |
+
using LayoutB = LayoutB_;
|
| 74 |
+
using ElementC = ElementC_;
|
| 75 |
+
using LayoutC = LayoutC_;
|
| 76 |
+
using Operator = cutlass::arch::OpMultiplyAdd;
|
| 77 |
+
using ArchTag = arch::Sm70;
|
| 78 |
+
|
| 79 |
+
// check supported wmma shape for the given multiplicand data types
|
| 80 |
+
static_assert(
|
| 81 |
+
platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
|
| 82 |
+
platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
|
| 83 |
+
platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
|
| 84 |
+
"Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");
|
| 85 |
+
|
| 86 |
+
// check supported wmma output data type for the given multiplicand data types
|
| 87 |
+
static_assert(
|
| 88 |
+
platform::is_same<cutlass::half_t, ElementC>::value || platform::is_same<float, ElementC>::value,
|
| 89 |
+
"Supported of wmma output data type for f16 multiplicands are: f16 and f32");
|
| 90 |
+
|
| 91 |
+
// Wmma Fragment
|
| 92 |
+
using FragmentA = nvcuda::wmma::fragment<
|
| 93 |
+
nvcuda::wmma::matrix_a,
|
| 94 |
+
Shape::kM,
|
| 95 |
+
Shape::kN,
|
| 96 |
+
Shape::kK,
|
| 97 |
+
typename CutlassToWmmaDataType<ElementA>::Type,
|
| 98 |
+
typename CutlassToWmmaLayout<LayoutA>::Layout>;
|
| 99 |
+
|
| 100 |
+
using FragmentB = nvcuda::wmma::fragment<
|
| 101 |
+
nvcuda::wmma::matrix_b,
|
| 102 |
+
Shape::kM,
|
| 103 |
+
Shape::kN,
|
| 104 |
+
Shape::kK,
|
| 105 |
+
typename CutlassToWmmaDataType<ElementB>::Type,
|
| 106 |
+
typename CutlassToWmmaLayout<LayoutB>::Layout>;
|
| 107 |
+
|
| 108 |
+
using FragmentC = nvcuda::wmma::fragment<
|
| 109 |
+
nvcuda::wmma::accumulator,
|
| 110 |
+
Shape::kM,
|
| 111 |
+
Shape::kN,
|
| 112 |
+
Shape::kK,
|
| 113 |
+
typename CutlassToWmmaDataType<ElementC>::Type>;
|
| 114 |
+
|
| 115 |
+
/// Performs a nvcuda::wmma matrix multiply-accumulate operation
|
| 116 |
+
CUTLASS_DEVICE
|
| 117 |
+
void operator()(
|
| 118 |
+
FragmentC &D,
|
| 119 |
+
FragmentA const &A,
|
| 120 |
+
FragmentB const &B,
|
| 121 |
+
FragmentC const &C) const {
|
| 122 |
+
|
| 123 |
+
nvcuda::wmma::mma_sync(D, A, B, C);
|
| 124 |
+
}
|
| 125 |
+
#else
|
| 126 |
+
static_assert(false, "wmma.mma.sync for floating point multiplicands is available only for SM70 and beyond");
|
| 127 |
+
#endif
|
| 128 |
+
|
| 129 |
+
};
|
| 130 |
+
|
| 131 |
+
} // namespace arch
|
| 132 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include CUDA_STD_HEADER(cassert)
|
| 38 |
+
#include "cutlass/layout/matrix.h"
|
| 39 |
+
|
| 40 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
namespace arch {
|
| 43 |
+
|
| 44 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
//
|
| 46 |
+
// WMMA template structure defines nvcuda::wmma::fragments and static assert for
|
| 47 |
+
// wmma native instruction sizes supported for int8_t
|
| 48 |
+
//
|
| 49 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
template <
|
| 51 |
+
typename Shape_,
|
| 52 |
+
typename LayoutA_,
|
| 53 |
+
typename LayoutB_,
|
| 54 |
+
typename LayoutC_>
|
| 55 |
+
struct Wmma<
|
| 56 |
+
Shape_, ///< Size of the matrix product (concept: GemmShape)
|
| 57 |
+
int8_t, ///< ElementA
|
| 58 |
+
LayoutA_, ///< LayoutA
|
| 59 |
+
int8_t, ///< ElementB
|
| 60 |
+
LayoutB_, ///< LayoutB
|
| 61 |
+
int32_t, ///< ElementC
|
| 62 |
+
LayoutC_, ///< LayoutC
|
| 63 |
+
cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc)
|
| 64 |
+
> {
|
| 65 |
+
#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)
|
| 66 |
+
using Shape = Shape_;
|
| 67 |
+
using ElementA = int8_t;
|
| 68 |
+
using LayoutA = LayoutA_;
|
| 69 |
+
using ElementB = int8_t;
|
| 70 |
+
using LayoutB = LayoutB_;
|
| 71 |
+
using ElementC = int32_t;
|
| 72 |
+
using LayoutC = LayoutC_;
|
| 73 |
+
using Operator = cutlass::arch::OpMultiplyAdd;
|
| 74 |
+
using ArchTag = arch::Sm72;
|
| 75 |
+
|
| 76 |
+
// check supported wmma shape for the given multiplicand data types
|
| 77 |
+
static_assert(
|
| 78 |
+
platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
|
| 79 |
+
platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
|
| 80 |
+
platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
|
| 81 |
+
"Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
// Wmma Fragment
|
| 85 |
+
using FragmentA = nvcuda::wmma::fragment<
|
| 86 |
+
nvcuda::wmma::matrix_a,
|
| 87 |
+
Shape::kM,
|
| 88 |
+
Shape::kN,
|
| 89 |
+
Shape::kK,
|
| 90 |
+
typename CutlassToWmmaDataType<ElementA>::Type,
|
| 91 |
+
typename CutlassToWmmaLayout<LayoutA>::Layout>;
|
| 92 |
+
|
| 93 |
+
using FragmentB = nvcuda::wmma::fragment<
|
| 94 |
+
nvcuda::wmma::matrix_b,
|
| 95 |
+
Shape::kM,
|
| 96 |
+
Shape::kN,
|
| 97 |
+
Shape::kK,
|
| 98 |
+
typename CutlassToWmmaDataType<ElementB>::Type,
|
| 99 |
+
typename CutlassToWmmaLayout<LayoutB>::Layout>;
|
| 100 |
+
|
| 101 |
+
using FragmentC = nvcuda::wmma::fragment<
|
| 102 |
+
nvcuda::wmma::accumulator,
|
| 103 |
+
Shape::kM,
|
| 104 |
+
Shape::kN,
|
| 105 |
+
Shape::kK,
|
| 106 |
+
typename CutlassToWmmaDataType<ElementC>::Type>;
|
| 107 |
+
|
| 108 |
+
/// Performs a nvcuda::wmma matrix multiply-accumulate operation
|
| 109 |
+
CUTLASS_DEVICE
|
| 110 |
+
void operator()(
|
| 111 |
+
FragmentC &D,
|
| 112 |
+
FragmentA const &A,
|
| 113 |
+
FragmentB const &B,
|
| 114 |
+
FragmentC const &C) const {
|
| 115 |
+
|
| 116 |
+
nvcuda::wmma::mma_sync(D, A, B, C);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
#else
|
| 120 |
+
static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond");
|
| 121 |
+
#endif
|
| 122 |
+
|
| 123 |
+
};
|
| 124 |
+
|
| 125 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 126 |
+
//
|
| 127 |
+
// WMMA template structure defines nvcuda::wmma::fragments and static assert for
|
| 128 |
+
// wmma native instruction sizes supported for uint8_t
|
| 129 |
+
//
|
| 130 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 131 |
+
template <
|
| 132 |
+
typename Shape_,
|
| 133 |
+
typename LayoutA_,
|
| 134 |
+
typename LayoutB_,
|
| 135 |
+
typename LayoutC_>
|
| 136 |
+
struct Wmma<
|
| 137 |
+
Shape_, ///< Size of the matrix product (concept: GemmShape)
|
| 138 |
+
uint8_t, ///< ElementA
|
| 139 |
+
LayoutA_, ///< LayoutA
|
| 140 |
+
uint8_t, ///< ElementB
|
| 141 |
+
LayoutB_, ///< LayoutB
|
| 142 |
+
int32_t, ///< ElementC
|
| 143 |
+
LayoutC_, ///< LayoutC
|
| 144 |
+
cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc)
|
| 145 |
+
> {
|
| 146 |
+
#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)
|
| 147 |
+
using Shape = Shape_;
|
| 148 |
+
using ElementA = uint8_t;
|
| 149 |
+
using LayoutA = LayoutA_;
|
| 150 |
+
using ElementB = uint8_t;
|
| 151 |
+
using LayoutB = LayoutB_;
|
| 152 |
+
using ElementC = int32_t;
|
| 153 |
+
using LayoutC = LayoutC_;
|
| 154 |
+
using Operator = cutlass::arch::OpMultiplyAdd;
|
| 155 |
+
using ArchTag = arch::Sm72;
|
| 156 |
+
|
| 157 |
+
// check supported wmma shape for the given multiplicand data types
|
| 158 |
+
static_assert(
|
| 159 |
+
platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
|
| 160 |
+
platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
|
| 161 |
+
platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
|
| 162 |
+
"Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");
|
| 163 |
+
|
| 164 |
+
// Wmma Fragment
|
| 165 |
+
using FragmentA = nvcuda::wmma::fragment<
|
| 166 |
+
nvcuda::wmma::matrix_a,
|
| 167 |
+
Shape::kM,
|
| 168 |
+
Shape::kN,
|
| 169 |
+
Shape::kK,
|
| 170 |
+
typename CutlassToWmmaDataType<ElementA>::Type,
|
| 171 |
+
typename CutlassToWmmaLayout<LayoutA>::Layout>;
|
| 172 |
+
|
| 173 |
+
using FragmentB = nvcuda::wmma::fragment<
|
| 174 |
+
nvcuda::wmma::matrix_b,
|
| 175 |
+
Shape::kM,
|
| 176 |
+
Shape::kN,
|
| 177 |
+
Shape::kK,
|
| 178 |
+
typename CutlassToWmmaDataType<ElementB>::Type,
|
| 179 |
+
typename CutlassToWmmaLayout<LayoutB>::Layout>;
|
| 180 |
+
|
| 181 |
+
using FragmentC = nvcuda::wmma::fragment<
|
| 182 |
+
nvcuda::wmma::accumulator,
|
| 183 |
+
Shape::kM,
|
| 184 |
+
Shape::kN,
|
| 185 |
+
Shape::kK,
|
| 186 |
+
typename CutlassToWmmaDataType<ElementC>::Type>;
|
| 187 |
+
|
| 188 |
+
/// Performs a nvcuda::wmma matrix multiply-accumulate operation
|
| 189 |
+
CUTLASS_DEVICE
|
| 190 |
+
void operator()(
|
| 191 |
+
FragmentC &D,
|
| 192 |
+
FragmentA const &A,
|
| 193 |
+
FragmentB const &B,
|
| 194 |
+
FragmentC const &C) const {
|
| 195 |
+
|
| 196 |
+
nvcuda::wmma::mma_sync(D, A, B, C);
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
#else
|
| 200 |
+
static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond");
|
| 201 |
+
#endif
|
| 202 |
+
|
| 203 |
+
};
|
| 204 |
+
|
| 205 |
+
} // namespace arch
|
| 206 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Matrix multiply
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include CUDA_STD_HEADER(cassert)
|
| 38 |
+
#include "cutlass/layout/matrix.h"
|
| 39 |
+
|
| 40 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
namespace arch {
|
| 43 |
+
|
| 44 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
//
|
| 46 |
+
// WMMA template structure defines nvcuda::wmma::fragments and static assert for
|
| 47 |
+
// wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4).
|
| 48 |
+
//
|
| 49 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
template <
|
| 51 |
+
typename Shape_,
|
| 52 |
+
typename LayoutA_,
|
| 53 |
+
typename LayoutB_,
|
| 54 |
+
typename LayoutC_>
|
| 55 |
+
struct Wmma<
|
| 56 |
+
Shape_, ///< Size of the matrix product (concept: GemmShape)
|
| 57 |
+
cutlass::int4b_t, ///< ElementA
|
| 58 |
+
LayoutA_, ///< LayoutA
|
| 59 |
+
cutlass::int4b_t, ///< ElementB
|
| 60 |
+
LayoutB_, ///< LayoutB
|
| 61 |
+
int32_t, ///< ElementC
|
| 62 |
+
LayoutC_, ///< LayoutC
|
| 63 |
+
cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc)
|
| 64 |
+
> {
|
| 65 |
+
#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
|
| 66 |
+
using Shape = Shape_;
|
| 67 |
+
using ElementA = cutlass::int4b_t;
|
| 68 |
+
using LayoutA = LayoutA_;
|
| 69 |
+
using ElementB = cutlass::int4b_t;
|
| 70 |
+
using LayoutB = LayoutB_;
|
| 71 |
+
using ElementC = int32_t;
|
| 72 |
+
using LayoutC = LayoutC_;
|
| 73 |
+
using Operator = cutlass::arch::OpMultiplyAdd;
|
| 74 |
+
using ArchTag = arch::Sm75;
|
| 75 |
+
|
| 76 |
+
// check supported wmma shape for the given multiplicand data types
|
| 77 |
+
static_assert(
|
| 78 |
+
platform::is_same<cutlass::gemm::GemmShape<8, 8, 32>, Shape>::value,
|
| 79 |
+
"Supported list of wmma operator shape for s8 multiplicands is: 8x8x32");
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
// Wmma Fragment
|
| 83 |
+
using FragmentA = nvcuda::wmma::fragment<
|
| 84 |
+
nvcuda::wmma::matrix_a,
|
| 85 |
+
Shape::kM,
|
| 86 |
+
Shape::kN,
|
| 87 |
+
Shape::kK,
|
| 88 |
+
typename CutlassToWmmaDataType<ElementA>::Type,
|
| 89 |
+
typename CutlassToWmmaLayout<LayoutA>::Layout>;
|
| 90 |
+
|
| 91 |
+
using FragmentB = nvcuda::wmma::fragment<
|
| 92 |
+
nvcuda::wmma::matrix_b,
|
| 93 |
+
Shape::kM,
|
| 94 |
+
Shape::kN,
|
| 95 |
+
Shape::kK,
|
| 96 |
+
typename CutlassToWmmaDataType<ElementB>::Type,
|
| 97 |
+
typename CutlassToWmmaLayout<LayoutB>::Layout>;
|
| 98 |
+
|
| 99 |
+
using FragmentC = nvcuda::wmma::fragment<
|
| 100 |
+
nvcuda::wmma::accumulator,
|
| 101 |
+
Shape::kM,
|
| 102 |
+
Shape::kN,
|
| 103 |
+
Shape::kK,
|
| 104 |
+
typename CutlassToWmmaDataType<ElementC>::Type>;
|
| 105 |
+
|
| 106 |
+
/// Performs a nvcuda::wmma matrix multiply-accumulate operation
|
| 107 |
+
CUTLASS_DEVICE
|
| 108 |
+
void operator()(
|
| 109 |
+
FragmentC &D,
|
| 110 |
+
FragmentA const &A,
|
| 111 |
+
FragmentB const &B,
|
| 112 |
+
FragmentC const &C) const {
|
| 113 |
+
nvcuda::wmma::mma_sync(D, A, B, C);
|
| 114 |
+
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
#else
|
| 118 |
+
static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
|
| 119 |
+
#endif
|
| 120 |
+
|
| 121 |
+
};
|
| 122 |
+
|
| 123 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 124 |
+
//
|
| 125 |
+
// WMMA template structure defines nvcuda::wmma::fragments and static assert for
|
| 126 |
+
// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1).
|
| 127 |
+
//
|
| 128 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 129 |
+
template <
|
| 130 |
+
typename Shape_,
|
| 131 |
+
typename LayoutA_,
|
| 132 |
+
typename LayoutB_,
|
| 133 |
+
typename LayoutC_>
|
| 134 |
+
struct Wmma<
|
| 135 |
+
Shape_, ///< Size of the matrix product (concept: GemmShape)
|
| 136 |
+
cutlass::uint1b_t, ///< ElementA
|
| 137 |
+
LayoutA_, ///< LayoutA
|
| 138 |
+
cutlass::uint1b_t, ///< ElementB
|
| 139 |
+
LayoutB_, ///< LayoutB
|
| 140 |
+
int32_t, ///< ElementC
|
| 141 |
+
LayoutC_, ///< LayoutC
|
| 142 |
+
cutlass::arch::OpXorPopc ///< Operator (multiply-add, xor.popc)
|
| 143 |
+
> {
|
| 144 |
+
#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
|
| 145 |
+
using Shape = Shape_;
|
| 146 |
+
using ElementA = cutlass::uint1b_t;
|
| 147 |
+
using LayoutA = LayoutA_;
|
| 148 |
+
using ElementB = cutlass::uint1b_t;
|
| 149 |
+
using LayoutB = LayoutB_;
|
| 150 |
+
using ElementC = int32_t;
|
| 151 |
+
using LayoutC = LayoutC_;
|
| 152 |
+
using Operator = cutlass::arch::OpXorPopc;
|
| 153 |
+
using ArchTag = arch::Sm75;
|
| 154 |
+
|
| 155 |
+
// check supported wmma shape for the given multiplicand data types
|
| 156 |
+
static_assert(
|
| 157 |
+
platform::is_same<cutlass::gemm::GemmShape<8, 8, 128>, Shape>::value,
|
| 158 |
+
"Supported list of wmma operator shape for b1 multiplicands is: 8x8x128");
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
// Wmma Fragment
|
| 162 |
+
using FragmentA = nvcuda::wmma::fragment<
|
| 163 |
+
nvcuda::wmma::matrix_a,
|
| 164 |
+
Shape::kM,
|
| 165 |
+
Shape::kN,
|
| 166 |
+
Shape::kK,
|
| 167 |
+
typename CutlassToWmmaDataType<ElementA>::Type,
|
| 168 |
+
typename CutlassToWmmaLayout<LayoutA>::Layout>;
|
| 169 |
+
|
| 170 |
+
using FragmentB = nvcuda::wmma::fragment<
|
| 171 |
+
nvcuda::wmma::matrix_b,
|
| 172 |
+
Shape::kM,
|
| 173 |
+
Shape::kN,
|
| 174 |
+
Shape::kK,
|
| 175 |
+
typename CutlassToWmmaDataType<ElementB>::Type,
|
| 176 |
+
typename CutlassToWmmaLayout<LayoutB>::Layout>;
|
| 177 |
+
|
| 178 |
+
using FragmentC = nvcuda::wmma::fragment<
|
| 179 |
+
nvcuda::wmma::accumulator,
|
| 180 |
+
Shape::kM,
|
| 181 |
+
Shape::kN,
|
| 182 |
+
Shape::kK,
|
| 183 |
+
typename CutlassToWmmaDataType<ElementC>::Type>;
|
| 184 |
+
|
| 185 |
+
/// Performs a nvcuda::wmma matrix multiply-accumulate operation
|
| 186 |
+
CUTLASS_DEVICE
|
| 187 |
+
void operator()(
|
| 188 |
+
FragmentC &D,
|
| 189 |
+
FragmentA const &A,
|
| 190 |
+
FragmentB const &B,
|
| 191 |
+
FragmentC const &C) const {
|
| 192 |
+
nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
|
| 193 |
+
nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
#else
|
| 197 |
+
static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
|
| 198 |
+
#endif
|
| 199 |
+
|
| 200 |
+
};
|
| 201 |
+
|
| 202 |
+
} // namespace arch
|
| 203 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h
ADDED
|
@@ -0,0 +1,2860 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types
|
| 33 |
+
and is safe to use in a union.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/functional.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/platform/platform.h"
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
|
| 43 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
/// Statically sized array for any data type
|
| 46 |
+
template <
|
| 47 |
+
typename T,
|
| 48 |
+
int N,
|
| 49 |
+
bool RegisterSized = sizeof_bits<T>::value >= 32
|
| 50 |
+
>
|
| 51 |
+
struct Array;
|
| 52 |
+
|
| 53 |
+
namespace detail {
|
| 54 |
+
|
| 55 |
+
template<class T>
|
| 56 |
+
struct is_Array : platform::false_type {};
|
| 57 |
+
|
| 58 |
+
template <
|
| 59 |
+
typename T,
|
| 60 |
+
int N,
|
| 61 |
+
bool RegisterSized
|
| 62 |
+
>
|
| 63 |
+
struct is_Array<Array<T, N, RegisterSized> > : platform::true_type {};
|
| 64 |
+
|
| 65 |
+
template<typename T>
|
| 66 |
+
constexpr bool is_Array_v = is_Array<T>::value;
|
| 67 |
+
|
| 68 |
+
} // namespace detail
|
| 69 |
+
|
| 70 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 71 |
+
|
| 72 |
+
/// Defines the size of an Array<> in bits
|
| 73 |
+
template <typename T, int N, bool RegisterSized>
|
| 74 |
+
struct sizeof_bits<Array<T, N, RegisterSized> > {
|
| 75 |
+
static constexpr int value = sizeof(Array<T, N, RegisterSized>) * 8;
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 79 |
+
|
| 80 |
+
/// Returns true if the argument is a power of 2
|
| 81 |
+
CUTLASS_HOST_DEVICE
|
| 82 |
+
constexpr bool ispow2(unsigned x) {
|
| 83 |
+
return x && (!(x & (x - 1)));
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 87 |
+
|
| 88 |
+
/// Returns the largest power of two not greater than the argument.
|
| 89 |
+
CUTLASS_HOST_DEVICE
|
| 90 |
+
constexpr unsigned floor_pow_2(unsigned x) {
|
| 91 |
+
return (x == 0 || ispow2(x)) ? x : ((floor_pow_2(x >> 1)) << 1);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 95 |
+
|
| 96 |
+
/// Statically sized array for any data type
|
| 97 |
+
template <
|
| 98 |
+
typename T,
|
| 99 |
+
int N
|
| 100 |
+
>
|
| 101 |
+
struct Array<T, N, true> {
|
| 102 |
+
|
| 103 |
+
/// Storage type
|
| 104 |
+
using Storage = T;
|
| 105 |
+
|
| 106 |
+
/// Element type
|
| 107 |
+
using Element = T;
|
| 108 |
+
|
| 109 |
+
/// Number of storage elements
|
| 110 |
+
//static std::size_t const kStorageElements = N;
|
| 111 |
+
static constexpr size_t kStorageElements = N;
|
| 112 |
+
|
| 113 |
+
/// Number of logical elements
|
| 114 |
+
static constexpr size_t kElements = N;
|
| 115 |
+
|
| 116 |
+
//
|
| 117 |
+
// C++ standard members
|
| 118 |
+
//
|
| 119 |
+
|
| 120 |
+
typedef T value_type;
|
| 121 |
+
typedef size_t size_type;
|
| 122 |
+
typedef ptrdiff_t difference_type;
|
| 123 |
+
typedef value_type &reference;
|
| 124 |
+
typedef value_type const & const_reference;
|
| 125 |
+
typedef value_type *pointer;
|
| 126 |
+
typedef value_type const * const_pointer;
|
| 127 |
+
|
| 128 |
+
//
|
| 129 |
+
// Iterators
|
| 130 |
+
//
|
| 131 |
+
|
| 132 |
+
/// Bidirectional iterator over elements
|
| 133 |
+
class iterator {
|
| 134 |
+
|
| 135 |
+
/// Pointer to object
|
| 136 |
+
T *ptr_;
|
| 137 |
+
|
| 138 |
+
public:
|
| 139 |
+
|
| 140 |
+
CUTLASS_HOST_DEVICE
|
| 141 |
+
iterator(): ptr_(nullptr) { }
|
| 142 |
+
|
| 143 |
+
CUTLASS_HOST_DEVICE
|
| 144 |
+
iterator(T *_ptr): ptr_(_ptr) { }
|
| 145 |
+
|
| 146 |
+
CUTLASS_HOST_DEVICE
|
| 147 |
+
iterator &operator++() {
|
| 148 |
+
++ptr_;
|
| 149 |
+
return *this;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
CUTLASS_HOST_DEVICE
|
| 153 |
+
iterator &operator--() {
|
| 154 |
+
--ptr_;
|
| 155 |
+
return *this;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
CUTLASS_HOST_DEVICE
|
| 159 |
+
iterator operator++(int) {
|
| 160 |
+
iterator ret(*this);
|
| 161 |
+
++ptr_;
|
| 162 |
+
return ret;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
CUTLASS_HOST_DEVICE
|
| 166 |
+
iterator operator--(int) {
|
| 167 |
+
iterator ret(*this);
|
| 168 |
+
--ptr_;
|
| 169 |
+
return ret;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
CUTLASS_HOST_DEVICE
|
| 173 |
+
T &operator*() const {
|
| 174 |
+
return *ptr_;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
CUTLASS_HOST_DEVICE
|
| 178 |
+
bool operator==(iterator const &other) const {
|
| 179 |
+
return ptr_ == other.ptr_;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
CUTLASS_HOST_DEVICE
|
| 183 |
+
bool operator!=(iterator const &other) const {
|
| 184 |
+
return ptr_ != other.ptr_;
|
| 185 |
+
}
|
| 186 |
+
};
|
| 187 |
+
|
| 188 |
+
/// Bidirectional constant iterator over elements
|
| 189 |
+
class const_iterator {
|
| 190 |
+
|
| 191 |
+
/// Pointer to object
|
| 192 |
+
const T *ptr_;
|
| 193 |
+
|
| 194 |
+
public:
|
| 195 |
+
|
| 196 |
+
CUTLASS_HOST_DEVICE
|
| 197 |
+
const_iterator(): ptr_(nullptr) { }
|
| 198 |
+
|
| 199 |
+
CUTLASS_HOST_DEVICE
|
| 200 |
+
const_iterator(T const *_ptr): ptr_(_ptr) { }
|
| 201 |
+
|
| 202 |
+
CUTLASS_HOST_DEVICE
|
| 203 |
+
const_iterator &operator++() {
|
| 204 |
+
++ptr_;
|
| 205 |
+
return *this;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
CUTLASS_HOST_DEVICE
|
| 209 |
+
const_iterator &operator--() {
|
| 210 |
+
--ptr_;
|
| 211 |
+
return *this;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
CUTLASS_HOST_DEVICE
|
| 215 |
+
const_iterator operator++(int) {
|
| 216 |
+
const_iterator ret(*this);
|
| 217 |
+
++ptr_;
|
| 218 |
+
return ret;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
CUTLASS_HOST_DEVICE
|
| 222 |
+
const_iterator operator--(int) {
|
| 223 |
+
const_iterator ret(*this);
|
| 224 |
+
--ptr_;
|
| 225 |
+
return ret;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
CUTLASS_HOST_DEVICE
|
| 229 |
+
T const &operator*() const {
|
| 230 |
+
return *ptr_;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
CUTLASS_HOST_DEVICE
|
| 234 |
+
bool operator==(const_iterator const &other) const {
|
| 235 |
+
return ptr_ == other.ptr_;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
CUTLASS_HOST_DEVICE
|
| 239 |
+
bool operator!=(const_iterator const &other) const {
|
| 240 |
+
return ptr_ != other.ptr_;
|
| 241 |
+
}
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
/// Bidirectional iterator over elements
|
| 245 |
+
class reverse_iterator {
|
| 246 |
+
|
| 247 |
+
/// Pointer to object
|
| 248 |
+
T *ptr_;
|
| 249 |
+
|
| 250 |
+
public:
|
| 251 |
+
|
| 252 |
+
CUTLASS_HOST_DEVICE
|
| 253 |
+
reverse_iterator(): ptr_(nullptr) { }
|
| 254 |
+
|
| 255 |
+
CUTLASS_HOST_DEVICE
|
| 256 |
+
reverse_iterator(T *_ptr): ptr_(_ptr) { }
|
| 257 |
+
|
| 258 |
+
CUTLASS_HOST_DEVICE
|
| 259 |
+
reverse_iterator &operator++() {
|
| 260 |
+
--ptr_;
|
| 261 |
+
return *this;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
CUTLASS_HOST_DEVICE
|
| 265 |
+
reverse_iterator &operator--() {
|
| 266 |
+
++ptr_;
|
| 267 |
+
return *this;
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
CUTLASS_HOST_DEVICE
|
| 271 |
+
reverse_iterator operator++(int) {
|
| 272 |
+
iterator ret(*this);
|
| 273 |
+
--ptr_;
|
| 274 |
+
return ret;
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
CUTLASS_HOST_DEVICE
|
| 278 |
+
reverse_iterator operator--(int) {
|
| 279 |
+
iterator ret(*this);
|
| 280 |
+
++ptr_;
|
| 281 |
+
return ret;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
CUTLASS_HOST_DEVICE
|
| 285 |
+
T &operator*() const {
|
| 286 |
+
return *(ptr_ - 1);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
CUTLASS_HOST_DEVICE
|
| 290 |
+
bool operator==(reverse_iterator const &other) const {
|
| 291 |
+
return ptr_ == other.ptr_;
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
CUTLASS_HOST_DEVICE
|
| 295 |
+
bool operator!=(reverse_iterator const &other) const {
|
| 296 |
+
return ptr_ != other.ptr_;
|
| 297 |
+
}
|
| 298 |
+
};
|
| 299 |
+
|
| 300 |
+
/// Bidirectional constant iterator over elements
|
| 301 |
+
class const_reverse_iterator {
|
| 302 |
+
|
| 303 |
+
/// Pointer to object
|
| 304 |
+
T const *ptr_;
|
| 305 |
+
|
| 306 |
+
public:
|
| 307 |
+
|
| 308 |
+
CUTLASS_HOST_DEVICE
|
| 309 |
+
const_reverse_iterator(): ptr_(nullptr) { }
|
| 310 |
+
|
| 311 |
+
CUTLASS_HOST_DEVICE
|
| 312 |
+
const_reverse_iterator(T const *_ptr): ptr_(_ptr) { }
|
| 313 |
+
|
| 314 |
+
CUTLASS_HOST_DEVICE
|
| 315 |
+
const_reverse_iterator &operator++() {
|
| 316 |
+
--ptr_;
|
| 317 |
+
return *this;
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
CUTLASS_HOST_DEVICE
|
| 321 |
+
const_reverse_iterator &operator--() {
|
| 322 |
+
++ptr_;
|
| 323 |
+
return *this;
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
CUTLASS_HOST_DEVICE
|
| 327 |
+
const_reverse_iterator operator++(int) {
|
| 328 |
+
const_reverse_iterator ret(*this);
|
| 329 |
+
--ptr_;
|
| 330 |
+
return ret;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
CUTLASS_HOST_DEVICE
|
| 334 |
+
const_reverse_iterator operator--(int) {
|
| 335 |
+
const_reverse_iterator ret(*this);
|
| 336 |
+
++ptr_;
|
| 337 |
+
return ret;
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
CUTLASS_HOST_DEVICE
|
| 341 |
+
T const &operator*() const {
|
| 342 |
+
return *(ptr_ - 1);
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
CUTLASS_HOST_DEVICE
|
| 346 |
+
bool operator==(const_iterator const &other) const {
|
| 347 |
+
return ptr_ == other.ptr_;
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
CUTLASS_HOST_DEVICE
|
| 351 |
+
bool operator!=(const_iterator const &other) const {
|
| 352 |
+
return ptr_ != other.ptr_;
|
| 353 |
+
}
|
| 354 |
+
};
|
| 355 |
+
|
| 356 |
+
/// Internal storage
|
| 357 |
+
Storage storage[kElements];
|
| 358 |
+
|
| 359 |
+
/// Efficient clear method
|
| 360 |
+
CUTLASS_HOST_DEVICE
|
| 361 |
+
void clear() {
|
| 362 |
+
fill(T(0));
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
CUTLASS_HOST_DEVICE
|
| 366 |
+
reference at(size_type pos) {
|
| 367 |
+
return reinterpret_cast<reference>(storage[pos]);
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
CUTLASS_HOST_DEVICE
|
| 371 |
+
const_reference at(size_type pos) const {
|
| 372 |
+
return reinterpret_cast<const_reference>(storage[pos]);
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
CUTLASS_HOST_DEVICE
|
| 376 |
+
reference operator[](size_type pos) {
|
| 377 |
+
return reinterpret_cast<reference>(storage[pos]);
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
CUTLASS_HOST_DEVICE
|
| 381 |
+
const_reference operator[](size_type pos) const {
|
| 382 |
+
return reinterpret_cast<const_reference>(storage[pos]);
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
CUTLASS_HOST_DEVICE
|
| 386 |
+
reference front() {
|
| 387 |
+
return reinterpret_cast<reference>(storage[0]);
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
CUTLASS_HOST_DEVICE
|
| 391 |
+
const_reference front() const {
|
| 392 |
+
return reinterpret_cast<const_reference>(storage[0]);
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
CUTLASS_HOST_DEVICE
|
| 396 |
+
reference back() {
|
| 397 |
+
return reinterpret_cast<reference>(storage[kStorageElements - 1]);
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
CUTLASS_HOST_DEVICE
|
| 401 |
+
const_reference back() const {
|
| 402 |
+
return reinterpret_cast<const_reference>(storage[kStorageElements - 1]);
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
CUTLASS_HOST_DEVICE
|
| 406 |
+
pointer data() {
|
| 407 |
+
return reinterpret_cast<pointer>(storage);
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
CUTLASS_HOST_DEVICE
|
| 411 |
+
const_pointer data() const {
|
| 412 |
+
return reinterpret_cast<const_pointer>(storage);
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
CUTLASS_HOST_DEVICE
|
| 416 |
+
pointer raw_data() {
|
| 417 |
+
return reinterpret_cast<pointer>(storage);
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
CUTLASS_HOST_DEVICE
|
| 421 |
+
const_pointer raw_data() const {
|
| 422 |
+
return reinterpret_cast<const_pointer>(storage);
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
/// True when the array holds zero elements (compile-time property).
CUTLASS_HOST_DEVICE
constexpr bool empty() const {
  return kElements == 0;
}

/// Number of elements (compile-time constant).
CUTLASS_HOST_DEVICE
constexpr size_type size() const {
  return kElements;
}

/// Capacity equals size for a fixed-size array.
CUTLASS_HOST_DEVICE
constexpr size_type max_size() const {
  return kElements;
}

/// Assigns `value` to every element.
CUTLASS_HOST_DEVICE
void fill(T const &value) {
  // Convert once; the loop then performs plain stores.
  Storage const filled = static_cast<Storage>(value);

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < int(kElements); ++idx) {
    storage[idx] = filled;
  }
}
|
| 448 |
+
|
| 449 |
+
/// Iterator to the first element.
CUTLASS_HOST_DEVICE
iterator begin() {
  return iterator(storage);
}

/// Const overload of begin(); forwards to cbegin().
CUTLASS_HOST_DEVICE
const_iterator begin() const {
  return cbegin();
}

/// Const iterator to the first element.
CUTLASS_HOST_DEVICE
const_iterator cbegin() const {
  return const_iterator(storage);
}

/// Iterator one past the last element.
CUTLASS_HOST_DEVICE
iterator end() {
  return iterator(reinterpret_cast<pointer>(storage + kStorageElements));
}

/// Const overload of end(); forwards to cend().
CUTLASS_HOST_DEVICE
const_iterator end() const {
  return cend();
}

/// Const iterator one past the last element.
CUTLASS_HOST_DEVICE
const_iterator cend() const {
  return const_iterator(reinterpret_cast<const_pointer>(storage + kStorageElements));
}

/// Reverse iterator starting at the last element.
CUTLASS_HOST_DEVICE
reverse_iterator rbegin() {
  return reverse_iterator(reinterpret_cast<pointer>(storage + kStorageElements));
}

/// Const overload of rbegin(); forwards to crbegin().
CUTLASS_HOST_DEVICE
const_reverse_iterator rbegin() const {
  return crbegin();
}

/// Const reverse iterator starting at the last element.
CUTLASS_HOST_DEVICE
const_reverse_iterator crbegin() const {
  return const_reverse_iterator(reinterpret_cast<const_pointer>(storage + kStorageElements));
}

/// Reverse iterator one before the first element.
CUTLASS_HOST_DEVICE
reverse_iterator rend() {
  return reverse_iterator(reinterpret_cast<pointer>(storage));
}

/// Const overload of rend(); forwards to crend().
CUTLASS_HOST_DEVICE
const_reverse_iterator rend() const {
  return crend();
}

/// Const reverse iterator one before the first element.
CUTLASS_HOST_DEVICE
const_reverse_iterator crend() const {
  return const_reverse_iterator(reinterpret_cast<const_pointer>(storage));
}
|
| 508 |
+
|
| 509 |
+
//
|
| 510 |
+
// Comparison operators
|
| 511 |
+
//
|
| 512 |
+
|
| 513 |
+
};
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 517 |
+
// Factories
|
| 518 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 519 |
+
|
| 520 |
+
template <typename Element>
|
| 521 |
+
CUTLASS_HOST_DEVICE
|
| 522 |
+
Array<Element, 1> make_Array(Element x) {
|
| 523 |
+
return {x};
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
template <typename Element>
|
| 527 |
+
CUTLASS_HOST_DEVICE
|
| 528 |
+
Array<Element, 2> make_Array(Element x, Element y) {
|
| 529 |
+
return {x,y};
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
template <typename Element>
|
| 533 |
+
CUTLASS_HOST_DEVICE
|
| 534 |
+
Array<Element, 3> make_Array(Element x, Element y, Element z) {
|
| 535 |
+
return {x,y,z};
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
template <typename Element>
|
| 539 |
+
CUTLASS_HOST_DEVICE
|
| 540 |
+
Array<Element, 4> make_Array(Element x, Element y, Element z, Element w) {
|
| 541 |
+
return {x,y,z,w};
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 546 |
+
// functional.h numeric specializations
|
| 547 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 548 |
+
|
| 549 |
+
template <typename T, int N>
|
| 550 |
+
struct absolute_value_op< Array<T, N> > {
|
| 551 |
+
|
| 552 |
+
CUTLASS_HOST_DEVICE
|
| 553 |
+
Array<T, N> operator()(Array<T, N> const &lhs) const {
|
| 554 |
+
|
| 555 |
+
Array<T, N> result;
|
| 556 |
+
absolute_value_op<T> scalar_op;
|
| 557 |
+
|
| 558 |
+
CUTLASS_PRAGMA_UNROLL
|
| 559 |
+
for (int i = 0; i < N; ++i) {
|
| 560 |
+
result[i] = scalar_op(lhs[i]);
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
return result;
|
| 564 |
+
}
|
| 565 |
+
};
|
| 566 |
+
|
| 567 |
+
template <typename T, int N>
|
| 568 |
+
struct plus<Array<T, N>> {
|
| 569 |
+
CUTLASS_HOST_DEVICE
|
| 570 |
+
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
| 571 |
+
|
| 572 |
+
Array<T, N> result;
|
| 573 |
+
plus<T> scalar_op;
|
| 574 |
+
|
| 575 |
+
CUTLASS_PRAGMA_UNROLL
|
| 576 |
+
for (int i = 0; i < N; ++i) {
|
| 577 |
+
result[i] = scalar_op(lhs[i], rhs[i]);
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
return result;
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
CUTLASS_HOST_DEVICE
|
| 584 |
+
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
| 585 |
+
|
| 586 |
+
Array<T, N> result;
|
| 587 |
+
plus<T> scalar_op;
|
| 588 |
+
|
| 589 |
+
CUTLASS_PRAGMA_UNROLL
|
| 590 |
+
for (int i = 0; i < N; ++i) {
|
| 591 |
+
result[i] = scalar_op(lhs[i], scalar);
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
return result;
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
CUTLASS_HOST_DEVICE
|
| 598 |
+
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
|
| 599 |
+
|
| 600 |
+
Array<T, N> result;
|
| 601 |
+
plus<T> scalar_op;
|
| 602 |
+
|
| 603 |
+
CUTLASS_PRAGMA_UNROLL
|
| 604 |
+
for (int i = 0; i < N; ++i) {
|
| 605 |
+
result[i] = scalar_op(scalar, rhs[i]);
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
return result;
|
| 609 |
+
}
|
| 610 |
+
};
|
| 611 |
+
template <typename T, int N>
|
| 612 |
+
struct minus<Array<T, N>> {
|
| 613 |
+
|
| 614 |
+
CUTLASS_HOST_DEVICE
|
| 615 |
+
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
| 616 |
+
|
| 617 |
+
Array<T, N> result;
|
| 618 |
+
minus<T> scalar_op;
|
| 619 |
+
|
| 620 |
+
CUTLASS_PRAGMA_UNROLL
|
| 621 |
+
for (int i = 0; i < N; ++i) {
|
| 622 |
+
result[i] = scalar_op(lhs[i], rhs[i]);
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
return result;
|
| 626 |
+
}
|
| 627 |
+
|
| 628 |
+
CUTLASS_HOST_DEVICE
|
| 629 |
+
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
| 630 |
+
|
| 631 |
+
Array<T, N> result;
|
| 632 |
+
minus<T> scalar_op;
|
| 633 |
+
|
| 634 |
+
CUTLASS_PRAGMA_UNROLL
|
| 635 |
+
for (int i = 0; i < N; ++i) {
|
| 636 |
+
result[i] = scalar_op(lhs[i], scalar);
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
return result;
|
| 640 |
+
}
|
| 641 |
+
|
| 642 |
+
CUTLASS_HOST_DEVICE
|
| 643 |
+
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
|
| 644 |
+
|
| 645 |
+
Array<T, N> result;
|
| 646 |
+
minus<T> scalar_op;
|
| 647 |
+
|
| 648 |
+
CUTLASS_PRAGMA_UNROLL
|
| 649 |
+
for (int i = 0; i < N; ++i) {
|
| 650 |
+
result[i] = scalar_op(scalar, rhs[i]);
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
return result;
|
| 654 |
+
}
|
| 655 |
+
};
|
| 656 |
+
|
| 657 |
+
template <typename T, int N>
|
| 658 |
+
struct multiplies<Array<T, N>> {
|
| 659 |
+
|
| 660 |
+
CUTLASS_HOST_DEVICE
|
| 661 |
+
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
| 662 |
+
|
| 663 |
+
Array<T, N> result;
|
| 664 |
+
multiplies<T> scalar_op;
|
| 665 |
+
|
| 666 |
+
CUTLASS_PRAGMA_UNROLL
|
| 667 |
+
for (int i = 0; i < N; ++i) {
|
| 668 |
+
result[i] = scalar_op(lhs[i], rhs[i]);
|
| 669 |
+
}
|
| 670 |
+
|
| 671 |
+
return result;
|
| 672 |
+
}
|
| 673 |
+
|
| 674 |
+
CUTLASS_HOST_DEVICE
|
| 675 |
+
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
| 676 |
+
|
| 677 |
+
Array<T, N> result;
|
| 678 |
+
multiplies<T> scalar_op;
|
| 679 |
+
|
| 680 |
+
CUTLASS_PRAGMA_UNROLL
|
| 681 |
+
for (int i = 0; i < N; ++i) {
|
| 682 |
+
result[i] = scalar_op(lhs[i], scalar);
|
| 683 |
+
}
|
| 684 |
+
|
| 685 |
+
return result;
|
| 686 |
+
}
|
| 687 |
+
|
| 688 |
+
CUTLASS_HOST_DEVICE
|
| 689 |
+
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
|
| 690 |
+
|
| 691 |
+
Array<T, N> result;
|
| 692 |
+
multiplies<T> scalar_op;
|
| 693 |
+
|
| 694 |
+
CUTLASS_PRAGMA_UNROLL
|
| 695 |
+
for (int i = 0; i < N; ++i) {
|
| 696 |
+
result[i] = scalar_op(scalar, rhs[i]);
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
return result;
|
| 700 |
+
}
|
| 701 |
+
};
|
| 702 |
+
|
| 703 |
+
template <typename T, int N, bool PropogateNaN>
|
| 704 |
+
struct maximum_absolute_value_reduction<Array<T, N>, PropogateNaN> {
|
| 705 |
+
|
| 706 |
+
CUTLASS_HOST_DEVICE
|
| 707 |
+
T operator() (T const& scalar, Array<T, N> const& rhs) const {
|
| 708 |
+
|
| 709 |
+
T result = scalar;
|
| 710 |
+
maximum_absolute_value_reduction<T, PropogateNaN> scalar_op;
|
| 711 |
+
|
| 712 |
+
CUTLASS_PRAGMA_UNROLL
|
| 713 |
+
for (int i = 0; i < N; ++i) {
|
| 714 |
+
result = scalar_op(result, rhs[i]);
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
return result;
|
| 718 |
+
}
|
| 719 |
+
};
|
| 720 |
+
|
| 721 |
+
template <typename T, int N>
|
| 722 |
+
struct scale<Array<T, N>> {
|
| 723 |
+
T const scaling_factor_;
|
| 724 |
+
|
| 725 |
+
CUTLASS_HOST_DEVICE
|
| 726 |
+
scale(T scaling_factor) : scaling_factor_(scaling_factor) {
|
| 727 |
+
}
|
| 728 |
+
|
| 729 |
+
CUTLASS_HOST_DEVICE
|
| 730 |
+
Array<T, N> operator()(Array<T, N> const & rhs) const {
|
| 731 |
+
Array<T, N> result;
|
| 732 |
+
|
| 733 |
+
CUTLASS_PRAGMA_UNROLL
|
| 734 |
+
for (int i = 0; i < N; ++i) {
|
| 735 |
+
result[i] = rhs[i] * scaling_factor_;
|
| 736 |
+
}
|
| 737 |
+
|
| 738 |
+
return result;
|
| 739 |
+
}
|
| 740 |
+
};
|
| 741 |
+
|
| 742 |
+
template <typename T, int N>
|
| 743 |
+
struct divides<Array<T, N>> {
|
| 744 |
+
|
| 745 |
+
CUTLASS_HOST_DEVICE
|
| 746 |
+
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
| 747 |
+
|
| 748 |
+
Array<T, N> result;
|
| 749 |
+
divides<T> scalar_op;
|
| 750 |
+
|
| 751 |
+
CUTLASS_PRAGMA_UNROLL
|
| 752 |
+
for (int i = 0; i < N; ++i) {
|
| 753 |
+
result[i] = scalar_op(lhs[i], rhs[i]);
|
| 754 |
+
}
|
| 755 |
+
|
| 756 |
+
return result;
|
| 757 |
+
}
|
| 758 |
+
|
| 759 |
+
CUTLASS_HOST_DEVICE
|
| 760 |
+
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
| 761 |
+
|
| 762 |
+
Array<T, N> result;
|
| 763 |
+
divides<T> scalar_op;
|
| 764 |
+
|
| 765 |
+
CUTLASS_PRAGMA_UNROLL
|
| 766 |
+
for (int i = 0; i < N; ++i) {
|
| 767 |
+
result[i] = scalar_op(lhs[i], scalar);
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
return result;
|
| 771 |
+
}
|
| 772 |
+
|
| 773 |
+
CUTLASS_HOST_DEVICE
|
| 774 |
+
Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
|
| 775 |
+
|
| 776 |
+
Array<T, N> result;
|
| 777 |
+
divides<T> scalar_op;
|
| 778 |
+
|
| 779 |
+
CUTLASS_PRAGMA_UNROLL
|
| 780 |
+
for (int i = 0; i < N; ++i) {
|
| 781 |
+
result[i] = scalar_op(scalar, rhs[i]);
|
| 782 |
+
}
|
| 783 |
+
|
| 784 |
+
return result;
|
| 785 |
+
}
|
| 786 |
+
};
|
| 787 |
+
|
| 788 |
+
template <typename T, int N>
|
| 789 |
+
struct reciprocal_approximate<Array<T, N>> {
|
| 790 |
+
|
| 791 |
+
CUTLASS_HOST_DEVICE
|
| 792 |
+
Array<T, N> operator()(Array<T, N> const &lhs) const {
|
| 793 |
+
|
| 794 |
+
Array<T, N> result;
|
| 795 |
+
reciprocal_approximate<T> scalar_op;
|
| 796 |
+
|
| 797 |
+
CUTLASS_PRAGMA_UNROLL
|
| 798 |
+
for (int i = 0; i < N; ++i) {
|
| 799 |
+
result[i] = scalar_op(lhs[i]);
|
| 800 |
+
}
|
| 801 |
+
|
| 802 |
+
return result;
|
| 803 |
+
}
|
| 804 |
+
};
|
| 805 |
+
|
| 806 |
+
template <typename T, int N>
|
| 807 |
+
struct reciprocal_approximate_ftz<Array<T, N>> {
|
| 808 |
+
|
| 809 |
+
CUTLASS_HOST_DEVICE
|
| 810 |
+
Array<T, N> operator()(Array<T, N> const &lhs) const {
|
| 811 |
+
|
| 812 |
+
Array<T, N> result;
|
| 813 |
+
reciprocal_approximate_ftz<T> scalar_op;
|
| 814 |
+
|
| 815 |
+
CUTLASS_PRAGMA_UNROLL
|
| 816 |
+
for (int i = 0; i < N; ++i) {
|
| 817 |
+
result[i] = scalar_op(lhs[i]);
|
| 818 |
+
}
|
| 819 |
+
|
| 820 |
+
return result;
|
| 821 |
+
}
|
| 822 |
+
};
|
| 823 |
+
|
| 824 |
+
template <typename T, int N, bool PropagateNaN>
|
| 825 |
+
struct maximum<Array<T, N>, PropagateNaN> {
|
| 826 |
+
|
| 827 |
+
CUTLASS_HOST_DEVICE
|
| 828 |
+
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
| 829 |
+
|
| 830 |
+
Array<T, N> result;
|
| 831 |
+
maximum<T, PropagateNaN> scalar_op;
|
| 832 |
+
|
| 833 |
+
CUTLASS_PRAGMA_UNROLL
|
| 834 |
+
for (int i = 0; i < N; ++i) {
|
| 835 |
+
result[i] = scalar_op(lhs[i], rhs[i]);
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
return result;
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
CUTLASS_HOST_DEVICE
|
| 842 |
+
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
| 843 |
+
|
| 844 |
+
Array<T, N> result;
|
| 845 |
+
maximum<T, PropagateNaN> scalar_op;
|
| 846 |
+
|
| 847 |
+
CUTLASS_PRAGMA_UNROLL
|
| 848 |
+
for (int i = 0; i < N; ++i) {
|
| 849 |
+
result[i] = scalar_op(lhs[i], scalar);
|
| 850 |
+
}
|
| 851 |
+
|
| 852 |
+
return result;
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
CUTLASS_HOST_DEVICE
|
| 856 |
+
Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
|
| 857 |
+
|
| 858 |
+
Array<T, N> result;
|
| 859 |
+
maximum<T, PropagateNaN> scalar_op;
|
| 860 |
+
|
| 861 |
+
CUTLASS_PRAGMA_UNROLL
|
| 862 |
+
for (int i = 0; i < N; ++i) {
|
| 863 |
+
result[i] = scalar_op(scalar, rhs[i]);
|
| 864 |
+
}
|
| 865 |
+
|
| 866 |
+
return result;
|
| 867 |
+
}
|
| 868 |
+
};
|
| 869 |
+
|
| 870 |
+
template <typename T, int N, bool PropagateNaN>
|
| 871 |
+
struct minimum<Array<T, N>, PropagateNaN> {
|
| 872 |
+
|
| 873 |
+
CUTLASS_HOST_DEVICE
|
| 874 |
+
static T scalar_op(T const &lhs, T const &rhs) {
|
| 875 |
+
return (rhs < lhs ? rhs : lhs);
|
| 876 |
+
}
|
| 877 |
+
|
| 878 |
+
CUTLASS_HOST_DEVICE
|
| 879 |
+
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
| 880 |
+
|
| 881 |
+
Array<T, N> result;
|
| 882 |
+
minimum<T, PropagateNaN> scalar_op;
|
| 883 |
+
|
| 884 |
+
CUTLASS_PRAGMA_UNROLL
|
| 885 |
+
for (int i = 0; i < N; ++i) {
|
| 886 |
+
result[i] = scalar_op(lhs[i], rhs[i]);
|
| 887 |
+
}
|
| 888 |
+
|
| 889 |
+
return result;
|
| 890 |
+
}
|
| 891 |
+
|
| 892 |
+
CUTLASS_HOST_DEVICE
|
| 893 |
+
Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
|
| 894 |
+
|
| 895 |
+
Array<T, N> result;
|
| 896 |
+
minimum<T, PropagateNaN> scalar_op;
|
| 897 |
+
|
| 898 |
+
CUTLASS_PRAGMA_UNROLL
|
| 899 |
+
for (int i = 0; i < N; ++i) {
|
| 900 |
+
result[i] = scalar_op(lhs[i], scalar);
|
| 901 |
+
}
|
| 902 |
+
|
| 903 |
+
return result;
|
| 904 |
+
}
|
| 905 |
+
|
| 906 |
+
CUTLASS_HOST_DEVICE
|
| 907 |
+
Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
|
| 908 |
+
|
| 909 |
+
Array<T, N> result;
|
| 910 |
+
minimum<T, PropagateNaN> scalar_op;
|
| 911 |
+
|
| 912 |
+
CUTLASS_PRAGMA_UNROLL
|
| 913 |
+
for (int i = 0; i < N; ++i) {
|
| 914 |
+
result[i] = scalar_op(scalar, rhs[i]);
|
| 915 |
+
}
|
| 916 |
+
|
| 917 |
+
return result;
|
| 918 |
+
}
|
| 919 |
+
};
|
| 920 |
+
|
| 921 |
+
template <typename T, int N>
|
| 922 |
+
struct minimum_with_nan_propagation<Array<T, N>> : minimum<Array<T, N>, true>
|
| 923 |
+
{};
|
| 924 |
+
|
| 925 |
+
template <typename T, int N>
|
| 926 |
+
struct negate<Array<T, N>> {
|
| 927 |
+
|
| 928 |
+
CUTLASS_HOST_DEVICE
|
| 929 |
+
Array<T, N> operator()(Array<T, N> const &lhs) const {
|
| 930 |
+
|
| 931 |
+
Array<T, N> result;
|
| 932 |
+
negate<T> scalar_op;
|
| 933 |
+
|
| 934 |
+
CUTLASS_PRAGMA_UNROLL
|
| 935 |
+
for (int i = 0; i < N; ++i) {
|
| 936 |
+
result[i] = scalar_op(lhs[i]);
|
| 937 |
+
}
|
| 938 |
+
|
| 939 |
+
return result;
|
| 940 |
+
}
|
| 941 |
+
};
|
| 942 |
+
|
| 943 |
+
/// Fused multiply-add
|
| 944 |
+
template <typename T, int N>
|
| 945 |
+
struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {
|
| 946 |
+
|
| 947 |
+
CUTLASS_HOST_DEVICE
|
| 948 |
+
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 949 |
+
|
| 950 |
+
Array<T, N> result;
|
| 951 |
+
multiply_add<T> scalar_op;
|
| 952 |
+
|
| 953 |
+
CUTLASS_PRAGMA_UNROLL
|
| 954 |
+
for (int i = 0; i < N; ++i) {
|
| 955 |
+
result[i] = scalar_op(a[i], b[i], c[i]);
|
| 956 |
+
}
|
| 957 |
+
|
| 958 |
+
return result;
|
| 959 |
+
}
|
| 960 |
+
|
| 961 |
+
CUTLASS_HOST_DEVICE
|
| 962 |
+
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
|
| 963 |
+
|
| 964 |
+
Array<T, N> result;
|
| 965 |
+
multiply_add<T> scalar_op;
|
| 966 |
+
|
| 967 |
+
CUTLASS_PRAGMA_UNROLL
|
| 968 |
+
for (int i = 0; i < N; ++i) {
|
| 969 |
+
result[i] = scalar_op(a[i], scalar, c[i]);
|
| 970 |
+
}
|
| 971 |
+
|
| 972 |
+
return result;
|
| 973 |
+
}
|
| 974 |
+
|
| 975 |
+
CUTLASS_HOST_DEVICE
|
| 976 |
+
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 977 |
+
|
| 978 |
+
Array<T, N> result;
|
| 979 |
+
multiply_add<T> scalar_op;
|
| 980 |
+
|
| 981 |
+
CUTLASS_PRAGMA_UNROLL
|
| 982 |
+
for (int i = 0; i < N; ++i) {
|
| 983 |
+
result[i] = scalar_op(scalar, b[i], c[i]);
|
| 984 |
+
}
|
| 985 |
+
|
| 986 |
+
return result;
|
| 987 |
+
}
|
| 988 |
+
|
| 989 |
+
CUTLASS_HOST_DEVICE
|
| 990 |
+
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, T const &scalar) const {
|
| 991 |
+
|
| 992 |
+
Array<T, N> result;
|
| 993 |
+
multiply_add<T> scalar_op;
|
| 994 |
+
|
| 995 |
+
CUTLASS_PRAGMA_UNROLL
|
| 996 |
+
for (int i = 0; i < N; ++i) {
|
| 997 |
+
result[i] = scalar_op(a[i], b[i], scalar);
|
| 998 |
+
}
|
| 999 |
+
|
| 1000 |
+
return result;
|
| 1001 |
+
}
|
| 1002 |
+
|
| 1003 |
+
|
| 1004 |
+
CUTLASS_HOST_DEVICE
|
| 1005 |
+
Array<T, N> operator()(Array<T, N> const &a, T const &scalar_b, T const &scalar_c) const {
|
| 1006 |
+
|
| 1007 |
+
Array<T, N> result;
|
| 1008 |
+
multiply_add<T> scalar_op;
|
| 1009 |
+
|
| 1010 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1011 |
+
for (int i = 0; i < N; ++i) {
|
| 1012 |
+
result[i] = scalar_op(a[i], scalar_b, scalar_c);
|
| 1013 |
+
}
|
| 1014 |
+
|
| 1015 |
+
return result;
|
| 1016 |
+
}
|
| 1017 |
+
};
|
| 1018 |
+
|
| 1019 |
+
/// Fused square-and-plus
|
| 1020 |
+
template <typename T, int N>
|
| 1021 |
+
struct square_and_plus<Array<T, N>> {
|
| 1022 |
+
|
| 1023 |
+
CUTLASS_HOST_DEVICE
|
| 1024 |
+
Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
|
| 1025 |
+
multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> ma_op;
|
| 1026 |
+
return ma_op(rhs, rhs, lhs);
|
| 1027 |
+
}
|
| 1028 |
+
|
| 1029 |
+
CUTLASS_HOST_DEVICE
|
| 1030 |
+
Array<T, N> operator()(Array<T, N> const &lhs, T const &rhs) const {
|
| 1031 |
+
plus<Array<T, N>> plus_op;
|
| 1032 |
+
multiplies<T> multiplies_op;
|
| 1033 |
+
return plus_op(multiplies_op(rhs, rhs), lhs);
|
| 1034 |
+
}
|
| 1035 |
+
};
|
| 1036 |
+
|
| 1037 |
+
/// Inverse-square-root
|
| 1038 |
+
template <typename T, int N>
|
| 1039 |
+
struct inverse_square_root<Array<T, N>> {
|
| 1040 |
+
CUTLASS_HOST_DEVICE
|
| 1041 |
+
Array<T, N> operator()(Array<T, N> const &a) const {
|
| 1042 |
+
Array<T, N> result;
|
| 1043 |
+
inverse_square_root<T> scalar_op;
|
| 1044 |
+
|
| 1045 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1046 |
+
for (int i = 0; i < N; ++i) {
|
| 1047 |
+
result[i] = scalar_op(a[i]);
|
| 1048 |
+
}
|
| 1049 |
+
return result;
|
| 1050 |
+
}
|
| 1051 |
+
};
|
| 1052 |
+
|
| 1053 |
+
template <int N>
|
| 1054 |
+
struct inverse_square_root<Array<half_t, N>> {
|
| 1055 |
+
CUTLASS_HOST_DEVICE
|
| 1056 |
+
Array<half_t, N> operator()(Array<half_t, N> const & a) const {
|
| 1057 |
+
Array<half_t, N> result;
|
| 1058 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
|
| 1059 |
+
|
| 1060 |
+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
|
| 1061 |
+
__half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
|
| 1062 |
+
|
| 1063 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1064 |
+
for (int i = 0; i < N / 2; ++i) {
|
| 1065 |
+
result_ptr[i] = h2rsqrt(a_ptr[i]);
|
| 1066 |
+
}
|
| 1067 |
+
|
| 1068 |
+
if constexpr (N % 2) {
|
| 1069 |
+
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
|
| 1070 |
+
__half d_residual = hrsqrt(a_residual_ptr[N - 1]);
|
| 1071 |
+
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
|
| 1072 |
+
}
|
| 1073 |
+
|
| 1074 |
+
#else
|
| 1075 |
+
|
| 1076 |
+
inverse_square_root<half_t> scalar_op;
|
| 1077 |
+
|
| 1078 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1079 |
+
for (int i = 0; i < N; ++i) {
|
| 1080 |
+
result[i] = scalar_op(a[i]);
|
| 1081 |
+
}
|
| 1082 |
+
|
| 1083 |
+
#endif
|
| 1084 |
+
|
| 1085 |
+
return result;
|
| 1086 |
+
}
|
| 1087 |
+
};
|
| 1088 |
+
|
| 1089 |
+
/// Fused multiply-add-relu0
|
| 1090 |
+
template <typename T, int N>
|
| 1091 |
+
struct multiply_add_relu0<Array<T, N>, Array<T, N>, Array<T, N>> {
|
| 1092 |
+
|
| 1093 |
+
CUTLASS_HOST_DEVICE
|
| 1094 |
+
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 1095 |
+
|
| 1096 |
+
Array<T, N> result;
|
| 1097 |
+
multiply_add<T> scalar_op;
|
| 1098 |
+
maximum<T> mx;
|
| 1099 |
+
|
| 1100 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1101 |
+
for (int i = 0; i < N; ++i) {
|
| 1102 |
+
result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0));
|
| 1103 |
+
}
|
| 1104 |
+
|
| 1105 |
+
return result;
|
| 1106 |
+
}
|
| 1107 |
+
|
| 1108 |
+
CUTLASS_HOST_DEVICE
|
| 1109 |
+
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
|
| 1110 |
+
|
| 1111 |
+
Array<T, N> result;
|
| 1112 |
+
multiply_add<T> scalar_op;
|
| 1113 |
+
maximum<T> mx;
|
| 1114 |
+
|
| 1115 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1116 |
+
for (int i = 0; i < N; ++i) {
|
| 1117 |
+
result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0));
|
| 1118 |
+
}
|
| 1119 |
+
|
| 1120 |
+
return result;
|
| 1121 |
+
}
|
| 1122 |
+
|
| 1123 |
+
CUTLASS_HOST_DEVICE
|
| 1124 |
+
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 1125 |
+
|
| 1126 |
+
Array<T, N> result;
|
| 1127 |
+
multiply_add<T> scalar_op;
|
| 1128 |
+
maximum<T> mx;
|
| 1129 |
+
|
| 1130 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1131 |
+
for (int i = 0; i < N; ++i) {
|
| 1132 |
+
result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0));
|
| 1133 |
+
}
|
| 1134 |
+
|
| 1135 |
+
return result;
|
| 1136 |
+
}
|
| 1137 |
+
};
|
| 1138 |
+
|
| 1139 |
+
|
| 1140 |
+
template <typename T, int N>
|
| 1141 |
+
struct conjugate<Array<T, N> > {
|
| 1142 |
+
CUTLASS_HOST_DEVICE
|
| 1143 |
+
Array<T, N> operator()(Array<T, N> const &a) const {
|
| 1144 |
+
|
| 1145 |
+
conjugate<T> conj_op;
|
| 1146 |
+
|
| 1147 |
+
Array<T, N> ca;
|
| 1148 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1149 |
+
for (int i = 0; i < N; ++i) {
|
| 1150 |
+
ca[i] = conj_op(a[i]);
|
| 1151 |
+
}
|
| 1152 |
+
return ca;
|
| 1153 |
+
}
|
| 1154 |
+
};
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1158 |
+
// functional.h numeric specializations targeting SIMD instructions in device code.
|
| 1159 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1160 |
+
|
| 1161 |
+
template <int N>
|
| 1162 |
+
struct plus<Array<half_t, N>> {
|
| 1163 |
+
CUTLASS_HOST_DEVICE
|
| 1164 |
+
Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
|
| 1165 |
+
Array<half_t, N> result;
|
| 1166 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
|
| 1167 |
+
|
| 1168 |
+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
|
| 1169 |
+
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
|
| 1170 |
+
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
|
| 1171 |
+
|
| 1172 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1173 |
+
for (int i = 0; i < N / 2; ++i) {
|
| 1174 |
+
result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]);
|
| 1175 |
+
}
|
| 1176 |
+
|
| 1177 |
+
if constexpr (N % 2) {
|
| 1178 |
+
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
|
| 1179 |
+
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
|
| 1180 |
+
__half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
|
| 1181 |
+
|
| 1182 |
+
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
|
| 1183 |
+
}
|
| 1184 |
+
|
| 1185 |
+
#else
|
| 1186 |
+
|
| 1187 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1188 |
+
for (int i = 0; i < N; ++i) {
|
| 1189 |
+
result[i] = lhs[i] + rhs[i];
|
| 1190 |
+
}
|
| 1191 |
+
#endif
|
| 1192 |
+
|
| 1193 |
+
return result;
|
| 1194 |
+
}
|
| 1195 |
+
|
| 1196 |
+
CUTLASS_HOST_DEVICE
|
| 1197 |
+
Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
|
| 1198 |
+
Array<half_t, N> result;
|
| 1199 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
|
| 1200 |
+
|
| 1201 |
+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
|
| 1202 |
+
__half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
|
| 1203 |
+
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
|
| 1204 |
+
|
| 1205 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1206 |
+
for (int i = 0; i < N / 2; ++i) {
|
| 1207 |
+
result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]);
|
| 1208 |
+
}
|
| 1209 |
+
|
| 1210 |
+
if constexpr (N % 2) {
|
| 1211 |
+
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
|
| 1212 |
+
__half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
|
| 1213 |
+
|
| 1214 |
+
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
|
| 1215 |
+
}
|
| 1216 |
+
|
| 1217 |
+
#else
|
| 1218 |
+
|
| 1219 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1220 |
+
for (int i = 0; i < N; ++i) {
|
| 1221 |
+
result[i] = lhs + rhs[i];
|
| 1222 |
+
}
|
| 1223 |
+
#endif
|
| 1224 |
+
|
| 1225 |
+
return result;
|
| 1226 |
+
}
|
| 1227 |
+
|
| 1228 |
+
CUTLASS_HOST_DEVICE
|
| 1229 |
+
Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
|
| 1230 |
+
Array<half_t, N> result;
|
| 1231 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
|
| 1232 |
+
|
| 1233 |
+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
|
| 1234 |
+
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
|
| 1235 |
+
__half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
|
| 1236 |
+
|
| 1237 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1238 |
+
for (int i = 0; i < N / 2; ++i) {
|
| 1239 |
+
result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair);
|
| 1240 |
+
}
|
| 1241 |
+
|
| 1242 |
+
if constexpr (N % 2) {
|
| 1243 |
+
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
|
| 1244 |
+
__half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
|
| 1245 |
+
|
| 1246 |
+
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
|
| 1247 |
+
}
|
| 1248 |
+
|
| 1249 |
+
#else
|
| 1250 |
+
|
| 1251 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1252 |
+
for (int i = 0; i < N; ++i) {
|
| 1253 |
+
result[i] = lhs[i] + rhs;
|
| 1254 |
+
}
|
| 1255 |
+
#endif
|
| 1256 |
+
|
| 1257 |
+
return result;
|
| 1258 |
+
}
|
| 1259 |
+
};
|
| 1260 |
+
|
| 1261 |
+
template <int N>
|
| 1262 |
+
struct minus<Array<half_t, N>> {
|
| 1263 |
+
CUTLASS_HOST_DEVICE
|
| 1264 |
+
Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
|
| 1265 |
+
Array<half_t, N> result;
|
| 1266 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
|
| 1267 |
+
|
| 1268 |
+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
|
| 1269 |
+
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
|
| 1270 |
+
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
|
| 1271 |
+
|
| 1272 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1273 |
+
for (int i = 0; i < N / 2; ++i) {
|
| 1274 |
+
result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]);
|
| 1275 |
+
}
|
| 1276 |
+
|
| 1277 |
+
if constexpr (N % 2) {
|
| 1278 |
+
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
|
| 1279 |
+
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
|
| 1280 |
+
__half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
|
| 1281 |
+
|
| 1282 |
+
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
|
| 1283 |
+
}
|
| 1284 |
+
|
| 1285 |
+
#else
|
| 1286 |
+
|
| 1287 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1288 |
+
for (int i = 0; i < N; ++i) {
|
| 1289 |
+
result[i] = lhs[i] - rhs[i];
|
| 1290 |
+
}
|
| 1291 |
+
#endif
|
| 1292 |
+
|
| 1293 |
+
return result;
|
| 1294 |
+
}
|
| 1295 |
+
|
| 1296 |
+
CUTLASS_HOST_DEVICE
|
| 1297 |
+
Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
|
| 1298 |
+
Array<half_t, N> result;
|
| 1299 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
|
| 1300 |
+
|
| 1301 |
+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
|
| 1302 |
+
__half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
|
| 1303 |
+
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
|
| 1304 |
+
|
| 1305 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1306 |
+
for (int i = 0; i < N / 2; ++i) {
|
| 1307 |
+
result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]);
|
| 1308 |
+
}
|
| 1309 |
+
|
| 1310 |
+
if constexpr (N % 2) {
|
| 1311 |
+
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
|
| 1312 |
+
__half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
|
| 1313 |
+
|
| 1314 |
+
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
|
| 1315 |
+
}
|
| 1316 |
+
|
| 1317 |
+
#else
|
| 1318 |
+
|
| 1319 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1320 |
+
for (int i = 0; i < N; ++i) {
|
| 1321 |
+
result[i] = lhs - rhs[i];
|
| 1322 |
+
}
|
| 1323 |
+
#endif
|
| 1324 |
+
|
| 1325 |
+
return result;
|
| 1326 |
+
}
|
| 1327 |
+
|
| 1328 |
+
CUTLASS_HOST_DEVICE
|
| 1329 |
+
Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
|
| 1330 |
+
Array<half_t, N> result;
|
| 1331 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
|
| 1332 |
+
|
| 1333 |
+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
|
| 1334 |
+
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
|
| 1335 |
+
__half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
|
| 1336 |
+
|
| 1337 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1338 |
+
for (int i = 0; i < N / 2; ++i) {
|
| 1339 |
+
result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair);
|
| 1340 |
+
}
|
| 1341 |
+
|
| 1342 |
+
if constexpr (N % 2) {
|
| 1343 |
+
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
|
| 1344 |
+
__half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
|
| 1345 |
+
|
| 1346 |
+
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
|
| 1347 |
+
}
|
| 1348 |
+
|
| 1349 |
+
#else
|
| 1350 |
+
|
| 1351 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1352 |
+
for (int i = 0; i < N; ++i) {
|
| 1353 |
+
result[i] = lhs[i] - rhs;
|
| 1354 |
+
}
|
| 1355 |
+
#endif
|
| 1356 |
+
|
| 1357 |
+
return result;
|
| 1358 |
+
}
|
| 1359 |
+
};
|
| 1360 |
+
|
| 1361 |
+
template <int N>
struct multiplies<Array<half_t, N>> {

  /// Elementwise product of two arrays: result[i] = lhs[i] * rhs[i].
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // SM53+ device path: view the arrays as packed __half2 pairs and use the
    // paired-multiply intrinsic, two elements per instruction.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]);
    }

    // Odd N leaves one trailing element; handle it with the scalar intrinsic.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    // Host / pre-SM53 fallback: scalar half_t arithmetic.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] * rhs[i];
    }
#endif

    return result;
  }

  /// Scalar-by-array product (scalar broadcast on the left): result[i] = lhs * rhs[i].
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // Broadcast the scalar into both lanes of a __half2 once, then multiply pairwise.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

      __half d_residual = __hmul(
        reinterpret_cast<__half const &>(lhs),
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs * rhs[i];
    }
#endif

    return result;
  }

  /// Array-by-scalar product (scalar broadcast on the right): result[i] = lhs[i] * rhs.
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

      __half d_residual = __hmul(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] * rhs;
    }
#endif

    return result;
  }
};
|
| 1466 |
+
|
| 1467 |
+
template <int N>
struct divides<Array<half_t, N>> {

  /// Elementwise quotient of two arrays: result[i] = lhs[i] / rhs[i].
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // SM53+ device path: packed __half2 division, two elements per instruction.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]);
    }

    // Odd N leaves one trailing element; handle it with the scalar intrinsic.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

      __half d_residual = __hdiv(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    // Host / pre-SM53 fallback: scalar half_t arithmetic.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] / rhs[i];
    }
#endif

    return result;
  }

  /// Scalar numerator broadcast over an array denominator: result[i] = lhs / rhs[i].
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // Broadcast the scalar into both lanes of a __half2 once, then divide pairwise.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

      __half d_residual = __hdiv(
        reinterpret_cast<__half const &>(lhs),
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs / rhs[i];
    }
#endif

    return result;
  }

  /// Array numerator over a broadcast scalar denominator: result[i] = lhs[i] / rhs.
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

      __half d_residual = __hdiv(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] / rhs;
    }
#endif

    return result;
  }
};
|
| 1575 |
+
|
| 1576 |
+
template <int N>
struct negate<Array<half_t, N>> {

  /// Elementwise negation: result[i] = -lhs[i].
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // SM53+ device path: negate two packed halves per instruction.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hneg2(source_ptr[i]);
    }

    // Odd N: negate the trailing element with scalar half_t arithmetic,
    // round-tripping through __half to match the vectorized path's bit handling.
    if constexpr (N % 2) {
      half_t x = -lhs[N - 1];
      __half lhs_val = reinterpret_cast<__half const &>(x);
      result[N - 1] = reinterpret_cast<half_t const &>(lhs_val);
    }

#else

    // Host / pre-SM53 fallback: scalar negation.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = -lhs[i];
    }
#endif

    return result;
  }
};
|
| 1608 |
+
|
| 1609 |
+
/// Fused multiply-add
|
| 1610 |
+
template <int N>
struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {

  /// Elementwise fused multiply-add: result[i] = a[i] * b[i] + c[i].
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // SM53+ device path: packed __half2 FMA, two elements per instruction.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]);
    }

    // Odd N leaves one trailing element; handle it with the scalar FMA intrinsic.
    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    // Host / pre-SM53 fallback: defer to the scalar multiply_add functor.
    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c[i]);
    }
#endif

    return result;
  }

  /// FMA with scalar multiplicand 'a' broadcast: result[i] = a * b[i] + c[i].
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    half_t const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    // Broadcast 'a' into both lanes of a __half2 once, then FMA pairwise.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {

      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
      __half d_residual = __hfma(
        reinterpret_cast<__half const &>(a),
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a, b[i], c[i]);
    }
#endif

    return result;
  }

  /// FMA with scalar multiplicand 'b' broadcast: result[i] = a[i] * b + c[i].
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    half_t const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(b),
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b, c[i]);
    }
#endif

    return result;
  }

  /// FMA with scalar addend 'c' broadcast: result[i] = a[i] * b[i] + c.
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    half_t const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(c));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c);
    }
#endif

    return result;
  }

  /// FMA with scalars 'b' and 'c' broadcast: result[i] = a[i] * b + c.
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    half_t const &b,
    half_t const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_pair);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(b),
        reinterpret_cast<__half const &>(c));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b, c);
    }
#endif

    return result;
  }
};
|
| 1837 |
+
|
| 1838 |
+
/// Fused multiply-add-relu0
|
| 1839 |
+
template <int N>
struct multiply_add_relu0<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {

  /// Elementwise FMA followed by ReLU: result[i] = max(a[i] * b[i] + c[i], 0).
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    // SM80+ device path: fused FMA+ReLU intrinsic on packed __half2 pairs.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]);
    }

    // Odd N leaves one trailing element; handle it with the scalar intrinsic.
    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma_relu(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    // Host / pre-SM80 fallback: scalar FMA then clamp at zero.
    multiply_add<half_t> op;
    maximum<half_t> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(op(a[i], b[i], c[i]), (half_t)0);
    }
#endif

    return result;
  }

  /// FMA+ReLU with scalar multiplicand 'a' broadcast: result[i] = max(a * b[i] + c[i], 0).
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    half_t const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    // Broadcast 'a' into both lanes of a __half2 once.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {

      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
      __half d_residual = __hfma_relu(
        reinterpret_cast<__half const &>(a),
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;
    maximum<half_t> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(op(a, b[i], c[i]), half_t(0));
    }
#endif

    return result;
  }

  /// FMA+ReLU with scalar multiplicand 'b' broadcast: result[i] = max(a[i] * b + c[i], 0).
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    half_t const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma_relu(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(b),
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;
    maximum<half_t> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(op(a[i], b, c[i]), half_t(0));
    }
#endif

    return result;
  }

  /// FMA+ReLU with scalar addend 'c' broadcast: result[i] = max(a[i] * b[i] + c, 0).
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    half_t const &c) const {

    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {

      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);

      __half d_residual = __hfma_relu(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(c));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    multiply_add<half_t> op;
    maximum<half_t> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(op(a[i], b[i], c), half_t(0));
    }
#endif

    return result;
  }
};
|
| 2026 |
+
|
| 2027 |
+
template <int N, bool PropagateNaN>
struct minimum<Array<half_t, N>, PropagateNaN> {

  /// Elementwise minimum of two arrays. When PropagateNaN is true the
  /// NaN-propagating intrinsics (__hmin_nan / __hmin2_nan) are selected so a NaN
  /// input yields a NaN output; otherwise the IEEE-style __hmin variants are used.
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    // SM80+ device path: packed __half2 minimum, two elements per instruction.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_ptr[i])
                                   : __hmin2(lhs_ptr[i], rhs_ptr[i]);
    }

    // Odd N leaves one trailing element; handle it with the scalar intrinsic.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

      __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1])
                                       : __hmin(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    // Host / pre-SM80 fallback: defer to the scalar minimum functor.
    minimum<half_t,PropagateNaN> mn;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mn(lhs[i],rhs[i]);
    }
#endif

    return result;
  }

  /// Minimum with the scalar lhs broadcast against each rhs element.
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    // Broadcast the scalar into both lanes of a __half2 once.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_pair, rhs_ptr[i])
                                   : __hmin2(lhs_pair, rhs_ptr[i]);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

      __half d_residual = PropagateNaN ? __hmin_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1])
                                       : __hmin(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    minimum<half_t,PropagateNaN> mn;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mn(lhs, rhs[i]);
    }
#endif

    return result;
  }

  /// Minimum with the scalar rhs broadcast against each lhs element.
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_pair)
                                   : __hmin2(lhs_ptr[i], rhs_pair);
    }

    // Residual element for odd N.
    if constexpr (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

      __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs))
                                       : __hmin(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else

    minimum<half_t, PropagateNaN> mn;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mn(lhs[i], rhs);
    }
#endif

    return result;
  }
};
|
| 2141 |
+
|
| 2142 |
+
template <int N, bool PropagateNaN>
|
| 2143 |
+
struct maximum<Array<half_t, N>, PropagateNaN> {
|
| 2144 |
+
CUTLASS_HOST_DEVICE
|
| 2145 |
+
Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
|
| 2146 |
+
Array<half_t, N> result;
|
| 2147 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
| 2148 |
+
|
| 2149 |
+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
|
| 2150 |
+
__half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
|
| 2151 |
+
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
|
| 2152 |
+
|
| 2153 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2154 |
+
for (int i = 0; i < N / 2; ++i) {
|
| 2155 |
+
result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_ptr[i])
|
| 2156 |
+
: __hmax2(lhs_ptr[i], rhs_ptr[i]);
|
| 2157 |
+
}
|
| 2158 |
+
|
| 2159 |
+
if constexpr (N % 2) {
|
| 2160 |
+
__half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
|
| 2161 |
+
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
|
| 2162 |
+
|
| 2163 |
+
__half d_residual = PropagateNaN ? __hmax(a_residual_ptr[N - 1], b_residual_ptr[N - 1])
|
| 2164 |
+
: __hmax_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
|
| 2165 |
+
|
| 2166 |
+
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
|
| 2167 |
+
}
|
| 2168 |
+
|
| 2169 |
+
#else
|
| 2170 |
+
|
| 2171 |
+
maximum<half_t,PropagateNaN> mx;
|
| 2172 |
+
|
| 2173 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2174 |
+
for (int i = 0; i < N; ++i) {
|
| 2175 |
+
result[i] = mx(lhs[i], rhs[i]);
|
| 2176 |
+
}
|
| 2177 |
+
#endif
|
| 2178 |
+
|
| 2179 |
+
return result;
|
| 2180 |
+
}
|
| 2181 |
+
|
| 2182 |
+
CUTLASS_HOST_DEVICE
|
| 2183 |
+
Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
|
| 2184 |
+
Array<half_t, N> result;
|
| 2185 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
| 2186 |
+
|
| 2187 |
+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
|
| 2188 |
+
__half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
|
| 2189 |
+
__half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
|
| 2190 |
+
|
| 2191 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2192 |
+
for (int i = 0; i < N / 2; ++i) {
|
| 2193 |
+
result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_pair, rhs_ptr[i])
|
| 2194 |
+
: __hmax2(lhs_pair, rhs_ptr[i]);
|
| 2195 |
+
}
|
| 2196 |
+
|
| 2197 |
+
if constexpr (N % 2) {
|
| 2198 |
+
__half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
|
| 2199 |
+
|
| 2200 |
+
__half d_residual = PropagateNaN ? __hmax_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1])
|
| 2201 |
+
: __hmax(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
|
| 2202 |
+
|
| 2203 |
+
result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
|
| 2204 |
+
}
|
| 2205 |
+
|
| 2206 |
+
#else
|
| 2207 |
+
|
| 2208 |
+
maximum<half_t,PropagateNaN> mx;
|
| 2209 |
+
|
| 2210 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2211 |
+
for (int i = 0; i < N; ++i) {
|
| 2212 |
+
result[i] = mx(lhs, rhs[i]);
|
| 2213 |
+
}
|
| 2214 |
+
#endif
|
| 2215 |
+
|
| 2216 |
+
return result;
|
| 2217 |
+
}
|
| 2218 |
+
|
| 2219 |
+
  /// Computes the elementwise maximum of array `lhs` against broadcast scalar `rhs`.
  /// NaN handling follows the PropagateNaN template flag of the enclosing functor.
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
    Array<half_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    // SM80+ device path: process two fp16 elements per packed __half2 instruction.
    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    // Replicate the scalar into both lanes of a __half2.
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      // __hmax2_nan returns NaN if either operand is NaN; __hmax2 returns the non-NaN operand.
      result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_pair)
                                   : __hmax2(lhs_ptr[i], rhs_pair);
    }

    if constexpr (N % 2) {
      // Odd N: the last element does not fit a __half2 pair; use the scalar intrinsic.
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

      __half d_residual = PropagateNaN ? __hmax_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs))
                                       : __hmax(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

#else
    // Host / pre-SM80 fallback: plain scalar elementwise maximum.
    maximum<half_t,PropagateNaN> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(lhs[i], rhs);
    }
#endif

    return result;
  }
|
| 2255 |
+
};
|
| 2256 |
+
|
| 2257 |
+
/// Fused multiply-add for arrays of bfloat16: d[i] = a[i] * b[i] + c[i], with
/// scalar-broadcast overloads for each operand. On SM80+ the device path issues
/// packed bf16x2 PTX FMA instructions (two elements per instruction); elsewhere
/// it falls back to the scalar multiply_add<bfloat16_t> functor.
template <int N>
struct multiply_add<Array<bfloat16_t, N>, Array<bfloat16_t, N>, Array<bfloat16_t, N>> {

  /// Elementwise fused multiply-add of three arrays.
  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(
    Array<bfloat16_t, N> const &a,
    Array<bfloat16_t, N> const &b,
    Array<bfloat16_t, N> const &c) const {

    Array<bfloat16_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    // Reinterpret arrays as 32-bit words, each holding a pair of bf16 values.
    unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
    unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
    unsigned const *b_ptr = reinterpret_cast<unsigned const *>(&b);
    unsigned const *c_ptr = reinterpret_cast<unsigned const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      // Packed bf16x2 FMA with round-to-nearest-even.
      asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(result_ptr[i])
        : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i])
      );
    }

    if constexpr (N % 2) {
      // Odd N: last element handled with a scalar bf16 FMA.
      uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
      uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
      uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
      uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);

      asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
        : "=h"(result_ptr[N - 1])
        : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1])
      );
    }

#else
    // Host / pre-SM80 fallback: scalar FMA per element.
    multiply_add<bfloat16_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c[i]);
    }
#endif

    return result;
  }

  /// Broadcast-scalar `a`: d[i] = a * b[i] + c[i].
  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(
    bfloat16_t const &a,
    Array<bfloat16_t, N> const &b,
    Array<bfloat16_t, N> const &c) const {

    Array<bfloat16_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);

    unsigned const *b_ptr = reinterpret_cast<unsigned const *>(&b);
    unsigned const *c_ptr = reinterpret_cast<unsigned const *>(&c);

    // Duplicate the 16-bit scalar into both halves of a 32-bit word.
    unsigned a_packed = static_cast<unsigned>(a.raw());
    a_packed = (a_packed | (a_packed << 16));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(result_ptr[i])
        : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i])
      );
    }

    if constexpr (N % 2) {
      // Odd N: scalar operand is at index 0 of its own storage; arrays use N-1.
      uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
      uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
      uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
      uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);

      asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
        : "=h"(result_ptr[N - 1])
        : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1])
      );
    }

#else
    // Host / pre-SM80 fallback.
    multiply_add<bfloat16_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a, b[i], c[i]);
    }
#endif

    return result;
  }

  /// Broadcast-scalar `b`: d[i] = a[i] * b + c[i].
  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(
    Array<bfloat16_t, N> const &a,
    bfloat16_t const &b,
    Array<bfloat16_t, N> const &c) const {

    Array<bfloat16_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);

    unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
    unsigned const *c_ptr = reinterpret_cast<unsigned const *>(&c);

    // Duplicate the 16-bit scalar into both halves of a 32-bit word.
    unsigned b_packed = static_cast<unsigned>(b.raw());
    b_packed = (b_packed | (b_packed << 16));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(result_ptr[i])
        : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i])
      );
    }

    if constexpr (N % 2) {
      // Odd N: scalar operand is at index 0 of its own storage; arrays use N-1.
      uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
      uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
      uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
      uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);

      asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
        : "=h"(result_ptr[N - 1])
        : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1])
      );
    }

#else
    // Host / pre-SM80 fallback.
    multiply_add<bfloat16_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b, c[i]);
    }
#endif

    return result;
  }

  /// Broadcast-scalar `c`: d[i] = a[i] * b[i] + c.
  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(
    Array<bfloat16_t, N> const &a,
    Array<bfloat16_t, N> const &b,
    bfloat16_t const &c) const {

    Array<bfloat16_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);

    unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
    unsigned const *b_ptr = reinterpret_cast<unsigned const *>(&b);

    // Duplicate the 16-bit scalar into both halves of a 32-bit word.
    unsigned c_packed = static_cast<unsigned>(c.raw());
    c_packed = (c_packed | (c_packed << 16));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(result_ptr[i])
        : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed)
      );
    }

    if constexpr (N % 2) {
      // Odd N: scalar operand is at index 0 of its own storage; arrays use N-1.
      uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
      uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
      uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
      uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);

      asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
        : "=h"(result_ptr[N - 1])
        : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0])
      );
    }

#else
    // Host / pre-SM80 fallback.
    multiply_add<bfloat16_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c);
    }
#endif

    return result;
  }

  /// Broadcast-scalar `b` and `c`: d[i] = a[i] * b + c.
  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(
    Array<bfloat16_t, N> const &a,
    bfloat16_t const &b,
    bfloat16_t const &c) const {

    Array<bfloat16_t, N> result;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);

    unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);

    // Duplicate each 16-bit scalar into both halves of a 32-bit word.
    unsigned b_packed = static_cast<unsigned>(b.raw());
    b_packed = (b_packed | (b_packed << 16));

    unsigned c_packed = static_cast<unsigned>(c.raw());
    c_packed = (c_packed | (c_packed << 16));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(result_ptr[i])
        : "r"(a_ptr[i]), "r"(b_packed), "r"(c_packed)
      );
    }

    if constexpr (N % 2) {
      // Odd N: scalar operands are at index 0 of their own storage; the array uses N-1.
      uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
      uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
      uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
      uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);

      asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
        : "=h"(result_ptr[N - 1])
        : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[0])
      );
    }


#else
    // Host / pre-SM80 fallback.
    multiply_add<bfloat16_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b, c);
    }
#endif

    return result;
  }
};
|
| 2516 |
+
|
| 2517 |
+
|
| 2518 |
+
/// bit_and
|
| 2519 |
+
template <int N>
|
| 2520 |
+
struct bit_and<Array<uint1b_t, N>> {
|
| 2521 |
+
CUTLASS_HOST_DEVICE
|
| 2522 |
+
Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a, Array<uint1b_t, N> const &b) const {
|
| 2523 |
+
using ArrayType = Array<uint1b_t, N>;
|
| 2524 |
+
using Storage = typename ArrayType::Storage;
|
| 2525 |
+
ArrayType result;
|
| 2526 |
+
|
| 2527 |
+
Storage *result_data = result.raw_data();
|
| 2528 |
+
Storage const *a_data = a.raw_data();
|
| 2529 |
+
Storage const *b_data = b.raw_data();
|
| 2530 |
+
|
| 2531 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2532 |
+
for (int i = 0; i < ArrayType::kStorageElements; ++i) {
|
| 2533 |
+
result_data[i] = (a_data[i] & b_data[i]);
|
| 2534 |
+
}
|
| 2535 |
+
|
| 2536 |
+
return result;
|
| 2537 |
+
}
|
| 2538 |
+
};
|
| 2539 |
+
|
| 2540 |
+
|
| 2541 |
+
/// bit_or
|
| 2542 |
+
template <int N>
|
| 2543 |
+
struct bit_or<Array<uint1b_t, N>> {
|
| 2544 |
+
CUTLASS_HOST_DEVICE
|
| 2545 |
+
Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a, Array<uint1b_t, N> const &b) const {
|
| 2546 |
+
using ArrayType = Array<uint1b_t, N>;
|
| 2547 |
+
using Storage = typename ArrayType::Storage;
|
| 2548 |
+
ArrayType result;
|
| 2549 |
+
|
| 2550 |
+
Storage *result_data = result.raw_data();
|
| 2551 |
+
Storage const *a_data = a.raw_data();
|
| 2552 |
+
Storage const *b_data = b.raw_data();
|
| 2553 |
+
|
| 2554 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2555 |
+
for (int i = 0; i < ArrayType::kStorageElements; ++i) {
|
| 2556 |
+
result_data[i] = (a_data[i] | b_data[i]);
|
| 2557 |
+
}
|
| 2558 |
+
|
| 2559 |
+
return result;
|
| 2560 |
+
}
|
| 2561 |
+
};
|
| 2562 |
+
|
| 2563 |
+
|
| 2564 |
+
/// bit_not
|
| 2565 |
+
template <int N>
|
| 2566 |
+
struct bit_not<Array<uint1b_t, N>> {
|
| 2567 |
+
CUTLASS_HOST_DEVICE
|
| 2568 |
+
Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a) const {
|
| 2569 |
+
using ArrayType = Array<uint1b_t, N>;
|
| 2570 |
+
using Storage = typename ArrayType::Storage;
|
| 2571 |
+
ArrayType result;
|
| 2572 |
+
|
| 2573 |
+
Storage *result_data = result.raw_data();
|
| 2574 |
+
Storage const *a_data = a.raw_data();
|
| 2575 |
+
|
| 2576 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2577 |
+
for (int i = 0; i < ArrayType::kStorageElements; ++i) {
|
| 2578 |
+
result_data[i] = (~a_data[i]);
|
| 2579 |
+
}
|
| 2580 |
+
|
| 2581 |
+
return result;
|
| 2582 |
+
}
|
| 2583 |
+
};
|
| 2584 |
+
|
| 2585 |
+
/// bit_xor
|
| 2586 |
+
template <int N>
|
| 2587 |
+
struct bit_xor<Array<uint1b_t, N>> {
|
| 2588 |
+
CUTLASS_HOST_DEVICE
|
| 2589 |
+
Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a, Array<uint1b_t, N> const &b) const {
|
| 2590 |
+
using ArrayType = Array<uint1b_t, N>;
|
| 2591 |
+
using Storage = typename ArrayType::Storage;
|
| 2592 |
+
ArrayType result;
|
| 2593 |
+
|
| 2594 |
+
Storage *result_data = result.raw_data();
|
| 2595 |
+
Storage const *a_data = a.raw_data();
|
| 2596 |
+
Storage const *b_data = b.raw_data();
|
| 2597 |
+
|
| 2598 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2599 |
+
for (int i = 0; i < ArrayType::kStorageElements; ++i) {
|
| 2600 |
+
result_data[i] = (a_data[i] ^ b_data[i]);
|
| 2601 |
+
}
|
| 2602 |
+
|
| 2603 |
+
return result;
|
| 2604 |
+
}
|
| 2605 |
+
};
|
| 2606 |
+
|
| 2607 |
+
/// Fused and-popc-add
|
| 2608 |
+
template <typename T, int N>
|
| 2609 |
+
struct and_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
|
| 2610 |
+
CUTLASS_HOST_DEVICE
|
| 2611 |
+
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 2612 |
+
Array<T, N> result;
|
| 2613 |
+
and_popc_add<T> scalar_op;
|
| 2614 |
+
|
| 2615 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2616 |
+
for (int i = 0; i < N; ++i) {
|
| 2617 |
+
result[i] = scalar_op(a[i], b[i], c[i]);
|
| 2618 |
+
}
|
| 2619 |
+
|
| 2620 |
+
return result;
|
| 2621 |
+
}
|
| 2622 |
+
|
| 2623 |
+
CUTLASS_HOST_DEVICE
|
| 2624 |
+
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
|
| 2625 |
+
Array<T, N> result;
|
| 2626 |
+
and_popc_add<T> scalar_op;
|
| 2627 |
+
|
| 2628 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2629 |
+
for (int i = 0; i < N; ++i) {
|
| 2630 |
+
result[i] = scalar_op(a[i], scalar, c[i]);
|
| 2631 |
+
}
|
| 2632 |
+
|
| 2633 |
+
return result;
|
| 2634 |
+
}
|
| 2635 |
+
|
| 2636 |
+
CUTLASS_HOST_DEVICE
|
| 2637 |
+
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 2638 |
+
Array<T, N> result;
|
| 2639 |
+
and_popc_add<T> scalar_op;
|
| 2640 |
+
|
| 2641 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2642 |
+
for (int i = 0; i < N; ++i) {
|
| 2643 |
+
result[i] = scalar_op(scalar, b[i], c[i]);
|
| 2644 |
+
}
|
| 2645 |
+
|
| 2646 |
+
return result;
|
| 2647 |
+
}
|
| 2648 |
+
};
|
| 2649 |
+
|
| 2650 |
+
|
| 2651 |
+
/// Fused or-popc-add
|
| 2652 |
+
template <typename T, int N>
|
| 2653 |
+
struct or_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
|
| 2654 |
+
CUTLASS_HOST_DEVICE
|
| 2655 |
+
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 2656 |
+
Array<T, N> result;
|
| 2657 |
+
or_popc_add<T> scalar_op;
|
| 2658 |
+
|
| 2659 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2660 |
+
for (int i = 0; i < N; ++i) {
|
| 2661 |
+
result[i] = scalar_op(a[i], b[i], c[i]);
|
| 2662 |
+
}
|
| 2663 |
+
|
| 2664 |
+
return result;
|
| 2665 |
+
}
|
| 2666 |
+
|
| 2667 |
+
CUTLASS_HOST_DEVICE
|
| 2668 |
+
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
|
| 2669 |
+
Array<T, N> result;
|
| 2670 |
+
or_popc_add<T> scalar_op;
|
| 2671 |
+
|
| 2672 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2673 |
+
for (int i = 0; i < N; ++i) {
|
| 2674 |
+
result[i] = scalar_op(a[i], scalar, c[i]);
|
| 2675 |
+
}
|
| 2676 |
+
|
| 2677 |
+
return result;
|
| 2678 |
+
}
|
| 2679 |
+
|
| 2680 |
+
CUTLASS_HOST_DEVICE
|
| 2681 |
+
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 2682 |
+
Array<T, N> result;
|
| 2683 |
+
or_popc_add<T> scalar_op;
|
| 2684 |
+
|
| 2685 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2686 |
+
for (int i = 0; i < N; ++i) {
|
| 2687 |
+
result[i] = scalar_op(scalar, b[i], c[i]);
|
| 2688 |
+
}
|
| 2689 |
+
|
| 2690 |
+
return result;
|
| 2691 |
+
}
|
| 2692 |
+
};
|
| 2693 |
+
|
| 2694 |
+
/// Fused xor-popc-add
|
| 2695 |
+
template <typename T, int N>
|
| 2696 |
+
struct xor_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
|
| 2697 |
+
CUTLASS_HOST_DEVICE
|
| 2698 |
+
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 2699 |
+
Array<T, N> result;
|
| 2700 |
+
xor_popc_add<T> scalar_op;
|
| 2701 |
+
|
| 2702 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2703 |
+
for (int i = 0; i < N; ++i) {
|
| 2704 |
+
result[i] = scalar_op(a[i], b[i], c[i]);
|
| 2705 |
+
}
|
| 2706 |
+
|
| 2707 |
+
return result;
|
| 2708 |
+
}
|
| 2709 |
+
|
| 2710 |
+
CUTLASS_HOST_DEVICE
|
| 2711 |
+
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
|
| 2712 |
+
Array<T, N> result;
|
| 2713 |
+
xor_popc_add<T> scalar_op;
|
| 2714 |
+
|
| 2715 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2716 |
+
for (int i = 0; i < N; ++i) {
|
| 2717 |
+
result[i] = scalar_op(a[i], scalar, c[i]);
|
| 2718 |
+
}
|
| 2719 |
+
|
| 2720 |
+
return result;
|
| 2721 |
+
}
|
| 2722 |
+
|
| 2723 |
+
CUTLASS_HOST_DEVICE
|
| 2724 |
+
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
|
| 2725 |
+
Array<T, N> result;
|
| 2726 |
+
xor_popc_add<T> scalar_op;
|
| 2727 |
+
|
| 2728 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2729 |
+
for (int i = 0; i < N; ++i) {
|
| 2730 |
+
result[i] = scalar_op(scalar, b[i], c[i]);
|
| 2731 |
+
}
|
| 2732 |
+
|
| 2733 |
+
return result;
|
| 2734 |
+
}
|
| 2735 |
+
};
|
| 2736 |
+
|
| 2737 |
+
|
| 2738 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2739 |
+
// Operator overloads
|
| 2740 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2741 |
+
|
| 2742 |
+
template <typename T, int N>
|
| 2743 |
+
CUTLASS_HOST_DEVICE
|
| 2744 |
+
Array<T, N> operator+(Array<T, N> const &lhs, Array<T, N> const &rhs) {
|
| 2745 |
+
plus<Array<T, N>> op;
|
| 2746 |
+
return op(lhs, rhs);
|
| 2747 |
+
}
|
| 2748 |
+
|
| 2749 |
+
template <typename T, int N>
|
| 2750 |
+
CUTLASS_HOST_DEVICE
|
| 2751 |
+
Array<T, N> operator+(T const &lhs, Array<T, N> const &rhs) {
|
| 2752 |
+
plus<Array<T, N>> op;
|
| 2753 |
+
return op(lhs, rhs);
|
| 2754 |
+
}
|
| 2755 |
+
|
| 2756 |
+
template <typename T, int N>
|
| 2757 |
+
CUTLASS_HOST_DEVICE
|
| 2758 |
+
Array<T, N> operator+(Array<T, N> const &lhs, T const &rhs) {
|
| 2759 |
+
plus<Array<T, N>> op;
|
| 2760 |
+
return op(lhs, rhs);
|
| 2761 |
+
}
|
| 2762 |
+
|
| 2763 |
+
template <typename T, int N>
|
| 2764 |
+
CUTLASS_HOST_DEVICE
|
| 2765 |
+
Array<T, N> operator-(Array<T, N> const &lhs, Array<T, N> const &rhs) {
|
| 2766 |
+
minus<Array<T, N>> op;
|
| 2767 |
+
return op(lhs, rhs);
|
| 2768 |
+
}
|
| 2769 |
+
|
| 2770 |
+
template <typename T, int N>
|
| 2771 |
+
CUTLASS_HOST_DEVICE
|
| 2772 |
+
Array<T, N> operator-(Array<T, N> const &lhs) {
|
| 2773 |
+
negate<Array<T, N>> op;
|
| 2774 |
+
return op(lhs);
|
| 2775 |
+
}
|
| 2776 |
+
|
| 2777 |
+
template <typename T, int N>
|
| 2778 |
+
CUTLASS_HOST_DEVICE
|
| 2779 |
+
Array<T, N> operator*(Array<T, N> const &lhs, Array<T, N> const &rhs) {
|
| 2780 |
+
multiplies<Array<T, N>> op;
|
| 2781 |
+
return op(lhs, rhs);
|
| 2782 |
+
}
|
| 2783 |
+
|
| 2784 |
+
template <typename T, int N>
|
| 2785 |
+
CUTLASS_HOST_DEVICE
|
| 2786 |
+
Array<T, N> operator*(T lhs, Array<T, N> const &rhs) {
|
| 2787 |
+
multiplies<Array<T, N>> op;
|
| 2788 |
+
return op(lhs, rhs);
|
| 2789 |
+
}
|
| 2790 |
+
|
| 2791 |
+
template <typename T, int N>
|
| 2792 |
+
CUTLASS_HOST_DEVICE
|
| 2793 |
+
Array<T, N> operator*(Array<T, N> const &lhs, T rhs) {
|
| 2794 |
+
multiplies<Array<T, N>> op;
|
| 2795 |
+
return op(lhs, rhs);
|
| 2796 |
+
}
|
| 2797 |
+
|
| 2798 |
+
template <typename T, int N>
|
| 2799 |
+
CUTLASS_HOST_DEVICE
|
| 2800 |
+
Array<T, N> operator/(Array<T, N> const &lhs, Array<T, N> const &rhs) {
|
| 2801 |
+
divides<Array<T, N>> op;
|
| 2802 |
+
return op(lhs, rhs);
|
| 2803 |
+
}
|
| 2804 |
+
|
| 2805 |
+
template <typename T, int N>
|
| 2806 |
+
CUTLASS_HOST_DEVICE
|
| 2807 |
+
Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) {
|
| 2808 |
+
multiply_add<Array<T, N>> op;
|
| 2809 |
+
return op(a, b, c);
|
| 2810 |
+
}
|
| 2811 |
+
|
| 2812 |
+
template <typename T, int N>
|
| 2813 |
+
CUTLASS_HOST_DEVICE
|
| 2814 |
+
Array<T, N> fma(T a, Array<T, N> const &b, Array<T, N> const &c) {
|
| 2815 |
+
multiply_add<Array<T, N>> op;
|
| 2816 |
+
return op(a, b, c);
|
| 2817 |
+
}
|
| 2818 |
+
|
| 2819 |
+
template <typename T, int N>
|
| 2820 |
+
CUTLASS_HOST_DEVICE
|
| 2821 |
+
Array<T, N> fma(Array<T, N> const &a, T b, Array<T, N> const &c) {
|
| 2822 |
+
multiply_add<Array<T, N>> op;
|
| 2823 |
+
return op(a, b, c);
|
| 2824 |
+
}
|
| 2825 |
+
|
| 2826 |
+
template <typename T, int N>
|
| 2827 |
+
CUTLASS_HOST_DEVICE
|
| 2828 |
+
Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, T c) {
|
| 2829 |
+
multiply_add<Array<T, N>> op;
|
| 2830 |
+
return op(a, b, c);
|
| 2831 |
+
}
|
| 2832 |
+
|
| 2833 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2834 |
+
|
| 2835 |
+
|
| 2836 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2837 |
+
// AlignedArray
|
| 2838 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2839 |
+
|
| 2840 |
+
/// Aligned array type: an Array<T, N> whose storage carries an explicit
/// alignas() requirement so it can be loaded/stored with vectorized accesses.
template <
  /// Element type
  typename T,
  /// Number of elements in the array
  int N,
  /// Alignment requirement in bytes; defaults to the packed bit-size of the
  /// array rounded up to a whole byte.
  int Alignment = ( sizeof_bits<T>::value * N + 7 ) / 8
>
class alignas(Alignment) AlignedArray: public Array<T, N> {
public:
  // No additional members: behavior is inherited entirely from Array<T, N>.
};
|
| 2853 |
+
|
| 2854 |
+
} // namespace cutlass
|
| 2855 |
+
|
| 2856 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2857 |
+
|
| 2858 |
+
#include "cutlass/array_subbyte.h"
|
| 2859 |
+
|
| 2860 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing warp-level matrix multiply-accumulate operations.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
|
| 40 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
namespace cutlass {
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
/// Array holding planar complex elements
|
| 47 |
+
template <typename Element_, int N>
|
| 48 |
+
struct ArrayPlanarComplex {
|
| 49 |
+
|
| 50 |
+
/// Underlying real element
|
| 51 |
+
using Element = Element_;
|
| 52 |
+
|
| 53 |
+
/// Number of logical elements
|
| 54 |
+
static constexpr size_t kElements = N;
|
| 55 |
+
|
| 56 |
+
/// Underlying Fragment of real-valued elemenets
|
| 57 |
+
using ArrayReal = cutlass::Array<Element, N>;
|
| 58 |
+
|
| 59 |
+
public:
|
| 60 |
+
/// Fragment of real-valued elements representing the real part
|
| 61 |
+
ArrayReal real;
|
| 62 |
+
|
| 63 |
+
/// Fragment of real-valued elements representing the imaginary part
|
| 64 |
+
ArrayReal imag;
|
| 65 |
+
|
| 66 |
+
public:
|
| 67 |
+
/// Sets the array to zero efficiently
|
| 68 |
+
CUTLASS_HOST_DEVICE
|
| 69 |
+
void clear() {
|
| 70 |
+
real.clear();
|
| 71 |
+
imag.clear();
|
| 72 |
+
}
|
| 73 |
+
};
|
| 74 |
+
|
| 75 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 76 |
+
|
| 77 |
+
/// Helper to deduce template arguments: builds an ArrayPlanarComplex from
/// separate real and imaginary fragments without spelling out <Element, N>.
template <typename Element, int N>
CUTLASS_HOST_DEVICE
ArrayPlanarComplex<Element, N>
make_ArrayPlanarComplex(Array<Element, N> const &real, Array<Element, N> const &imag) {
  // Aggregate-initialize in declaration order: {real, imag}.
  return ArrayPlanarComplex<Element, N>{real, imag};
}
|
| 84 |
+
|
| 85 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 86 |
+
|
| 87 |
+
} // namespace cutlass
|
| 88 |
+
|
| 89 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types
|
| 33 |
+
and is safe to use in a union.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/platform/platform.h"
|
| 41 |
+
|
| 42 |
+
namespace cutlass {
|
| 43 |
+
|
| 44 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
/// Statically sized array for any data type
|
| 47 |
+
template <
|
| 48 |
+
typename T,
|
| 49 |
+
int N
|
| 50 |
+
>
|
| 51 |
+
struct Array<T, N, false> {
|
| 52 |
+
static constexpr int kSizeBits = sizeof_bits<T>::value * N;
|
| 53 |
+
|
| 54 |
+
/// Storage type
|
| 55 |
+
using Storage = typename platform::conditional<
|
| 56 |
+
((kSizeBits % 32) != 0),
|
| 57 |
+
typename platform::conditional<
|
| 58 |
+
((kSizeBits % 16) != 0),
|
| 59 |
+
uint8_t,
|
| 60 |
+
uint16_t
|
| 61 |
+
>::type,
|
| 62 |
+
uint32_t
|
| 63 |
+
>::type;
|
| 64 |
+
|
| 65 |
+
/// Element type
|
| 66 |
+
using Element = T;
|
| 67 |
+
|
| 68 |
+
/// Number of logical elements per stored object
|
| 69 |
+
static constexpr int kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits<T>::value;
|
| 70 |
+
|
| 71 |
+
/// Number of storage elements
|
| 72 |
+
static constexpr size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem;
|
| 73 |
+
|
| 74 |
+
/// Number of logical elements
|
| 75 |
+
static constexpr size_t kElements = N;
|
| 76 |
+
|
| 77 |
+
/// Bitmask for covering one item
|
| 78 |
+
static constexpr Storage kMask = ((Storage(1) << sizeof_bits<T>::value) - 1);
|
| 79 |
+
|
| 80 |
+
//
|
| 81 |
+
// C++ standard members with pointer types removed
|
| 82 |
+
//
|
| 83 |
+
|
| 84 |
+
typedef T value_type;
|
| 85 |
+
typedef size_t size_type;
|
| 86 |
+
typedef ptrdiff_t difference_type;
|
| 87 |
+
typedef value_type *pointer;
|
| 88 |
+
typedef value_type const *const_pointer;
|
| 89 |
+
|
| 90 |
+
//
|
| 91 |
+
// References
|
| 92 |
+
//
|
| 93 |
+
|
| 94 |
+
/// Reference object inserts or extracts sub-byte items
|
| 95 |
+
class reference {
|
| 96 |
+
/// Pointer to storage element
|
| 97 |
+
Storage *ptr_{nullptr};
|
| 98 |
+
|
| 99 |
+
/// Index into elements packed into Storage object
|
| 100 |
+
int idx_{0};
|
| 101 |
+
|
| 102 |
+
public:
|
| 103 |
+
|
| 104 |
+
reference() = default;
|
| 105 |
+
|
| 106 |
+
/// Ctor
|
| 107 |
+
CUTLASS_HOST_DEVICE
|
| 108 |
+
reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
|
| 109 |
+
|
| 110 |
+
/// Assignment
|
| 111 |
+
CUTLASS_HOST_DEVICE
|
| 112 |
+
reference &operator=(T x) {
|
| 113 |
+
// `*ptr_ & kUpdateMask` will read ptr_ before write to it
|
| 114 |
+
// This means code pattern like
|
| 115 |
+
//
|
| 116 |
+
// ```cpp
|
| 117 |
+
// Array<half_t, N> result;
|
| 118 |
+
// result[0] = xxx;
|
| 119 |
+
// ```
|
| 120 |
+
//
|
| 121 |
+
// Will leads to compiler warning on use of uninitialized member variable. Although we know
|
| 122 |
+
// this read of uninitialized member variable is harmeless.
|
| 123 |
+
|
| 124 |
+
#if defined(__clang__)
|
| 125 |
+
# pragma clang diagnostic push
|
| 126 |
+
# pragma clang diagnostic ignored "-Wuninitialized"
|
| 127 |
+
#elif defined(__GNUC__)
|
| 128 |
+
# pragma GCC diagnostic push
|
| 129 |
+
# pragma GCC diagnostic ignored "-Wuninitialized"
|
| 130 |
+
# pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
|
| 131 |
+
#endif
|
| 132 |
+
|
| 133 |
+
Storage item = (reinterpret_cast<Storage const &>(x) & kMask);
|
| 134 |
+
|
| 135 |
+
Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits<T>::value)));
|
| 136 |
+
|
| 137 |
+
*ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits<T>::value)));
|
| 138 |
+
|
| 139 |
+
#if defined(__clang__)
|
| 140 |
+
# pragma clang diagnostic pop
|
| 141 |
+
#elif defined(__GNUC__)
|
| 142 |
+
# pragma GCC diagnostic pop
|
| 143 |
+
#endif
|
| 144 |
+
|
| 145 |
+
return *this;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
CUTLASS_HOST_DEVICE
|
| 149 |
+
T get() const {
|
| 150 |
+
Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits<T>::value)) & kMask);
|
| 151 |
+
return reinterpret_cast<T const &>(item);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
/// Extract
|
| 155 |
+
CUTLASS_HOST_DEVICE
|
| 156 |
+
operator T() const {
|
| 157 |
+
return get();
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/// Explicit cast to int
|
| 161 |
+
CUTLASS_HOST_DEVICE
|
| 162 |
+
explicit operator int() const {
|
| 163 |
+
return int(get());
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
/// Explicit cast to float
|
| 167 |
+
CUTLASS_HOST_DEVICE
|
| 168 |
+
explicit operator float() const {
|
| 169 |
+
return float(get());
|
| 170 |
+
}
|
| 171 |
+
};
|
| 172 |
+
|
| 173 |
+
/// Reference object extracts sub-byte items
|
| 174 |
+
class const_reference {
|
| 175 |
+
|
| 176 |
+
/// Pointer to storage element
|
| 177 |
+
Storage const *ptr_{nullptr};
|
| 178 |
+
|
| 179 |
+
/// Index into elements packed into Storage object
|
| 180 |
+
int idx_{0};
|
| 181 |
+
|
| 182 |
+
public:
|
| 183 |
+
|
| 184 |
+
const_reference() = default;
|
| 185 |
+
|
| 186 |
+
/// Ctor
|
| 187 |
+
CUTLASS_HOST_DEVICE
|
| 188 |
+
const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
|
| 189 |
+
|
| 190 |
+
CUTLASS_HOST_DEVICE
|
| 191 |
+
const T get() const {
|
| 192 |
+
Storage item = (*ptr_ >> (idx_ * sizeof_bits<T>::value)) & kMask;
|
| 193 |
+
return reinterpret_cast<T const &>(item);
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
/// Extract
|
| 197 |
+
CUTLASS_HOST_DEVICE
|
| 198 |
+
operator T() const {
|
| 199 |
+
Storage item = Storage(Storage(*ptr_ >> Storage(idx_ * sizeof_bits<T>::value)) & kMask);
|
| 200 |
+
return reinterpret_cast<T const &>(item);
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
/// Explicit cast to int
|
| 204 |
+
CUTLASS_HOST_DEVICE
|
| 205 |
+
explicit operator int() const {
|
| 206 |
+
return int(get());
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
/// Explicit cast to float
|
| 210 |
+
CUTLASS_HOST_DEVICE
|
| 211 |
+
explicit operator float() const {
|
| 212 |
+
return float(get());
|
| 213 |
+
}
|
| 214 |
+
};
|
| 215 |
+
|
| 216 |
+
//
|
| 217 |
+
// Iterators
|
| 218 |
+
//
|
| 219 |
+
|
| 220 |
+
/// Bidirectional iterator over elements
|
| 221 |
+
class iterator {
|
| 222 |
+
|
| 223 |
+
/// Pointer to storage element
|
| 224 |
+
Storage *ptr_{nullptr};
|
| 225 |
+
|
| 226 |
+
/// Index into elements packed into Storage object
|
| 227 |
+
int idx_{0};
|
| 228 |
+
|
| 229 |
+
public:
|
| 230 |
+
|
| 231 |
+
iterator() = default;
|
| 232 |
+
|
| 233 |
+
CUTLASS_HOST_DEVICE
|
| 234 |
+
iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
|
| 235 |
+
|
| 236 |
+
CUTLASS_HOST_DEVICE
|
| 237 |
+
iterator &operator++() {
|
| 238 |
+
++idx_;
|
| 239 |
+
if (idx_ == kElementsPerStoredItem) {
|
| 240 |
+
++ptr_;
|
| 241 |
+
idx_ = 0;
|
| 242 |
+
}
|
| 243 |
+
return *this;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
CUTLASS_HOST_DEVICE
|
| 247 |
+
iterator &operator--() {
|
| 248 |
+
if (!idx_) {
|
| 249 |
+
--ptr_;
|
| 250 |
+
idx_ = kElementsPerStoredItem - 1;
|
| 251 |
+
}
|
| 252 |
+
else {
|
| 253 |
+
--idx_;
|
| 254 |
+
}
|
| 255 |
+
return *this;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
CUTLASS_HOST_DEVICE
|
| 259 |
+
iterator operator++(int) {
|
| 260 |
+
iterator ret(*this);
|
| 261 |
+
++idx_;
|
| 262 |
+
if (idx_ == kElementsPerStoredItem) {
|
| 263 |
+
++ptr_;
|
| 264 |
+
idx_ = 0;
|
| 265 |
+
}
|
| 266 |
+
return ret;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
CUTLASS_HOST_DEVICE
|
| 270 |
+
iterator operator--(int) {
|
| 271 |
+
iterator ret(*this);
|
| 272 |
+
if (!idx_) {
|
| 273 |
+
--ptr_;
|
| 274 |
+
idx_ = kElementsPerStoredItem - 1;
|
| 275 |
+
}
|
| 276 |
+
else {
|
| 277 |
+
--idx_;
|
| 278 |
+
}
|
| 279 |
+
return ret;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
CUTLASS_HOST_DEVICE
|
| 283 |
+
reference operator*() const {
|
| 284 |
+
return reference(ptr_, idx_);
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
CUTLASS_HOST_DEVICE
|
| 288 |
+
bool operator==(iterator const &other) const {
|
| 289 |
+
return ptr_ == other.ptr_ && idx_ == other.idx_;
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
CUTLASS_HOST_DEVICE
|
| 293 |
+
bool operator!=(iterator const &other) const {
|
| 294 |
+
return !(*this == other);
|
| 295 |
+
}
|
| 296 |
+
};
|
| 297 |
+
|
| 298 |
+
/// Bidirectional constant iterator over elements
|
| 299 |
+
class const_iterator {
|
| 300 |
+
|
| 301 |
+
/// Pointer to storage element
|
| 302 |
+
Storage const *ptr_{nullptr};
|
| 303 |
+
|
| 304 |
+
/// Index into elements packed into Storage object
|
| 305 |
+
int idx_{0};
|
| 306 |
+
|
| 307 |
+
public:
|
| 308 |
+
|
| 309 |
+
const_iterator() = default;
|
| 310 |
+
|
| 311 |
+
CUTLASS_HOST_DEVICE
|
| 312 |
+
const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
|
| 313 |
+
|
| 314 |
+
CUTLASS_HOST_DEVICE
|
| 315 |
+
iterator &operator++() {
|
| 316 |
+
++idx_;
|
| 317 |
+
if (idx_ == kElementsPerStoredItem) {
|
| 318 |
+
++ptr_;
|
| 319 |
+
idx_ = 0;
|
| 320 |
+
}
|
| 321 |
+
return *this;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
CUTLASS_HOST_DEVICE
|
| 325 |
+
iterator &operator--() {
|
| 326 |
+
if (!idx_) {
|
| 327 |
+
--ptr_;
|
| 328 |
+
idx_ = kElementsPerStoredItem - 1;
|
| 329 |
+
}
|
| 330 |
+
else {
|
| 331 |
+
--idx_;
|
| 332 |
+
}
|
| 333 |
+
return *this;
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
CUTLASS_HOST_DEVICE
|
| 337 |
+
iterator operator++(int) {
|
| 338 |
+
iterator ret(*this);
|
| 339 |
+
++idx_;
|
| 340 |
+
if (idx_ == kElementsPerStoredItem) {
|
| 341 |
+
++ptr_;
|
| 342 |
+
idx_ = 0;
|
| 343 |
+
}
|
| 344 |
+
return ret;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
CUTLASS_HOST_DEVICE
|
| 348 |
+
iterator operator--(int) {
|
| 349 |
+
iterator ret(*this);
|
| 350 |
+
if (!idx_) {
|
| 351 |
+
--ptr_;
|
| 352 |
+
idx_ = kElementsPerStoredItem - 1;
|
| 353 |
+
}
|
| 354 |
+
else {
|
| 355 |
+
--idx_;
|
| 356 |
+
}
|
| 357 |
+
return ret;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
CUTLASS_HOST_DEVICE
|
| 361 |
+
const_reference operator*() const {
|
| 362 |
+
return const_reference(ptr_, idx_);
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
CUTLASS_HOST_DEVICE
|
| 366 |
+
bool operator==(iterator const &other) const {
|
| 367 |
+
return ptr_ == other.ptr_ && idx_ == other.idx_;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
CUTLASS_HOST_DEVICE
|
| 371 |
+
bool operator!=(iterator const &other) const {
|
| 372 |
+
return !(*this == other);
|
| 373 |
+
}
|
| 374 |
+
};
|
| 375 |
+
|
| 376 |
+
/// Bidirectional iterator over elements
|
| 377 |
+
class reverse_iterator {
|
| 378 |
+
|
| 379 |
+
/// Pointer to storage element
|
| 380 |
+
Storage *ptr_{nullptr};
|
| 381 |
+
|
| 382 |
+
/// Index into elements packed into Storage object
|
| 383 |
+
int idx_{0};
|
| 384 |
+
|
| 385 |
+
public:
|
| 386 |
+
|
| 387 |
+
reverse_iterator() = default;
|
| 388 |
+
|
| 389 |
+
CUTLASS_HOST_DEVICE
|
| 390 |
+
reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
|
| 391 |
+
};
|
| 392 |
+
|
| 393 |
+
/// Bidirectional constant iterator over elements
|
| 394 |
+
class const_reverse_iterator {
|
| 395 |
+
|
| 396 |
+
/// Pointer to storage element
|
| 397 |
+
Storage const *ptr_{nullptr};
|
| 398 |
+
|
| 399 |
+
/// Index into elements packed into Storage object
|
| 400 |
+
int idx_{0};
|
| 401 |
+
|
| 402 |
+
public:
|
| 403 |
+
|
| 404 |
+
const_reverse_iterator() = default;
|
| 405 |
+
|
| 406 |
+
CUTLASS_HOST_DEVICE
|
| 407 |
+
const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
|
| 408 |
+
};
|
| 409 |
+
|
| 410 |
+
/// Efficient clear method
|
| 411 |
+
CUTLASS_HOST_DEVICE
|
| 412 |
+
void clear() {
|
| 413 |
+
|
| 414 |
+
CUTLASS_PRAGMA_UNROLL
|
| 415 |
+
for (int i = 0; i < int(kStorageElements); ++i) {
|
| 416 |
+
storage[i] = Storage(0);
|
| 417 |
+
}
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
CUTLASS_HOST_DEVICE
|
| 421 |
+
reference at(size_type pos) {
|
| 422 |
+
return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
CUTLASS_HOST_DEVICE
|
| 426 |
+
const_reference at(size_type pos) const {
|
| 427 |
+
return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
CUTLASS_HOST_DEVICE
|
| 431 |
+
reference operator[](size_type pos) {
|
| 432 |
+
return at(pos);
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
CUTLASS_HOST_DEVICE
|
| 436 |
+
const_reference operator[](size_type pos) const {
|
| 437 |
+
return at(pos);
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
CUTLASS_HOST_DEVICE
|
| 441 |
+
reference front() {
|
| 442 |
+
return at(0);
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
CUTLASS_HOST_DEVICE
|
| 446 |
+
const_reference front() const {
|
| 447 |
+
return at(0);
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
CUTLASS_HOST_DEVICE
|
| 451 |
+
reference back() {
|
| 452 |
+
return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
CUTLASS_HOST_DEVICE
|
| 456 |
+
const_reference back() const {
|
| 457 |
+
return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
CUTLASS_HOST_DEVICE
|
| 461 |
+
pointer data() {
|
| 462 |
+
return reinterpret_cast<pointer>(storage);
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
CUTLASS_HOST_DEVICE
|
| 466 |
+
const_pointer data() const {
|
| 467 |
+
return reinterpret_cast<const_pointer>(storage);
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
CUTLASS_HOST_DEVICE
|
| 471 |
+
Storage * raw_data() {
|
| 472 |
+
return storage;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
CUTLASS_HOST_DEVICE
|
| 476 |
+
Storage const * raw_data() const {
|
| 477 |
+
return storage;
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
CUTLASS_HOST_DEVICE
|
| 481 |
+
constexpr bool empty() const {
|
| 482 |
+
return !kElements;
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
CUTLASS_HOST_DEVICE
|
| 486 |
+
constexpr size_type size() const {
|
| 487 |
+
return kElements;
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
CUTLASS_HOST_DEVICE
|
| 491 |
+
constexpr size_type max_size() const {
|
| 492 |
+
return kElements;
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
CUTLASS_HOST_DEVICE
|
| 496 |
+
void fill(T const &value) {
|
| 497 |
+
|
| 498 |
+
CUTLASS_PRAGMA_UNROLL
|
| 499 |
+
for (int i = 0; i < kElementsPerStoredItem; ++i) {
|
| 500 |
+
reference ref(storage, i);
|
| 501 |
+
ref = value;
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
CUTLASS_PRAGMA_UNROLL
|
| 505 |
+
for (int i = 1; i < kStorageElements; ++i) {
|
| 506 |
+
storage[i] = storage[0];
|
| 507 |
+
}
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
CUTLASS_HOST_DEVICE
|
| 511 |
+
iterator begin() {
|
| 512 |
+
return iterator(storage);
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
CUTLASS_HOST_DEVICE
|
| 516 |
+
const_iterator cbegin() const {
|
| 517 |
+
return const_iterator(storage);
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
CUTLASS_HOST_DEVICE
|
| 521 |
+
iterator end() {
|
| 522 |
+
return iterator(storage + kStorageElements);
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
CUTLASS_HOST_DEVICE
|
| 526 |
+
const_iterator cend() const {
|
| 527 |
+
return const_iterator(storage + kStorageElements);
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
CUTLASS_HOST_DEVICE
|
| 531 |
+
reverse_iterator rbegin() {
|
| 532 |
+
return reverse_iterator(storage + kStorageElements);
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
CUTLASS_HOST_DEVICE
|
| 536 |
+
const_reverse_iterator crbegin() const {
|
| 537 |
+
return const_reverse_iterator(storage + kStorageElements);
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
CUTLASS_HOST_DEVICE
|
| 541 |
+
reverse_iterator rend() {
|
| 542 |
+
return reverse_iterator(storage);
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
CUTLASS_HOST_DEVICE
|
| 546 |
+
const_reverse_iterator crend() const {
|
| 547 |
+
return const_reverse_iterator(storage);
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
private:
|
| 551 |
+
/// Internal storage
|
| 552 |
+
Storage storage[kStorageElements];
|
| 553 |
+
};
|
| 554 |
+
|
| 555 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 556 |
+
|
| 557 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 558 |
+
|
| 559 |
+
} // namespace cutlass
|
| 560 |
+
|
| 561 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implementation of a CTA-wide barrier for inter-CTA synchronization.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/arch/barrier.h"
|
| 39 |
+
|
| 40 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
namespace cutlass {
|
| 43 |
+
|
| 44 |
+
namespace detail {

//
// Utilities for abstracting synchronization methods for barriers.
// Each helper exposes a static `sync()` so GenericBarrier can be
// parameterized on the synchronization scope.
//

/// Synchronizes all threads of the thread block via __syncthreads().
struct SyncthreadsSync {
  CUTLASS_DEVICE
  static void sync() {
    __syncthreads();
  }
};

/// Synchronizes the threads of the calling warp via __syncwarp().
struct SyncwarpSync {
  CUTLASS_DEVICE
  static void sync() {
    __syncwarp();
  }
};

/// Synchronizes ThreadCount threads on the named barrier whose ID is the
/// compile-time constant BarrierId (cast to the reserved-barrier enum).
template <
  int ThreadCount,
  int BarrierId
>
struct NamedBarrierSync {
  CUTLASS_DEVICE
  static void sync() {
    cutlass::arch::NamedBarrier::sync(ThreadCount, static_cast<arch::ReservedNamedBarriers>(BarrierId));
  }
};

} // namespace detail
|
| 76 |
+
|
| 77 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 78 |
+
|
| 79 |
+
/// Group or CTA-wide semaphore for inter-CTA synchronization.
|
| 80 |
+
template <class Sync>
struct GenericBarrier {

public:

  /// Flag type
  using T = int;

  /// Initial flag value
  static const T INIT = 0;


protected:

  /// Load flag, as a strong acquire operation (int specialization).
  /// NOTE(review): `ptr` is expected to reference global memory (the PTX
  /// below uses `ld.global`) — confirm against callers.
  CUTLASS_DEVICE
  static int ld_acquire(int *ptr)
  {
    int state = 0;

#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers

    // Acquire pattern using acquire modifier
    asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));

#else
    // Pre-SM70 fallback: cache-global load (no acquire qualifier available)
    asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#endif // (__CUDA_ARCH__ >= 700)

    return state;
  }


  /// Reduce into flag, with release pattern (int specialization)
  CUTLASS_DEVICE
  static void red_release(int *ptr, int val)
  {
#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers

    // Release pattern using acq_rel fence + relaxed modifier. (The fence also releases data
    // that was weakly-written by other threads prior to the last syncthreads)
    asm volatile ("fence.acq_rel.gpu;\n");
    asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));

#else
    // Pre-SM70 fallback: full device fence followed by an atomic add
    __threadfence();
    atomicAdd(ptr, val);
#endif // (__CUDA_ARCH__ >= 700)
  }


public:

  /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter.
  /// `lock_ptr` is the base of the flag array; `flag_idx` selects the counter.
  /// All participating threads then rendezvous via Sync::sync().
  CUTLASS_DEVICE
  static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count)
  {
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(ld_acquire(flag_ptr) < count) {}
    }

    Sync::sync();
  }

  /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter.
  /// Spins until the flag is exactly equal to `val`; does not modify the flag.
  CUTLASS_DEVICE
  static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1)
  {
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(ld_acquire(flag_ptr) != val) {}
    }
    Sync::sync();
  }

  /// Uses thread[0] to wait for the specified count of signals on the given flag counter.
  /// On success the flag is atomically reset to 0 (compare-and-swap), so the
  /// barrier can be reused without a separate reset step.
  CUTLASS_DEVICE
  static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(atomicCAS(flag_ptr, val, 0) != val) {}
    }

    Sync::sync();
  }

  /// Increment the arrival count for a flag.
  /// Synchronizes the group first so all prior writes are covered by the
  /// release performed in red_release, then thread[0] signals once.
  CUTLASS_DEVICE
  static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx, int val = 1)
  {
    T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    Sync::sync();

    if (thread_idx == 0)
    {
      red_release(flag_ptr, val);
    }
  }


  /// Increment the arrival counts for a range of flags.
  /// Thread i signals flag (first_flag_idx + i) for i in [0, count).
  CUTLASS_DEVICE
  static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1)
  {
    int flag_idx = first_flag_idx + thread_idx;
    T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    // Barrier to make sure all other threads in group have written their data
    Sync::sync();

    // Select threads increment their flags
    if (thread_idx < count) {
      red_release(flag_ptr, val);
    }
  }
};
|
| 212 |
+
|
| 213 |
+
/// Default barrier type: CTA-wide rendezvous via __syncthreads().
using Barrier = GenericBarrier<detail::SyncthreadsSync>;
|
| 214 |
+
|
| 215 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 216 |
+
|
| 217 |
+
/** Structure for managing multiple NamedBarriers to be used by different warp groups, allowing
|
| 218 |
+
* runtime index values to be used to call into named barriers with compile-time-constant IDs.
|
| 219 |
+
*
|
| 220 |
+
* @param ThreadCount_ Number of threads that will wait on a NamedBarrier with a given ID
|
| 221 |
+
* @param Offset Value added to the ID passed in by the user to determine the NamedBarrier ID to call into
|
| 222 |
+
* @param MaxNumNamedBarriers The maximum number of unique barrier IDs that will be requested on this type
|
| 223 |
+
**/
|
| 224 |
+
template <
  uint32_t ThreadCount_,
  uint32_t Offset = 0,
  uint32_t MaxNumNamedBarriers = 16
>
struct NamedBarrierManager {

  // All barrier IDs (after adding Offset) must fit within the hardware limit.
  static_assert(MaxNumNamedBarriers <= arch::NamedBarrier::HardwareMaxNumNamedBarriers);
  static_assert(MaxNumNamedBarriers + Offset <= arch::NamedBarrier::HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15");

  // Number of threads participating in the barrier
  static constexpr uint32_t ThreadCount = ThreadCount_;

  /// Barrier implementation for one compile-time barrier ID
  template <uint32_t BarrierId>
  using BarrierSync = cutlass::GenericBarrier<cutlass::detail::NamedBarrierSync<ThreadCount, BarrierId>>;

  // Underlying type used by all barriers for synchronization. Does not depend on
  // template parameter BarrierId, so passing in 0 suffices.
  using T = typename BarrierSync<0>::T;

  // Compile-time sequence 0..MaxNumNamedBarriers-1 used to expand the
  // runtime-index -> compile-time-ID dispatch in the helpers below.
  using IntegerSequence = cute::make_integer_sequence<uint32_t, MaxNumNamedBarriers>;

  /// Wait until the flag counter reaches at least `count` (see GenericBarrier::wait_lt)
  CUTLASS_DEVICE
  static
  void wait_lt(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count) {
    wait_lt_helper(idx, lock_ptr, thread_idx, flag_idx, count, IntegerSequence{});
  }

  /// Wait until the flag counter equals `val` (see GenericBarrier::wait_eq)
  CUTLASS_DEVICE
  static void
  wait_eq(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
    wait_eq_helper<false>(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{});
  }

  /// Wait until the flag counter equals `val`, then reset it to 0
  CUTLASS_DEVICE
  static void
  wait_eq_reset(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
    wait_eq_helper<true>(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{});
  }

  /// Increment a flag's arrival count (see GenericBarrier::arrive_inc)
  CUTLASS_DEVICE
  static void
  arrive_inc(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) {
    arrive_inc_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{});
  }

  /// Increment arrival counts for a range of flags (see GenericBarrier::arrive_range_inc)
  CUTLASS_DEVICE
  static void
  arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) {
    arrive_range_inc_helper(idx, lock_ptr, thread_idx, first_flag_idx, count, val, IntegerSequence{});
  }

private:
  /// Debug-build sanity check that the runtime index maps to a managed barrier
  CUTLASS_DEVICE
  static void
  check_barrier_in_range([[maybe_unused]] uint32_t idx) {
    assert((idx < MaxNumNamedBarriers) && "Index exceeds barrier count");
  }

  // Each helper below uses a fold expression over Idx... to select, at run
  // time, the single BarrierSync instantiation whose compile-time ID matches
  // `idx` (short-circuiting `||` stops after the matching branch runs).

  template <uint32_t... Idx>
  CUTLASS_DEVICE
  static void
  wait_lt_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count, cute::integer_sequence<uint32_t, Idx...>) {
    check_barrier_in_range(idx);
    ((Idx == idx && (BarrierSync<Idx + Offset>::wait_lt(lock_ptr, thread_idx, flag_idx, count), true)) || ...);
  }

  template <bool Reset, uint32_t... Idx>
  CUTLASS_DEVICE
  static void
  wait_eq_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val, cute::integer_sequence<uint32_t, Idx...>) {
    check_barrier_in_range(idx);
    if constexpr (Reset) {
      ((Idx == idx && (BarrierSync<Idx + Offset>::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val), true)) || ...);
    }
    else {
      ((Idx == idx && (BarrierSync<Idx + Offset>::wait_eq(lock_ptr, thread_idx, flag_idx, val), true)) || ...);
    }
  }

  template <uint32_t... Idx>
  CUTLASS_DEVICE
  static void
  arrive_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val, cute::integer_sequence<uint32_t, Idx...>) {
    check_barrier_in_range(idx);
    ((Idx == idx && (BarrierSync<Idx + Offset>::arrive_inc(lock_ptr, thread_idx, flag_idx, val), true)) || ...);
  }

  template <uint32_t... Idx>
  CUTLASS_DEVICE
  static void
  arrive_range_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count, int val, cute::integer_sequence<uint32_t, Idx...>) {
    check_barrier_in_range(idx);
    ((Idx == idx && (BarrierSync<Idx + Offset>::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val), true)) || ...);
  }
};
|
| 320 |
+
|
| 321 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 322 |
+
|
| 323 |
+
/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads)
|
| 324 |
+
* via an API that mirrors that of NamedBarrierManager
|
| 325 |
+
*
|
| 326 |
+
* @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization
|
| 327 |
+
**/
|
| 328 |
+
template <
|
| 329 |
+
class Synchronizer,
|
| 330 |
+
uint32_t ThreadCount_
|
| 331 |
+
>
|
| 332 |
+
struct SyncManager {
|
| 333 |
+
|
| 334 |
+
// Number of threads participating in the barrier
|
| 335 |
+
static constexpr uint32_t ThreadCount = ThreadCount_;
|
| 336 |
+
|
| 337 |
+
using BarrierSync = cutlass::GenericBarrier<Synchronizer>;
|
| 338 |
+
|
| 339 |
+
// Underlying type used by all barriers for synchronization.
|
| 340 |
+
using T = typename BarrierSync::T;
|
| 341 |
+
|
| 342 |
+
CUTLASS_DEVICE
|
| 343 |
+
static
|
| 344 |
+
void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) {
|
| 345 |
+
BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count);
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
CUTLASS_DEVICE
|
| 349 |
+
static void
|
| 350 |
+
wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
|
| 351 |
+
BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val);
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
CUTLASS_DEVICE
|
| 355 |
+
static void
|
| 356 |
+
wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
|
| 357 |
+
BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val);
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
CUTLASS_DEVICE
|
| 361 |
+
static void
|
| 362 |
+
arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) {
|
| 363 |
+
BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val);
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
CUTLASS_DEVICE
|
| 367 |
+
static void
|
| 368 |
+
arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) {
|
| 369 |
+
BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val);
|
| 370 |
+
}
|
| 371 |
+
};
|
| 372 |
+
|
| 373 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 374 |
+
|
| 375 |
+
} // namespace cutlass
|
| 376 |
+
|
| 377 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h
ADDED
|
@@ -0,0 +1,679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*!
|
| 32 |
+
\file
|
| 33 |
+
\brief Defines a proxy class for storing non-standard 16-bit floating point values with
|
| 34 |
+
8 bits of exponent and 7 bit of mantissa.
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#if defined(__CUDACC_RTC__)
|
| 40 |
+
#include "cutlass/floating_point_nvrtc.h"
|
| 41 |
+
#else
|
| 42 |
+
#include <cmath>
|
| 43 |
+
#include <limits>
|
| 44 |
+
#include <cstdint>
|
| 45 |
+
#include <cstring>
|
| 46 |
+
#endif
|
| 47 |
+
|
| 48 |
+
#include <cuda_bf16.h>
|
| 49 |
+
#include "cutlass/cutlass.h"
|
| 50 |
+
#include "cutlass/platform/platform.h"
|
| 51 |
+
|
| 52 |
+
namespace cutlass {
|
| 53 |
+
|
| 54 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
/// Floating-point type with 8 bits of exponent and 7 bits of mantissa
/// (bfloat16: the top 16 bits of an IEEE-754 binary32). Two-byte aligned
/// so arrays can be reinterpreted as vectors of CUDA's __nv_bfloat16.
struct alignas(2) bfloat16_t {

  //
  // Data members
  //

  /// Storage type: the raw 16-bit pattern (1 sign, 8 exponent, 7 mantissa bits)
  uint16_t storage;

  //
  // Methods
  //

  /// Constructs from an unsigned short, reinterpreting it as the raw bit pattern
  CUTLASS_HOST_DEVICE
  static bfloat16_t bitcast(uint16_t x) {
    bfloat16_t h;
    h.storage = x;
    return h;
  }

private:
  /// Tag type selecting the private 32-bit-integer conversion constructor below.
  struct from_32_bit_integer_t {};
  static constexpr from_32_bit_integer_t from_32_bit_integer{};

  /// Converts a 32-bit integer by first widening to float, then keeping the
  /// top 16 bits of the float's bit pattern.
  /// NOTE(review): this truncates the low mantissa bits (round toward zero)
  /// rather than rounding to nearest as the float constructor does -- confirm
  /// that this asymmetry is intended.
  template<class T>
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(from_32_bit_integer_t, T x) {
    static_assert(cutlass::platform::is_integral<T>::value && sizeof(T) == 4, "Requires 32-bit integer");

    float flt = static_cast<float>(x);
    uint32_t bits;

#if defined(__CUDA_ARCH__)
    bits = reinterpret_cast<uint32_t &>(flt);
#else
    // memcpy avoids undefined behavior from pointer-based type punning on the host.
    std::memcpy(&bits, &flt, sizeof(bits));
#endif

    storage = uint16_t(bits >> 16);
  }

public:
  /// Default constructor: storage is intentionally left uninitialized.
  bfloat16_t() = default;

  /// Reinterpret cast from CUDA's __nv_bfloat16 type
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(__nv_bfloat16 const & x) {
#if defined(__CUDA_ARCH__)
    storage = reinterpret_cast<uint16_t const &>(x);
#else
    // On the host, go through the documented raw-access type.
    __nv_bfloat16_raw raw(x);
    std::memcpy(&storage, &raw.x, sizeof(storage));
#endif
  }

  /// Floating-point conversion - round toward nearest
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(float x) {

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)

    // SM80+ exposes a native f32 -> bf16 convert with round-to-nearest.
    asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x));

#else
    uint32_t bits;

#if defined(__CUDA_ARCH__)
    bits = reinterpret_cast<uint32_t &>(x);
#else
    std::memcpy(&bits, &x, sizeof(bits));
#endif

    // Finite input: round to nearest, ties to even. Bit 16 is the last kept
    // mantissa bit, bit 15 is the round bit, bits 0-14 are the sticky bits.
    if ((bits & 0x7f800000) != 0x7f800000) {

      bool mantissa_bit = ((bits & (1 << 16)) != 0);
      bool round_bit = ((bits & (1 << 15)) != 0);
      bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0);

      // Round up when past the halfway point, or exactly halfway with an odd
      // kept mantissa (ties-to-even). The carry may propagate into the exponent,
      // which correctly rounds to infinity at the top of the range.
      if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) {
        bits += uint32_t(1 << 16);
      }
    }
    else if (bits & ~0xff800000) {
      // NaN input: canonicalize so the truncation below yields 0x7fff.
      bits = 0x7fffffff;
    }

    storage = uint16_t((bits >> 16) & 0xffff);
#endif
  }

  /// Floating-point conversion - round toward nearest
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(double x): bfloat16_t(float(x)) {

  }

  /// Integer conversion - via float; truncates the low float bits (see the
  /// private tag constructor above)
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {}

  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {}

  /// Converts to float (exact: every bf16 value is representable as float)
  CUTLASS_HOST_DEVICE
  operator float() const {
    unsigned bits = (unsigned(storage) << 16);
#if defined(__CUDA_ARCH__)
    return reinterpret_cast<float const &>(bits);
#else
    float flt;
    std::memcpy(&flt, &bits, sizeof(flt));
    return flt;
#endif
  }

  /// Converts to double via float
  CUTLASS_HOST_DEVICE
  explicit operator double() const {
    return double(float(*this));
  }

  /// Converts to int (semantics of the float -> int conversion)
  CUTLASS_HOST_DEVICE
  explicit operator int() const {
    return int(float(*this));
  }

  /// Casts to bool: true for any nonzero value
  CUTLASS_HOST_DEVICE
  explicit operator bool() const {
    return (float(*this) != 0.0f);
  }

  /// Bitcasts to CUDA's bf16 type (device-only: host __nv_bfloat16 access
  /// goes through the raw type in the converting constructor above)
  CUTLASS_DEVICE
  __nv_bfloat16 to_nv_bfloat16() const {
    return reinterpret_cast<__nv_bfloat16 const &>(storage);
  }

  /// Obtains raw bits
  CUTLASS_HOST_DEVICE
  uint16_t raw() const {
    return storage;
  }

  /// Returns the sign bit
  CUTLASS_HOST_DEVICE
  bool signbit() const {
    return ((raw() & 0x8000) != 0);
  }

  /// Returns the biased exponent (0..255)
  CUTLASS_HOST_DEVICE
  int exponent_biased() const {
    return int((raw() >> 7) & 0x0ff);
  }

  /// Returns the unbiased exponent (bias 127, as for binary32)
  CUTLASS_HOST_DEVICE
  int exponent() const {
    return exponent_biased() - 127;
  }

  /// Returns the mantissa (low 7 bits, without the implicit leading bit)
  CUTLASS_HOST_DEVICE
  int mantissa() const {
    return int(raw() & 0x7f);
  }
};
|
| 228 |
+
|
| 229 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 230 |
+
|
| 231 |
+
/// True iff the sign bit of h is set (including -0 and negative NaNs).
CUTLASS_HOST_DEVICE
bool signbit(cutlass::bfloat16_t const& h) {
  return (h.raw() & 0x8000) != 0;
}
|
| 235 |
+
|
| 236 |
+
/// Absolute value: clears the sign bit, which yields the magnitude for
/// every input including infinities and NaNs.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) {
  uint16_t magnitude = uint16_t(h.raw() & 0x7fff);
  return cutlass::bfloat16_t::bitcast(magnitude);
}
|
| 240 |
+
|
| 241 |
+
/// True iff h is NaN: all-ones exponent with a nonzero mantissa.
CUTLASS_HOST_DEVICE
bool isnan(cutlass::bfloat16_t const& h) {
  return (h.exponent_biased() == 0x0ff) && (h.mantissa() != 0);
}
|
| 245 |
+
|
| 246 |
+
/// True iff h is neither infinite nor NaN (exponent field not all ones).
CUTLASS_HOST_DEVICE
bool isfinite(cutlass::bfloat16_t const& h) {
  return h.exponent_biased() != 0x0ff;
}
|
| 250 |
+
|
| 251 |
+
/// Returns the canonical bf16 quiet NaN. The string payload argument is ignored.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t nan_bf16(const char*) {
  // NVIDIA canonical NaN
  return cutlass::bfloat16_t::bitcast(0x7fff);
}
|
| 256 |
+
|
| 257 |
+
/// True iff h is +/- infinity: all-ones exponent and a zero mantissa.
CUTLASS_HOST_DEVICE
bool isinf(cutlass::bfloat16_t const& h) {
  return (h.exponent_biased() == 0x0ff) && (h.mantissa() == 0);
}
|
| 261 |
+
|
| 262 |
+
/// True iff h is a normal number: exponent field neither zero (zero/subnormal)
/// nor all ones (infinity/NaN). The exponent is read once and reused.
CUTLASS_HOST_DEVICE
bool isnormal(cutlass::bfloat16_t const& h) {
  int exp = h.exponent_biased();
  return (exp != 0) && (exp != 0x0ff);
}
|
| 266 |
+
|
| 267 |
+
/// Classifies h into the standard floating-point categories
/// (FP_NAN, FP_INFINITE, FP_SUBNORMAL, FP_ZERO, FP_NORMAL).
CUTLASS_HOST_DEVICE
int fpclassify(cutlass::bfloat16_t const& h) {
  int exp = h.exponent_biased();
  int mantissa = h.mantissa();

  // All-ones exponent: NaN when the mantissa is set, otherwise infinity.
  if (exp == 0x0ff) {
    return mantissa ? FP_NAN : FP_INFINITE;
  }

  // Zero exponent: subnormal when the mantissa is set, otherwise zero.
  if (exp == 0) {
    return mantissa ? FP_SUBNORMAL : FP_ZERO;
  }

  return FP_NORMAL;
}
|
| 289 |
+
|
| 290 |
+
/// Square root computed in float precision, then rounded back to bfloat16.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) {
#if defined(__CUDACC_RTC__)
  // NVRTC compilation has no <cmath>; use the CUDA builtin instead.
  return cutlass::bfloat16_t(sqrtf(float(h)));
#else
  return cutlass::bfloat16_t(std::sqrt(float(h)));
#endif
}
|
| 298 |
+
|
| 299 |
+
/// Returns a value with the magnitude of a and the sign of b.
CUTLASS_HOST_DEVICE
bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) {

  uint16_t a_bits;
  uint16_t b_bits;

#if defined(__CUDA_ARCH__)
  a_bits = reinterpret_cast<uint16_t const &>(a);
  b_bits = reinterpret_cast<uint16_t const &>(b);
#else
  // memcpy avoids pointer-based type punning on the host.
  std::memcpy(&a_bits, &a, sizeof(a_bits));
  std::memcpy(&b_bits, &b, sizeof(b_bits));
#endif

  // Combine a's magnitude (low 15 bits) with b's sign (top bit).
  uint16_t combined = uint16_t((a_bits & 0x7fff) | (b_bits & 0x8000));
  return bfloat16_t::bitcast(combined);
}
|
| 319 |
+
|
| 320 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 321 |
+
|
| 322 |
+
} // namespace cutlass
|
| 323 |
+
|
| 324 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 325 |
+
//
|
| 326 |
+
// Standard Library operations and definitions
|
| 327 |
+
//
|
| 328 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 329 |
+
|
| 330 |
+
#if !defined(__CUDACC_RTC__)
|
| 331 |
+
namespace std {
|
| 332 |
+
|
| 333 |
+
/// Numeric limits
template <>
struct numeric_limits<cutlass::bfloat16_t> {
  static bool const is_specialized = true;
  static bool const is_signed = true;
  static bool const is_integer = false;
  static bool const is_exact = false;
  static bool const has_infinity = true;
  static bool const has_quiet_NaN = true;
  static bool const has_signaling_NaN = false;
  static std::float_denorm_style const has_denorm = std::denorm_present;
  static bool const has_denorm_loss = true;
  static std::float_round_style const round_style = std::round_to_nearest;
  static bool const is_iec559 = false;
  static bool const is_bounded = true;
  static bool const is_modulo = false;
  static int const digits = 7;

  /// Least positive value. NOTE(review): 0x0001 is the smallest positive
  /// subnormal; std::numeric_limits<T>::min() conventionally returns the
  /// smallest *normalized* value (bit pattern 0x0080) -- confirm intent.
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); }

  /// Minimum finite value (negative of max())
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); }

  /// Maximum finite value
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); }

  /// Machine epsilon. NOTE(review): bit pattern 0x1000 encodes 2^-95, which
  /// does not match the conventional bf16 epsilon 2^-7 (0x3c00) -- verify.
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); }

  /// Maximum rounding error (0.5 ULP for round-to-nearest)
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); }

  /// Positive infinity
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); }

  /// Quiet NaN (NVIDIA canonical NaN)
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }

  /// Signaling NaN; returns the same canonical NaN as quiet_NaN()
  /// (consistent with has_signaling_NaN == false above)
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }

  /// Smallest positive subnormal value
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); }
};
|
| 387 |
+
|
| 388 |
+
} // namespace std
|
| 389 |
+
#endif
|
| 390 |
+
|
| 391 |
+
namespace cutlass {
|
| 392 |
+
namespace platform {
|
| 393 |
+
|
| 394 |
+
/// Forward Declaration
template <class T>
struct numeric_limits;

/// Numeric limits (mirror of the std specialization above, usable under NVRTC
/// where <limits> members guarded below are unavailable)
template <>
struct numeric_limits<cutlass::bfloat16_t> {
  static bool const is_specialized = true;
  static bool const is_signed = true;
  static bool const is_integer = false;
  static bool const is_exact = false;
  static bool const has_infinity = true;
  static bool const has_quiet_NaN = true;
  static bool const has_signaling_NaN = false;
#if !defined(__CUDACC_RTC__)
  static std::float_denorm_style const has_denorm = std::denorm_present;
#endif
  static bool const has_denorm_loss = true;
#if !defined(__CUDACC_RTC__)
  static std::float_round_style const round_style = std::round_to_nearest;
#endif
  static bool const is_iec559 = false;
  static bool const is_bounded = true;
  static bool const is_modulo = false;
  static int const digits = 7;

  /// Least positive value. NOTE(review): 0x0001 is the smallest positive
  /// subnormal, not the smallest normalized value (0x0080) -- confirm intent.
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); }

  /// Minimum finite value (negative of max())
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); }

  /// Maximum finite value
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); }

  /// Machine epsilon. NOTE(review): bit pattern 0x1000 encodes 2^-95 -- verify
  /// against the intended bf16 epsilon (2^-7, pattern 0x3c00).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); }

  /// Maximum rounding error (0.5 ULP for round-to-nearest)
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); }

  /// Positive infinity
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); }

  /// Quiet NaN (NVIDIA canonical NaN)
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }

  /// Signaling NaN; same canonical NaN as quiet_NaN()
  /// (consistent with has_signaling_NaN == false above)
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }

  /// Smallest positive subnormal value
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); }
};
|
| 456 |
+
|
| 457 |
+
} // namespace platform
|
| 458 |
+
} // namespace cutlass
|
| 459 |
+
|
| 460 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 461 |
+
//
|
| 462 |
+
// Arithmetic operators
|
| 463 |
+
//
|
| 464 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 465 |
+
|
| 466 |
+
namespace cutlass {
|
| 467 |
+
|
| 468 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 469 |
+
|
| 470 |
+
// Comparison operators. Each uses the native bf16 comparison intrinsic on
// SM80+ device code and falls back to comparing after widening to float
// elsewhere (host, older architectures).

CUTLASS_HOST_DEVICE
bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __heq(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
#else
  return float(lhs) == float(rhs);
#endif
}

CUTLASS_HOST_DEVICE
bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __hne(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
#else
  return float(lhs) != float(rhs);
#endif
}

CUTLASS_HOST_DEVICE
bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __hlt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
#else
  return float(lhs) < float(rhs);
#endif
}

CUTLASS_HOST_DEVICE
bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __hle(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
#else
  return float(lhs) <= float(rhs);
#endif
}

CUTLASS_HOST_DEVICE
bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __hgt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
#else
  return float(lhs) > float(rhs);
#endif
}

CUTLASS_HOST_DEVICE
bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __hge(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
#else
  return float(lhs) >= float(rhs);
#endif
}
|
| 523 |
+
|
| 524 |
+
// Arithmetic operators. Each uses the native bf16 intrinsic on SM80+ device
// code; elsewhere the operands are widened to float, combined, and rounded
// back to bfloat16 via the round-to-nearest converting constructor.

CUTLASS_HOST_DEVICE
bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
#else
  return bfloat16_t(float(lhs) + float(rhs));
#endif
}

/// Unary negation
CUTLASS_HOST_DEVICE
bfloat16_t operator-(bfloat16_t const& lhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return bfloat16_t(__hneg(lhs.to_nv_bfloat16()));
#else
  return bfloat16_t(-float(lhs));
#endif
}

CUTLASS_HOST_DEVICE
bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
#else
  return bfloat16_t(float(lhs) - float(rhs));
#endif
}

CUTLASS_HOST_DEVICE
bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
#else
  return bfloat16_t(float(lhs) * float(rhs));
#endif
}

CUTLASS_HOST_DEVICE
bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
#else
  return bfloat16_t(float(lhs) / float(rhs));
#endif
}

// Compound assignment operators mirror the binary operators above.

CUTLASS_HOST_DEVICE
bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
#else
  lhs = bfloat16_t(float(lhs) + float(rhs));
#endif
  return lhs;
}

CUTLASS_HOST_DEVICE
bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
#else
  lhs = bfloat16_t(float(lhs) - float(rhs));
#endif
  return lhs;
}

CUTLASS_HOST_DEVICE
bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  lhs = bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
#else
  lhs = bfloat16_t(float(lhs) * float(rhs));
#endif
  return lhs;
}

CUTLASS_HOST_DEVICE
bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  lhs = bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()));
#else
  lhs = bfloat16_t(float(lhs) / float(rhs));
#endif
  return lhs;
}
|
| 608 |
+
|
| 609 |
+
CUTLASS_HOST_DEVICE
|
| 610 |
+
bfloat16_t& operator++(bfloat16_t & lhs) {
|
| 611 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
| 612 |
+
lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16()));
|
| 613 |
+
#else
|
| 614 |
+
float tmp(lhs);
|
| 615 |
+
++tmp;
|
| 616 |
+
lhs = bfloat16_t(tmp);
|
| 617 |
+
#endif
|
| 618 |
+
return lhs;
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
CUTLASS_HOST_DEVICE
|
| 622 |
+
bfloat16_t& operator--(bfloat16_t & lhs) {
|
| 623 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
| 624 |
+
lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16()));
|
| 625 |
+
#else
|
| 626 |
+
float tmp(lhs);
|
| 627 |
+
--tmp;
|
| 628 |
+
lhs = bfloat16_t(tmp);
|
| 629 |
+
#endif
|
| 630 |
+
return lhs;
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
/// Post-increment for bfloat16_t (returns the value prior to incrementing).
CUTLASS_HOST_DEVICE
bfloat16_t operator++(bfloat16_t & lhs, int) {
  bfloat16_t prior(lhs);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  // SM80+ device path: add 1 using the native bfloat16 add.
  lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16()));
#else
  // Host / pre-SM80 fallback: increment in float precision, then narrow.
  float widened(lhs);
  lhs = bfloat16_t(widened + 1.0f);
#endif
  return prior;
}
|
| 645 |
+
|
| 646 |
+
/// Post-decrement for bfloat16_t (returns the value prior to decrementing).
CUTLASS_HOST_DEVICE
bfloat16_t operator--(bfloat16_t & lhs, int) {
  bfloat16_t prior(lhs);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  // SM80+ device path: subtract 1 using the native bfloat16 subtract.
  lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16()));
#else
  // Host / pre-SM80 fallback: decrement in float precision, then narrow.
  float widened(lhs);
  lhs = bfloat16_t(widened - 1.0f);
#endif
  return prior;
}
|
| 658 |
+
|
| 659 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 660 |
+
|
| 661 |
+
} // namespace cutlass
|
| 662 |
+
|
| 663 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 664 |
+
|
| 665 |
+
//
|
| 666 |
+
// User-defined literals
|
| 667 |
+
//
|
| 668 |
+
|
| 669 |
+
/// User-defined literal: e.g. 1.5_bf16 yields a cutlass::bfloat16_t.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t operator "" _bf16(long double x) {
  float narrowed = float(x);
  return cutlass::bfloat16_t(narrowed);
}
|
| 673 |
+
|
| 674 |
+
/// User-defined literal: e.g. 2_bf16 yields a cutlass::bfloat16_t.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) {
  int as_int = int(x);
  return cutlass::bfloat16_t(as_int);
}
|
| 678 |
+
|
| 679 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Basic include for CUTLASS BLAS3/HPC code.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/array.h"
|
| 42 |
+
#include "cutlass/blas3_types.h"
|
| 43 |
+
#include "cutlass/coord.h"
|
| 44 |
+
#include "cutlass/complex.h"
|
| 45 |
+
#include "cutlass/functional.h"
|
| 46 |
+
#include "cutlass/numeric_types.h"
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
|
| 52 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// Defines FillMode inversions: maps a triangular fill mode to its
/// transpose-equivalent (lower <-> upper). Only kLower and kUpper are
/// specialized; other FillMode values are intentionally unsupported
/// (compile error on use).
template <FillMode kFillMode>
struct InvertFillMode;

/// Invert FillMode lower to upper
template <>
struct InvertFillMode<FillMode::kLower> {
  static FillMode const mode = FillMode::kUpper;
};

/// Invert FillMode upper to lower
template <>
struct InvertFillMode<FillMode::kUpper> {
  static FillMode const mode = FillMode::kLower;
};
|
| 69 |
+
|
| 70 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 71 |
+
/// Defines SideMode inversions: swaps which side of the matrix equation a
/// dense operand appears on (left <-> right). Only kLeft and kRight are
/// specialized; kInvalid is intentionally unsupported (compile error on use).
template <SideMode kSideMode>
struct InvertSideMode;

/// Invert SideMode left to right
template <>
struct InvertSideMode<SideMode::kLeft> {
  static SideMode const mode = SideMode::kRight;
};

/// Invert SideMode right to left
template <>
struct InvertSideMode<SideMode::kRight> {
  static SideMode const mode = SideMode::kLeft;
};
|
| 86 |
+
|
| 87 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 88 |
+
/// Defines correct compare operation for Triangular matrix boundary.
/// Selects the index-comparison functor that answers "is this coordinate
/// inside the stored triangle?" for a given fill mode and diagonal type.
template <FillMode kFillMode, DiagType kDiagType = DiagType::kNonUnit>
struct TrMatrixCompareOp {
  using Index = int32_t;
  // Inclusive comparison: the diagonal belongs to the triangle.
  // Lower: row >= column;  Upper: row <= column.
  using Type = typename platform::conditional<
      (kFillMode == FillMode::kLower),
      greater_equal<Index>,
      less_equal<Index>>::type;
};

// kUnit diagonal: boundary test is the same inclusive comparison as kNonUnit
// (unit-diagonal handling happens elsewhere, not in the boundary predicate).
template <FillMode kFillMode>
struct TrMatrixCompareOp <kFillMode, DiagType::kUnit> {
  using Index = int32_t;
  using Type = typename platform::conditional<
      (kFillMode == FillMode::kLower),
      greater_equal<Index>,
      less_equal<Index>>::type;
};

// kZero diagonal: strict comparison excludes the diagonal from the triangle.
// Lower: row > column;  Upper: row < column.
template <FillMode kFillMode>
struct TrMatrixCompareOp <kFillMode, DiagType::kZero> {
  using Index = int32_t;
  using Type = typename platform::conditional<
      (kFillMode == FillMode::kLower),
      greater<Index>,
      less<Index>>::type;
};
|
| 115 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 116 |
+
// Returns precision in terms of bits (based on datatype) to fill tensors with.
// Defaults to 5 bits of mantissa for TF32 and FP32 (with implicit round-offs).
// Also defines acceptable mantissa result variance/error.
template <typename Element>
struct MantissaInBits {
  // Number of mantissa bits used when generating test data.
  static int constexpr bits = 5;
  // Maximum acceptable absolute error for result comparison.
  static double constexpr error = 1.0e-7;
};

// Full precision is supported for FP64
template <>
struct MantissaInBits<double> {
  static int constexpr bits = 30;
  static double constexpr error = 1.0e-15;
};

// Complex double: same 30-bit fill, slightly looser tolerance than real FP64.
template <>
struct MantissaInBits<cutlass::complex<double>> {
  static int constexpr bits = 30;
  static double constexpr error = 1.0e-14;
};
|
| 137 |
+
|
| 138 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 139 |
+
|
| 140 |
+
} // namespace cutlass
|
| 141 |
+
|
| 142 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 143 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 35 |
+
|
| 36 |
+
namespace cutlass {
|
| 37 |
+
|
| 38 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 39 |
+
|
| 40 |
+
/// Enumerated type describing the type of kernel (based on input or output matrices).
enum class BlasMode {
  kGemm,        ///< General matrix-matrix multiply
  kSymmetric,   ///< Symmetric-matrix kernel (SYMM/SYRK family)
  kHermitian,   ///< Hermitian-matrix kernel (HEMM/HERK family)
  kTriangular,  ///< Triangular-matrix kernel (TRMM/TRSM family)
  kInvalid      ///< Sentinel for an unset/unrecognized mode
};
|
| 48 |
+
|
| 49 |
+
/// Enumerated type describing the fill mode for matrices for BLAS functions.
enum class FillMode {
  kFull,     /// The entire tensor is covered.
  kLower,    /// The 'lower' part of a tensor is covered including diagonal
  kUpper,    /// The 'upper' part of a tensor is covered including diagonal
  kDiagonal, /// Only diagonal elements are covered.
  kNone,     /// No element is covered.
  kInvalid   /// Sentinel for an unset/unrecognized fill mode.
};
|
| 58 |
+
|
| 59 |
+
/// Enumerated type describing the diagonal property of matrices for BLAS functions.
enum class DiagType {
  kNonUnit,  ///< Diagonal elements are stored and used as-is
  kUnit,     ///< Diagonal elements are assumed to be one (not read)
  kZero,     ///< Only used internally for computing SYMM/HEMM
  kInvalid   ///< Sentinel for an unset/unrecognized diagonal type
};
|
| 66 |
+
|
| 67 |
+
/// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions.
enum class SideMode {
  kLeft,    ///< Dense operand multiplies from the left
  kRight,   ///< Dense operand multiplies from the right
  kInvalid  ///< Sentinel for an unset/unrecognized side
};
|
| 73 |
+
|
| 74 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 75 |
+
|
| 76 |
+
} // namespace cutlass
|
| 77 |
+
|
| 78 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable,
|
| 33 |
+
statically-sized array types to global memory.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/wmma_array.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/complex.h"
|
| 43 |
+
|
| 44 |
+
namespace cutlass {
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
// AccessWidth
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit.
/// Implemented by induction: AlignBytes is doubled while it still divides
/// sizeof(T) and stays within Limit; the first failing width halves back to
/// the answer.
template <
  typename T,
  int Limit>
struct AccessWidth
{
  // Inductive case: AlignBytes divides sizeof(T), so try twice the width.
  template <
    int ObjectBytes,      /// Size of T in bytes
    int AlignBytes,       /// Template induction variable
    bool IsAligned =      /// Whether ObjectBytes is an even multiple of AlignBytes
      ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))>
  struct Detail
  {
    static const int value = Detail<ObjectBytes, AlignBytes * 2>::value;
  };

  // Base case (ObjectBytes is not an even multiple of AlignBytes):
  // the previous (halved) width was the largest valid one.
  template <
    int ObjectBytes,      /// Size of T in bytes
    int AlignBytes>       /// Template induction variable
  struct Detail<ObjectBytes, AlignBytes, false>
  {
    static const int value = AlignBytes / 2;
  };

  /// The maximal power-of-two that evenly divides the size of T
  static const int value = Detail<
    (int) sizeof(T),
    1>::value;
};
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 85 |
+
// StripedAccessType
|
| 86 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 87 |
+
|
| 88 |
+
/// ReinterpretCast type for striping a trivially-copyable type in global memory
/// (Default specialization. Striping granularity is type T.)
/// The alignas(TransferBytes) lets the compiler emit vectorized loads/stores
/// of the chosen width.
template <
  typename T,             /// Data type
  int TransferBytes =     /// Data access width (16 byte max for global memory access on current architectures)
    AccessWidth<T, 16>::value>
struct alignas(TransferBytes) StripedAccessType : public T
{};
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
/// ReinterpretCast type for striping a trivially-copyable type in global memory
/// (Specialization for cutlass::Array<T>. Striping granularity is a multiple of T.)
/// The access type is an AlignedArray holding as many T as fit in TransferBytes
/// (at least one), so each stripe moves a whole vectorized chunk of elements.
template <
  typename T,           /// Array element type
  int N,                /// Number of elements in array
  bool RegisterSized,   /// T is register-sized
  int TransferBytes>    /// Data access width
struct StripedAccessType<
  Array<T, N, RegisterSized>,
  TransferBytes>
: public AlignedArray<
    T,                                                // Element type of StripedAccessType
    __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType
    TransferBytes>                                    // Alignment of StripedAccessType
{};
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
| 116 |
+
|
| 117 |
+
/// ReinterpretCast type for striping a trivially-copyable type in global memory
/// (Specialization for cutlass::WmmaFragmentArray<T>. Striping granularity is a multiple of T.)
/// As with the Array specialization, the access type is an AlignedArray of the
/// fragment's element type sized to TransferBytes (at least one element).
template<
  typename Use,
  int m,
  int n,
  int k,
  typename ElementT,
  typename Layout,
  int kFragments,
  int TransferBytes>
struct StripedAccessType<
  WmmaFragmentArray<nvcuda::wmma::fragment<Use, m, n, k, ElementT, Layout>, kFragments>,
  TransferBytes>
: public AlignedArray<
    ElementT,
    __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)),
    TransferBytes>
{};
|
| 136 |
+
|
| 137 |
+
#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 141 |
+
// BlockStriped
|
| 142 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 143 |
+
|
| 144 |
+
/// Utility for performing block-striped access (load, store) of trivially-copyable,
/// statically-sized array types to global memory.
/// Each of the BlockThreads threads owns kStripes AccessT-sized pieces of the
/// ArrayT; piece i of thread t lives at global offset (BlockThreads * i + t),
/// so consecutive threads touch consecutive AccessT slots (coalesced access).
template <
  int BlockThreads,
  typename ArrayT,
  typename AccessT = StripedAccessType<ArrayT> >
struct BlockStriped
{
  /// Number of striped accesses
  static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT));
  static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type");

  /// Load this thread's stripes of *ptr into the thread-local copy `data`.
  CUTLASS_DEVICE
  static void load(ArrayT &data, ArrayT *ptr, int thread_idx)
  {
    AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
    AccessT *access_data = reinterpret_cast<AccessT*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kStripes; ++i) {
      access_data[i] = access_input[(BlockThreads * i) + thread_idx];
    }
  }

  /// Load & Add: accumulate this thread's stripes of *ptr into `data`.
  CUTLASS_DEVICE
  static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx)
  {
    AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
    AccessT *access_data = reinterpret_cast<AccessT*>(&data);

    plus<AccessT> add;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kStripes; ++i)
    {
      access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]);
    }
  }

  /// Store this thread's stripes of `data` out to *ptr.
  CUTLASS_DEVICE
  static void store(ArrayT *ptr, const ArrayT &data, int thread_idx)
  {
    AccessT *access_output = reinterpret_cast<AccessT*>(ptr);
    const AccessT *access_data = reinterpret_cast<const AccessT*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kStripes; ++i) {
      access_output[(BlockThreads * i) + thread_idx] = access_data[i];
    }
  }

};
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 202 |
+
// BlockStripedReduce
|
| 203 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
/// statically-sized array types to global memory.
/// (Default specialization)
/// Extends BlockStriped with a `reduce` that atomically accumulates this
/// thread's stripes into global memory. Note the base class is instantiated
/// with ElementT as the access type, so stripes are element-granular here.
template <
  int BlockThreads,
  typename ArrayT,
  typename ElementT = typename StripedAccessType<ArrayT>::Element>
struct BlockStripedReduce :
  BlockStriped<
    BlockThreads,
    ArrayT,
    ElementT>
{
  /// Atomically add this thread's stripes of `data` into *ptr.
  CUTLASS_DEVICE
  static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
  {
    cutlass::atomic_add<ElementT> reduce;
    ElementT *access_output = reinterpret_cast<ElementT*>(ptr);
    const ElementT *access_data = reinterpret_cast<const ElementT*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < BlockStripedReduce::kStripes; ++i) {
      reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
    }
  }
};
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
/// statically-sized array types to global memory.
/// (Specialization for half_t. Uses half2 vectorized-reduction.)
/// Pairs of half_t are reduced with a single half2 atomic add, so the array
/// length (in stripes) must be even.
template <
  int BlockThreads,
  typename ArrayT>
struct BlockStripedReduce<BlockThreads, ArrayT, half_t> :
  BlockStriped<
    BlockThreads,
    ArrayT,
    half2>
{
  static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length");

  /// Atomically add this thread's half2 stripes of `data` into *ptr.
  CUTLASS_DEVICE
  static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
  {
    cutlass::atomic_add<half2> reduce;
    half2 *access_output = reinterpret_cast<half2*>(ptr);
    const half2 *access_data = reinterpret_cast<const half2*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < BlockStripedReduce::kStripes; ++i)
    {
      reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
    }
  }
};
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
} // namespace cutlass
|
| 267 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief CUDA interfaces to launch CUTLASS device-level operators (for >= SM90) that use thread-block clusters.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include <cuda_runtime_api.h>
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/trace.h"
|
| 41 |
+
#include <cute/arch/cluster_sm100.hpp>
|
| 42 |
+
#include "cutlass/arch/synclog.hpp"
|
| 43 |
+
|
| 44 |
+
#if defined(__CUDACC_RTC__)
|
| 45 |
+
#include CUDA_STD_HEADER(type_traits)
|
| 46 |
+
#else
|
| 47 |
+
#include <type_traits>
|
| 48 |
+
#include <cstdio>
|
| 49 |
+
#endif
|
| 50 |
+
|
| 51 |
+
#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))
|
| 52 |
+
# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED
|
| 53 |
+
#endif
|
| 54 |
+
|
| 55 |
+
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
| 56 |
+
# define CUDA_ENABLE_PREFERRED_CLUSTER
|
| 57 |
+
#endif
|
| 58 |
+
namespace cutlass {
|
| 59 |
+
|
| 60 |
+
#ifndef NDEBUG
// Debug builds: convert a cudaError_t into a cutlass::Status, printing the
// CUDA error string (with file/line) to stderr before returning kInvalid on
// failure. Expands to a return statement, so it must be the last statement
// of the enclosing function.
#define Return_Status(cudaError_t_status)         \
  if (cudaError_t_status != cudaSuccess) {        \
    fprintf(stderr,                               \
            "[ ERROR: CUDA Runtime ] %s:%d: %s\n",\
            __FILE__,                             \
            __LINE__,                             \
            cudaGetErrorString(cudaError_t_status)); \
    return Status::kInvalid;                      \
  } else {                                        \
    return Status::kSuccess;                      \
  }
#else
// Release builds: same conversion, without any logging.
#define Return_Status(cudaError_t_status)         \
  if (cudaError_t_status != cudaSuccess) {        \
    return Status::kInvalid;                      \
  } else {                                        \
    return Status::kSuccess;                      \
  }
#endif
|
| 80 |
+
|
| 81 |
+
/// Host-side helper for launching device kernels that use thread-block
/// clusters (SM90+). Every entry point degrades gracefully to
/// Status::kInvalid when CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED is not defined
/// (i.e. the CUDA toolkit is older than 11.8).
struct ClusterLauncher {
  /// Upper bound on cluster.x * cluster.y * cluster.z accepted by
  /// check_cluster_dims().
  constexpr static int MaxClusterSize = 32;

  /// Bundles a cudaLaunchConfig_t with the launch-attribute array it points
  /// at, so the attributes outlive the config for the duration of the launch.
  struct LaunchConfig {
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
    cudaLaunchConfig_t launch_config;

#if defined(CUDA_ENABLE_PREFERRED_CLUSTER)
    // Slots: cluster dimension, preferred cluster dimension, PDL.
    constexpr static int numAttrs = 3;
#else

    // Slots: cluster dimension, PDL.
    constexpr static int numAttrs = 2;
#endif
    cudaLaunchAttribute launch_attribute[numAttrs];
    // Commonly used utility functions
    dim3 gridDim() { return launch_config.gridDim; }
    dim3 blockDim() { return launch_config.blockDim; }
#endif
  };

  // Check for hardware compatibility
  /// Validates that the cluster is no larger than MaxClusterSize and that
  /// the grid is evenly divisible by the cluster in every dimension.
  static inline CUTLASS_HOST
  Status check_cluster_dims(dim3 grid, dim3 cluster) {
    if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) &&
        (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) {
      return Status::kSuccess;
    }
    else {
      CUTLASS_TRACE_HOST("ClusterLauncher: Invalid cluster configuration -- aborting launch.");
      return Status::kInvalid;
    }
  }

  /// Enables non-portable cluster sizes on the kernel via
  /// cudaFuncSetAttribute. Must succeed before a cluster launch.
  static inline CUTLASS_HOST
  Status
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
  init(void const* kernel_function)
#else
  init(void const* /* kernel_function */)
#endif
  {
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
    if (kernel_function == nullptr) {
      CUTLASS_TRACE_HOST("kernel_function is null");
      return Status::kInvalid;
    }
    CUTLASS_TRACE_HOST("Checking previous error state before calling cudaFuncSetAttribute");
    // Clear/inspect any sticky error left by earlier CUDA calls so it is not
    // misattributed to cudaFuncSetAttribute below.
    cudaError_t prevStatus = cudaGetLastError();
    if (prevStatus != cudaSuccess) {
      fprintf(stderr,
              "[ ERROR: CUDA Runtime ] %s:%d: %s\n",
              __FILE__,
              __LINE__,
              cudaGetErrorString(prevStatus));
      return Status::kInvalid;
    }
    CUTLASS_TRACE_HOST("Calling cudaFuncSetAttribute");
#endif
    // This attribute was added in CUDA 11.8.
    cudaError_t status =
      cudaFuncSetAttribute(
        kernel_function, cudaFuncAttributeNonPortableClusterSizeAllowed, 1);
    Return_Status(status);
#else
    return Status::kInvalid;
#endif
  }

  /// Builds the cudaLaunchConfig_t + attribute array for a cluster launch.
  /// When fallback_cluster_dims is non-zero (and preferred-cluster support
  /// is compiled in, CUDA >= 12.8), the fallback dims go into the mandatory
  /// cluster attribute and cluster_dims become the preferred dims.
  static inline CUTLASS_HOST
  LaunchConfig make_cluster_launch_config(
      dim3 const grid_dims,
      dim3 const cluster_dims,
      dim3 const block_dims,
      size_t const smem_size = 0,
      cudaStream_t cuda_stream = 0,
      bool launch_with_pdl = false
      , dim3 const fallback_cluster_dims = {0, 0, 0}
      ) {
    LaunchConfig cluster_launch_config;
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
    auto &launch_config = cluster_launch_config.launch_config;
    auto &launch_attribute = cluster_launch_config.launch_attribute;
    // Local copy: may be decremented below when no fallback is requested.
    auto numAttrs = cluster_launch_config.numAttrs;

    launch_attribute[0].id = cudaLaunchAttributeClusterDimension;

    // All-zero fallback dims (the default) mean "no fallback cluster".
    bool have_fallback = fallback_cluster_dims.x * fallback_cluster_dims.y * fallback_cluster_dims.z > 0;

    if (have_fallback) {
      launch_attribute[0].val.clusterDim = {fallback_cluster_dims.x, fallback_cluster_dims.y, fallback_cluster_dims.z};
      CUTLASS_TRACE_HOST("ClusterLauncher: Setting fallback ClusterDims = "
        "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n");
    }
    else {

      launch_attribute[0].val.clusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z};
      CUTLASS_TRACE_HOST("ClusterLauncher: Setting ClusterDims = "
        "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n");

    }

#if defined(CUDA_ENABLE_PREFERRED_CLUSTER)
    if (have_fallback) {
      // NOTE(review): if initialize_preferred_cluster_launch() returns false,
      // launch_attribute[1] is left unset while numAttrs stays 3 — confirm
      // upstream that the driver tolerates the zero-initialized attribute.
      if (cute::initialize_preferred_cluster_launch(nullptr, grid_dims, cluster_dims, fallback_cluster_dims)) {
        launch_attribute[1].id = cudaLaunchAttributePreferredClusterDimension;
        launch_attribute[1].val.preferredClusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z};
        CUTLASS_TRACE_HOST("ClusterLauncher: Setting preferred ClusterDims = "
          "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n");
      }
    }
    else {
      // No fallback requested: drop the preferred-cluster slot.
      numAttrs--;
    }
#endif


    // PDL attributes
    // The PDL attribute always occupies the last active slot; it is only
    // counted in launch_config.numAttrs when launch_with_pdl is true.
    launch_attribute[numAttrs - 1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    launch_attribute[numAttrs - 1].val.programmaticStreamSerializationAllowed = 1;

    launch_config.gridDim = {grid_dims.x, grid_dims.y, grid_dims.z};
    launch_config.blockDim = {block_dims.x, block_dims.y, block_dims.z};
    launch_config.dynamicSmemBytes = smem_size;
    launch_config.stream = cuda_stream;
    launch_config.numAttrs = launch_with_pdl ? numAttrs : numAttrs - 1;
    launch_config.attrs = launch_attribute;
    return cluster_launch_config;
#else
    CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch.");
    return cluster_launch_config;
#endif
  }

  // This is the method we expect to use going forward
  /// Validates dims, enables non-portable clusters on the kernel, then
  /// launches via cudaLaunchKernelExC. kernel_params follows the CUDA
  /// convention: an array of pointers to the kernel arguments (or nullptr).
  static inline CUTLASS_HOST
  Status launch(
      dim3 const grid_dims,
      dim3 const cluster_dims,
      dim3 const block_dims,
      size_t const smem_size,
      cudaStream_t cuda_stream,
      void const* kernel,
      void** kernel_params,
      bool launch_with_pdl = false) {
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
    LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, cluster_dims,
      block_dims, smem_size, cuda_stream, launch_with_pdl);

    auto launch_grid_dims = cluster_launch_config.gridDim();
    if (check_cluster_dims(launch_grid_dims, cluster_dims) != Status::kSuccess) {
      CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting.");
      return Status::kInvalid;
    }

    auto init_status = init(kernel);
    if (init_status != Status::kSuccess) {
      CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting.");
      return Status::kInvalid;
    }

    CUTLASS_TRACE_HOST("ClusterLauncher: Launching GridDims = "
      "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), "
      "And ClusterDims = "
      "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n");

    cutlass::arch::synclog_setup();
    cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params);
    Return_Status(status);
#else
    CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch.");
    return Status::kInvalid;
#endif
  }


  // This is the method we expect to use going forward
  // Launch a preferred cluster grid
  /// Same as launch(), but requests preferred_cluster_dims with
  /// fallback_cluster_dims as the guaranteed fallback (CUDA >= 12.8).
  static inline CUTLASS_HOST
  Status launch_with_fallback_cluster(
      dim3 const grid_dims,
      dim3 const preferred_cluster_dims,
      dim3 const fallback_cluster_dims,
      dim3 const block_dims,
      size_t const smem_size,
      cudaStream_t cuda_stream,
      void const* kernel,
      void** kernel_params,
      bool launch_with_pdl = false) {
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
    LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, preferred_cluster_dims,
      block_dims, smem_size, cuda_stream, launch_with_pdl, fallback_cluster_dims);

    auto launch_grid_dims = cluster_launch_config.gridDim();
    if (check_cluster_dims(launch_grid_dims, preferred_cluster_dims) != Status::kSuccess) {
      CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting.");
      return Status::kInvalid;
    }

    auto init_status = init(kernel);
    if (init_status != Status::kSuccess) {
      CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting.");
      return Status::kInvalid;
    }

    CUTLASS_TRACE_HOST("ClusterLauncher: Launching \n\tGridDims = "
      "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), "
      "\n\tPreferred ClusterDims = "
      "(" << preferred_cluster_dims.x << ", " << preferred_cluster_dims.y << ", " << preferred_cluster_dims.z << "),"
      "\n\tFallback ClusterDims = "
      "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n");

    cutlass::arch::synclog_setup();
    cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params);
    Return_Status(status);
#else
    CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch.");
    return Status::kInvalid;
#endif
  }


};
|
| 304 |
+
|
| 305 |
+
namespace detail {

/// Yields the address of `arg` as a non-const `void*`, suitable for the
/// kernel-parameter array passed to the CUDA launch API.
/// std::addressof is used (rather than unary &) so that types with an
/// overloaded operator& still give the true object address; the const is
/// then cast away because the launch API traffics in plain void*.
template<class Arg>
void* checked_addressof(Arg&& arg) {
  static_assert(! std::is_rvalue_reference_v<Arg> || ! std::is_const_v<Arg>, "You cannot take the address of a const rvalue reference (const T&&).");
  void const* object_address = reinterpret_cast<void const*>(std::addressof(arg));
  return const_cast<void*>(object_address);
}

} // namespace detail
|
| 317 |
+
|
| 318 |
+
//! Parameters for launch_on_cluster (see below).
//! Aggregate — initialize with brace syntax in declaration order.
struct ClusterLaunchParams {
  //! Grid dimensions
  dim3 grid_dims{1, 1, 1};

  //! Block dimensions
  dim3 block_dims{1, 1, 1};

  //! Cluster dimensions
  //! (each grid dimension must be divisible by the matching cluster dimension)
  dim3 cluster_dims{1, 1, 1};

  //! Number of bytes required for the kernel's shared memory.
  int smem_size_in_bytes = 0;

  //! CUDA stream on which to launch the kernel.
  cudaStream_t cuda_stream = nullptr;
};
|
| 335 |
+
|
| 336 |
+
/// @brief Launch the kernel on the stream using cluster launch.
|
| 337 |
+
///
|
| 338 |
+
/// @param params Cluster launch parameters (see above).
|
| 339 |
+
/// @param kernel_ptr Pointer to the kernel function (see example).
|
| 340 |
+
/// @param args Zero or more arguments to pass to the kernel.
|
| 341 |
+
///
|
| 342 |
+
/// @tparam Args Types of the arguments passed to the kernel.
|
| 343 |
+
/// Don't specify this/these template argument(s) explicitly.
|
| 344 |
+
///
|
| 345 |
+
/// @return Status::Success on success, else an error code.
|
| 346 |
+
///
|
| 347 |
+
/// @code
|
| 348 |
+
/// template<class SharedMemoryType, class A, class B, class C>
|
| 349 |
+
/// __global__ void kernel(A a, B b, C c);
|
| 350 |
+
///
|
| 351 |
+
/// X x = get_x();
|
| 352 |
+
/// Y y = get_y();
|
| 353 |
+
/// Z z = get_z();
|
| 354 |
+
///
|
| 355 |
+
/// void const* kernel_ptr =
|
| 356 |
+
/// const_cast<void const*>(reinterpret_cast<void*>(
|
| 357 |
+
/// &kernel<SharedMemory, X, Y, Z>));
|
| 358 |
+
/// auto status = launch_kernel_on_cluster(
|
| 359 |
+
/// {grid_dims, block_dims, cluster_dims, sizeof(SharedMemory)},
|
| 360 |
+
/// kernel_ptr, x, y, z);
|
| 361 |
+
/// @endcode
|
| 362 |
+
template<class ... Args>
|
| 363 |
+
CUTLASS_HOST cutlass::Status
|
| 364 |
+
launch_kernel_on_cluster(const ClusterLaunchParams& params,
|
| 365 |
+
void const* kernel_ptr,
|
| 366 |
+
Args&& ... args)
|
| 367 |
+
{
|
| 368 |
+
// Unfortunately, we find ourselves needing to pass in
|
| 369 |
+
// the parameters as an array of raw pointers.
|
| 370 |
+
if constexpr (sizeof...(Args) == 0) {
|
| 371 |
+
return cutlass::ClusterLauncher::launch(
|
| 372 |
+
params.grid_dims,
|
| 373 |
+
params.cluster_dims,
|
| 374 |
+
params.block_dims,
|
| 375 |
+
params.smem_size_in_bytes,
|
| 376 |
+
params.cuda_stream,
|
| 377 |
+
kernel_ptr, nullptr);
|
| 378 |
+
}
|
| 379 |
+
else {
|
| 380 |
+
void* kernel_params[sizeof...(Args)] = {
|
| 381 |
+
detail::checked_addressof(std::forward<Args>(args))...
|
| 382 |
+
};
|
| 383 |
+
return cutlass::ClusterLauncher::launch(
|
| 384 |
+
params.grid_dims,
|
| 385 |
+
params.cluster_dims,
|
| 386 |
+
params.block_dims,
|
| 387 |
+
params.smem_size_in_bytes,
|
| 388 |
+
params.cuda_stream,
|
| 389 |
+
kernel_ptr,
|
| 390 |
+
kernel_params);
|
| 391 |
+
}
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h
ADDED
|
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include <cuComplex.h>
|
| 35 |
+
|
| 36 |
+
#include <cuda_fp16.h>
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#if defined(__CUDACC_RTC__)
|
| 39 |
+
#include CUDA_STD_HEADER(cstdint)
|
| 40 |
+
#else
|
| 41 |
+
#include <cstdint>
|
| 42 |
+
#endif
|
| 43 |
+
#include "cutlass/functional.h"
|
| 44 |
+
#include "cutlass/platform/platform.h"
|
| 45 |
+
#include "cutlass/real.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/numeric_types.h"
|
| 48 |
+
|
| 49 |
+
#include "cutlass/fast_math.h"
|
| 50 |
+
|
| 51 |
+
#if !defined(__CUDACC_RTC__)
|
| 52 |
+
#include <iosfwd>
|
| 53 |
+
#endif
|
| 54 |
+
|
| 55 |
+
namespace cutlass {
|
| 56 |
+
|
| 57 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 58 |
+
/// Enumerated type describing a transformation on a complex value.
enum class ComplexTransform {
  kNone,      ///< value is used unchanged
  kConjugate  ///< complex conjugate of the value is used
};
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
/// Defines ComplexTransform inversions
/// (compile-time map from a transform to its opposite; primary template is
/// intentionally left undefined so unlisted values fail to compile)
template <ComplexTransform kTransform>
struct InvertComplexTransform;

/// Invert ComplexTransform from kNone to kConjugate
template <>
struct InvertComplexTransform<ComplexTransform::kNone> {
  static ComplexTransform const transform = ComplexTransform::kConjugate;
};

/// Invert ComplexTransform from kConjugate to kNone
template <>
struct InvertComplexTransform<ComplexTransform::kConjugate> {
  static ComplexTransform const transform = ComplexTransform::kNone;
};
|
| 80 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 81 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////
|
| 82 |
+
|
| 83 |
+
//
|
| 84 |
+
// Accessors for CUDA complex types
|
| 85 |
+
//
|
| 86 |
+
|
| 87 |
+
#if !defined(__CUDACC_RTC__)
// Free-function accessors for the CUDA built-in complex types
// (cuFloatComplex / cuDoubleComplex store real in .x, imaginary in .y).
// Excluded under NVRTC, where <cuComplex.h> is unavailable.

/// Returns the real part of the complex number
CUTLASS_HOST_DEVICE
float const &real(cuFloatComplex const &z) { return z.x; }

/// Returns the real part of the complex number
CUTLASS_HOST_DEVICE
float &real(cuFloatComplex &z) { return z.x; }

/// Returns the real part of the complex number
CUTLASS_HOST_DEVICE
double const &real(cuDoubleComplex const &z) { return z.x; }

/// Returns the real part of the complex number
CUTLASS_HOST_DEVICE
double &real(cuDoubleComplex &z) { return z.x; }

/// Returns the imaginary part of the complex number
CUTLASS_HOST_DEVICE
float const &imag(cuFloatComplex const &z) { return z.y; }

/// Returns the imaginary part of the complex number
CUTLASS_HOST_DEVICE
float &imag(cuFloatComplex &z) { return z.y; }

/// Returns the imaginary part of the complex number
CUTLASS_HOST_DEVICE
double const &imag(cuDoubleComplex const &z) { return z.y; }

/// Returns the imaginary part of the complex number
CUTLASS_HOST_DEVICE
double &imag(cuDoubleComplex &z) { return z.y; }

// Returns the conjugate of the complex number
CUTLASS_HOST_DEVICE cuFloatComplex
conj(cuFloatComplex const& z) {
  return make_cuFloatComplex(z.x, -z.y);
}

// Returns the conjugate of the complex number
CUTLASS_HOST_DEVICE cuDoubleComplex
conj(cuDoubleComplex const& z) {
  return make_cuDoubleComplex(z.x, -z.y);
}
#endif
|
| 132 |
+
|
| 133 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 134 |
+
|
| 135 |
+
/// Class for representing and manipulating complex numbers with conversions from built-in CUDA
|
| 136 |
+
/// complex types.
|
| 137 |
+
|
| 138 |
+
/// Class for representing and manipulating complex numbers, with conversions
/// to and from the built-in CUDA complex types. Arithmetic operators accept a
/// mixed scalar type A; results are expressed in this type's T.
template <typename T>
class complex
{
public:
  /// Type alias for scalar type
  using value_type = T;

private:
  //
  // Data members
  //

  /// Real part
  T _real;

  /// Imaginary part
  T _imag;

public:

  //
  // Methods
  //

  /// Default constructor (leaves both components uninitialized, like T)
  complex() = default;

  /// Constructor from a real value; imaginary part is zero-initialized
  CUTLASS_HOST_DEVICE
  complex(T r) : _real(r), _imag(T(0)) {}

  /// Constructor from real and imaginary parts
  CUTLASS_HOST_DEVICE
  complex(T r, T i) : _real(r), _imag(i) {}

  /// Converting constructor from a complex of another scalar type
  template<typename A>
  CUTLASS_HOST_DEVICE
  complex(complex<A> const &z) : _real(static_cast<T>(z.real())), _imag(static_cast<T>(z.imag())) {}


#if !defined(__CUDACC_RTC__)
  /// Conversion from cuFloatComplex
  CUTLASS_HOST_DEVICE
  complex(cuFloatComplex const &z) : _real(static_cast<T>(cuCrealf(z))), _imag(static_cast<T>(cuCimagf(z))) {}

  /// Conversion from cuDoubleComplex
  CUTLASS_HOST_DEVICE
  complex(cuDoubleComplex const &z) : _real(static_cast<T>(cuCreal(z))), _imag(static_cast<T>(cuCimag(z))) {}
#endif

  /// Equality operator (component-wise exact comparison)
  CUTLASS_HOST_DEVICE bool operator==(complex<T> const &rhs) const {
    return this->real() == rhs.real() && this->imag() == rhs.imag();
  }

  /// Inequality operator
  CUTLASS_HOST_DEVICE bool operator!=(complex<T> const &rhs) const {
    return !(*this == rhs);
  }

  /// Addition
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> operator+(complex<A> const &rhs) const {
    return complex<T>(this->real() + rhs.real(), this->imag() + rhs.imag());
  }

  /// Reduction into memory address. Components may update out of order.
  /// Atomically adds this value into *ptr, component by component.
  template <typename OtherT>
  CUTLASS_DEVICE void red(complex<OtherT> *ptr) const {
    static_assert(platform::is_same<T, OtherT>::value, "Component type must match");
    cutlass::atomic_add<T> reduce;
    reduce(&ptr->_real, _real);
    reduce(&ptr->_imag, _imag);
  }

  /// Reduction into memory address. Components may update out of order. (Half specialization)
  /// Packs both half_t components into one half2 atomic add; relies on
  /// _real/_imag being laid out contiguously so the object aliases a half2.
  CUTLASS_DEVICE void red(complex<half_t> *ptr) const {
    static_assert(platform::is_same<T, half_t>::value, "Component type must match");
    half2 *h2_ptr = reinterpret_cast<half2*>(ptr);
    half2 h2_data = reinterpret_cast<half2&>(*this);
    cutlass::atomic_add<half2> reduce;
    reduce(h2_ptr, h2_data);
  }

  /// Subtraction
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> operator-(complex<A> const &rhs) const {
    return complex<T>(this->real() - rhs.real(), this->imag() - rhs.imag());
  }

  /// Multiplication: (a+bi)(c+di) = (ac - bd) + (ad + bc)i
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> operator*(complex<A> const &rhs) const {
    return complex<T>(this->real() * rhs.real() - this->imag() * rhs.imag(),
                      this->real() * rhs.imag() + this->imag() * rhs.real());
  }

  /// Scalar Multiplication
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> operator*(A const &s) const {
    return complex<T>(this->real() * s, this->imag() * s);
  }

  /// Division via multiplication by the conjugate over |rhs|^2.
  /// No scaling is applied, so |rhs|^2 may overflow/underflow for
  /// extreme-magnitude operands.
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> operator/(complex<A> const &rhs) const {
    T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag());

    return complex<T>(
      (real() * rhs.real() + imag() * rhs.imag()) / d,
      (imag() * rhs.real() - real() * rhs.imag()) / d
    );
  }

  /// Scalar Division
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> operator/(A const &s) const {
    return complex<T>(this->real() / s, this->imag() / s);
  }

  /// Addition
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> &operator+=(complex<A> const &rhs) {
    *this = *this + rhs;
    return *this;
  }

  /// Subtraction
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> &operator-=(complex<A> const &rhs) {
    *this = *this - rhs;
    return *this;
  }

  /// Multiplication
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> &operator*=(complex<A> const &rhs) {
    *this = *this * rhs;
    return *this;
  }

  /// Scalar multiplication
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> &operator*=(A s) {
    *this = *this * s;
    return *this;
  }

  /// Division
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> &operator/=(complex<A> const &rhs) {
    *this = *this / rhs;
    return *this;
  }

  /// Accesses the real part of the complex number
  CUTLASS_HOST_DEVICE
  T const &real() const { return _real; }

  /// Accesses the real part of the complex number
  CUTLASS_HOST_DEVICE
  T &real() { return _real; }

  /// Accesses the imaginary part of the complex number
  CUTLASS_HOST_DEVICE
  T const &imag() const { return _imag; }

  /// Accesses the imaginary part of the complex number
  CUTLASS_HOST_DEVICE
  T &imag() { return _imag; }

  /// Set the real part of the complex number
  CUTLASS_HOST_DEVICE
  void real(T real) { _real = real; }

  /// Set the imaginary part of the complex number
  CUTLASS_HOST_DEVICE
  void imag(T imag) { _imag = imag; }

#if !defined(__CUDACC_RTC__)
  /// Converts to cuFloatComplex
  CUTLASS_HOST_DEVICE
  explicit operator cuFloatComplex() const { return make_cuFloatComplex(float(real()), float(imag())); }

  /// Converts to cuDoubleComplex
  CUTLASS_HOST_DEVICE
  explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); }
#endif
};
|
| 328 |
+
|
| 329 |
+
// Complex conjugate
|
| 330 |
+
template<class T>
|
| 331 |
+
CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const& z) {
|
| 332 |
+
return {z.real(), -z.imag()};
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 336 |
+
|
| 337 |
+
//
|
| 338 |
+
// Accessors for complex template
|
| 339 |
+
//
|
| 340 |
+
|
| 341 |
+
// Nonmember real and imag need to work for non-complex numbers too.
|
| 342 |
+
// That means cutlass::complex, std::complex, cuda::std::complex, and
|
| 343 |
+
// any user-defined complex number type that looks like std::complex.
|
| 344 |
+
// It's reasonable to assume that a "complex number type" has
|
| 345 |
+
// zero-argument real() and imag() member functions returning
|
| 346 |
+
// non-void. While cuFloatComplex and cuDoubleComplex lack those
|
| 347 |
+
// member functions, one-argument nonmember real and imag overloads
|
| 348 |
+
// for those types are defined above.
|
| 349 |
+
|
| 350 |
+
namespace detail {
|
| 351 |
+
|
| 352 |
+
template <typename T, typename Enable = void>
|
| 353 |
+
struct has_zero_argument_real_member_function :
|
| 354 |
+
cutlass::platform::false_type
|
| 355 |
+
{};
|
| 356 |
+
|
| 357 |
+
template <typename T>
|
| 358 |
+
struct has_zero_argument_real_member_function<T,
|
| 359 |
+
cutlass::platform::enable_if_t<
|
| 360 |
+
! cutlass::platform::is_void_v<
|
| 361 |
+
decltype(cutlass::platform::declval<T>().real())
|
| 362 |
+
>
|
| 363 |
+
>
|
| 364 |
+
> : cutlass::platform::true_type
|
| 365 |
+
{};
|
| 366 |
+
|
| 367 |
+
template <typename T>
|
| 368 |
+
constexpr bool has_zero_argument_real_member_function_v =
|
| 369 |
+
has_zero_argument_real_member_function<T>::value;
|
| 370 |
+
|
| 371 |
+
template <typename T, typename Enable = void>
|
| 372 |
+
struct has_zero_argument_imag_member_function :
|
| 373 |
+
cutlass::platform::false_type
|
| 374 |
+
{};
|
| 375 |
+
|
| 376 |
+
template <typename T>
|
| 377 |
+
struct has_zero_argument_imag_member_function<T,
|
| 378 |
+
cutlass::platform::enable_if_t<
|
| 379 |
+
! cutlass::platform::is_void_v<
|
| 380 |
+
decltype(cutlass::platform::declval<T>().imag())
|
| 381 |
+
>
|
| 382 |
+
>
|
| 383 |
+
> : cutlass::platform::true_type
|
| 384 |
+
{};
|
| 385 |
+
|
| 386 |
+
template <typename T>
|
| 387 |
+
constexpr bool has_zero_argument_imag_member_function_v =
|
| 388 |
+
has_zero_argument_imag_member_function<T>::value;
|
| 389 |
+
|
| 390 |
+
} // namespace detail
|
| 391 |
+
|
| 392 |
+
template<typename T>
|
| 393 |
+
CUTLASS_HOST_DEVICE auto real(T z) {
|
| 394 |
+
if constexpr (detail::has_zero_argument_real_member_function_v<T>) {
|
| 395 |
+
return z.real();
|
| 396 |
+
} else {
|
| 397 |
+
return z;
|
| 398 |
+
}
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
template<typename T>
|
| 402 |
+
CUTLASS_HOST_DEVICE auto imag(T z) {
|
| 403 |
+
if constexpr (detail::has_zero_argument_imag_member_function_v<T>) {
|
| 404 |
+
return z.imag();
|
| 405 |
+
} else {
|
| 406 |
+
// Imaginary part of a non-complex input has the same type as the
|
| 407 |
+
// input, and its value is zero. CUTLASS assumes in this case
|
| 408 |
+
// that value-initializing T is well-formed and results in zero.
|
| 409 |
+
return T{};
|
| 410 |
+
}
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
//
|
| 414 |
+
// Output operators
|
| 415 |
+
//
|
| 416 |
+
|
| 417 |
+
#if !defined(__CUDACC_RTC__)
|
| 418 |
+
template <typename T>
|
| 419 |
+
std::ostream &operator<<(std::ostream &out, complex<T> const &z) {
|
| 420 |
+
T _r = real(z);
|
| 421 |
+
T _i = imag(z);
|
| 422 |
+
|
| 423 |
+
if (bool(_i)) {
|
| 424 |
+
return out << _r << "+i" << _i;
|
| 425 |
+
}
|
| 426 |
+
return out << _r;
|
| 427 |
+
}
|
| 428 |
+
#endif
|
| 429 |
+
|
| 430 |
+
//
|
| 431 |
+
// Non-member operators defined for complex types
|
| 432 |
+
//
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
//
|
| 436 |
+
// Non-member functions defined for complex numbers
|
| 437 |
+
//
|
| 438 |
+
|
| 439 |
+
// abs returns the magnitude of the complex number.
|
| 440 |
+
|
| 441 |
+
CUTLASS_HOST_DEVICE float abs(complex<float> const &z) {
|
| 442 |
+
return ::hypot(z.real(), z.imag());
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
CUTLASS_HOST_DEVICE double abs(complex<double> const &z) {
|
| 446 |
+
return ::hypot(z.real(), z.imag());
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
// In theory, it would make sense to add a complex<long double>
|
| 450 |
+
// specialization of abs here, since hypot works for long double too.
|
| 451 |
+
// In practice, long double doesn't have a portable number of bits or
|
| 452 |
+
// behavior, so users who care about higher-precision floating-point
|
| 453 |
+
// computation should probably insist on an actual FP128 type.
|
| 454 |
+
|
| 455 |
+
template <typename T>
|
| 456 |
+
CUTLASS_HOST_DEVICE T abs(complex<T> const &z) {
|
| 457 |
+
// cutlass::complex permits all kinds of T, including types that
|
| 458 |
+
// don't have NaN. For a generic floating-point type with Inf
|
| 459 |
+
// and/or NaN, LAPACK's DLAPY2 algorithm would make sense, as it
|
| 460 |
+
// would handle issues like avoiding unwarranted overflow if
|
| 461 |
+
// z.real() or z.imag() is slightly bigger than the square root of
|
| 462 |
+
// the max finite number. That could be a future improvement; for
|
| 463 |
+
// now, the code just uses the naive algorithm.
|
| 464 |
+
//
|
| 465 |
+
// Use the "swap two-step" idiom so that argument-dependent lookup
|
| 466 |
+
// can find any CUTLASS-specific overloads.
|
| 467 |
+
using cutlass::sqrt;
|
| 468 |
+
return sqrt(z.real() * z.real() + z.imag() * z.imag());
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
/// Returns the magnitude of the complex number
|
| 472 |
+
template <typename T>
|
| 473 |
+
CUTLASS_HOST_DEVICE T arg(complex<T> const &z) {
|
| 474 |
+
return atan2(imag(z), real(z));
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
/// Returns the squared magnitude of a real number
|
| 478 |
+
template <typename T>
|
| 479 |
+
CUTLASS_HOST_DEVICE T norm(T const &z) {
|
| 480 |
+
return z * z;
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
/// Returns the squared magnitude of a real number
|
| 484 |
+
template <>
|
| 485 |
+
CUTLASS_HOST_DEVICE int8_t norm(int8_t const &z) {
|
| 486 |
+
return static_cast<int8_t>(z * z);
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
/// Returns the squared magnitude of a complex number
|
| 490 |
+
template <typename T>
|
| 491 |
+
CUTLASS_HOST_DEVICE double norm(complex<T> const &z) {
|
| 492 |
+
return real(z) * real(z) + imag(z) * imag(z);
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
/// Norm-accumulate calculation
|
| 496 |
+
template <typename T, typename R>
|
| 497 |
+
CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) {
|
| 498 |
+
return accumulator + static_cast<R>(x) * static_cast<R>(x);
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
/// Norm accumulate specialized for complex types
|
| 502 |
+
template <typename T, typename R>
|
| 503 |
+
CUTLASS_HOST_DEVICE R norm_accumulate(complex<T> const &z, R const &accumulator) {
|
| 504 |
+
return accumulator + static_cast<R>(real(z)) * static_cast<R>(real(z)) +
|
| 505 |
+
static_cast<R>(imag(z)) * static_cast<R>(imag(z));
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
namespace detail {
|
| 509 |
+
|
| 510 |
+
template<class T>
|
| 511 |
+
CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::true_type) {
|
| 512 |
+
return conj(z);
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
template<class T>
|
| 516 |
+
CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::false_type) {
|
| 517 |
+
return z;
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
template<class T>
|
| 521 |
+
CUTLASS_HOST_DEVICE T conj_impl(T const& z) {
|
| 522 |
+
constexpr bool use_unqualified_conj =
|
| 523 |
+
! cutlass::platform::is_arithmetic_v<T> &&
|
| 524 |
+
! detail::has_cutlass_conj_v<T> &&
|
| 525 |
+
detail::has_unqualified_conj_v<T>;
|
| 526 |
+
return conj_impl(z, cutlass::platform::bool_constant<use_unqualified_conj>{});
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
} // namespace detail
|
| 530 |
+
|
| 531 |
+
// Return the complex conjugate of the input.
|
| 532 |
+
//
|
| 533 |
+
// This MUST be a function and not a function object, because it may
|
| 534 |
+
// be common practice for downstream types to define specifically
|
| 535 |
+
// cutlass::conj overloads, instead of overloads in their namespace.
|
| 536 |
+
//
|
| 537 |
+
// As a result of this being a function and not a function object,
|
| 538 |
+
// CUTLASS code needs to declare "using cutlass::conj;" in scope and
|
| 539 |
+
// then call this function unqualified, just like std::swap.
|
| 540 |
+
//
|
| 541 |
+
// If an overload already exists for cutlass::conj(T), that overload
|
| 542 |
+
// will be called instead of this one. Otherwise:
|
| 543 |
+
//
|
| 544 |
+
// 1. for arithmetic types, return z;
|
| 545 |
+
//
|
| 546 |
+
// 2. for types where (namespace-unqualified) conj(z) is well formed
|
| 547 |
+
// and cutlass::conj(z) is NOT well formed, return conj(z); and,
|
| 548 |
+
//
|
| 549 |
+
// 3. for everything else, return z.
|
| 550 |
+
//
|
| 551 |
+
// Regarding (1), the C++ Standard Library makes std::conj always
|
| 552 |
+
// return std::complex, even for (noncomplex) arithmetic types.
|
| 553 |
+
// cutlass::conj(T t) needs to return type T. This follows the
|
| 554 |
+
// convention of linear algebra software like the BLAS, where
|
| 555 |
+
// "conjugate transpose" means the same thing as "transpose" for a
|
| 556 |
+
// matrix of noncomplex numbers.
|
| 557 |
+
//
|
| 558 |
+
// Case (2) covers std::complex, cuda::std::complex, and non-Standard
|
| 559 |
+
// (including user-defined) complex number types (for which "conj(z)"
|
| 560 |
+
// is findable via argument-dependent lookup, but does not live in the
|
| 561 |
+
// cutlass namespace). It excludes cutlass::conj(z) in order to
|
| 562 |
+
// prevent infinite recursion.
|
| 563 |
+
//
|
| 564 |
+
// Case (3) covers non-Standard non-complex number types.
|
| 565 |
+
template<class T>
|
| 566 |
+
CUTLASS_HOST_DEVICE T conj(T const& z) {
|
| 567 |
+
return detail::conj_impl(z);
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
/// Projects the complex number z onto the Riemann sphere
|
| 571 |
+
template <typename T>
|
| 572 |
+
CUTLASS_HOST_DEVICE complex<T> proj(complex<T> const &z) {
|
| 573 |
+
T d = real(z) * real(z) + imag(z) * imag(z) + T(1);
|
| 574 |
+
return complex<T>((T(2) * real(z)) / d, (T(2) * imag(z)) / d);
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
/// Returns a complex number with magnitude r and phase theta
|
| 578 |
+
template <typename T>
|
| 579 |
+
CUTLASS_HOST_DEVICE complex<T> polar(T const &r, T const &theta = T()) {
|
| 580 |
+
return complex<T>(r * cos(theta), r * sin(theta));
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
/// Computes the complex exponential of z.
|
| 584 |
+
template <typename T>
|
| 585 |
+
CUTLASS_HOST_DEVICE complex<T> exp(complex<T> const &z) {
|
| 586 |
+
return complex<T>(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z)));
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
/// Computes the log of z
|
| 590 |
+
template <typename T>
|
| 591 |
+
CUTLASS_HOST_DEVICE complex<T> log(complex<T> const &z) {
|
| 592 |
+
return complex<T>(log(abs(z)), arg(z));
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
/// Computes the log base 10 of z
|
| 596 |
+
template <typename T>
|
| 597 |
+
CUTLASS_HOST_DEVICE complex<T> log10(complex<T> const &z) {
|
| 598 |
+
return log(z) / T(log(T(10)));
|
| 599 |
+
}
|
| 600 |
+
|
| 601 |
+
/// Computes the square root of complex number z
|
| 602 |
+
template <typename T>
|
| 603 |
+
CUTLASS_HOST_DEVICE complex<T> sqrt(complex<T> const &z) {
|
| 604 |
+
return sqrt(T(2)) / T(2) *
|
| 605 |
+
complex<T>(sqrt(sqrt(norm(z)) + real(z)),
|
| 606 |
+
(imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z)));
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
/// Computes the cosine of complex z.
|
| 610 |
+
template <typename T>
|
| 611 |
+
CUTLASS_HOST_DEVICE complex<T> cos(complex<T> const &z) {
|
| 612 |
+
return (exp(z) + exp(-z)) / T(2);
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
/// Computes the sin of complex z.
|
| 616 |
+
template <typename T>
|
| 617 |
+
CUTLASS_HOST_DEVICE complex<T> sin(complex<T> const &z) {
|
| 618 |
+
return (exp(-z) - exp(z)) * complex<T>(T(0), T(1) / T(2));
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
/// Comparison
|
| 622 |
+
template <typename T>
|
| 623 |
+
CUTLASS_HOST_DEVICE bool operator<(complex<T> const &lhs, complex<T> const &rhs) {
|
| 624 |
+
return true;
|
| 625 |
+
}
|
| 626 |
+
|
| 627 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////
|
| 628 |
+
|
| 629 |
+
/// Partial specialization for complex-valued type.
|
| 630 |
+
template <typename T>
|
| 631 |
+
struct RealType< complex<T> >
|
| 632 |
+
{
|
| 633 |
+
using Type = T;
|
| 634 |
+
|
| 635 |
+
/// Number of elements
|
| 636 |
+
static int const kExtent = 2;
|
| 637 |
+
|
| 638 |
+
CUTLASS_HOST_DEVICE
|
| 639 |
+
static complex<T> from_real(double x) {
|
| 640 |
+
return complex<T>(static_cast<T>(x));
|
| 641 |
+
}
|
| 642 |
+
};
|
| 643 |
+
|
| 644 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 645 |
+
|
| 646 |
+
template <>
|
| 647 |
+
CUTLASS_HOST_DEVICE
|
| 648 |
+
cutlass::complex<half_t> from_real<cutlass::complex<half_t> >(double r) {
|
| 649 |
+
return cutlass::complex<half_t>(half_t(r));
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
template <>
|
| 653 |
+
CUTLASS_HOST_DEVICE
|
| 654 |
+
cutlass::complex<float> from_real<cutlass::complex<float> >(double r) {
|
| 655 |
+
return cutlass::complex<float>(float(r));
|
| 656 |
+
}
|
| 657 |
+
|
| 658 |
+
template <>
|
| 659 |
+
CUTLASS_HOST_DEVICE
|
| 660 |
+
cutlass::complex<double> from_real<cutlass::complex<double> >(double r) {
|
| 661 |
+
return cutlass::complex<double>(r);
|
| 662 |
+
}
|
| 663 |
+
|
| 664 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////
|
| 665 |
+
|
| 666 |
+
template <typename T>
|
| 667 |
+
struct is_complex {
|
| 668 |
+
static bool const value = false;
|
| 669 |
+
};
|
| 670 |
+
|
| 671 |
+
template <typename T>
|
| 672 |
+
struct is_complex<complex<T>> {
|
| 673 |
+
static bool const value = true;
|
| 674 |
+
};
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 678 |
+
// functional.h numeric specializations
|
| 679 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 680 |
+
|
| 681 |
+
/// Squares with optional conversion
|
| 682 |
+
template <typename T, typename Output>
|
| 683 |
+
struct magnitude_squared<complex<T>, Output> {
|
| 684 |
+
CUTLASS_HOST_DEVICE
|
| 685 |
+
Output operator()(complex<T> lhs) const {
|
| 686 |
+
multiplies<Output> mul_op;
|
| 687 |
+
|
| 688 |
+
Output y_r = Output(lhs.real());
|
| 689 |
+
Output y_i = Output(lhs.imag());
|
| 690 |
+
|
| 691 |
+
return mul_op(y_r, y_r) + mul_op(y_i, y_i);
|
| 692 |
+
}
|
| 693 |
+
};
|
| 694 |
+
|
| 695 |
+
/// Fused multiply-add
|
| 696 |
+
template <typename T>
|
| 697 |
+
struct multiply_add<complex<T>, complex<T>, complex<T>> {
|
| 698 |
+
CUTLASS_HOST_DEVICE
|
| 699 |
+
complex<T> operator()(
|
| 700 |
+
complex<T> const &a,
|
| 701 |
+
complex<T> const &b,
|
| 702 |
+
complex<T> const &c) const {
|
| 703 |
+
|
| 704 |
+
T real = c.real();
|
| 705 |
+
T imag = c.imag();
|
| 706 |
+
|
| 707 |
+
real += a.real() * b.real();
|
| 708 |
+
real += -a.imag() * b.imag();
|
| 709 |
+
imag += a.real() * b.imag();
|
| 710 |
+
imag += a.imag () * b.real();
|
| 711 |
+
|
| 712 |
+
return complex<T>{
|
| 713 |
+
real,
|
| 714 |
+
imag
|
| 715 |
+
};
|
| 716 |
+
}
|
| 717 |
+
};
|
| 718 |
+
|
| 719 |
+
/// Fused multiply-add
|
| 720 |
+
template <typename T>
|
| 721 |
+
struct multiply_add<complex<T>, T, complex<T>> {
|
| 722 |
+
CUTLASS_HOST_DEVICE
|
| 723 |
+
complex<T> operator()(
|
| 724 |
+
complex<T> const &a,
|
| 725 |
+
T const &b,
|
| 726 |
+
complex<T> const &c) const {
|
| 727 |
+
|
| 728 |
+
T real = c.real();
|
| 729 |
+
T imag = c.imag();
|
| 730 |
+
|
| 731 |
+
real += a.real() * b;
|
| 732 |
+
imag += a.imag () * b;
|
| 733 |
+
|
| 734 |
+
return complex<T>{
|
| 735 |
+
real,
|
| 736 |
+
imag
|
| 737 |
+
};
|
| 738 |
+
}
|
| 739 |
+
};
|
| 740 |
+
|
| 741 |
+
/// Fused multiply-add
|
| 742 |
+
template <typename T>
|
| 743 |
+
struct multiply_add<T, complex<T>, complex<T>> {
|
| 744 |
+
CUTLASS_HOST_DEVICE
|
| 745 |
+
complex<T> operator()(
|
| 746 |
+
T const &a,
|
| 747 |
+
complex<T> const &b,
|
| 748 |
+
complex<T> const &c) const {
|
| 749 |
+
|
| 750 |
+
T real = c.real();
|
| 751 |
+
T imag = c.imag();
|
| 752 |
+
|
| 753 |
+
real += a * b.real();
|
| 754 |
+
imag += a * b.imag();
|
| 755 |
+
|
| 756 |
+
return complex<T>{
|
| 757 |
+
real,
|
| 758 |
+
imag
|
| 759 |
+
};
|
| 760 |
+
}
|
| 761 |
+
};
|
| 762 |
+
|
| 763 |
+
/// Conjugate
|
| 764 |
+
template <typename T>
|
| 765 |
+
struct conjugate<complex<T>> {
|
| 766 |
+
CUTLASS_HOST_DEVICE
|
| 767 |
+
complex<T> operator()(complex<T> const &a) const {
|
| 768 |
+
// Invoke the complex<T> overload specifically, rather than
|
| 769 |
+
// wasting the compiler's effort on overload resolution.
|
| 770 |
+
return cutlass::conj(a);
|
| 771 |
+
}
|
| 772 |
+
};
|
| 773 |
+
|
| 774 |
+
#if ! defined(__CUDACC_RTC__)
|
| 775 |
+
template <>
|
| 776 |
+
struct conjugate<cuFloatComplex> {
|
| 777 |
+
CUTLASS_HOST_DEVICE
|
| 778 |
+
cuFloatComplex operator()(cuFloatComplex const& z) const {
|
| 779 |
+
return make_cuFloatComplex(z.x, -z.y);
|
| 780 |
+
}
|
| 781 |
+
};
|
| 782 |
+
|
| 783 |
+
template <>
|
| 784 |
+
struct conjugate<cuDoubleComplex> {
|
| 785 |
+
CUTLASS_HOST_DEVICE
|
| 786 |
+
cuDoubleComplex operator()(cuDoubleComplex const& z) const {
|
| 787 |
+
return make_cuDoubleComplex(z.x, -z.y);
|
| 788 |
+
}
|
| 789 |
+
};
|
| 790 |
+
#endif
|
| 791 |
+
|
| 792 |
+
/// Computes the square of a difference with optional conversion
|
| 793 |
+
template <typename T, typename Output>
|
| 794 |
+
struct magnitude_squared_difference<complex<T>, Output> {
|
| 795 |
+
CUTLASS_HOST_DEVICE
|
| 796 |
+
Output operator()(complex<T> lhs, complex<T> rhs) const {
|
| 797 |
+
multiplies<Output> mul_op;
|
| 798 |
+
|
| 799 |
+
Output y_r = Output(lhs.real()) - Output(rhs.real());
|
| 800 |
+
Output y_i = Output(lhs.imag()) - Output(rhs.imag());
|
| 801 |
+
|
| 802 |
+
return mul_op(y_r, y_r) + mul_op(y_i, y_i);
|
| 803 |
+
}
|
| 804 |
+
};
|
| 805 |
+
|
| 806 |
+
/// Reduces value into the data pointed to by ptr (complex<T> specialization)
|
| 807 |
+
template <typename T>
|
| 808 |
+
struct atomic_add<complex<T>> {
|
| 809 |
+
CUTLASS_DEVICE
|
| 810 |
+
void operator()(complex<T> *ptr, const complex<T> &data)
|
| 811 |
+
{
|
| 812 |
+
data.red(ptr);
|
| 813 |
+
}
|
| 814 |
+
};
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////
|
| 818 |
+
|
| 819 |
+
} // namespace cutlass
|
| 820 |
+
|
| 821 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h
ADDED
|
@@ -0,0 +1,1239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/* \file
|
| 33 |
+
\brief Boost-style constant definitions for floating-point types.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
|
| 41 |
+
#include "cutlass/complex.h"
|
| 42 |
+
|
| 43 |
+
///////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
namespace cutlass {
|
| 46 |
+
namespace constants {
|
| 47 |
+
|
| 48 |
+
///////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
//
// Primary templates
//
// Each function template below is declared here and explicitly specialized further
// down in this file for double, float, and tfloat32_t (and the corresponding
// cutlass::complex<T> types). There is intentionally no generic definition: using
// one of these with an unspecialized T is a link-time error.
//

/// Returns 1, the multiplicative identity element
template <typename T> CUTLASS_HOST_DEVICE T one();

/// Returns 0, the additive identity element
template <typename T> CUTLASS_HOST_DEVICE T zero();

/// Returns 2
template <typename T> CUTLASS_HOST_DEVICE T two();

/// Returns pi, approximately 3.141
template <typename T> CUTLASS_HOST_DEVICE T pi();

/// Returns 2 * pi
template <typename T> CUTLASS_HOST_DEVICE T two_pi();

/// Returns pi / 2
template <typename T> CUTLASS_HOST_DEVICE T half_pi();

/// Returns sqrt(pi)
template <typename T> CUTLASS_HOST_DEVICE T root_pi();

/// Returns sqrt(pi / 2)
template <typename T> CUTLASS_HOST_DEVICE T root_half_pi();

/// Returns sqrt(2 * pi)
template <typename T> CUTLASS_HOST_DEVICE T root_two_pi();

/// Returns sqrt(ln(4))
template <typename T> CUTLASS_HOST_DEVICE T root_ln_four();

/// Returns e, approximately 2.718...
template <typename T> CUTLASS_HOST_DEVICE T e();

/// Returns (1/2)
template <typename T> CUTLASS_HOST_DEVICE T half();

/// Returns sqrt(2), approximately 1.414...
template <typename T> CUTLASS_HOST_DEVICE T root_two();

/// Returns sqrt(2)/2, approximately 0.707...
template <typename T> CUTLASS_HOST_DEVICE T half_root_two();

/// Returns ln(2), approximately 0.693...
template <typename T> CUTLASS_HOST_DEVICE T ln_two();

/// Returns ln(ln(2)), approximately -0.3665...
template <typename T> CUTLASS_HOST_DEVICE T ln_ln_two();

/// Returns 1/3, approximately 0.333...
template <typename T> CUTLASS_HOST_DEVICE T third();

/// Returns 2/3, approximately 0.666...
template <typename T> CUTLASS_HOST_DEVICE T twothirds();

/// Returns pi - 3, approximately 0.1416...
template <typename T> CUTLASS_HOST_DEVICE T pi_minus_three();

/// Returns 4 - pi, approximately 0.858...
template <typename T> CUTLASS_HOST_DEVICE T four_minus_pi();
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
/////////////////////////////////////////////////////////////////////////////////////
|
| 116 |
+
|
| 117 |
+
// Specialization for double
//
// Each constant is materialized from its exact IEEE-754 binary64 bit pattern so the
// returned value is bit-reproducible across host and device compilers (a decimal
// literal could round differently).
//
// NOTE(review): reinterpret_cast from the uint64_t holding the bits is type punning
// and formally UB under strict aliasing rules; it is the established idiom throughout
// this file — confirm toolchain guarantees before replacing it (e.g. with memcpy or
// C++20 std::bit_cast).
//
// Every complex<double> specialization returns the real constant with a
// value-initialized (zero) imaginary part.

/// Returns 1, the multiplicative identity element (specialization for double)
template <> CUTLASS_HOST_DEVICE double one<double>() {
  uint64_t bits = 0x3ff0000000000000ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns 1, the multiplicative identity element (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> one< complex<double> >() {
  return complex<double>(one<double>(), double());
}

/// Returns 0, the additive identity element (specialization for double)
template <> CUTLASS_HOST_DEVICE double zero<double>() {
  uint64_t bits = 0x0ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns 0, the additive identity element (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> zero< complex<double> >() {
  return complex<double>(zero<double>(), double());
}

/// Returns 2 (specialization for double)
template <> CUTLASS_HOST_DEVICE double two<double>() {
  uint64_t bits = 0x4000000000000000ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns 2 (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> two< complex<double> >() {
  return complex<double>(two<double>(), double());
}

/// Returns pi, approximately 3.141 (specialization for double)
template <> CUTLASS_HOST_DEVICE double pi<double>() {
  uint64_t bits = 0x400921fb54442d18ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns pi, approximately 3.141 (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> pi< complex<double> >() {
  return complex<double>(pi<double>(), double());
}

/// Returns 2 * pi (specialization for double)
template <> CUTLASS_HOST_DEVICE double two_pi<double>() {
  uint64_t bits = 0x401921fb54442d18ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns 2 * pi (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> two_pi< complex<double> >() {
  return complex<double>(two_pi<double>(), double());
}

/// Returns pi / 2 (specialization for double)
template <> CUTLASS_HOST_DEVICE double half_pi<double>() {
  uint64_t bits = 0x3ff921fb54442d18ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns pi / 2 (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> half_pi< complex<double> >() {
  return complex<double>(half_pi<double>(), double());
}

/// Returns sqrt(pi) (specialization for double)
template <> CUTLASS_HOST_DEVICE double root_pi<double>() {
  uint64_t bits = 0x3ffc5bf891b4ef6aull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns sqrt(pi) (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> root_pi< complex<double> >() {
  return complex<double>(root_pi<double>(), double());
}

/// Returns sqrt(pi / 2) (specialization for double)
template <> CUTLASS_HOST_DEVICE double root_half_pi<double>() {
  uint64_t bits = 0x3ff40d931ff62705ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns sqrt(pi / 2) (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> root_half_pi< complex<double> >() {
  return complex<double>(root_half_pi<double>(), double());
}

/// Returns sqrt(2 * pi) (specialization for double)
template <> CUTLASS_HOST_DEVICE double root_two_pi<double>() {
  uint64_t bits = 0x40040d931ff62705ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns sqrt(2 * pi) (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> root_two_pi< complex<double> >() {
  return complex<double>(root_two_pi<double>(), double());
}

/// Returns sqrt(ln(4)) (specialization for double)
template <> CUTLASS_HOST_DEVICE double root_ln_four<double>() {
  uint64_t bits = 0x3ff2d6abe44afc43ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns sqrt(ln(4)) (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> root_ln_four< complex<double> >() {
  return complex<double>(root_ln_four<double>(), double());
}

/// Returns e, approximately 2.718... (specialization for double)
template <> CUTLASS_HOST_DEVICE double e<double>() {
  uint64_t bits = 0x4005bf0a8b145769ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns e, approximately 2.718... (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> e< complex<double> >() {
  return complex<double>(e<double>(), double());
}

/// Returns (1/2) (specialization for double)
template <> CUTLASS_HOST_DEVICE double half<double>() {
  uint64_t bits = 0x3fe0000000000000ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns (1/2) (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> half< complex<double> >() {
  return complex<double>(half<double>(), double());
}

/// Returns sqrt(2), approximately 1.414... (specialization for double)
template <> CUTLASS_HOST_DEVICE double root_two<double>() {
  uint64_t bits = 0x3ff6a09e667f3bcdull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns sqrt(2), approximately 1.414... (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> root_two< complex<double> >() {
  return complex<double>(root_two<double>(), double());
}

/// Returns sqrt(2)/2, approximately 0.707... (specialization for double)
template <> CUTLASS_HOST_DEVICE double half_root_two<double>() {
  uint64_t bits = 0x3fe6a09e667f3bcdull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> half_root_two< complex<double> >() {
  return complex<double>(half_root_two<double>(), double());
}

/// Returns ln(2), approximately 0.693... (specialization for double)
template <> CUTLASS_HOST_DEVICE double ln_two<double>() {
  uint64_t bits = 0x3fe62e42fefa39efull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns ln(2), approximately 0.693... (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> ln_two< complex<double> >() {
  return complex<double>(ln_two<double>(), double());
}

/// Returns ln(ln(2)), approximately -0.3665... (specialization for double)
template <> CUTLASS_HOST_DEVICE double ln_ln_two<double>() {
  uint64_t bits = 0xbfd774f29bdd6b9full;
  return reinterpret_cast<double const &>(bits);
}

/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> ln_ln_two< complex<double> >() {
  return complex<double>(ln_ln_two<double>(), double());
}

/// Returns 1/3, approximately 0.333... (specialization for double)
template <> CUTLASS_HOST_DEVICE double third<double>() {
  uint64_t bits = 0x3fd5555555555555ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns 1/3, approximately 0.333... (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> third< complex<double> >() {
  return complex<double>(third<double>(), double());
}

/// Returns 2/3, approximately 0.666... (specialization for double)
template <> CUTLASS_HOST_DEVICE double twothirds<double>() {
  uint64_t bits = 0x3fe5555555555555ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns 2/3, approximately 0.666... (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> twothirds< complex<double> >() {
  return complex<double>(twothirds<double>(), double());
}

/// Returns pi - 3, approximately 0.1416... (specialization for double)
template <> CUTLASS_HOST_DEVICE double pi_minus_three<double>() {
  uint64_t bits = 0x3fc21fb54442d180ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns pi - 3, approximately 0.1416... (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> pi_minus_three< complex<double> >() {
  return complex<double>(pi_minus_three<double>(), double());
}

/// Returns 4 - pi, approximately 0.858... (specialization for double)
template <> CUTLASS_HOST_DEVICE double four_minus_pi<double>() {
  uint64_t bits = 0x3feb7812aeef4ba0ull;
  return reinterpret_cast<double const &>(bits);
}

/// Returns 4 - pi, approximately 0.858... (specialization for complex<double>)
template <> CUTLASS_HOST_DEVICE complex<double> four_minus_pi< complex<double> >() {
  return complex<double>(four_minus_pi<double>(), double());
}
|
| 338 |
+
|
| 339 |
+
/////////////////////////////////////////////////////////////////////////////////////
|
| 340 |
+
|
| 341 |
+
// Specialization for float
//
// Each constant is materialized from its exact IEEE-754 binary32 bit pattern so the
// returned value is bit-reproducible across host and device compilers.
//
// NOTE(review): reinterpret_cast from the uint32_t holding the bits is type punning
// and formally UB under strict aliasing rules; it is the established idiom throughout
// this file — confirm toolchain guarantees before replacing it.
//
// Every complex<float> specialization returns the real constant with a
// value-initialized (zero) imaginary part.

/// Returns 1, the multiplicative identity element (specialization for float)
template <> CUTLASS_HOST_DEVICE float one<float>() {
  uint32_t bits = 0x3f800000u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns 1, the multiplicative identity element (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> one< complex<float> >() {
  return complex<float>(one<float>(), float());
}

/// Returns 0, the additive identity element (specialization for float)
template <> CUTLASS_HOST_DEVICE float zero<float>() {
  uint32_t bits = 0x0u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns 0, the additive identity element (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> zero< complex<float> >() {
  return complex<float>(zero<float>(), float());
}

/// Returns 2 (specialization for float)
template <> CUTLASS_HOST_DEVICE float two<float>() {
  uint32_t bits = 0x40000000u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns 2 (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> two< complex<float> >() {
  return complex<float>(two<float>(), float());
}

/// Returns pi, approximately 3.141 (specialization for float)
template <> CUTLASS_HOST_DEVICE float pi<float>() {
  uint32_t bits = 0x40490fdbu;
  return reinterpret_cast<float const &>(bits);
}

/// Returns pi, approximately 3.141 (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> pi< complex<float> >() {
  return complex<float>(pi<float>(), float());
}

/// Returns 2 * pi (specialization for float)
template <> CUTLASS_HOST_DEVICE float two_pi<float>() {
  uint32_t bits = 0x40c90fdbu;
  return reinterpret_cast<float const &>(bits);
}

/// Returns 2 * pi (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> two_pi< complex<float> >() {
  return complex<float>(two_pi<float>(), float());
}

/// Returns pi / 2 (specialization for float)
template <> CUTLASS_HOST_DEVICE float half_pi<float>() {
  uint32_t bits = 0x3fc90fdbu;
  return reinterpret_cast<float const &>(bits);
}

/// Returns pi / 2 (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> half_pi< complex<float> >() {
  return complex<float>(half_pi<float>(), float());
}

/// Returns sqrt(pi) (specialization for float)
template <> CUTLASS_HOST_DEVICE float root_pi<float>() {
  uint32_t bits = 0x3fe2dfc5u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns sqrt(pi) (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> root_pi< complex<float> >() {
  return complex<float>(root_pi<float>(), float());
}

/// Returns sqrt(pi / 2) (specialization for float)
template <> CUTLASS_HOST_DEVICE float root_half_pi<float>() {
  uint32_t bits = 0x3fa06c99u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns sqrt(pi / 2) (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> root_half_pi< complex<float> >() {
  return complex<float>(root_half_pi<float>(), float());
}

/// Returns sqrt(2 * pi) (specialization for float)
template <> CUTLASS_HOST_DEVICE float root_two_pi<float>() {
  uint32_t bits = 0x40206c99u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns sqrt(2 * pi) (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> root_two_pi< complex<float> >() {
  return complex<float>(root_two_pi<float>(), float());
}

/// Returns sqrt(ln(4)) (specialization for float)
template <> CUTLASS_HOST_DEVICE float root_ln_four<float>() {
  uint32_t bits = 0x3f96b55fu;
  return reinterpret_cast<float const &>(bits);
}

/// Returns sqrt(ln(4)) (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> root_ln_four< complex<float> >() {
  return complex<float>(root_ln_four<float>(), float());
}

/// Returns e, approximately 2.718... (specialization for float)
template <> CUTLASS_HOST_DEVICE float e<float>() {
  uint32_t bits = 0x402df854u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns e, approximately 2.718... (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> e< complex<float> >() {
  return complex<float>(e<float>(), float());
}

/// Returns (1/2) (specialization for float)
template <> CUTLASS_HOST_DEVICE float half<float>() {
  uint32_t bits = 0x3f000000u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns (1/2) (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> half< complex<float> >() {
  return complex<float>(half<float>(), float());
}

/// Returns sqrt(2), approximately 1.414... (specialization for float)
template <> CUTLASS_HOST_DEVICE float root_two<float>() {
  uint32_t bits = 0x3fb504f3u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns sqrt(2), approximately 1.414... (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> root_two< complex<float> >() {
  return complex<float>(root_two<float>(), float());
}

/// Returns sqrt(2)/2, approximately 0.707... (specialization for float)
template <> CUTLASS_HOST_DEVICE float half_root_two<float>() {
  uint32_t bits = 0x3f3504f3u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> half_root_two< complex<float> >() {
  return complex<float>(half_root_two<float>(), float());
}

/// Returns ln(2), approximately 0.693... (specialization for float)
template <> CUTLASS_HOST_DEVICE float ln_two<float>() {
  uint32_t bits = 0x3f317218u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns ln(2), approximately 0.693... (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> ln_two< complex<float> >() {
  return complex<float>(ln_two<float>(), float());
}

/// Returns ln(ln(2)), approximately -0.3665... (specialization for float)
template <> CUTLASS_HOST_DEVICE float ln_ln_two<float>() {
  uint32_t bits = 0xbebba795u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> ln_ln_two< complex<float> >() {
  return complex<float>(ln_ln_two<float>(), float());
}

/// Returns 1/3, approximately 0.333... (specialization for float)
template <> CUTLASS_HOST_DEVICE float third<float>() {
  uint32_t bits = 0x3eaaaaabu;
  return reinterpret_cast<float const &>(bits);
}

/// Returns 1/3, approximately 0.333... (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> third< complex<float> >() {
  return complex<float>(third<float>(), float());
}

/// Returns 2/3, approximately 0.666... (specialization for float)
template <> CUTLASS_HOST_DEVICE float twothirds<float>() {
  uint32_t bits = 0x3f2aaaabu;
  return reinterpret_cast<float const &>(bits);
}

/// Returns 2/3, approximately 0.666... (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> twothirds< complex<float> >() {
  return complex<float>(twothirds<float>(), float());
}

/// Returns pi - 3, approximately 0.1416... (specialization for float)
template <> CUTLASS_HOST_DEVICE float pi_minus_three<float>() {
  uint32_t bits = 0x3e10fdaau;
  return reinterpret_cast<float const &>(bits);
}

/// Returns pi - 3, approximately 0.1416... (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> pi_minus_three< complex<float> >() {
  return complex<float>(pi_minus_three<float>(), float());
}

/// Returns 4 - pi, approximately 0.858... (specialization for float)
template <> CUTLASS_HOST_DEVICE float four_minus_pi<float>() {
  uint32_t bits = 0x3f5bc095u;
  return reinterpret_cast<float const &>(bits);
}

/// Returns 4 - pi, approximately 0.858... (specialization for complex<float>)
template <> CUTLASS_HOST_DEVICE complex<float> four_minus_pi< complex<float> >() {
  return complex<float>(four_minus_pi<float>(), float());
}
|
| 562 |
+
|
| 563 |
+
/////////////////////////////////////////////////////////////////////////////////////
|
| 564 |
+
|
| 565 |
+
// Specialization for tfloat32_t
//
// Constants are materialized from 32-bit patterns, as for float, and reinterpreted
// as tfloat32_t (same type-punning idiom and caveats as the sections above).
//
// NOTE(review): relative to the float section, every tfloat32_t pattern here has
// bit 0x1000 set (e.g. zero is 0x1000, not 0x0, and one is 0x3f801000, not
// 0x3f800000) — this looks like a deliberate encoding choice tied to tfloat32_t's
// reduced-mantissa storage convention; confirm against the tfloat32_t type's
// definition before relying on or altering these patterns.
//
// Every complex<tfloat32_t> specialization returns the real constant with a
// value-initialized imaginary part.

/// Returns 1, the multiplicative identity element (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t one<tfloat32_t>() {
  uint32_t bits = 0x3f801000u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns 1, the multiplicative identity element (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> one< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(one<tfloat32_t>(), tfloat32_t());
}

/// Returns 0, the additive identity element (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t zero<tfloat32_t>() {
  uint32_t bits = 0x1000u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns 0, the additive identity element (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> zero< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(zero<tfloat32_t>(), tfloat32_t());
}

/// Returns 2 (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t two<tfloat32_t>() {
  uint32_t bits = 0x40001000u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns 2 (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> two< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(two<tfloat32_t>(), tfloat32_t());
}

/// Returns pi, approximately 3.141 (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t pi<tfloat32_t>() {
  uint32_t bits = 0x40491fdbu;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns pi, approximately 3.141 (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> pi< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(pi<tfloat32_t>(), tfloat32_t());
}

/// Returns 2 * pi (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t two_pi<tfloat32_t>() {
  uint32_t bits = 0x40c91fdbu;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns 2 * pi (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> two_pi< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(two_pi<tfloat32_t>(), tfloat32_t());
}

/// Returns pi / 2 (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t half_pi<tfloat32_t>() {
  uint32_t bits = 0x3fc91fdbu;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns pi / 2 (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> half_pi< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(half_pi<tfloat32_t>(), tfloat32_t());
}

/// Returns sqrt(pi) (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t root_pi<tfloat32_t>() {
  uint32_t bits = 0x3fe2efc5u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns sqrt(pi) (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_pi< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(root_pi<tfloat32_t>(), tfloat32_t());
}

/// Returns sqrt(pi / 2) (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t root_half_pi<tfloat32_t>() {
  uint32_t bits = 0x3fa07c99u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns sqrt(pi / 2) (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_half_pi< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(root_half_pi<tfloat32_t>(), tfloat32_t());
}

/// Returns sqrt(2 * pi) (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t root_two_pi<tfloat32_t>() {
  uint32_t bits = 0x40207c99u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns sqrt(2 * pi) (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_two_pi< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(root_two_pi<tfloat32_t>(), tfloat32_t());
}

/// Returns sqrt(ln(4)) (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t root_ln_four<tfloat32_t>() {
  uint32_t bits = 0x3f96c55fu;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns sqrt(ln(4)) (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_ln_four< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(root_ln_four<tfloat32_t>(), tfloat32_t());
}

/// Returns e, approximately 2.718... (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t e<tfloat32_t>() {
  uint32_t bits = 0x402e0854u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns e, approximately 2.718... (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> e< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(e<tfloat32_t>(), tfloat32_t());
}

/// Returns (1/2) (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t half<tfloat32_t>() {
  uint32_t bits = 0x3f001000u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns (1/2) (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> half< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(half<tfloat32_t>(), tfloat32_t());
}

/// Returns sqrt(2), approximately 1.414... (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t root_two<tfloat32_t>() {
  uint32_t bits = 0x3fb514f3u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns sqrt(2), approximately 1.414... (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> root_two< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(root_two<tfloat32_t>(), tfloat32_t());
}

/// Returns sqrt(2)/2, approximately 0.707... (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t half_root_two<tfloat32_t>() {
  uint32_t bits = 0x3f3514f3u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> half_root_two< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(half_root_two<tfloat32_t>(), tfloat32_t());
}

/// Returns ln(2), approximately 0.693... (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t ln_two<tfloat32_t>() {
  uint32_t bits = 0x3f318218u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns ln(2), approximately 0.693... (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> ln_two< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(ln_two<tfloat32_t>(), tfloat32_t());
}

/// Returns ln(ln(2)), approximately -0.3665... (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t ln_ln_two<tfloat32_t>() {
  uint32_t bits = 0xbebbb795u;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> ln_ln_two< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(ln_ln_two<tfloat32_t>(), tfloat32_t());
}

/// Returns 1/3, approximately 0.333... (specialization for tfloat32_t)
template <> CUTLASS_HOST_DEVICE tfloat32_t third<tfloat32_t>() {
  uint32_t bits = 0x3eaabaabu;
  return reinterpret_cast<tfloat32_t const &>(bits);
}

/// Returns 1/3, approximately 0.333... (specialization for complex<tfloat32_t>)
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> third< complex<tfloat32_t> >() {
  return complex<tfloat32_t>(third<tfloat32_t>(), tfloat32_t());
}
|
| 753 |
+
|
| 754 |
+
/// Returns 2/3, approximately 0.666... (specialization for tfloat32_t)
|
| 755 |
+
template <> CUTLASS_HOST_DEVICE tfloat32_t twothirds<tfloat32_t>() {
|
| 756 |
+
uint32_t bits = 0x3f2abaabu;
|
| 757 |
+
return reinterpret_cast<tfloat32_t const &>(bits);
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
/// Returns 2/3, approximately 0.666... (specialization for complex<tfloat32_t>)
|
| 761 |
+
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> twothirds< complex<tfloat32_t> >() {
|
| 762 |
+
return complex<tfloat32_t>(twothirds<tfloat32_t>(), tfloat32_t());
|
| 763 |
+
}
|
| 764 |
+
|
| 765 |
+
/// Returns pi - 3, approximately 0.1416... (specialization for tfloat32_t)
|
| 766 |
+
template <> CUTLASS_HOST_DEVICE tfloat32_t pi_minus_three<tfloat32_t>() {
|
| 767 |
+
uint32_t bits = 0x3e110daau;
|
| 768 |
+
return reinterpret_cast<tfloat32_t const &>(bits);
|
| 769 |
+
}
|
| 770 |
+
|
| 771 |
+
/// Returns pi - 3, approximately 0.1416... (specialization for complex<tfloat32_t>)
|
| 772 |
+
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> pi_minus_three< complex<tfloat32_t> >() {
|
| 773 |
+
return complex<tfloat32_t>(pi_minus_three<tfloat32_t>(), tfloat32_t());
|
| 774 |
+
}
|
| 775 |
+
|
| 776 |
+
/// Returns 4 - pi, approximately 0.858... (specialization for tfloat32_t)
|
| 777 |
+
template <> CUTLASS_HOST_DEVICE tfloat32_t four_minus_pi<tfloat32_t>() {
|
| 778 |
+
uint32_t bits = 0x3f5bd095u;
|
| 779 |
+
return reinterpret_cast<tfloat32_t const &>(bits);
|
| 780 |
+
}
|
| 781 |
+
|
| 782 |
+
/// Returns 4 - pi, approximately 0.858... (specialization for complex<tfloat32_t>)
|
| 783 |
+
template <> CUTLASS_HOST_DEVICE complex<tfloat32_t> four_minus_pi< complex<tfloat32_t> >() {
|
| 784 |
+
return complex<tfloat32_t>(four_minus_pi<tfloat32_t>(), tfloat32_t());
|
| 785 |
+
}
|
| 786 |
+
|
| 787 |
+
/////////////////////////////////////////////////////////////////////////////////////
|
| 788 |
+
|
| 789 |
+
// Specialization for half_t
|
| 790 |
+
|
| 791 |
+
/// Returns 1, the multiplicative identity element (specialization for half_t)
|
| 792 |
+
template <> CUTLASS_HOST_DEVICE half_t one<half_t>() {
|
| 793 |
+
uint16_t bits = 0x3c00u;
|
| 794 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
/// Returns 1, the multiplicative identity element (specialization for complex<half_t>)
|
| 798 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> one< complex<half_t> >() {
|
| 799 |
+
return complex<half_t>(one<half_t>(), half_t());
|
| 800 |
+
}
|
| 801 |
+
|
| 802 |
+
/// Returns 0, the additive identity element (specialization for half_t)
|
| 803 |
+
template <> CUTLASS_HOST_DEVICE half_t zero<half_t>() {
|
| 804 |
+
uint16_t bits = 0x0u;
|
| 805 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 806 |
+
}
|
| 807 |
+
|
| 808 |
+
/// Returns 0, the additive identity element (specialization for complex<half_t>)
|
| 809 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> zero< complex<half_t> >() {
|
| 810 |
+
return complex<half_t>(zero<half_t>(), half_t());
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
/// Returns 2 (specialization for half_t)
|
| 814 |
+
template <> CUTLASS_HOST_DEVICE half_t two<half_t>() {
|
| 815 |
+
uint16_t bits = 0x4000u;
|
| 816 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
/// Returns 2 (specialization for complex<half_t>)
|
| 820 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> two< complex<half_t> >() {
|
| 821 |
+
return complex<half_t>(two<half_t>(), half_t());
|
| 822 |
+
}
|
| 823 |
+
|
| 824 |
+
/// Returns pi, approximately 3.141 (specialization for half_t)
|
| 825 |
+
template <> CUTLASS_HOST_DEVICE half_t pi<half_t>() {
|
| 826 |
+
uint16_t bits = 0x4248u;
|
| 827 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 828 |
+
}
|
| 829 |
+
|
| 830 |
+
/// Returns pi, approximately 3.141 (specialization for complex<half_t>)
|
| 831 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> pi< complex<half_t> >() {
|
| 832 |
+
return complex<half_t>(pi<half_t>(), half_t());
|
| 833 |
+
}
|
| 834 |
+
|
| 835 |
+
/// Returns 2 * pi (specialization for half_t)
|
| 836 |
+
template <> CUTLASS_HOST_DEVICE half_t two_pi<half_t>() {
|
| 837 |
+
uint16_t bits = 0x4648u;
|
| 838 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
/// Returns 2 * pi (specialization for complex<half_t>)
|
| 842 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> two_pi< complex<half_t> >() {
|
| 843 |
+
return complex<half_t>(two_pi<half_t>(), half_t());
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
/// Returns pi / 2 (specialization for half_t)
|
| 847 |
+
template <> CUTLASS_HOST_DEVICE half_t half_pi<half_t>() {
|
| 848 |
+
uint16_t bits = 0x3e48u;
|
| 849 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 850 |
+
}
|
| 851 |
+
|
| 852 |
+
/// Returns pi / 2 (specialization for complex<half_t>)
|
| 853 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> half_pi< complex<half_t> >() {
|
| 854 |
+
return complex<half_t>(half_pi<half_t>(), half_t());
|
| 855 |
+
}
|
| 856 |
+
|
| 857 |
+
/// Returns sqrt(pi) (specialization for half_t)
|
| 858 |
+
template <> CUTLASS_HOST_DEVICE half_t root_pi<half_t>() {
|
| 859 |
+
uint16_t bits = 0x3f17u;
|
| 860 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 861 |
+
}
|
| 862 |
+
|
| 863 |
+
/// Returns sqrt(pi) (specialization for complex<half_t>)
|
| 864 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> root_pi< complex<half_t> >() {
|
| 865 |
+
return complex<half_t>(root_pi<half_t>(), half_t());
|
| 866 |
+
}
|
| 867 |
+
|
| 868 |
+
/// Returns sqrt(pi / 2) (specialization for half_t)
|
| 869 |
+
template <> CUTLASS_HOST_DEVICE half_t root_half_pi<half_t>() {
|
| 870 |
+
uint16_t bits = 0x3d03u;
|
| 871 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 872 |
+
}
|
| 873 |
+
|
| 874 |
+
/// Returns sqrt(pi / 2) (specialization for complex<half_t>)
|
| 875 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> root_half_pi< complex<half_t> >() {
|
| 876 |
+
return complex<half_t>(root_half_pi<half_t>(), half_t());
|
| 877 |
+
}
|
| 878 |
+
|
| 879 |
+
/// Returns sqrt(2 * pi) (specialization for half_t)
|
| 880 |
+
template <> CUTLASS_HOST_DEVICE half_t root_two_pi<half_t>() {
|
| 881 |
+
uint16_t bits = 0x4103u;
|
| 882 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 883 |
+
}
|
| 884 |
+
|
| 885 |
+
/// Returns sqrt(2 * pi) (specialization for complex<half_t>)
|
| 886 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> root_two_pi< complex<half_t> >() {
|
| 887 |
+
return complex<half_t>(root_two_pi<half_t>(), half_t());
|
| 888 |
+
}
|
| 889 |
+
|
| 890 |
+
/// Returns sqrt(ln(4)) (specialization for half_t)
|
| 891 |
+
template <> CUTLASS_HOST_DEVICE half_t root_ln_four<half_t>() {
|
| 892 |
+
uint16_t bits = 0x3cb6u;
|
| 893 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 894 |
+
}
|
| 895 |
+
|
| 896 |
+
/// Returns sqrt(ln(4)) (specialization for complex<half_t>)
|
| 897 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> root_ln_four< complex<half_t> >() {
|
| 898 |
+
return complex<half_t>(root_ln_four<half_t>(), half_t());
|
| 899 |
+
}
|
| 900 |
+
|
| 901 |
+
/// Returns e, approximately 2.718... (specialization for half_t)
|
| 902 |
+
template <> CUTLASS_HOST_DEVICE half_t e<half_t>() {
|
| 903 |
+
uint16_t bits = 0x4170u;
|
| 904 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 905 |
+
}
|
| 906 |
+
|
| 907 |
+
/// Returns e, approximately 2.718... (specialization for complex<half_t>)
|
| 908 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> e< complex<half_t> >() {
|
| 909 |
+
return complex<half_t>(e<half_t>(), half_t());
|
| 910 |
+
}
|
| 911 |
+
|
| 912 |
+
/// Returns (1/2) (specialization for half_t)
|
| 913 |
+
template <> CUTLASS_HOST_DEVICE half_t half<half_t>() {
|
| 914 |
+
uint16_t bits = 0x3800u;
|
| 915 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 916 |
+
}
|
| 917 |
+
|
| 918 |
+
/// Returns (1/2) (specialization for complex<half_t>)
|
| 919 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> half< complex<half_t> >() {
|
| 920 |
+
return complex<half_t>(half<half_t>(), half_t());
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
/// Returns sqrt(2), approximately 1.414... (specialization for half_t)
|
| 924 |
+
template <> CUTLASS_HOST_DEVICE half_t root_two<half_t>() {
|
| 925 |
+
uint16_t bits = 0x3da8u;
|
| 926 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 927 |
+
}
|
| 928 |
+
|
| 929 |
+
/// Returns sqrt(2), approximately 1.414... (specialization for complex<half_t>)
|
| 930 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> root_two< complex<half_t> >() {
|
| 931 |
+
return complex<half_t>(root_two<half_t>(), half_t());
|
| 932 |
+
}
|
| 933 |
+
|
| 934 |
+
/// Returns sqrt(2)/2, approximately 0.707... (specialization for half_t)
|
| 935 |
+
template <> CUTLASS_HOST_DEVICE half_t half_root_two<half_t>() {
|
| 936 |
+
uint16_t bits = 0x39a8u;
|
| 937 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 938 |
+
}
|
| 939 |
+
|
| 940 |
+
/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<half_t>)
|
| 941 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> half_root_two< complex<half_t> >() {
|
| 942 |
+
return complex<half_t>(half_root_two<half_t>(), half_t());
|
| 943 |
+
}
|
| 944 |
+
|
| 945 |
+
/// Returns ln(2), approximately 0.693... (specialization for half_t)
|
| 946 |
+
template <> CUTLASS_HOST_DEVICE half_t ln_two<half_t>() {
|
| 947 |
+
uint16_t bits = 0x398cu;
|
| 948 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
/// Returns ln(2), approximately 0.693... (specialization for complex<half_t>)
|
| 952 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> ln_two< complex<half_t> >() {
|
| 953 |
+
return complex<half_t>(ln_two<half_t>(), half_t());
|
| 954 |
+
}
|
| 955 |
+
|
| 956 |
+
/// Returns ln(ln(2)), approximately -0.3665... (specialization for half_t)
|
| 957 |
+
template <> CUTLASS_HOST_DEVICE half_t ln_ln_two<half_t>() {
|
| 958 |
+
uint16_t bits = 0xb5ddu;
|
| 959 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 960 |
+
}
|
| 961 |
+
|
| 962 |
+
/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<half_t>)
|
| 963 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> ln_ln_two< complex<half_t> >() {
|
| 964 |
+
return complex<half_t>(ln_ln_two<half_t>(), half_t());
|
| 965 |
+
}
|
| 966 |
+
|
| 967 |
+
/// Returns 1/3, approximately 0.333... (specialization for half_t)
|
| 968 |
+
template <> CUTLASS_HOST_DEVICE half_t third<half_t>() {
|
| 969 |
+
uint16_t bits = 0x3555u;
|
| 970 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 971 |
+
}
|
| 972 |
+
|
| 973 |
+
/// Returns 1/3, approximately 0.333... (specialization for complex<half_t>)
|
| 974 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> third< complex<half_t> >() {
|
| 975 |
+
return complex<half_t>(third<half_t>(), half_t());
|
| 976 |
+
}
|
| 977 |
+
|
| 978 |
+
/// Returns 2/3, approximately 0.666... (specialization for half_t)
|
| 979 |
+
template <> CUTLASS_HOST_DEVICE half_t twothirds<half_t>() {
|
| 980 |
+
uint16_t bits = 0x3955u;
|
| 981 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 982 |
+
}
|
| 983 |
+
|
| 984 |
+
/// Returns 2/3, approximately 0.666... (specialization for complex<half_t>)
|
| 985 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> twothirds< complex<half_t> >() {
|
| 986 |
+
return complex<half_t>(twothirds<half_t>(), half_t());
|
| 987 |
+
}
|
| 988 |
+
|
| 989 |
+
/// Returns pi - 3, approximately 0.1416... (specialization for half_t)
|
| 990 |
+
template <> CUTLASS_HOST_DEVICE half_t pi_minus_three<half_t>() {
|
| 991 |
+
uint16_t bits = 0x3088u;
|
| 992 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 993 |
+
}
|
| 994 |
+
|
| 995 |
+
/// Returns pi - 3, approximately 0.1416... (specialization for complex<half_t>)
|
| 996 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> pi_minus_three< complex<half_t> >() {
|
| 997 |
+
return complex<half_t>(pi_minus_three<half_t>(), half_t());
|
| 998 |
+
}
|
| 999 |
+
|
| 1000 |
+
/// Returns 4 - pi, approximately 0.858... (specialization for half_t)
|
| 1001 |
+
template <> CUTLASS_HOST_DEVICE half_t four_minus_pi<half_t>() {
|
| 1002 |
+
uint16_t bits = 0x3adeu;
|
| 1003 |
+
return reinterpret_cast<half_t const &>(bits);
|
| 1004 |
+
}
|
| 1005 |
+
|
| 1006 |
+
/// Returns 4 - pi, approximately 0.858... (specialization for complex<half_t>)
|
| 1007 |
+
template <> CUTLASS_HOST_DEVICE complex<half_t> four_minus_pi< complex<half_t> >() {
|
| 1008 |
+
return complex<half_t>(four_minus_pi<half_t>(), half_t());
|
| 1009 |
+
}
|
| 1010 |
+
|
| 1011 |
+
/////////////////////////////////////////////////////////////////////////////////////
|
| 1012 |
+
|
| 1013 |
+
// Specialization for bfloat16_t
|
| 1014 |
+
|
| 1015 |
+
/// Returns 1, the multiplicative identity element (specialization for bfloat16_t)
|
| 1016 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t one<bfloat16_t>() {
|
| 1017 |
+
uint16_t bits = 0x3f80u;
|
| 1018 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
/// Returns 1, the multiplicative identity element (specialization for complex<bfloat16_t>)
|
| 1022 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> one< complex<bfloat16_t> >() {
|
| 1023 |
+
return complex<bfloat16_t>(one<bfloat16_t>(), bfloat16_t());
|
| 1024 |
+
}
|
| 1025 |
+
|
| 1026 |
+
/// Returns 0, the additive identity element (specialization for bfloat16_t)
|
| 1027 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t zero<bfloat16_t>() {
|
| 1028 |
+
uint16_t bits = 0x0u;
|
| 1029 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1030 |
+
}
|
| 1031 |
+
|
| 1032 |
+
/// Returns 0, the additive identity element (specialization for complex<bfloat16_t>)
|
| 1033 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> zero< complex<bfloat16_t> >() {
|
| 1034 |
+
return complex<bfloat16_t>(zero<bfloat16_t>(), bfloat16_t());
|
| 1035 |
+
}
|
| 1036 |
+
|
| 1037 |
+
/// Returns 2 (specialization for bfloat16_t)
|
| 1038 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t two<bfloat16_t>() {
|
| 1039 |
+
uint16_t bits = 0x4000u;
|
| 1040 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1041 |
+
}
|
| 1042 |
+
|
| 1043 |
+
/// Returns 2 (specialization for complex<bfloat16_t>)
|
| 1044 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> two< complex<bfloat16_t> >() {
|
| 1045 |
+
return complex<bfloat16_t>(two<bfloat16_t>(), bfloat16_t());
|
| 1046 |
+
}
|
| 1047 |
+
|
| 1048 |
+
/// Returns pi, approximately 3.141 (specialization for bfloat16_t)
|
| 1049 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t pi<bfloat16_t>() {
|
| 1050 |
+
uint16_t bits = 0x4049u;
|
| 1051 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1052 |
+
}
|
| 1053 |
+
|
| 1054 |
+
/// Returns pi, approximately 3.141 (specialization for complex<bfloat16_t>)
|
| 1055 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> pi< complex<bfloat16_t> >() {
|
| 1056 |
+
return complex<bfloat16_t>(pi<bfloat16_t>(), bfloat16_t());
|
| 1057 |
+
}
|
| 1058 |
+
|
| 1059 |
+
/// Returns 2 * pi (specialization for bfloat16_t)
|
| 1060 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t two_pi<bfloat16_t>() {
|
| 1061 |
+
uint16_t bits = 0x40c9u;
|
| 1062 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1063 |
+
}
|
| 1064 |
+
|
| 1065 |
+
/// Returns 2 * pi (specialization for complex<bfloat16_t>)
|
| 1066 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> two_pi< complex<bfloat16_t> >() {
|
| 1067 |
+
return complex<bfloat16_t>(two_pi<bfloat16_t>(), bfloat16_t());
|
| 1068 |
+
}
|
| 1069 |
+
|
| 1070 |
+
/// Returns pi / 2 (specialization for bfloat16_t)
|
| 1071 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t half_pi<bfloat16_t>() {
|
| 1072 |
+
uint16_t bits = 0x3fc9u;
|
| 1073 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1074 |
+
}
|
| 1075 |
+
|
| 1076 |
+
/// Returns pi / 2 (specialization for complex<bfloat16_t>)
|
| 1077 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> half_pi< complex<bfloat16_t> >() {
|
| 1078 |
+
return complex<bfloat16_t>(half_pi<bfloat16_t>(), bfloat16_t());
|
| 1079 |
+
}
|
| 1080 |
+
|
| 1081 |
+
/// Returns sqrt(pi) (specialization for bfloat16_t)
|
| 1082 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t root_pi<bfloat16_t>() {
|
| 1083 |
+
uint16_t bits = 0x3fe3u;
|
| 1084 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1085 |
+
}
|
| 1086 |
+
|
| 1087 |
+
/// Returns sqrt(pi) (specialization for complex<bfloat16_t>)
|
| 1088 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_pi< complex<bfloat16_t> >() {
|
| 1089 |
+
return complex<bfloat16_t>(root_pi<bfloat16_t>(), bfloat16_t());
|
| 1090 |
+
}
|
| 1091 |
+
|
| 1092 |
+
/// Returns sqrt(pi / 2) (specialization for bfloat16_t)
|
| 1093 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t root_half_pi<bfloat16_t>() {
|
| 1094 |
+
uint16_t bits = 0x3fa0u;
|
| 1095 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1096 |
+
}
|
| 1097 |
+
|
| 1098 |
+
/// Returns sqrt(pi / 2) (specialization for complex<bfloat16_t>)
|
| 1099 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_half_pi< complex<bfloat16_t> >() {
|
| 1100 |
+
return complex<bfloat16_t>(root_half_pi<bfloat16_t>(), bfloat16_t());
|
| 1101 |
+
}
|
| 1102 |
+
|
| 1103 |
+
/// Returns sqrt(2 * pi) (specialization for bfloat16_t)
|
| 1104 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t root_two_pi<bfloat16_t>() {
|
| 1105 |
+
uint16_t bits = 0x4020u;
|
| 1106 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1107 |
+
}
|
| 1108 |
+
|
| 1109 |
+
/// Returns sqrt(2 * pi) (specialization for complex<bfloat16_t>)
|
| 1110 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_two_pi< complex<bfloat16_t> >() {
|
| 1111 |
+
return complex<bfloat16_t>(root_two_pi<bfloat16_t>(), bfloat16_t());
|
| 1112 |
+
}
|
| 1113 |
+
|
| 1114 |
+
/// Returns sqrt(ln(4)) (specialization for bfloat16_t)
|
| 1115 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t root_ln_four<bfloat16_t>() {
|
| 1116 |
+
uint16_t bits = 0x3f97u;
|
| 1117 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1118 |
+
}
|
| 1119 |
+
|
| 1120 |
+
/// Returns sqrt(ln(4)) (specialization for complex<bfloat16_t>)
|
| 1121 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_ln_four< complex<bfloat16_t> >() {
|
| 1122 |
+
return complex<bfloat16_t>(root_ln_four<bfloat16_t>(), bfloat16_t());
|
| 1123 |
+
}
|
| 1124 |
+
|
| 1125 |
+
/// Returns e, approximately 2.718... (specialization for bfloat16_t)
|
| 1126 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t e<bfloat16_t>() {
|
| 1127 |
+
uint16_t bits = 0x402eu;
|
| 1128 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1129 |
+
}
|
| 1130 |
+
|
| 1131 |
+
/// Returns e, approximately 2.718... (specialization for complex<bfloat16_t>)
|
| 1132 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> e< complex<bfloat16_t> >() {
|
| 1133 |
+
return complex<bfloat16_t>(e<bfloat16_t>(), bfloat16_t());
|
| 1134 |
+
}
|
| 1135 |
+
|
| 1136 |
+
/// Returns (1/2) (specialization for bfloat16_t)
|
| 1137 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t half<bfloat16_t>() {
|
| 1138 |
+
uint16_t bits = 0x3f00u;
|
| 1139 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1140 |
+
}
|
| 1141 |
+
|
| 1142 |
+
/// Returns (1/2) (specialization for complex<bfloat16_t>)
|
| 1143 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> half< complex<bfloat16_t> >() {
|
| 1144 |
+
return complex<bfloat16_t>(half<bfloat16_t>(), bfloat16_t());
|
| 1145 |
+
}
|
| 1146 |
+
|
| 1147 |
+
/// Returns sqrt(2), approximately 1.414... (specialization for bfloat16_t)
|
| 1148 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t root_two<bfloat16_t>() {
|
| 1149 |
+
uint16_t bits = 0x3fb5u;
|
| 1150 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1151 |
+
}
|
| 1152 |
+
|
| 1153 |
+
/// Returns sqrt(2), approximately 1.414... (specialization for complex<bfloat16_t>)
|
| 1154 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> root_two< complex<bfloat16_t> >() {
|
| 1155 |
+
return complex<bfloat16_t>(root_two<bfloat16_t>(), bfloat16_t());
|
| 1156 |
+
}
|
| 1157 |
+
|
| 1158 |
+
/// Returns sqrt(2)/2, approximately 0.707... (specialization for bfloat16_t)
|
| 1159 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t half_root_two<bfloat16_t>() {
|
| 1160 |
+
uint16_t bits = 0x3f35u;
|
| 1161 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1162 |
+
}
|
| 1163 |
+
|
| 1164 |
+
/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex<bfloat16_t>)
|
| 1165 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> half_root_two< complex<bfloat16_t> >() {
|
| 1166 |
+
return complex<bfloat16_t>(half_root_two<bfloat16_t>(), bfloat16_t());
|
| 1167 |
+
}
|
| 1168 |
+
|
| 1169 |
+
/// Returns ln(2), approximately 0.693... (specialization for bfloat16_t)
|
| 1170 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t ln_two<bfloat16_t>() {
|
| 1171 |
+
uint16_t bits = 0x3f31u;
|
| 1172 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1173 |
+
}
|
| 1174 |
+
|
| 1175 |
+
/// Returns ln(2), approximately 0.693... (specialization for complex<bfloat16_t>)
|
| 1176 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> ln_two< complex<bfloat16_t> >() {
|
| 1177 |
+
return complex<bfloat16_t>(ln_two<bfloat16_t>(), bfloat16_t());
|
| 1178 |
+
}
|
| 1179 |
+
|
| 1180 |
+
/// Returns ln(ln(2)), approximately -0.3665... (specialization for bfloat16_t)
|
| 1181 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t ln_ln_two<bfloat16_t>() {
|
| 1182 |
+
uint16_t bits = 0xbebcu;
|
| 1183 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1184 |
+
}
|
| 1185 |
+
|
| 1186 |
+
/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex<bfloat16_t>)
|
| 1187 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> ln_ln_two< complex<bfloat16_t> >() {
|
| 1188 |
+
return complex<bfloat16_t>(ln_ln_two<bfloat16_t>(), bfloat16_t());
|
| 1189 |
+
}
|
| 1190 |
+
|
| 1191 |
+
/// Returns 1/3, approximately 0.333... (specialization for bfloat16_t)
|
| 1192 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t third<bfloat16_t>() {
|
| 1193 |
+
uint16_t bits = 0x3eabu;
|
| 1194 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1195 |
+
}
|
| 1196 |
+
|
| 1197 |
+
/// Returns 1/3, approximately 0.333... (specialization for complex<bfloat16_t>)
|
| 1198 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> third< complex<bfloat16_t> >() {
|
| 1199 |
+
return complex<bfloat16_t>(third<bfloat16_t>(), bfloat16_t());
|
| 1200 |
+
}
|
| 1201 |
+
|
| 1202 |
+
/// Returns 2/3, approximately 0.666... (specialization for bfloat16_t)
|
| 1203 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t twothirds<bfloat16_t>() {
|
| 1204 |
+
uint16_t bits = 0x3f2bu;
|
| 1205 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1206 |
+
}
|
| 1207 |
+
|
| 1208 |
+
/// Returns 2/3, approximately 0.666... (specialization for complex<bfloat16_t>)
|
| 1209 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> twothirds< complex<bfloat16_t> >() {
|
| 1210 |
+
return complex<bfloat16_t>(twothirds<bfloat16_t>(), bfloat16_t());
|
| 1211 |
+
}
|
| 1212 |
+
|
| 1213 |
+
/// Returns pi - 3, approximately 0.1416... (specialization for bfloat16_t)
|
| 1214 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t pi_minus_three<bfloat16_t>() {
|
| 1215 |
+
uint16_t bits = 0x3e11u;
|
| 1216 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1217 |
+
}
|
| 1218 |
+
|
| 1219 |
+
/// Returns pi - 3, approximately 0.1416... (specialization for complex<bfloat16_t>)
|
| 1220 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> pi_minus_three< complex<bfloat16_t> >() {
|
| 1221 |
+
return complex<bfloat16_t>(pi_minus_three<bfloat16_t>(), bfloat16_t());
|
| 1222 |
+
}
|
| 1223 |
+
|
| 1224 |
+
/// Returns 4 - pi, approximately 0.858... (specialization for bfloat16_t)
|
| 1225 |
+
template <> CUTLASS_HOST_DEVICE bfloat16_t four_minus_pi<bfloat16_t>() {
|
| 1226 |
+
uint16_t bits = 0x3f5cu;
|
| 1227 |
+
return reinterpret_cast<bfloat16_t const &>(bits);
|
| 1228 |
+
}
|
| 1229 |
+
|
| 1230 |
+
/// Returns 4 - pi, approximately 0.858... (specialization for complex<bfloat16_t>)
|
| 1231 |
+
template <> CUTLASS_HOST_DEVICE complex<bfloat16_t> four_minus_pi< complex<bfloat16_t> >() {
|
| 1232 |
+
return complex<bfloat16_t>(four_minus_pi<bfloat16_t>(), bfloat16_t());
|
| 1233 |
+
}
|
| 1234 |
+
///////////////////////////////////////////////////////////////////////////////////
|
| 1235 |
+
|
| 1236 |
+
} // namespace constants
|
| 1237 |
+
} // namespace cutlass
|
| 1238 |
+
|
| 1239 |
+
///////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include "cutlass/detail/dependent_false.hpp"
|
| 34 |
+
#include "cutlass/conv/collective/collective_conv.hpp"
|
| 35 |
+
|
| 36 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 37 |
+
|
| 38 |
+
namespace cutlass::conv::collective {
|
| 39 |
+
|
| 40 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
// Used to specify stage counts or dispatch to automatic computation of stage count
|
| 43 |
+
template<int num_stages>
|
| 44 |
+
struct StageCount {
|
| 45 |
+
static constexpr int value = num_stages;
|
| 46 |
+
|
| 47 |
+
StageCount() = default;
|
| 48 |
+
explicit StageCount(cute::Int<num_stages>) {}
|
| 49 |
+
};
|
| 50 |
+
|
| 51 |
+
template<int carveout_bytes>
|
| 52 |
+
struct StageCountAutoCarveout {
|
| 53 |
+
static constexpr int bytes = carveout_bytes;
|
| 54 |
+
|
| 55 |
+
StageCountAutoCarveout() = default;
|
| 56 |
+
explicit StageCountAutoCarveout(cute::Int<carveout_bytes>) {}
|
| 57 |
+
};
|
| 58 |
+
|
| 59 |
+
// Used to automatically let the builder pick the kernel schedule.
|
| 60 |
+
// Can be overridden with kernel schedule tags in cutlass/conv/dispatch_policy.hpp
|
| 61 |
+
struct KernelScheduleAuto {};
|
| 62 |
+
|
| 63 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
template <
|
| 66 |
+
class ArchTag,
|
| 67 |
+
class OpClass,
|
| 68 |
+
conv::Operator,
|
| 69 |
+
class ElementA,
|
| 70 |
+
class GmemLayoutA,
|
| 71 |
+
int AlignmentA,
|
| 72 |
+
class ElementB,
|
| 73 |
+
class GmemLayoutB,
|
| 74 |
+
int AlignmentB,
|
| 75 |
+
class ElementAccumulator,
|
| 76 |
+
class TileShape_MNK,
|
| 77 |
+
class ClusterShape_MNK,
|
| 78 |
+
class StageCountType,
|
| 79 |
+
class KernelScheduleType,
|
| 80 |
+
class Enable = void
|
| 81 |
+
>
|
| 82 |
+
struct CollectiveBuilder {
|
| 83 |
+
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not build a collective for given parameters.");
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 87 |
+
|
| 88 |
+
} // namespace cutlass::conv::collective
|
| 89 |
+
|
| 90 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 91 |
+
|
| 92 |
+
#include "builders/sm90_gmma_builder.inl"
|
| 93 |
+
#include "builders/sm100_umma_builder.inl"
|
| 94 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include "cutlass/detail/dependent_false.hpp"
|
| 34 |
+
#include "cutlass/conv/collective/detail.hpp"
|
| 35 |
+
|
| 36 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 37 |
+
|
| 38 |
+
namespace cutlass::conv::collective {
|
| 39 |
+
|
| 40 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
template <
|
| 43 |
+
class DispatchPolicy,
|
| 44 |
+
class TileShape,
|
| 45 |
+
class ElementA,
|
| 46 |
+
class ElementB,
|
| 47 |
+
class TiledMma,
|
| 48 |
+
class TileTraitsA,
|
| 49 |
+
class TileTraitsB
|
| 50 |
+
>
|
| 51 |
+
struct CollectiveConv {
|
| 52 |
+
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
|
| 53 |
+
};
|
| 54 |
+
|
| 55 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
} // namespace cutlass::conv::collective
|
| 58 |
+
|
| 59 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
#include "sm90_implicit_gemm_gmma_ss_warpspecialized.hpp"
|
| 62 |
+
#include "sm100_implicit_gemm_umma_warpspecialized.hpp"
|
| 63 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include "cutlass/conv/convnd_problem_shape.hpp"
|
| 34 |
+
|
| 35 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 36 |
+
|
| 37 |
+
namespace cutlass::conv::collective::detail {
|
| 38 |
+
|
| 39 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
+
|
| 41 |
+
// Construct the stride types for conv collectives based on the dispatch policy, strides 64b by default
|
| 42 |
+
template <class DispatchPolicy>
|
| 43 |
+
constexpr auto
|
| 44 |
+
sm90_dispatch_policy_to_stride_A() {
|
| 45 |
+
if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) {
|
| 46 |
+
// Maps to modes ((w,n), C)
|
| 47 |
+
if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
|
| 48 |
+
return cute::Stride<cute::Stride<int64_t, int64_t>,
|
| 49 |
+
cute::Int<1>>{};
|
| 50 |
+
}
|
| 51 |
+
// Maps to modes ((w,h,n), C)
|
| 52 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
|
| 53 |
+
return cute::Stride<cute::Stride<int64_t, int64_t, int64_t>,
|
| 54 |
+
cute::Int<1>>{};
|
| 55 |
+
}
|
| 56 |
+
// Maps to modes ((w,h,d,n), C)
|
| 57 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
|
| 58 |
+
return cute::Stride<cute::Stride<int64_t, int64_t, int64_t, int64_t>,
|
| 59 |
+
cute::Int<1>>{};
|
| 60 |
+
}
|
| 61 |
+
// error dims assert
|
| 62 |
+
else {
|
| 63 |
+
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) {
|
| 67 |
+
// Maps to modes (k, nq/npq/nzpq)
|
| 68 |
+
if constexpr (DispatchPolicy::NumSpatialDimensions == 1 ||
|
| 69 |
+
DispatchPolicy::NumSpatialDimensions == 2 ||
|
| 70 |
+
DispatchPolicy::NumSpatialDimensions == 3) {
|
| 71 |
+
return cute::Stride<cute::Int<1>, int64_t>{};
|
| 72 |
+
}
|
| 73 |
+
// error dims assert
|
| 74 |
+
else {
|
| 75 |
+
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) {
|
| 79 |
+
// Maps to modes ((q,n), K)
|
| 80 |
+
if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
|
| 81 |
+
return cute::Stride<cute::Stride<int64_t, int64_t>,
|
| 82 |
+
cute::Int<1>>{};
|
| 83 |
+
}
|
| 84 |
+
// Maps to modes ((q,p,n), K)
|
| 85 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
|
| 86 |
+
return cute::Stride<cute::Stride<int64_t, int64_t, int64_t>,
|
| 87 |
+
cute::Int<1>>{};
|
| 88 |
+
}
|
| 89 |
+
// Maps to modes ((q,p,z,n), K)
|
| 90 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
|
| 91 |
+
return cute::Stride<cute::Stride<int64_t, int64_t, int64_t, int64_t>,
|
| 92 |
+
cute::Int<1>>{};
|
| 93 |
+
}
|
| 94 |
+
// error dims assert
|
| 95 |
+
else {
|
| 96 |
+
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
else {
|
| 100 |
+
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported ConvOp.");
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// Construct the stirde types for conv collectives based on the dispatch policy, strides 64b by default
|
| 105 |
+
template <class DispatchPolicy>
|
| 106 |
+
constexpr auto
|
| 107 |
+
sm90_dispatch_policy_to_stride_B() {
|
| 108 |
+
if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) {
|
| 109 |
+
// Maps to modes (k, (C,s))
|
| 110 |
+
if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
|
| 111 |
+
return cute::Stride<int64_t, cute::Stride<cute::Int<1>, int64_t>>{};
|
| 112 |
+
}
|
| 113 |
+
// Maps to modes (k, (C,s,r))
|
| 114 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
|
| 115 |
+
return cute::Stride<int64_t, cute::Stride<cute::Int<1>, int64_t, int64_t>>{};
|
| 116 |
+
}
|
| 117 |
+
// Maps to modes (k, (C,s,r,t))
|
| 118 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
|
| 119 |
+
return cute::Stride<int64_t, cute::Stride<cute::Int<1>, int64_t, int64_t, int64_t>>{};
|
| 120 |
+
}
|
| 121 |
+
// error dims assert
|
| 122 |
+
else {
|
| 123 |
+
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) {
|
| 127 |
+
// Maps to modes (C, (w,n))
|
| 128 |
+
if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
|
| 129 |
+
return cute::Stride<cute::Int<1>,
|
| 130 |
+
cute::Stride<int64_t, int64_t>>{};
|
| 131 |
+
}
|
| 132 |
+
// Maps to modes (C, (w,h,n))
|
| 133 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
|
| 134 |
+
return cute::Stride<cute::Int<1>,
|
| 135 |
+
cute::Stride<int64_t, int64_t, int64_t>>{};
|
| 136 |
+
}
|
| 137 |
+
// Maps to modes (C, (w,h,d,n))
|
| 138 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
|
| 139 |
+
return cute::Stride<cute::Int<1>,
|
| 140 |
+
cute::Stride<int64_t, int64_t, int64_t, int64_t>>{};
|
| 141 |
+
}
|
| 142 |
+
// error dims assert
|
| 143 |
+
else {
|
| 144 |
+
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) {
|
| 148 |
+
// Maps to modes (C, (k,s))
|
| 149 |
+
if constexpr (DispatchPolicy::NumSpatialDimensions == 1) {
|
| 150 |
+
return cute::Stride<cute::Int<1>, cute::Stride<int64_t, int64_t>>{};
|
| 151 |
+
}
|
| 152 |
+
// Maps to modes (C, (k,s,r))
|
| 153 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) {
|
| 154 |
+
return cute::Stride<cute::Int<1>, cute::Stride<int64_t, int64_t, int64_t>>{};
|
| 155 |
+
}
|
| 156 |
+
// Maps to modes (C, (k,s,r,t))
|
| 157 |
+
else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) {
|
| 158 |
+
return cute::Stride<cute::Int<1>, cute::Stride<int64_t, int64_t, int64_t, int64_t>>{};
|
| 159 |
+
}
|
| 160 |
+
// error dims assert
|
| 161 |
+
else {
|
| 162 |
+
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported spatial dim count.");
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
else {
|
| 166 |
+
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Unsupported ConvOp.");
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
template <class DispatchPolicy>
|
| 172 |
+
constexpr auto
|
| 173 |
+
sm100_dispatch_policy_to_stride_A() {
|
| 174 |
+
return sm90_dispatch_policy_to_stride_A<DispatchPolicy>();
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
template <class DispatchPolicy>
|
| 178 |
+
constexpr auto
|
| 179 |
+
sm100_dispatch_policy_to_stride_B() {
|
| 180 |
+
return sm90_dispatch_policy_to_stride_B<DispatchPolicy>();
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 185 |
+
|
| 186 |
+
// Compute the lower/near corner, returning it as a cute::array in [W,H,D] order
|
| 187 |
+
template <conv::Operator ConvOp, int NumSpatialDimensions>
|
| 188 |
+
CUTLASS_HOST_DEVICE
|
| 189 |
+
constexpr auto
|
| 190 |
+
compute_lower_corner_whd(ConvProblemShape<ConvOp, NumSpatialDimensions> const& problem_shape) {
|
| 191 |
+
using cute::for_each;
|
| 192 |
+
using cute::make_seq;
|
| 193 |
+
|
| 194 |
+
cute::array<int, NumSpatialDimensions> lower{};
|
| 195 |
+
if constexpr (ConvOp == conv::Operator::kFprop ||
|
| 196 |
+
ConvOp == conv::Operator::kWgrad) {
|
| 197 |
+
for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
|
| 198 |
+
lower[NumSpatialDimensions-1-i] = -1 * problem_shape.lower_padding[i];
|
| 199 |
+
});
|
| 200 |
+
}
|
| 201 |
+
else if constexpr (ConvOp == conv::Operator::kDgrad) {
|
| 202 |
+
for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
|
| 203 |
+
lower[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] -
|
| 204 |
+
(problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i];
|
| 205 |
+
});
|
| 206 |
+
}
|
| 207 |
+
return lower;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
// Computes the upper/far corner, returning it as a cute::array in [W,H,D] order
|
| 211 |
+
template <conv::Operator ConvOp, int NumSpatialDimensions>
|
| 212 |
+
CUTLASS_HOST_DEVICE
|
| 213 |
+
constexpr auto
|
| 214 |
+
compute_upper_corner_whd(ConvProblemShape<ConvOp, NumSpatialDimensions> const& problem_shape) {
|
| 215 |
+
using cute::for_each;
|
| 216 |
+
using cute::make_seq;
|
| 217 |
+
|
| 218 |
+
cute::array<int, NumSpatialDimensions> upper{};
|
| 219 |
+
if constexpr (ConvOp == conv::Operator::kFprop) {
|
| 220 |
+
for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
|
| 221 |
+
upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] -
|
| 222 |
+
(problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i];
|
| 223 |
+
});
|
| 224 |
+
}
|
| 225 |
+
else if constexpr (ConvOp == conv::Operator::kWgrad) {
|
| 226 |
+
for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
|
| 227 |
+
upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] -
|
| 228 |
+
(problem_shape.shape_C[i+1] - 1) * problem_shape.dilation[i];
|
| 229 |
+
});
|
| 230 |
+
}
|
| 231 |
+
else if constexpr (ConvOp == conv::Operator::kDgrad) {
|
| 232 |
+
for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
|
| 233 |
+
upper[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] -
|
| 234 |
+
(problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i] + problem_shape.shape_C[i+1] - problem_shape.shape_A[i+1];
|
| 235 |
+
});
|
| 236 |
+
}
|
| 237 |
+
return upper;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
// Compute the lower/near corner of (t,r,s), returning it as a cute::array in [S,R,T] order
|
| 241 |
+
template <conv::Operator ConvOp, int NumSpatialDimensions>
|
| 242 |
+
CUTLASS_HOST_DEVICE
|
| 243 |
+
constexpr auto
|
| 244 |
+
compute_lower_srt(ConvProblemShape<ConvOp, NumSpatialDimensions> const& problem_shape) {
|
| 245 |
+
using cute::for_each;
|
| 246 |
+
using cute::make_seq;
|
| 247 |
+
|
| 248 |
+
cute::array<int, NumSpatialDimensions> lower{};
|
| 249 |
+
if constexpr (ConvOp == conv::Operator::kFprop ||
|
| 250 |
+
ConvOp == conv::Operator::kWgrad) {
|
| 251 |
+
for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
|
| 252 |
+
lower[NumSpatialDimensions-1-i] = 0;
|
| 253 |
+
});
|
| 254 |
+
}
|
| 255 |
+
else if constexpr (ConvOp == conv::Operator::kDgrad) {
|
| 256 |
+
for_each(make_seq<NumSpatialDimensions>{}, [&](auto i) {
|
| 257 |
+
lower[NumSpatialDimensions-1-i] = (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i];
|
| 258 |
+
});
|
| 259 |
+
}
|
| 260 |
+
return lower;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
template <class CopyOp> struct is_im2col_load { static constexpr bool value = false; };
|
| 264 |
+
template <> struct is_im2col_load<cute::SM90_TMA_LOAD_IM2COL > { static constexpr bool value = true; };
|
| 265 |
+
template <> struct is_im2col_load<cute::SM90_TMA_LOAD_IM2COL_MULTICAST> { static constexpr bool value = true; };
|
| 266 |
+
template <> struct is_im2col_load<cute::SM100_TMA_2SM_LOAD_IM2COL > { static constexpr bool value = true; };
|
| 267 |
+
template <> struct is_im2col_load<cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST> { static constexpr bool value = true; };
|
| 268 |
+
|
| 269 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 270 |
+
|
| 271 |
+
} // namespace cutlass::conv::collective::detail
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp
ADDED
|
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
#pragma once
|
| 34 |
+
|
| 35 |
+
#include "cutlass/cutlass.h"
|
| 36 |
+
#include "cutlass/gemm/dispatch_policy.hpp"
|
| 37 |
+
#include "cutlass/pipeline/pipeline.hpp"
|
| 38 |
+
#include "cutlass/gemm/gemm.h"
|
| 39 |
+
#include "cutlass/detail/cluster.hpp"
|
| 40 |
+
|
| 41 |
+
#include "cutlass/conv/detail.hpp"
|
| 42 |
+
#include "cute/algorithm/functional.hpp"
|
| 43 |
+
#include "cute/arch/cluster_sm90.hpp"
|
| 44 |
+
#include "cute/atom/mma_atom.hpp"
|
| 45 |
+
#include "cute/algorithm/gemm.hpp"
|
| 46 |
+
#include "cute/numeric/arithmetic_tuple.hpp"
|
| 47 |
+
#include "cutlass/trace.h"
|
| 48 |
+
|
| 49 |
+
#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0)
|
| 50 |
+
# include <sstream>
|
| 51 |
+
#endif
|
| 52 |
+
|
| 53 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
namespace cutlass::conv::collective {
|
| 56 |
+
using namespace cute;
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
// WarpSpecialized Mainloop
|
| 61 |
+
// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one
|
| 62 |
+
template <
|
| 63 |
+
conv::Operator ConvOp,
|
| 64 |
+
int Stages,
|
| 65 |
+
int NumSpatialDims,
|
| 66 |
+
int SchedulerPipelineStageCount,
|
| 67 |
+
int AccumulatorPipelineStageCount,
|
| 68 |
+
class ClusterShape, // Static cluster shape or dynamic (int, int, _1)
|
| 69 |
+
class TileShapeMNKL_, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL)
|
| 70 |
+
class ElementA_,
|
| 71 |
+
class ElementB_,
|
| 72 |
+
class TiledMma_,
|
| 73 |
+
class TileTraitsA_,
|
| 74 |
+
class TileTraitsB_>
|
| 75 |
+
struct CollectiveConv<
|
| 76 |
+
MainloopSm100TmaUmmaWarpSpecializedImplicitGemm<
|
| 77 |
+
ConvOp,
|
| 78 |
+
Stages,
|
| 79 |
+
NumSpatialDims,
|
| 80 |
+
SchedulerPipelineStageCount,
|
| 81 |
+
AccumulatorPipelineStageCount,
|
| 82 |
+
ClusterShape>,
|
| 83 |
+
TileShapeMNKL_,
|
| 84 |
+
ElementA_,
|
| 85 |
+
ElementB_,
|
| 86 |
+
TiledMma_,
|
| 87 |
+
TileTraitsA_,
|
| 88 |
+
TileTraitsB_>
|
| 89 |
+
{
|
| 90 |
+
//
|
| 91 |
+
// Type Aliases
|
| 92 |
+
//
|
| 93 |
+
using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedImplicitGemm<
|
| 94 |
+
ConvOp,
|
| 95 |
+
Stages,
|
| 96 |
+
NumSpatialDims,
|
| 97 |
+
SchedulerPipelineStageCount,
|
| 98 |
+
AccumulatorPipelineStageCount,
|
| 99 |
+
ClusterShape>;
|
| 100 |
+
// --- Compile-time type plumbing for the SM100 conv collective mainloop ---
// TileShapeMNKL_ may carry an optional 4th mode (used as GroupsPerTile for grouped
// conv, see is_grouped_wgrad below); only the first three modes form the CTA tile.
using TileShape = decltype(cute::take<0,3>(TileShapeMNKL_{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK)
using ElementA = ElementA_;
using ElementB = ElementB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy;
using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy;
using SmemLayoutAtomA = typename TileTraitsA_::SmemLayoutAtom;
using SmemLayoutAtomB = typename TileTraitsB_::SmemLayoutAtom;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions;
// Spatial dims plus two non-spatial modes.
static constexpr int NumTensorDimensions = NumSpatialDimensions + 2;
// Deduce the kernel facing stride tuple types based on the dispatch policy (spatial dim, algo, etc.)
using StrideA = decltype(detail::sm100_dispatch_policy_to_stride_A<DispatchPolicy>());
using StrideB = decltype(detail::sm100_dispatch_policy_to_stride_B<DispatchPolicy>());

static constexpr bool IsDynamicCluster = not cute::is_static_v<ClusterShape>;
// float inputs are fed to the MMA as tf32; the TMA moves raw bits otherwise.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using TmaInternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>>;
using TmaInternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, cute::uint_bit_t<cute::sizeof_bits_v<ElementB>>>;

using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

// Determine MMA type: MMA_1SM vs MMA_2SM (first mode of the MMA atom's thread layout)
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>;

using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
  DispatchPolicy::Stages,
  ClusterShape,
  AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;

using ProblemShape = ConvProblemShape<ConvOp, NumSpatialDimensions>;

CUTE_STATIC_ASSERT_V(evenly_divides(shape<0>(TileShape{}), tile_size<0>(TiledMma{})), "TileShape_M should be evenly divided by TiledMma_M");
// Wgrad is exempt from the N-divisibility requirement.
CUTE_STATIC_ASSERT_V(evenly_divides(shape<1>(TileShape{}), tile_size<1>(TiledMma{})) || (ConvOp == conv::Operator::kWgrad), "TileShape_N should be evenly divided by TiledMma_N");

using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{}));

// Define A and B block shapes for reduced size TMA_LOADs
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{}))));
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{}))));

static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0,
    "SmemLayoutAtom must evenly divide tile shape.");
static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0,
    "SmemLayoutAtom must evenly divide tile shape.");

static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0,
    "SmemLayoutAtom must evenly divide tile shape.");
static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0,
    "SmemLayoutAtom must evenly divide tile shape.");

// Tile along K mode first before tiling over MN. PIPE mode last as usual.
// This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs.
using SmemLayoutA = decltype(UMMA::tile_to_mma_shape(
    SmemLayoutAtomA{},
    append(MmaShapeA_MK{}, Int<DispatchPolicy::Stages>{}),
    Step<_2,_1,_3>{}));
using SmemLayoutB = decltype(UMMA::tile_to_mma_shape(
    SmemLayoutAtomB{},
    append(MmaShapeB_NK{}, Int<DispatchPolicy::Stages>{}),
    Step<_2,_1,_3>{}));
|
| 167 |
+
|
| 168 |
+
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
| 169 |
+
static_assert(cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
| 170 |
+
cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
| 171 |
+
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
| 172 |
+
|
| 173 |
+
// TMA load-mode classification: an im2col load encodes the conv window logic in
// the TMA descriptor, a tiled load is a plain GEMM-style load.
static constexpr bool is_im2col_A = detail::is_im2col_load<GmemTiledCopyA>::value;
static constexpr bool is_im2col_B = detail::is_im2col_load<GmemTiledCopyB>::value;
// Strided dgrad uses tiled TMA for both operands (neither side is im2col).
static constexpr bool is_strided_dgrad = ConvOp == conv::Operator::kDgrad && not is_im2col_A && not is_im2col_B;

static constexpr int TileShapeMNKLRank = rank(TileShapeMNKL_{});
// If rank > 3, TileL exists and it is GroupsPerTile. The kernel is grouped conv now.
static constexpr bool is_grouped_wgrad = ConvOp == conv::Operator::kWgrad && TileShapeMNKLRank > 3;
|
| 180 |
+
|
| 181 |
+
// Shared memory layout: operand staging buffers (one slot per pipeline stage,
// see SmemLayoutA/B) plus the mainloop pipeline's barrier storage.
struct SharedStorage {
  struct TensorStorage : cute::aligned_struct<128, _0> {
    cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
    cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
  } tensors;

  using PipelineStorage = typename MainloopPipeline::SharedStorage;
  PipelineStorage pipeline;
};

using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;

// Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
// (AtomThrShapeMNK scales the per-stage byte count by the number of CTAs in the MMA).
static constexpr uint32_t TmaTransactionBytes =
  size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof(ElementA))) +
  size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof(ElementB)));

// Host side kernel arguments
struct Arguments {
  ElementA const* ptr_A{nullptr};
  ElementB const* ptr_B{nullptr};
};
|
| 204 |
+
|
| 205 |
+
private:

// Note that for fprop and non-strided dgrad kernel, the tma load mode is im2col for tensor A and tiled for
// tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor
// B since operand A, B is swapped.
// For strided dgrad A and B are both tma tiled and not im2col

// Builds the TMA load atom for operand A.
// tensor_a           : full gmem tensor for A
// problem_shape      : untransformed conv problem (padding/stride/dilation source)
// cluster_shape_vmnk : (V,M,N,K) cluster layout used to derive multicast
template <class TensorA, class ClusterShapeVMNK>
static constexpr auto
get_tma_load_a_instance(
    TensorA const& tensor_a,
    ProblemShape const& problem_shape,
    ClusterShapeVMNK const& cluster_shape_vmnk) {

  if constexpr (is_im2col_A) {
    // compute the upper and lower corners based on the conv padding
    auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
    auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
    auto lower_srt = detail::compute_lower_srt(problem_shape);

    // gbasis strides for dgrad kernel need to be negated
    // (filter offsets walk backwards through the activation for dgrad)
    cute::array<int32_t, NumSpatialDimensions> stride_srt{};
    for (int i = 0; i < NumSpatialDimensions; ++i) {
      stride_srt[i] = ConvOp == conv::Operator::kDgrad ?
        -problem_shape.dilation[NumSpatialDimensions-1-i] :
        problem_shape.dilation[NumSpatialDimensions-1-i];
    }

    return make_im2col_tma_atom_A_sm100(
        GmemTiledCopyA{},
        tensor_a,
        SmemLayoutA{}(_,_,_,cute::Int<0>{}),   // one pipeline stage's smem layout
        TileShape{},
        TiledMma{},
        cluster_shape_vmnk,
        shape(lower_corner_whd),
        shape(upper_corner_whd),
        cute::reverse(shape(problem_shape.lower_padding)),     // reversed: flat arrays are outer-to-inner
        cute::reverse(shape(problem_shape.upper_padding)),
        cute::reverse(shape(problem_shape.traversal_stride)),
        shape(lower_srt),
        shape(stride_srt));
  }
  // TMA tiled mode for tensor A in wgrad and strided dgrad
  else {
    return make_tma_atom_A_sm100<TmaInternalElementA>(
        GmemTiledCopyA{},
        tensor_a,
        SmemLayoutA{}(_,_,_,cute::Int<0>{}),
        TileShape{},
        TiledMma{},
        cluster_shape_vmnk);
  }
}
|
| 259 |
+
|
| 260 |
+
// Builds the TMA load atom for operand B (mirror of get_tma_load_a_instance;
// im2col for wgrad, tiled otherwise — see the load-mode note above).
template <class TensorB, class ClusterShapeVMNK>
static constexpr auto
get_tma_load_b_instance(
    TensorB const& tensor_b,
    ProblemShape const& problem_shape,
    ClusterShapeVMNK const& cluster_shape_vmnk) {

  if constexpr (is_im2col_B) {
    // compute the upper and lower corners based on the conv padding
    auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
    auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
    auto lower_srt = detail::compute_lower_srt(problem_shape);

    return make_im2col_tma_atom_B_sm100(
        GmemTiledCopyB{},
        tensor_b,
        SmemLayoutB{}(_,_,_,cute::Int<0>{}),   // one pipeline stage's smem layout
        TileShape{},
        TiledMma{},
        cluster_shape_vmnk,
        shape(lower_corner_whd),
        shape(upper_corner_whd),
        cute::reverse(shape(problem_shape.lower_padding)),     // reversed: flat arrays are outer-to-inner
        cute::reverse(shape(problem_shape.upper_padding)),
        cute::reverse(shape(problem_shape.traversal_stride)),
        shape(lower_srt),
        cute::reverse(shape(problem_shape.dilation)));
  }
  else {
    return make_tma_atom_B_sm100<TmaInternalElementB>(
        GmemTiledCopyB{},
        tensor_b,
        SmemLayoutB{}(_,_,_,cute::Int<0>{}),
        TileShape{},
        TiledMma{},
        cluster_shape_vmnk);
  }
}
|
| 298 |
+
|
| 299 |
+
public:

// Performs im2col transformations on the input of type ConvProblemShape.
// Returns the GEMM-view (M,N,K,L) problem shape; when either operand is loaded
// via im2col the offset modes are additionally linearized.
static constexpr auto
get_problem_shape_MNKL(ProblemShape const& problem_shape) {
  if constexpr (is_im2col_A || is_im2col_B) {
    // transformation + im2col linearization
    return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape);
  }
  else {
    // transformation
    return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape);
  }
}
|
| 313 |
+
|
| 314 |
+
// Device-side kernel params
//
// Arguments has the untransformed problem shape from the user.
// Params will have the transformed problem shape.
struct Params {
  // All tensor modes except the last one.
  using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{}));

  // Cluster layout used for TMA construction; a dynamic cluster uses a
  // runtime-shaped (uint32_t) placeholder here.
  using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return<IsDynamicCluster>(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})),
                                                   make_tile(typename TiledMma::AtomThrID{})));

  // Assumption: StrideA is congruent with Problem_MK
  // Select TMA load type according to convolution operator.
  using TensorShapeA = cute::conditional_t<ConvOp == conv::Operator::kWgrad,
      decltype(repeat_like(StrideA{}, int32_t(0))),
      decltype(make_shape(_Submode{}, int32_t(0)))>;

  using TensorShapeB = cute::conditional_t<ConvOp == conv::Operator::kWgrad,
      decltype(make_shape(int32_t(0), _Submode{})),
      decltype(repeat_like(StrideB{}, int32_t(0)))>;

  // TMA atom types deduced from dummy (nullptr-based) tensors of the right shape/stride.
  using TMA_A = decltype(get_tma_load_a_instance(
      make_tensor(
          make_gmem_ptr(recast_ptr<TmaInternalElementA>(nullptr)),
          make_layout(TensorShapeA{}, StrideA{})),
      ConvProblemShape<ConvOp, NumSpatialDimensions>{},
      ClusterLayout_VMNK{}));

  using TMA_B = decltype(get_tma_load_b_instance(
      make_tensor(
          make_gmem_ptr(recast_ptr<TmaInternalElementB>(nullptr)),
          make_layout(TensorShapeB{}, StrideB{})),
      ConvProblemShape<ConvOp, NumSpatialDimensions>{},
      ClusterLayout_VMNK{}));

  // Members
  TMA_A tma_load_a;
  TMA_B tma_load_b;
  // Descriptors built for the fallback cluster shape (dynamic-cluster launches).
  TMA_A tma_load_a_fallback;
  TMA_B tma_load_b_fallback;
  dim3 cluster_shape_fallback;
};
|
| 355 |
+
|
| 356 |
+
//
|
| 357 |
+
// Constructor
|
| 358 |
+
//
|
| 359 |
+
CUTLASS_DEVICE
CollectiveConv(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster)
  : cluster_shape_(cluster_shape)
  , block_rank_in_cluster_(block_rank_in_cluster) {
  if constexpr (IsDynamicCluster) {
    // Dynamic cluster: choose between the primary and fallback TMA descriptors
    // depending on which cluster shape this launch actually got.
    dim3 const& fallback = params.cluster_shape_fallback;
    bool const use_fallback = cute::size<0>(cluster_shape_) == fallback.x
                           && cute::size<1>(cluster_shape_) == fallback.y;
    observed_tma_load_a_ = use_fallback ? &params.tma_load_a_fallback : &params.tma_load_a;
    observed_tma_load_b_ = use_fallback ? &params.tma_load_b_fallback : &params.tma_load_b;
  }
  else {
    // Static cluster: the primary descriptors always apply.
    observed_tma_load_a_ = &params.tma_load_a;
    observed_tma_load_b_ = &params.tma_load_b;
  }
}
|
| 374 |
+
|
| 375 |
+
//
|
| 376 |
+
// Methods
|
| 377 |
+
//
|
| 378 |
+
|
| 379 |
+
// Translates host-facing Arguments into device-facing Params: builds gmem
// tensors over the user's pointers and constructs primary + fallback TMA load
// descriptors for both operands.
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) {
  (void) workspace;

  // from the flat problem shape arrays of ConvProblemShape<N>, create a rank-3 MNK problem shape tuple
  // tma desc creation depends on the original untransformed domain.

  // A extents.
  auto shape_A_orig = problem_shape.get_shape_A();
  // B extents.
  auto shape_B_orig = problem_shape.get_shape_B();

  // Fill inferred cute strides from flat stride arrays
  auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp);
  auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp);

  // Reinterpret user pointers as the TMA-internal element types (e.g. float -> tf32).
  auto ptr_A = recast_ptr<TmaInternalElementA>(args.ptr_A);
  auto ptr_B = recast_ptr<TmaInternalElementB>(args.ptr_B);

  Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA));
  Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB));

  auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape);
  // Cluster layout for TMA construction
  auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{}));
  auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback);

  // Cluster layout for TMA construction (fallback cluster shape)
  auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{}));

  auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk);
  auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk);
  auto tma_load_a_fallback = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk_fallback);
  auto tma_load_b_fallback = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk_fallback);

  // Sanity check: each TMA atom's thread count must match the MMA's CTA group size.
  static_assert(size(typename decltype(tma_load_a)::ThrID{}) == size(AtomThrShapeMNK{}));
  static_assert(size(typename decltype(tma_load_b)::ThrID{}) == size(AtomThrShapeMNK{}));

  return {
    tma_load_a,
    tma_load_b,
    tma_load_a_fallback,
    tma_load_b_fallback,
    hw_info.cluster_shape_fallback
  };
}
|
| 425 |
+
|
| 426 |
+
// Host-side feasibility check: validates strides, alignment, TMA im2col limits,
// and grouped-conv constraints for the given problem. Returns false (tracing the
// reason at CUTLASS_DEBUG_TRACE_LEVEL > 0) if the problem cannot be run.
template<class ProblemShape>
static bool
can_implement(
    ProblemShape const& problem_shape,
    Arguments const& args) {
  // Activation and Filter channel mode extents must match
  bool implementable = true;
  // channel mode is major
  {
    const bool check = problem_shape.stride_A[NumTensorDimensions-1] == 1;
#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0)
    if (not check) {
      const auto offending_stride =
        problem_shape.stride_A[NumTensorDimensions-1];
      std::ostringstream os;
      os << "CollectiveConv::can_implement: "
        "problem_shape.stride_A[NumTensorDimensions-1 = "
        << (NumTensorDimensions-1) << "] = "
        << offending_stride << " != 1";
      CUTLASS_TRACE_HOST( os.str() );
    }
#endif
    implementable &= check;
  }

  {
    const bool check = problem_shape.stride_B[NumTensorDimensions-1] == 1;
#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0)
    if (not check) {
      const auto offending_stride =
        problem_shape.stride_B[NumTensorDimensions-1];
      std::ostringstream os;
      os << "CollectiveConv::can_implement: "
        "problem_shape.stride_B[NumTensorDimensions-1 = "
        << (NumTensorDimensions-1) << "] = "
        << offending_stride << " != 1\n";
      CUTLASS_TRACE_HOST( os.str() );
    }
#endif
    implementable &= check;
  }

  // Traversal strides must lie in [1, 8].
  {
    const auto & traversal_stride = problem_shape.traversal_stride;
    for (auto stride: traversal_stride) {
      implementable &= (stride >= 1 && stride <= 8);
    }
  }

  // Non-strided (im2col) dgrad only supports unit traversal stride.
  if constexpr (ConvOp == conv::Operator::kDgrad && not is_strided_dgrad) {
    const auto & traversal_stride = problem_shape.traversal_stride;
    for (auto stride: traversal_stride) {
      implementable &= (stride == 1);
    }
  }

  constexpr int tma_alignment_bits = 128;
  // A extents.
  auto shape_A_orig = problem_shape.get_shape_A();
  // B extents.
  auto shape_B_orig = problem_shape.get_shape_B();

  constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
  {
    const bool check = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(shape_A_orig, StrideA{});
    if (not check) {
      CUTLASS_TRACE_HOST("A shape and/or strides have alignment issue.");
    }
    implementable &= check;
  }

  constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
  {
    const bool check = cutlass::detail::check_alignment<min_tma_aligned_elements_B>(shape_B_orig, StrideB{});
    if (not check) {
      CUTLASS_TRACE_HOST("B shape and/or strides have alignment issue.");
    }
    implementable &= check;
  }

  if (not implementable) {
    CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
    return false;
  }

  if (is_im2col_A || is_im2col_B) {
    // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1]
    constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1);
    auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
    for (int i = 0; i < problem_shape.RankS; ++i) {
      implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1);
    }
    auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
    for (int i = 0; i < problem_shape.RankS; ++i) {
      implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1);
    }

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n");
      return false;
    }
  }

  if (is_im2col_A || is_im2col_B) {
    // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit]
    constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1;
    auto flt_data = (ConvOp == conv::Operator::kWgrad) ? problem_shape.shape_C : problem_shape.shape_B;
    for (int i = 0; i < problem_shape.RankS; ++i) {
      // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array
      implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0)
                                    && ((flt_data[i+1] - 1) * problem_shape.dilation[i] <= offset_limit);
    }

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: tensor coordinate offset values don't meet requirements for TMA LOAD IM2COL.\n");
      return false;
    }
  }

  // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized)
  if constexpr (ConvOp == conv::Operator::kWgrad) {

    const auto & input_shape  = problem_shape.shape_A;
    const auto & input_stride = problem_shape.stride_A;

    // Verify strides are exactly the packed (row-major) products of inner extents.
    implementable &= input_stride[ProblemShape::RankT - 1] == 1;
    int64_t input_shape_size = 1;
    for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
      input_shape_size *= input_shape[i + 1];
      implementable &= input_stride[i] == input_shape_size;
    }

    const auto & output_shape  = problem_shape.shape_C;
    const auto & output_stride = problem_shape.stride_C;

    implementable &= output_stride[ProblemShape::RankT - 1] == 1;
    int64_t output_shape_size = 1;
    for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
      output_shape_size *= output_shape[i + 1];
      implementable &= output_stride[i] == output_shape_size;
    }

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n");
      return false;
    }
  }

  // Conv kernels only support cross correlation mode currently.
  {
    implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation;

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n");
      return false;
    }
  }

  // When groups > 1, it should be a Grouped Conv.
  if (problem_shape.groups > 1) {
    implementable &= TileShapeMNKLRank > 3;

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Only Grouped Conv can support groups > 1.\n");
      return false;
    }
  }

  // Only support Grouped Wgrad currently.
  if constexpr (TileShapeMNKLRank > 3) {
    implementable &= ConvOp == conv::Operator::kWgrad;

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv Only support Grouped Wgrad currently.\n");
      return false;
    }
  }

  // Grouped Wgrad channel check.
  if constexpr (is_grouped_wgrad) {

    int input_K = size<0>(problem_shape.get_shape_A());
    int input_C = size<0>(problem_shape.get_shape_B());

    implementable &= input_K == input_C;

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv's input K and input C do not match.\n");
      return false;
    }

    int output_K = size<0>(problem_shape.get_shape_C());
    int output_C = size<1,0>(problem_shape.get_shape_C());

    implementable &= input_K == output_K;
    implementable &= input_C == output_C * problem_shape.groups;

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's input and output K,C and groups do not match\n");
      return false;
    }

    constexpr int Tile_N = size<1>(TileShape{});
    constexpr int GroupsPerTile = size<3>(TileShapeMNKL_{});

    // Per-group channel count must match the per-group slice of the N tile.
    implementable &= Tile_N / GroupsPerTile == input_C / problem_shape.groups;

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's Tile_N, GroupsPerTile and input_C, groups do not match.\n");
      return false;
    }
  }

  // The extents of linearized problem shape should be int32_t type(maximum is 2^31-1).
  if constexpr (is_im2col_A || is_im2col_B) {
    auto [M, N, K, L] = cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape);
    auto to_64b = [](auto S) { return transform_leaf(S, [](auto s) { return static_cast<int64_t>(s); }); };

    if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) {
      implementable &= (cute::product(to_64b(M)) <= cutlass::platform::numeric_limits<int32_t>::max()) &
                       (cute::product(to_64b(L)) <= cutlass::platform::numeric_limits<int32_t>::max());
    }
    else if constexpr (ConvOp == conv::Operator::kWgrad) {
      implementable &= (cute::product(to_64b(K)) <= cutlass::platform::numeric_limits<int32_t>::max());
    }

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: the extents exceed the maximum number.\n");
      return false;
    }
  }

  return true;
}
|
| 660 |
+
|
| 661 |
+
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
| 662 |
+
CUTLASS_DEVICE void
|
| 663 |
+
prefetch_tma_descriptors() {
|
| 664 |
+
cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor());
|
| 665 |
+
cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor());
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
/// Construct A Single Stage's Accumulator Shape
|
| 669 |
+
CUTLASS_DEVICE static auto
|
| 670 |
+
partition_accumulator_shape() {
|
| 671 |
+
auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N)
|
| 672 |
+
|
| 673 |
+
return acc_shape;
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
/// Issues TMA loads of A/B tiles into the smem pipeline stages, one k-tile per
/// iteration, until k_tile_count is exhausted. Returns the advanced pipeline
/// state and k-tile iterator.
template <
  class GTensorA, class GTensorB,
  class GTensorPartitionedA, class GTensorPartitionedB,
  class STensorA, class STensorB,
  class TileCoordMNKL,
  class KTileIterator
>
CUTLASS_DEVICE auto
load(
    Params const& params,
    MainloopPipeline pipeline,
    MainloopPipelineState mainloop_pipe_producer_state,
    cute::tuple<GTensorA, GTensorB,
                GTensorPartitionedA, GTensorPartitionedB,
                STensorA, STensorB,
                uint16_t, uint16_t> const& load_inputs,
    TileCoordMNKL const& cta_coord_mnkl,
    KTileIterator k_tile_iter, int k_tile_count) {

  auto [unused_gA, unused_gB,
        tAgA_mk, tBgB_nk, tAsA, tBsB,
        mcast_mask_a, mcast_mask_b] = load_inputs;

  // slice out the work coord from partitioned tensors
  Tensor tAgA = tAgA_mk(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _);
  auto tensor_b_coord = get<1>(cta_coord_mnkl);
  if constexpr (is_grouped_wgrad) {
    // in grouped wgrad, tensor A = NZPQK, tensor B = NDHWC, tensor C = KTRSc, where C = G*c, c = channel_per_group = 8,16,32.
    // CTA Tiling follows output tensor KTRSc. So cta_size_m = K/CTA_TILE_M. cta_size_n = T*R*S*ceil(c/CTA_TILE_N) = T*R*S*1 = T*R*S.
    // tensor_a_coord = K_idx = cta_coord_m.
    // tensor_b_coord = TRS_idx * C/CTA_TILE_N + C_idx = cta_coord_n * get<1,0>(shape(tBgB_nk) + cta_coord_m,
    // because K == C and CTA_TILE_M == CTA_TILE_N => C_idx = K_idx = cta_coord_m.
    tensor_b_coord = get<0>(cta_coord_mnkl) + get<1>(cta_coord_mnkl) * get<1,0>(shape(tBgB_nk));
  }
  Tensor tBgB = tBgB_nk(_, tensor_b_coord, _);

  // Speculatively try to acquire the first stage before entering the loop.
  auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state);

  // Issue the Mainloop loads
  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {
    // LOCK mainloop_pipe_producer_state for _writing_
    pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);

    using BarrierType = typename MainloopPipeline::ProducerBarrierType;
    BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state);

    int write_stage = mainloop_pipe_producer_state.index();
    // Advance and eagerly try-acquire the next stage so the wait overlaps with
    // this stage's TMA issue.
    ++mainloop_pipe_producer_state;
    barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state);

    if constexpr (is_strided_dgrad) {
      // construct gemm-k tile coord for gB
      auto [conv_k, flt_coord, out_coord] = *k_tile_iter;
      auto gemm_k_tile = prepend(flt_coord, conv_k); // (k,s,r,t)

      // gA doesn't have a gemm-k (k,s,r,t) iterator mode because it's not an im2col tensor
      auto offset_kqpzn = append(prepend(out_coord, _0{}),_0{}); // (k,q,p,z,n)
      auto tAgA_offset = make_tensor(tAgA.data() + offset_kqpzn, tAgA.layout()); // (TMA, k)

      // Single thread issues both TMAs; the barrier tracks the transaction bytes.
      if (cute::elect_one_sync()) {
        copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA_offset(_,conv_k), tAsA(_,write_stage));
        copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,gemm_k_tile) , tBsB(_,write_stage));
      }
    }
    else {
      if (cute::elect_one_sync()) {
        copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage));
        copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage));
      }
    }

    --k_tile_count;
    ++k_tile_iter;
  }

  return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter);
}
|
| 756 |
+
|
| 757 |
+
/// Set up the data needed by this collective for load.
|
| 758 |
+
/// Return tuple element contain
|
| 759 |
+
/// gA_mk - The tiled tma tensor for input A
|
| 760 |
+
/// gB_nk - The tiled tma tensor for input B
|
| 761 |
+
/// tAsA - partitioned smem tensor for A
|
| 762 |
+
/// tBsB - partitioned smem tensor for B
|
| 763 |
+
/// mcast_mask_a - tma multicast mask for A
|
| 764 |
+
/// mcast_mask_b - tma multicast mask for B
|
| 765 |
+
template <class ProblemShape_MNKL>
|
| 766 |
+
CUTLASS_DEVICE auto
|
| 767 |
+
load_init(
|
| 768 |
+
ProblemShape_MNKL const& problem_shape_MNKL,
|
| 769 |
+
Params const& params,
|
| 770 |
+
TensorStorage& shared_tensors) const {
|
| 771 |
+
using X = Underscore;
|
| 772 |
+
|
| 773 |
+
// Separate out problem shape for convenience
|
| 774 |
+
auto [M,N,K,L] = problem_shape_MNKL;
|
| 775 |
+
|
| 776 |
+
// Represent the full tensors -- get these from TMA
|
| 777 |
+
auto K_A = conditional_return<is_strided_dgrad>(get<0>(K), K);
|
| 778 |
+
Tensor mA_mk = observed_tma_load_a_->get_tma_tensor(make_shape(M, K_A));
|
| 779 |
+
Tensor mB_nk = observed_tma_load_b_->get_tma_tensor(make_shape(N, K));
|
| 780 |
+
|
| 781 |
+
// Tile the tensors and defer the slice
|
| 782 |
+
Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k)
|
| 783 |
+
Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k)
|
| 784 |
+
|
| 785 |
+
// Partition for this CTA
|
| 786 |
+
ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{}));
|
| 787 |
+
|
| 788 |
+
Tensor tCgA_mk = cta_mma.partition_A(gA_mk); // (MMA, MMA_M, MMA_K, m, k)
|
| 789 |
+
Tensor tCgB_nk = cta_mma.partition_B(gB_nk); // (MMA, MMA_N, MMA_K, n, k)
|
| 790 |
+
|
| 791 |
+
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE)
|
| 792 |
+
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE)
|
| 793 |
+
|
| 794 |
+
// Define the CTA-in-cluster Layout and Coord
|
| 795 |
+
Layout cta_layout_mnk = make_layout(cluster_shape_);
|
| 796 |
+
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{}));
|
| 797 |
+
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_);
|
| 798 |
+
|
| 799 |
+
// Project the cta_layout for tma_a along the n-modes
|
| 800 |
+
auto [tAgA_mk, tAsA] = tma_partition(*observed_tma_load_a_,
|
| 801 |
+
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
|
| 802 |
+
group_modes<0,3>(sA), group_modes<0,3>(tCgA_mk));
|
| 803 |
+
|
| 804 |
+
// Project the cta_layout for tma_b along the m-modes
|
| 805 |
+
auto [tBgB_nk, tBsB] = tma_partition(*observed_tma_load_b_,
|
| 806 |
+
get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
|
| 807 |
+
group_modes<0,3>(sB), group_modes<0,3>(tCgB_nk));
|
| 808 |
+
|
| 809 |
+
// TMA Multicast Masks
|
| 810 |
+
uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
|
| 811 |
+
uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
|
| 812 |
+
|
| 813 |
+
return cute::make_tuple(
|
| 814 |
+
gA_mk, gB_nk, // for scheduler
|
| 815 |
+
tAgA_mk, tBgB_nk, tAsA, tBsB, // for input tensor values
|
| 816 |
+
mcast_mask_a, mcast_mask_b); // multicast masks
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
/// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster
|
| 820 |
+
CUTLASS_DEVICE void
|
| 821 |
+
load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) {
|
| 822 |
+
// Issue the epilogue waits
|
| 823 |
+
/* This helps avoid early exit of ctas in Cluster
|
| 824 |
+
* Waits for all stages to either be released (all
|
| 825 |
+
* Consumer UNLOCKs), or if the stage was never used
|
| 826 |
+
* then would just be acquired since the phase was
|
| 827 |
+
* still inverted from make_producer_start_state
|
| 828 |
+
*/
|
| 829 |
+
pipeline.producer_tail(mainloop_pipe_producer_state);
|
| 830 |
+
}
|
| 831 |
+
|
| 832 |
+
/// Perform a collective-scoped matrix multiply-accumulate
|
| 833 |
+
/// Consumer Perspective
|
| 834 |
+
template <
|
| 835 |
+
class FrgEngine, class FrgLayout,
|
| 836 |
+
class FragmentA, class FragmentB
|
| 837 |
+
>
|
| 838 |
+
CUTLASS_DEVICE auto
|
| 839 |
+
mma(MainloopPipeline pipeline,
|
| 840 |
+
MainloopPipelineState mainloop_pipe_consumer_state,
|
| 841 |
+
cute::Tensor<FrgEngine, FrgLayout>& accumulators,
|
| 842 |
+
cute::tuple<TiledMma, FragmentA, FragmentB> const& mma_inputs,
|
| 843 |
+
int k_tile_count)
|
| 844 |
+
{
|
| 845 |
+
static_assert(is_tmem<FrgEngine>::value, "Accumulator must be tmem resident.");
|
| 846 |
+
static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)");
|
| 847 |
+
|
| 848 |
+
auto [tiled_mma, tCrA, tCrB] = mma_inputs;
|
| 849 |
+
|
| 850 |
+
uint32_t skip_wait = k_tile_count <= 0;
|
| 851 |
+
auto barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
|
| 852 |
+
|
| 853 |
+
//
|
| 854 |
+
// PIPELINED MAIN LOOP
|
| 855 |
+
//
|
| 856 |
+
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
| 857 |
+
|
| 858 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
| 859 |
+
while (k_tile_count > 0) {
|
| 860 |
+
// WAIT on mainloop_pipe_consumer_state until its data are available (phase bit flips from mainloop_pipe_consumer_state.phase() value)
|
| 861 |
+
pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
|
| 862 |
+
|
| 863 |
+
// Compute on k_tile
|
| 864 |
+
int read_stage = mainloop_pipe_consumer_state.index();
|
| 865 |
+
// Save current mainlop pipeline read state
|
| 866 |
+
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
|
| 867 |
+
|
| 868 |
+
// Advance mainloop_pipe
|
| 869 |
+
++mainloop_pipe_consumer_state;
|
| 870 |
+
--k_tile_count;
|
| 871 |
+
skip_wait = k_tile_count <= 0;
|
| 872 |
+
// Peek at next iteration
|
| 873 |
+
barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
|
| 874 |
+
|
| 875 |
+
// Unroll the K mode manually so we can set scale C to 1
|
| 876 |
+
CUTLASS_PRAGMA_UNROLL
|
| 877 |
+
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
| 878 |
+
// (V,M,K) x (V,N,K) => (V,M,N)
|
| 879 |
+
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulators);
|
| 880 |
+
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
|
| 881 |
+
}
|
| 882 |
+
pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
|
| 883 |
+
}
|
| 884 |
+
|
| 885 |
+
return mainloop_pipe_consumer_state;
|
| 886 |
+
}
|
| 887 |
+
|
| 888 |
+
CUTLASS_DEVICE auto
|
| 889 |
+
mma_init(TensorStorage& shared_tensors) const {
|
| 890 |
+
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
| 891 |
+
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
| 892 |
+
|
| 893 |
+
TiledMma tiled_mma;
|
| 894 |
+
|
| 895 |
+
// Allocate "fragments/descriptors" for A and B matrices
|
| 896 |
+
Tensor tCrA = tiled_mma.make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
| 897 |
+
Tensor tCrB = tiled_mma.make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
| 898 |
+
|
| 899 |
+
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sA)); // PIPE
|
| 900 |
+
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB)); // PIPE
|
| 901 |
+
return cute::make_tuple(tiled_mma, tCrA, tCrB);
|
| 902 |
+
}
|
| 903 |
+
|
| 904 |
+
private:
|
| 905 |
+
|
| 906 |
+
typename Params::TMA_A const* observed_tma_load_a_ = nullptr;
|
| 907 |
+
typename Params::TMA_B const* observed_tma_load_b_ = nullptr;
|
| 908 |
+
|
| 909 |
+
ClusterShape cluster_shape_;
|
| 910 |
+
uint32_t block_rank_in_cluster_;
|
| 911 |
+
};
|
| 912 |
+
|
| 913 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 914 |
+
|
| 915 |
+
} // namespace cutlass::conv::collective
|
| 916 |
+
|
| 917 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp
ADDED
|
@@ -0,0 +1,785 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include "cutlass/cutlass.h"
|
| 34 |
+
|
| 35 |
+
#include "cute/arch/cluster_sm90.hpp"
|
| 36 |
+
#include "cute/arch/copy_sm90.hpp"
|
| 37 |
+
#include "cute/atom/mma_atom.hpp"
|
| 38 |
+
#include "cute/atom/copy_traits_sm90_im2col.hpp"
|
| 39 |
+
#include "cute/numeric/arithmetic_tuple.hpp"
|
| 40 |
+
#include "cute/algorithm/functional.hpp"
|
| 41 |
+
#include "cute/algorithm/gemm.hpp"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/conv/detail.hpp"
|
| 44 |
+
#include "cutlass/conv/convolution.h"
|
| 45 |
+
#include "cutlass/conv/dispatch_policy.hpp"
|
| 46 |
+
#include "cutlass/pipeline/pipeline.hpp"
|
| 47 |
+
#include "cutlass/util/packed_stride.hpp"
|
| 48 |
+
|
| 49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
namespace cutlass::conv::collective {
|
| 52 |
+
using namespace cute;
|
| 53 |
+
|
| 54 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
template <
|
| 57 |
+
conv::Operator ConvOp,
|
| 58 |
+
int Stages,
|
| 59 |
+
int NumSpatialDims,
|
| 60 |
+
class ClusterShape,
|
| 61 |
+
class KernelSchedule,
|
| 62 |
+
int PipelineAsyncMmaStages,
|
| 63 |
+
class TileShape_,
|
| 64 |
+
class ElementA_,
|
| 65 |
+
class ElementB_,
|
| 66 |
+
class TiledMma_,
|
| 67 |
+
class TileTraitsA_,
|
| 68 |
+
class TileTraitsB_>
|
| 69 |
+
struct CollectiveConv<
|
| 70 |
+
MainloopSm90TmaGmmaWarpSpecializedImplicitGemm<
|
| 71 |
+
ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>,
|
| 72 |
+
TileShape_,
|
| 73 |
+
ElementA_,
|
| 74 |
+
ElementB_,
|
| 75 |
+
TiledMma_,
|
| 76 |
+
TileTraitsA_,
|
| 77 |
+
TileTraitsB_>
|
| 78 |
+
{
|
| 79 |
+
//
|
| 80 |
+
// Type Aliases
|
| 81 |
+
//
|
| 82 |
+
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm<
|
| 83 |
+
ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>;
|
| 84 |
+
using TileShape = TileShape_;
|
| 85 |
+
using ElementA = ElementA_;
|
| 86 |
+
using ElementB = ElementB_;
|
| 87 |
+
using TiledMma = TiledMma_;
|
| 88 |
+
using ElementAccumulator = typename TiledMma::ValTypeC;
|
| 89 |
+
using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy;
|
| 90 |
+
using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy;
|
| 91 |
+
using SmemLayoutA = typename TileTraitsA_::SmemLayout;
|
| 92 |
+
using SmemLayoutB = typename TileTraitsB_::SmemLayout;
|
| 93 |
+
using ArchTag = typename DispatchPolicy::ArchTag;
|
| 94 |
+
static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions;
|
| 95 |
+
static constexpr int NumTensorDimensions = NumSpatialDimensions + 2;
|
| 96 |
+
// Deduce the kernel-facing stride tuple types based on the dispatch policy
|
| 97 |
+
// (which is a function of the number of spatial dimensions, the algorithm, etc.)
|
| 98 |
+
using StrideA = decltype(detail::sm90_dispatch_policy_to_stride_A<DispatchPolicy>());
|
| 99 |
+
using StrideB = decltype(detail::sm90_dispatch_policy_to_stride_B<DispatchPolicy>());
|
| 100 |
+
|
| 101 |
+
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
| 102 |
+
|
| 103 |
+
using PipelineParams = typename MainloopPipeline::Params;
|
| 104 |
+
using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;
|
| 105 |
+
|
| 106 |
+
using ProblemShape = ConvProblemShape<ConvOp, NumSpatialDimensions>;
|
| 107 |
+
|
| 108 |
+
static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)");
|
| 109 |
+
static_assert((size<0>(TileShape{}) == size<0>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape.");
|
| 110 |
+
static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape.");
|
| 111 |
+
|
| 112 |
+
static_assert(rank(SmemLayoutB{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)");
|
| 113 |
+
static_assert((size<1>(TileShape{}) == size<0>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape.");
|
| 114 |
+
static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape.");
|
| 115 |
+
|
| 116 |
+
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
| 117 |
+
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
| 118 |
+
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
| 119 |
+
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
| 120 |
+
|
| 121 |
+
// The tma load mode of wgrad is tiled for tensor A and im2col for tensor B while the tma load mode of fprop and dgrad
|
| 122 |
+
// kernel is im2col for tensor A and tiled for tensor B.
|
| 123 |
+
static_assert((ConvOp == conv::Operator::kWgrad
|
| 124 |
+
&& (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>))
|
| 125 |
+
|| (ConvOp != conv::Operator::kWgrad
|
| 126 |
+
&& (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_IM2COL> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_IM2COL_MULTICAST>)),
|
| 127 |
+
"GmemTiledCopyA - invalid SM90 TMA copy atom specified.");
|
| 128 |
+
static_assert((ConvOp == conv::Operator::kWgrad
|
| 129 |
+
&& (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_IM2COL> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_IM2COL_MULTICAST>))
|
| 130 |
+
|| (ConvOp != conv::Operator::kWgrad
|
| 131 |
+
&& (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)),
|
| 132 |
+
"GmemTiledCopyB - invalid SM90 TMA copy atom specified.");
|
| 133 |
+
|
| 134 |
+
static constexpr bool is_im2col_A = detail::is_im2col_load<GmemTiledCopyA>::value;
|
| 135 |
+
static constexpr bool is_im2col_B = detail::is_im2col_load<GmemTiledCopyB>::value;
|
| 136 |
+
|
| 137 |
+
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
|
| 138 |
+
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
|
| 139 |
+
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
|
| 140 |
+
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
|
| 141 |
+
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
|
| 142 |
+
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
|
| 143 |
+
|
| 144 |
+
struct SharedStorage
|
| 145 |
+
{
|
| 146 |
+
struct TensorStorage : cute::aligned_struct<128, _0> {
|
| 147 |
+
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
| 148 |
+
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
|
| 149 |
+
} tensors;
|
| 150 |
+
|
| 151 |
+
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
| 152 |
+
PipelineStorage pipeline;
|
| 153 |
+
};
|
| 154 |
+
using TensorStorage = typename SharedStorage::TensorStorage;
|
| 155 |
+
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
| 156 |
+
|
| 157 |
+
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
| 158 |
+
static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages;
|
| 159 |
+
static constexpr uint32_t TmaTransactionBytes =
|
| 160 |
+
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof(InternalElementA)))+
|
| 161 |
+
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof(InternalElementB)));
|
| 162 |
+
|
| 163 |
+
// Host side kernel arguments
|
| 164 |
+
struct Arguments {
|
| 165 |
+
ElementA const* ptr_A{nullptr};
|
| 166 |
+
ElementB const* ptr_B{nullptr};
|
| 167 |
+
};
|
| 168 |
+
|
| 169 |
+
private:
|
| 170 |
+
// Note that for fprop and dgrad kernel, the tma load mode is im2col for tensor A and tiled for
|
| 171 |
+
// tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor
|
| 172 |
+
// B since operand A, B is swapped.
|
| 173 |
+
// Get tma_load_a instantce.
|
| 174 |
+
template <class TensorA>
|
| 175 |
+
static constexpr auto
|
| 176 |
+
get_tma_load_a_instance(TensorA const& tensor_a, ProblemShape const& problem_shape) {
|
| 177 |
+
if constexpr (is_im2col_A) {
|
| 178 |
+
// compute the upper and lower corners based on the conv padding
|
| 179 |
+
auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
|
| 180 |
+
auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
|
| 181 |
+
auto lower_srt = detail::compute_lower_srt(problem_shape);
|
| 182 |
+
|
| 183 |
+
// The calculation of gbasis strides for dgrad kernel needs perform negate for dilation values.
|
| 184 |
+
cute::array<int32_t, NumSpatialDimensions> stride_srt{};
|
| 185 |
+
for (int i = 0; i < NumSpatialDimensions; ++i) {
|
| 186 |
+
stride_srt[i] = ConvOp == conv::Operator::kDgrad ?
|
| 187 |
+
-problem_shape.dilation[NumSpatialDimensions-1-i] :
|
| 188 |
+
problem_shape.dilation[NumSpatialDimensions-1-i];
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
return make_im2col_tma_copy(
|
| 192 |
+
GmemTiledCopyA{},
|
| 193 |
+
tensor_a,
|
| 194 |
+
SmemLayoutA{}(_,_,_0{}),
|
| 195 |
+
product_each(shape(SmemLayoutA{}(_,_,_0{}))),
|
| 196 |
+
size<1>(ClusterShape{}),
|
| 197 |
+
shape(lower_corner_whd),
|
| 198 |
+
shape(upper_corner_whd),
|
| 199 |
+
cute::reverse(shape(problem_shape.lower_padding)),
|
| 200 |
+
cute::reverse(shape(problem_shape.upper_padding)),
|
| 201 |
+
cute::reverse(shape(problem_shape.traversal_stride)),
|
| 202 |
+
shape(lower_srt),
|
| 203 |
+
shape(stride_srt));
|
| 204 |
+
}
|
| 205 |
+
// TMA tiled mode for tensor A in wgrad kernel.
|
| 206 |
+
else {
|
| 207 |
+
return make_tma_copy(
|
| 208 |
+
GmemTiledCopyA{},
|
| 209 |
+
tensor_a,
|
| 210 |
+
SmemLayoutA{}(_,_,_0{}),
|
| 211 |
+
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
| 212 |
+
size<1>(ClusterShape{}));
|
| 213 |
+
}
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
// Get tma_load_b instantce.
|
| 217 |
+
template <class TensorB>
|
| 218 |
+
static constexpr auto
|
| 219 |
+
get_tma_load_b_instance(TensorB const& tensor_b, ProblemShape const& problem_shape) {
|
| 220 |
+
// TMA im2col mode for tensor B in wgrad kernel.
|
| 221 |
+
if constexpr (is_im2col_B) {
|
| 222 |
+
// compute the upper and lower corners based on the conv padding
|
| 223 |
+
auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
|
| 224 |
+
auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
|
| 225 |
+
auto lower_srt = detail::compute_lower_srt(problem_shape);
|
| 226 |
+
|
| 227 |
+
return make_im2col_tma_copy(
|
| 228 |
+
GmemTiledCopyB{},
|
| 229 |
+
tensor_b,
|
| 230 |
+
SmemLayoutB{}(_,_,_0{}),
|
| 231 |
+
product_each(shape(SmemLayoutB{}(_,_,_0{}))),
|
| 232 |
+
size<0>(ClusterShape{}),
|
| 233 |
+
shape(lower_corner_whd),
|
| 234 |
+
shape(upper_corner_whd),
|
| 235 |
+
cute::reverse(shape(problem_shape.lower_padding)),
|
| 236 |
+
cute::reverse(shape(problem_shape.upper_padding)),
|
| 237 |
+
cute::reverse(shape(problem_shape.traversal_stride)),
|
| 238 |
+
shape(lower_srt),
|
| 239 |
+
cute::reverse(shape(problem_shape.dilation)));
|
| 240 |
+
}
|
| 241 |
+
else {
|
| 242 |
+
return make_tma_copy(
|
| 243 |
+
GmemTiledCopyB{},
|
| 244 |
+
tensor_b,
|
| 245 |
+
SmemLayoutB{}(_,_,_0{}),
|
| 246 |
+
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
| 247 |
+
size<0>(ClusterShape{}));
|
| 248 |
+
}
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
public:
|
| 252 |
+
|
| 253 |
+
// Performs im2col transformations on the input of type ConvProblemShape
|
| 254 |
+
static constexpr auto
|
| 255 |
+
get_problem_shape_MNKL(ProblemShape const& problem_shape) {
|
| 256 |
+
|
| 257 |
+
if constexpr (is_im2col_A || is_im2col_B) {
|
| 258 |
+
// transformation + im2col linearization
|
| 259 |
+
return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape);
|
| 260 |
+
}
|
| 261 |
+
else {
|
| 262 |
+
// transformation
|
| 263 |
+
return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape);
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
// Device side kernel params
|
| 268 |
+
struct Params {
|
| 269 |
+
using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{}));
|
| 270 |
+
|
| 271 |
+
// Assumption: StrideA is congruent with Problem_MK
|
| 272 |
+
// Select TMA load type according to convolution operator.
|
| 273 |
+
using TensorShapeA = cute::conditional_t<ConvOp == conv::Operator::kWgrad,
|
| 274 |
+
decltype(repeat_like(StrideA{}, int32_t(0))),
|
| 275 |
+
decltype(make_shape(_Submode{}, int(0)))>;
|
| 276 |
+
|
| 277 |
+
using TensorShapeB = cute::conditional_t<ConvOp == conv::Operator::kWgrad,
|
| 278 |
+
decltype(make_shape(int(0), _Submode{})),
|
| 279 |
+
decltype(repeat_like(StrideB{}, int32_t(0)))>;
|
| 280 |
+
|
| 281 |
+
using TMA_A = decltype(get_tma_load_a_instance(
|
| 282 |
+
make_tensor(
|
| 283 |
+
make_gmem_ptr(static_cast<InternalElementA const*>(nullptr)),
|
| 284 |
+
make_layout(TensorShapeA{}, StrideA{})),
|
| 285 |
+
ConvProblemShape<ConvOp, NumSpatialDimensions>{}));
|
| 286 |
+
|
| 287 |
+
using TMA_B = decltype(get_tma_load_b_instance(
|
| 288 |
+
make_tensor(
|
| 289 |
+
make_gmem_ptr(static_cast<InternalElementB const*>(nullptr)),
|
| 290 |
+
make_layout(TensorShapeB{}, StrideB{})),
|
| 291 |
+
ConvProblemShape<ConvOp, NumSpatialDimensions>{}));
|
| 292 |
+
|
| 293 |
+
// Members
|
| 294 |
+
TMA_A tma_load_a;
|
| 295 |
+
TMA_B tma_load_b;
|
| 296 |
+
uint32_t tma_transaction_bytes = TmaTransactionBytes;
|
| 297 |
+
};
|
| 298 |
+
|
| 299 |
+
//
|
| 300 |
+
// Methods
|
| 301 |
+
//
|
| 302 |
+
|
| 303 |
+
// Lowers the host side user facing arguments to the kernel facing lauch params
|
| 304 |
+
static constexpr Params
|
| 305 |
+
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
| 306 |
+
(void) workspace;
|
| 307 |
+
// from the flat problem shape arrays of ConvProblemShape<ConvOp, N>, create a rank-3 MNK problem shape tuple
|
| 308 |
+
// tma desc creation depends on the original untransformed domain.
|
| 309 |
+
|
| 310 |
+
// A extents.
|
| 311 |
+
auto shape_A_orig = problem_shape.get_shape_A();
|
| 312 |
+
// B extents.
|
| 313 |
+
auto shape_B_orig = problem_shape.get_shape_B();
|
| 314 |
+
|
| 315 |
+
// Fill inferred cute strides from flat stride arrays
|
| 316 |
+
auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp);
|
| 317 |
+
auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp);
|
| 318 |
+
|
| 319 |
+
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
|
| 320 |
+
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
|
| 321 |
+
|
| 322 |
+
Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA));
|
| 323 |
+
Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB));
|
| 324 |
+
|
| 325 |
+
auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape);
|
| 326 |
+
auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape);
|
| 327 |
+
|
| 328 |
+
return {
|
| 329 |
+
tma_load_a,
|
| 330 |
+
tma_load_b,
|
| 331 |
+
TmaTransactionBytes
|
| 332 |
+
};
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
template <class ProblemShape>
|
| 336 |
+
static bool
|
| 337 |
+
can_implement(
|
| 338 |
+
ProblemShape const& problem_shape,
|
| 339 |
+
Arguments const& args) {
|
| 340 |
+
// Activation and Filter channel mode extents much match
|
| 341 |
+
bool implementable = true;
|
| 342 |
+
// channel mode is major
|
| 343 |
+
implementable &= problem_shape.stride_A[NumTensorDimensions-1] == 1;
|
| 344 |
+
implementable &= problem_shape.stride_B[NumTensorDimensions-1] == 1;
|
| 345 |
+
|
| 346 |
+
constexpr int tma_alignment_bits = 128;
|
| 347 |
+
// A extents.
|
| 348 |
+
auto shape_A_orig = problem_shape.get_shape_A();
|
| 349 |
+
// B extents.
|
| 350 |
+
auto shape_B_orig = problem_shape.get_shape_B();
|
| 351 |
+
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
| 352 |
+
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(shape_A_orig, StrideA{});
|
| 353 |
+
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
| 354 |
+
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(shape_B_orig, StrideB{});
|
| 355 |
+
|
| 356 |
+
if (!implementable) {
|
| 357 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
| 358 |
+
return false;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
// Check valid padding values for TMA_LOAD_IM2COL
|
| 362 |
+
constexpr int padding_limit = (ProblemShape::RankS == 1) ? 65536 : (ProblemShape::RankS == 2 ? 256 : 16);
|
| 363 |
+
for (int i = 0; i < problem_shape.RankS; ++i) {
|
| 364 |
+
implementable = implementable && problem_shape.lower_padding[i] <= padding_limit && problem_shape.lower_padding[i] >= 0;
|
| 365 |
+
implementable = implementable && problem_shape.upper_padding[i] <= padding_limit && problem_shape.upper_padding[i] >= 0;
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
if (!implementable) {
|
| 369 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n");
|
| 370 |
+
return false;
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
if (is_im2col_A || is_im2col_B) {
|
| 374 |
+
// Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1]
|
| 375 |
+
constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1);
|
| 376 |
+
auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
|
| 377 |
+
for (int i = 0; i < problem_shape.RankS; ++i) {
|
| 378 |
+
implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1);
|
| 379 |
+
}
|
| 380 |
+
auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
|
| 381 |
+
for (int i = 0; i < problem_shape.RankS; ++i) {
|
| 382 |
+
implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1);
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
if (!implementable) {
|
| 386 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n");
|
| 387 |
+
return false;
|
| 388 |
+
}
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
if (is_im2col_A || is_im2col_B) {
|
| 392 |
+
// Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit - 1]
|
| 393 |
+
constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1;
|
| 394 |
+
auto flt_data = (ConvOp == conv::Operator::kWgrad) ? problem_shape.shape_C : problem_shape.shape_B;
|
| 395 |
+
for (int i = 0; i < problem_shape.RankS; ++i) {
|
| 396 |
+
// flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array
|
| 397 |
+
implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0)
|
| 398 |
+
&& ((flt_data[i+1] - 1) * problem_shape.dilation[i] < offset_limit);
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
if (!implementable) {
|
| 402 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: tensor coordinate offset values don't meet requirements for TMA LOAD IM2COL.\n");
|
| 403 |
+
return false;
|
| 404 |
+
}
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
// Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized)
|
| 408 |
+
if constexpr (ConvOp == conv::Operator::kWgrad) {
|
| 409 |
+
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 410 |
+
std::ostringstream os;
|
| 411 |
+
#endif
|
| 412 |
+
const auto & input_shape = problem_shape.shape_A;
|
| 413 |
+
const auto & input_stride = problem_shape.stride_A;
|
| 414 |
+
|
| 415 |
+
implementable &= input_stride[ProblemShape::RankT - 1] == 1;
|
| 416 |
+
int64_t input_shape_size = 1;
|
| 417 |
+
for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
|
| 418 |
+
input_shape_size *= input_shape[i + 1];
|
| 419 |
+
implementable &= input_stride[i] == input_shape_size;
|
| 420 |
+
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 421 |
+
if (input_stride[i] != input_shape_size) {
|
| 422 |
+
os << "\n *** input_stride[" << i << "] = " << input_stride[i] << " != input_shape_size = " << input_shape_size << " ***";
|
| 423 |
+
}
|
| 424 |
+
#endif
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
if (!implementable) {
|
| 428 |
+
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 429 |
+
os << "\n input_shape_size: " << input_shape_size
|
| 430 |
+
<< "\n input_shape: " << input_shape
|
| 431 |
+
<< "\n input_stride: " << input_stride
|
| 432 |
+
<< "\n";
|
| 433 |
+
#endif
|
| 434 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n");
|
| 435 |
+
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 436 |
+
CUTLASS_TRACE_HOST(os.str());
|
| 437 |
+
#endif
|
| 438 |
+
return false;
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
const auto & output_shape = problem_shape.shape_C;
|
| 442 |
+
const auto & output_stride = problem_shape.stride_C;
|
| 443 |
+
|
| 444 |
+
implementable &= output_stride[ProblemShape::RankT - 1] == 1;
|
| 445 |
+
int64_t output_shape_size = 1;
|
| 446 |
+
for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
|
| 447 |
+
output_shape_size *= output_shape[i + 1];
|
| 448 |
+
implementable &= output_stride[i] == output_shape_size;
|
| 449 |
+
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 450 |
+
if (output_stride[i] != output_shape_size) {
|
| 451 |
+
os << "\n *** output_stride[" << i << "] = " << output_stride[i] << " != output_shape_size = " << output_shape_size << " ***";
|
| 452 |
+
}
|
| 453 |
+
#endif
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
if (!implementable) {
|
| 457 |
+
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 458 |
+
os << "\n output_shape_size: " << input_shape_size
|
| 459 |
+
<< "\n output_shape: " << input_shape
|
| 460 |
+
<< "\n output_stride: " << input_stride
|
| 461 |
+
<< "\n";
|
| 462 |
+
#endif
|
| 463 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n");
|
| 464 |
+
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 465 |
+
CUTLASS_TRACE_HOST(os.str());
|
| 466 |
+
#endif
|
| 467 |
+
return false;
|
| 468 |
+
}
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
// Conv kernels only support cross correlation mode currently.
|
| 472 |
+
implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation;
|
| 473 |
+
|
| 474 |
+
if (!implementable) {
|
| 475 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n");
|
| 476 |
+
return false;
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
if (problem_shape.groups > 1) {
|
| 480 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n");
|
| 481 |
+
return false;
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
if constexpr (is_im2col_A || is_im2col_B) {
|
| 485 |
+
auto [M, N, K, L] = cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape);
|
| 486 |
+
auto to_64b = [](auto S) { return transform_leaf(S, [](auto s) { return static_cast<int64_t>(s); }); };
|
| 487 |
+
|
| 488 |
+
if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) {
|
| 489 |
+
implementable &= (cute::product(to_64b(M)) <= cutlass::platform::numeric_limits<int32_t>::max()) &
|
| 490 |
+
(cute::product(to_64b(L)) <= cutlass::platform::numeric_limits<int32_t>::max());
|
| 491 |
+
}
|
| 492 |
+
else if constexpr (ConvOp == conv::Operator::kWgrad) {
|
| 493 |
+
implementable &= (cute::product(to_64b(K)) <= cutlass::platform::numeric_limits<int32_t>::max());
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
if (!implementable) {
|
| 497 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: the extents exceed the maximum number.\n");
|
| 498 |
+
return false;
|
| 499 |
+
}
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
return true;
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
| 506 |
+
CUTLASS_DEVICE
|
| 507 |
+
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
| 508 |
+
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
| 509 |
+
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
/// Set up the data needed by this collective for load and mma.
|
| 513 |
+
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
| 514 |
+
/// Returned tuple must contain at least two elements, with the first two elements being:
|
| 515 |
+
/// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k)
|
| 516 |
+
/// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k)
|
| 517 |
+
/// The rest of the tensors can be specified as needed by this collective.
|
| 518 |
+
/// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with
|
| 519 |
+
/// StrideA and StrideB set up for TMA
|
| 520 |
+
template <class ProblemShapeMNKL>
|
| 521 |
+
CUTLASS_DEVICE auto
|
| 522 |
+
load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){
|
| 523 |
+
//load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
| 524 |
+
using X = Underscore;
|
| 525 |
+
// Separate out problem shape for convenience
|
| 526 |
+
auto [M, N, K, L] = problem_shape_MNKL;
|
| 527 |
+
|
| 528 |
+
// TMA requires special handling of strides to deal with coord codomain mapping
|
| 529 |
+
// Represent the full tensors -- get these from TMA
|
| 530 |
+
Tensor mA_mk = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K)); // (m,k)
|
| 531 |
+
Tensor mB_nk = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K)); // (n,k)
|
| 532 |
+
|
| 533 |
+
// Make tiled views, defer the slice
|
| 534 |
+
Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k)
|
| 535 |
+
Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k)
|
| 536 |
+
|
| 537 |
+
return cute::make_tuple(gA_mk, gB_nk);
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
///
/// Issues the TMA loads that stream the A and B tiles for one output block from
/// global memory into the shared-memory pipeline stages. Only a single elected
/// lane issues TMA traffic; all other lanes exit immediately.
///
/// \param mainloop_params  Kernel params holding the tma_load_a/tma_load_b copy atoms
/// \param pipeline         Producer/consumer pipeline used to lock stages for writing
/// \param smem_pipe_producer_state  Current producer stage/phase; advanced once per k-tile
/// \param load_inputs      (gA_mk, gB_nk) tiled TMA tensors produced by load_init()
/// \param blk_coord        (m, n, k, l) coordinates of this CTA's output tile
/// \param k_tile_iter      Iterator over k-tiles; dereferenced to index the k mode
/// \param k_tile_count     Number of k-tiles to load
/// \param shared_tensors   Shared-memory storage backing the A/B stage buffers
template <
  class TensorA, class TensorB,
  class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load(
    Params const& mainloop_params,
    MainloopPipeline pipeline,
    PipelineState smem_pipe_producer_state,
    cute::tuple<TensorA, TensorB> const& load_inputs,
    BlockCoord const& blk_coord,
    KTileIterator k_tile_iter, int k_tile_count,
    int thread_idx,
    uint32_t block_rank_in_cluster,
    TensorStorage& shared_tensors) {

  // TMA loads are issued by one elected lane only.
  int lane_predicate = cute::elect_one_sync();
  if (lane_predicate) {
    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{});        // (BLK_M,BLK_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{});        // (BLK_N,BLK_K,PIPE)

    //
    // Prepare the TMA loads for A and B
    //
    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());

    // Linear cluster rank -> (x, y) coordinate within the thread-block cluster.
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
    // A is sliced along the cluster's y coordinate, B along x, so each CTA loads
    // only its multicast share of the tile.
    auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
    auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

    auto [gA_mk, gB_nk] = load_inputs;

    // Partition the inputs based on the current block coordinates.
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;

    Tensor gA = gA_mk(_,_,m_coord,_);                                                    // (BLK_M,BLK_K,k)
    Tensor gB = gB_nk(_,_,n_coord,_);                                                    // (BLK_N,BLK_K,k)

    // Applies the mapping from block_tma_a
    Tensor tAgA = block_tma_a.partition_S(gA);                                           // (TMA,TMA_M,TMA_K,k)
    Tensor tAsA = block_tma_a.partition_D(sA);                                           // (TMA,TMA_M,TMA_K,PIPE)

    Tensor tBgB = block_tma_b.partition_S(gB);                                           // (TMA,TMA_N,TMA_K,k)
    Tensor tBsB = block_tma_b.partition_D(sB);                                           // (TMA,TMA_N,TMA_K,PIPE)

    // Multicast masks: one bit per destination CTA in the cluster; zero means
    // no multicast (plain TMA load).
    uint16_t mcast_mask_a = 0;
    uint16_t mcast_mask_b = 0;

    // Issue TmaLoads
    // Maps the tile -> block, value
    if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_IM2COL_MULTICAST> ||
                  cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};               // (m,n) -> block_id
      // A tile is shared across the cluster's n dimension: set a bit for every
      // CTA in this CTA's column.
      for (int n = 0; n < size<1>(block_layout); ++n) {
        mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
      }
    }

    if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_IM2COL_MULTICAST> ||
                  cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};               // (m,n) -> block_id
      // B tile is shared across the cluster's m dimension: set a bit for every
      // CTA in this CTA's row.
      for (int m = 0; m < size<0>(block_layout); ++m) {
        mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
      }
    }

    // Mainloop
    CUTLASS_PRAGMA_NO_UNROLL
    for ( ; k_tile_count > 0; --k_tile_count) {
      // LOCK smem_pipe_producer_state for _writing_
      pipeline.producer_acquire(smem_pipe_producer_state);

      //
      // Copy gmem to smem for *k_tile_iter
      //

      // The TMA transaction arrives on the stage's mbarrier; the consumer waits
      // on the same barrier before reading the stage.
      using BarrierType = typename MainloopPipeline::ProducerBarrierType;
      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_producer_state);

      int write_stage = smem_pipe_producer_state.index();

      copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
      copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
      ++k_tile_iter;

      // Advance smem_pipe_producer_state
      ++smem_pipe_producer_state;
    }
  }
}
|
| 632 |
+
|
| 633 |
+
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
| 634 |
+
CUTLASS_DEVICE void
|
| 635 |
+
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_producer_state) {
|
| 636 |
+
int lane_predicate = cute::elect_one_sync();
|
| 637 |
+
|
| 638 |
+
// Issue the epilogue waits
|
| 639 |
+
if (lane_predicate) {
|
| 640 |
+
/* This helps avoid early exit of blocks in Cluster
|
| 641 |
+
* Waits for all stages to either be released (all
|
| 642 |
+
* Consumer UNLOCKs), or if the stage was never used
|
| 643 |
+
* then would just be acquired since the phase was
|
| 644 |
+
* still inverted from make_producer_start_state
|
| 645 |
+
*/
|
| 646 |
+
pipeline.producer_tail(smem_pipe_producer_state);
|
| 647 |
+
}
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
///
/// Consumes the shared-memory stages filled by load() and accumulates into the
/// register-resident fragment `accum` using warpgroup MMA (GMMA). A prologue of
/// up to K_PIPE_MMAS tiles is issued before any stage is released, keeping that
/// many MMAs in flight; the steady-state loop then releases one stage per tile.
///
/// \param pipeline                 Producer/consumer pipeline guarding the smem stages
/// \param smem_pipe_consumer_state Current consumer stage/phase; advanced once per k-tile
/// \param accum                    Register-resident accumulator fragment (must be rmem)
/// \param k_tile_count             Number of k-tiles to consume
/// \param shared_tensors           Shared-memory storage backing the A/B stage buffers
template <class FrgTensorC>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline,
    PipelineState smem_pipe_consumer_state,
    FrgTensorC& accum,
    int k_tile_count,
    int thread_idx,
    TensorStorage& shared_tensors,
    Params const& mainloop_params) {
  static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

  Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{});          // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{});          // (BLK_N,BLK_K,PIPE)

  //
  // Define C accumulators and A/B partitioning
  //

  TiledMma tiled_mma;
  auto thread_mma = tiled_mma.get_thread_slice(thread_idx);

  Tensor tCsA = thread_mma.partition_A(sA);                                                     // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCsB = thread_mma.partition_B(sB);                                                     // (MMA,MMA_N,MMA_K,PIPE)

  // Allocate "fragments/descriptors" -- for GMMA these view the smem stages
  Tensor tCrA = thread_mma.make_fragment_A(tCsA);                                               // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = thread_mma.make_fragment_B(tCsB);                                               // (MMA,MMA_N,MMA_K,PIPE)

  // Sanity-check that the partitioned extents agree with the accumulator and
  // the pipeline depth.
  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum));                                        // M
  CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum));                                        // N
  CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));                                         // K
  CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB));                                         // PIPE
  CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA));                           // PIPE
  CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB));                           // PIPE

  //
  // PIPELINED MAIN LOOP
  //
  static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS <  K_PIPE_MAX),
      "ERROR : Incorrect number of MMAs in flight");

  // We release buffers to producer warps(dma load) with some mmas in flight
  PipelineState smem_pipe_release = smem_pipe_consumer_state;

  // Prologue GMMAs
  int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);

  // First MMA of the prologue overwrites (rather than accumulates into) accum.
  tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

  warpgroup_fence_operand(accum);
  CUTLASS_PRAGMA_UNROLL
  for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) {
    // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value)
    pipeline.consumer_wait(smem_pipe_consumer_state);

    int read_stage = smem_pipe_consumer_state.index();
    warpgroup_arrive();
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // (V,M,K) x (V,N,K) => (V,M,N)
      cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_commit_batch();

    // Note: no consumer_release here -- prologue stages stay locked until the
    // mainloop (or mma_tail) releases them.
    ++smem_pipe_consumer_state;
  }

  warpgroup_fence_operand(accum);
  // Mainloop GMMAs
  k_tile_count -= prologue_mma_count;

  CUTLASS_PRAGMA_NO_UNROLL
  for ( ; k_tile_count > 0; --k_tile_count) {
    // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value)
    pipeline.consumer_wait(smem_pipe_consumer_state);

    //
    // Compute on k_tile
    //

    int read_stage = smem_pipe_consumer_state.index();
    warpgroup_fence_operand(accum);
    warpgroup_arrive();
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // (V,M) x (V,N) => (V,M,N)
      cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();

    /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_producer_state is consumed
    warpgroup_wait<K_PIPE_MMAS>();
    warpgroup_fence_operand(accum);

    // UNLOCK smem_pipe_release, done _computing_ on it
    pipeline.consumer_release(smem_pipe_release);

    // Advance smem_pipe_consumer_state and smem_pipe_release
    ++smem_pipe_consumer_state;
    ++smem_pipe_release;
  }

  warpgroup_fence_operand(accum);
}
|
| 761 |
+
|
| 762 |
+
/// Perform a Consumer Epilogue to release all buffers
|
| 763 |
+
CUTLASS_DEVICE void
|
| 764 |
+
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
| 765 |
+
// Prologue GMMAs
|
| 766 |
+
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
| 767 |
+
k_tile_count -= prologue_mma_count;
|
| 768 |
+
|
| 769 |
+
smem_pipe_release.advance(k_tile_count);
|
| 770 |
+
|
| 771 |
+
// Wait on all GMMAs to complete
|
| 772 |
+
warpgroup_wait<0>();
|
| 773 |
+
|
| 774 |
+
for (int count = 0; count < prologue_mma_count; ++count) {
|
| 775 |
+
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
| 776 |
+
++smem_pipe_release;
|
| 777 |
+
}
|
| 778 |
+
}
|
| 779 |
+
};
|
| 780 |
+
|
| 781 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 782 |
+
|
| 783 |
+
} // namespace cutlass::conv::collective
|
| 784 |
+
|
| 785 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief This file contains definitions and utility functions for describing convolution problem sizes.
|
| 33 |
+
|
| 34 |
+
Conv2dProblem description:
|
| 35 |
+
activation (NHWC),
|
| 36 |
+
filter (KRSC),
|
| 37 |
+
output (NPQK),
|
| 38 |
+
padding (pad_h, pad_w),
|
| 39 |
+
stride (stride_h, stride_w),
|
| 40 |
+
dilation (dilation_h, dilation_w).
|
| 41 |
+
|
| 42 |
+
Free functions to map:
|
| 43 |
+
Map tensor extents (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator)
|
| 44 |
+
Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator)
|
| 45 |
+
Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator)
|
| 46 |
+
*/
|
| 47 |
+
|
| 48 |
+
#pragma once
|
| 49 |
+
|
| 50 |
+
#include "cutlass/cutlass.h"
|
| 51 |
+
#include "cutlass/tensor_coord.h"
|
| 52 |
+
#include "cutlass/fast_math.h"
|
| 53 |
+
#include "cutlass/gemm/gemm_enumerated_types.h"
|
| 54 |
+
#include "cutlass/matrix_coord.h"
|
| 55 |
+
#include "cutlass/conv/convolution.h"
|
| 56 |
+
#include "cutlass/functional.h"
|
| 57 |
+
|
| 58 |
+
namespace cutlass {
|
| 59 |
+
namespace conv {
|
| 60 |
+
|
| 61 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
|
| 63 |
+
/// Problem size structure
|
| 64 |
+
struct Conv2dProblemSize {
|
| 65 |
+
|
| 66 |
+
// Conv2d strictly problem size parameters
|
| 67 |
+
int N, H, W, C, P, Q, K, R, S;
|
| 68 |
+
int pad_h, pad_w;
|
| 69 |
+
int stride_h, stride_w;
|
| 70 |
+
int dilation_h, dilation_w;
|
| 71 |
+
Mode mode;
|
| 72 |
+
|
| 73 |
+
// Conv2d implementation-related parameters
|
| 74 |
+
int split_k_slices;
|
| 75 |
+
int groups;
|
| 76 |
+
|
| 77 |
+
//
|
| 78 |
+
// Methods
|
| 79 |
+
//
|
| 80 |
+
|
| 81 |
+
public:
|
| 82 |
+
CUTLASS_HOST_DEVICE
|
| 83 |
+
Conv2dProblemSize():
|
| 84 |
+
N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0),
|
| 85 |
+
pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1),
|
| 86 |
+
mode(Mode::kConvolution), split_k_slices(1), groups(1) { }
|
| 87 |
+
|
| 88 |
+
/// Constructor for default padding, stride, dilation, and split-K
|
| 89 |
+
CUTLASS_HOST_DEVICE
|
| 90 |
+
Conv2dProblemSize(
|
| 91 |
+
int N,
|
| 92 |
+
int H,
|
| 93 |
+
int W,
|
| 94 |
+
int C,
|
| 95 |
+
int P,
|
| 96 |
+
int Q,
|
| 97 |
+
int K,
|
| 98 |
+
int R,
|
| 99 |
+
int S,
|
| 100 |
+
Mode mode
|
| 101 |
+
):
|
| 102 |
+
N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S),
|
| 103 |
+
pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1),
|
| 104 |
+
mode(mode), split_k_slices(1), groups (1) { }
|
| 105 |
+
|
| 106 |
+
/// Constructor
|
| 107 |
+
CUTLASS_HOST_DEVICE
|
| 108 |
+
Conv2dProblemSize(
|
| 109 |
+
int N,
|
| 110 |
+
int H,
|
| 111 |
+
int W,
|
| 112 |
+
int C,
|
| 113 |
+
int K,
|
| 114 |
+
int R,
|
| 115 |
+
int S,
|
| 116 |
+
int P,
|
| 117 |
+
int Q,
|
| 118 |
+
int pad_h,
|
| 119 |
+
int pad_w,
|
| 120 |
+
int stride_h,
|
| 121 |
+
int stride_w,
|
| 122 |
+
int dilation_h,
|
| 123 |
+
int dilation_w,
|
| 124 |
+
Mode mode,
|
| 125 |
+
int split_k_slices = 1,
|
| 126 |
+
int groups = 1
|
| 127 |
+
):
|
| 128 |
+
N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S),
|
| 129 |
+
pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w),
|
| 130 |
+
dilation_h(dilation_h), dilation_w(dilation_w),
|
| 131 |
+
mode(mode), split_k_slices(split_k_slices), groups (groups) { }
|
| 132 |
+
|
| 133 |
+
/// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord
|
| 134 |
+
// set user-defined output size and sets P and Q (include all data members in ctor)
|
| 135 |
+
CUTLASS_HOST_DEVICE
|
| 136 |
+
Conv2dProblemSize(
|
| 137 |
+
cutlass::Tensor4DCoord input_size, // NHWC
|
| 138 |
+
cutlass::Tensor4DCoord filter_size, // KRSC
|
| 139 |
+
cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _
|
| 140 |
+
cutlass::MatrixCoord stride, // stride_h, stride_w
|
| 141 |
+
cutlass::MatrixCoord dilation, // dilation_h, dilation_w
|
| 142 |
+
cutlass::Tensor4DCoord output_size, // NPQK
|
| 143 |
+
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
|
| 144 |
+
int split_k_slices = 1,
|
| 145 |
+
int groups = 1
|
| 146 |
+
):
|
| 147 |
+
N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()),
|
| 148 |
+
P(output_size.h()), Q(output_size.w()),
|
| 149 |
+
K(filter_size.n()), R(filter_size.h()), S(filter_size.w()),
|
| 150 |
+
pad_h(padding[0]), pad_w(padding[2]),
|
| 151 |
+
stride_h(stride.row()), stride_w(stride.column()),
|
| 152 |
+
dilation_h(dilation.row()), dilation_w(dilation.column()),
|
| 153 |
+
mode(mode), split_k_slices(split_k_slices), groups(groups) {}
|
| 154 |
+
|
| 155 |
+
  /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord
  // Computes the output extent and sets P and Q (the output size is
  // intentionally absent from the ctor arguments).
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize(
    cutlass::Tensor4DCoord input_size,    // NHWC
    cutlass::Tensor4DCoord filter_size,   // KRSC
    cutlass::Tensor4DCoord padding,       // pad_h, upper_pad_h, pad_w, upper_pad_w
    cutlass::MatrixCoord stride,          // stride_h, stride_w
    cutlass::MatrixCoord dilation,        // dilation_h, dilation_w
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()),
    K(filter_size.n()), R(filter_size.h()), S(filter_size.w()),
    pad_h(padding[0]), pad_w(padding[2]),
    stride_h(stride.row()), stride_w(stride.column()),
    dilation_h(dilation.row()), dilation_w(dilation.column()),
    mode(mode), split_k_slices(split_k_slices), groups(groups) {
    // set output P and Q; padding[1] / padding[3] are the far-side (upper)
    // pads, so the numerator accounts for padding on both sides
    P = ((H + pad_h + padding[1] - R * dilation_h) / stride_h) + 1;
    Q = ((W + pad_w + padding[3] - S * dilation_w) / stride_w) + 1;
  }
|
| 178 |
+
|
| 179 |
+
  /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord
  // Sets user-defined output size (P and Q) and applies defaults for the rest:
  // "same"-style half-filter padding (R/2, S/2), unit stride, unit dilation.
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize(
    cutlass::Tensor4DCoord input_size,    // NHWC
    cutlass::Tensor4DCoord filter_size,   // KRSC
    cutlass::Tensor4DCoord output_size,   // NPQK
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()),
    P(output_size.h()), Q(output_size.w()),
    K(filter_size.n()), R(filter_size.h()), S(filter_size.w()),
    pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1),
    dilation_h(1), dilation_w(1),
    mode(mode), split_k_slices(split_k_slices), groups(groups) {}
|
| 196 |
+
|
| 197 |
+
  // Returns a copy of this problem with the convolution mode replaced
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize reset_mode(cutlass::conv::Mode mode_) {
    Conv2dProblemSize tmp(*this);
    tmp.mode = mode_;
    return tmp;
  }
|
| 204 |
+
|
| 205 |
+
  // Returns a copy of this problem with the split-K slice count replaced
  CUTLASS_HOST_DEVICE
  Conv2dProblemSize reset_split_k_slices(int split_k_slices_) {
    Conv2dProblemSize tmp(*this);
    tmp.split_k_slices = split_k_slices_;
    return tmp;
  }
|
| 212 |
+
|
| 213 |
+
  /// Equality operator -- compares geometric parameters only
  /// (ignores mode, split_k_slices, and groups)
  CUTLASS_HOST_DEVICE
  bool operator==(Conv2dProblemSize const &conv) const {
    return (
      (N == conv.N) && (H == conv.H) && (W == conv.W) && (C == conv.C) &&
      (K == conv.K) && (R == conv.R) && (S == conv.S) &&
      (P == conv.P) && (Q == conv.Q) &&
      (pad_h == conv.pad_h) && (pad_w == conv.pad_w) &&
      (stride_h == conv.stride_h) && (stride_w == conv.stride_w) &&
      (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w)
    );
  }
|
| 225 |
+
|
| 226 |
+
/// Inequality operator
|
| 227 |
+
CUTLASS_HOST_DEVICE
|
| 228 |
+
bool operator!=(Conv2dProblemSize const &rhs) const {
|
| 229 |
+
return !(*this == rhs);
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
/// Returns activation extent as Tensor4DCoord
|
| 233 |
+
CUTLASS_HOST_DEVICE
|
| 234 |
+
cutlass::Tensor4DCoord activation_extent() const {
|
| 235 |
+
|
| 236 |
+
return cutlass::Tensor4DCoord ({N, H, W, C});
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
/// Returns filter extent as Tensor4DCoord
|
| 240 |
+
CUTLASS_HOST_DEVICE
|
| 241 |
+
cutlass::Tensor4DCoord filter_extent(bool is_deconv = false) const {
|
| 242 |
+
|
| 243 |
+
return is_deconv ? cutlass::Tensor4DCoord ({C, R, S, K / groups})
|
| 244 |
+
: cutlass::Tensor4DCoord ({K, R, S, C / groups});
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
/// Returns output extent as Tensor4DCoord
|
| 248 |
+
CUTLASS_HOST_DEVICE
|
| 249 |
+
cutlass::Tensor4DCoord output_extent() const {
|
| 250 |
+
|
| 251 |
+
return cutlass::Tensor4DCoord ({N, P, Q, K});
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
/// Returns activation size in number of elements
|
| 255 |
+
CUTLASS_HOST_DEVICE
|
| 256 |
+
int64_t activation_size() const {
|
| 257 |
+
|
| 258 |
+
return static_cast<int64_t>(N) * static_cast<int64_t>(H) *
|
| 259 |
+
static_cast<int64_t>(W) * static_cast<int64_t>(C);
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
/// Returns filter size in number of elements
|
| 263 |
+
CUTLASS_HOST_DEVICE
|
| 264 |
+
int64_t filter_size() const {
|
| 265 |
+
|
| 266 |
+
return static_cast<int64_t>(K) * static_cast<int64_t>(R) *
|
| 267 |
+
static_cast<int64_t>(S) * static_cast<int64_t>(C) /
|
| 268 |
+
static_cast<int64_t>(groups);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
/// Returns output size in number of elements
|
| 272 |
+
CUTLASS_HOST_DEVICE
|
| 273 |
+
int64_t output_size() const {
|
| 274 |
+
|
| 275 |
+
return static_cast<int64_t>(N) * static_cast<int64_t>(P) *
|
| 276 |
+
static_cast<int64_t>(Q) * static_cast<int64_t>(K);
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
/// Returns padding as Tensor4DCoord
|
| 280 |
+
CUTLASS_HOST_DEVICE
|
| 281 |
+
cutlass::Tensor4DCoord padding() const {
|
| 282 |
+
|
| 283 |
+
return cutlass::Tensor4DCoord ({pad_h, pad_h, pad_w, pad_w});
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
/// Returns stride as MatrixCoord
|
| 287 |
+
CUTLASS_HOST_DEVICE
|
| 288 |
+
cutlass::MatrixCoord stride() const {
|
| 289 |
+
|
| 290 |
+
return cutlass::MatrixCoord ({stride_h, stride_w});
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
/// Returns dilation as MatrixCoord
|
| 294 |
+
CUTLASS_HOST_DEVICE
|
| 295 |
+
cutlass::MatrixCoord dilation() const {
|
| 296 |
+
|
| 297 |
+
return cutlass::MatrixCoord ({dilation_h, dilation_w});
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
/////////////////////////////////////////////////////////////////
|
| 301 |
+
// Methods used for strided dgrad implementation
|
| 302 |
+
/////////////////////////////////////////////////////////////////
|
| 303 |
+
/// Number of filter r positions to accumulate in gemm-k dim
|
| 304 |
+
CUTLASS_HOST_DEVICE
|
| 305 |
+
int num_gemm_k_filter_r(int r) const {
|
| 306 |
+
return ((R - r + stride_h - 1) / stride_h);
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
/// Number of filter s positions to accumulate in gemm-k dim
|
| 310 |
+
CUTLASS_HOST_DEVICE
|
| 311 |
+
int num_gemm_k_filter_s(int s) const {
|
| 312 |
+
return ((S - s + stride_w - 1) / stride_w);
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
/// Number of filter positions to accumulate in gemm-k dim
|
| 316 |
+
CUTLASS_HOST_DEVICE
|
| 317 |
+
int num_gemm_k_filter_positions(int r, int s) const {
|
| 318 |
+
return num_gemm_k_filter_r(r) * num_gemm_k_filter_s(s);
|
| 319 |
+
}
|
| 320 |
+
};
|
| 321 |
+
|
| 322 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 323 |
+
// ImplicitGemm helper functions //
|
| 324 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 325 |
+
|
| 326 |
+
/// Determine the problem size of the implicit GEMM operation
|
| 327 |
+
CUTLASS_HOST_DEVICE
|
| 328 |
+
cutlass::gemm::GemmCoord implicit_gemm_problem_size(
|
| 329 |
+
Operator conv_operator,
|
| 330 |
+
Conv2dProblemSize const &problem_size) {
|
| 331 |
+
// Compute problem size
|
| 332 |
+
switch (conv_operator) {
|
| 333 |
+
case Operator::kFprop:
|
| 334 |
+
return gemm::GemmCoord(
|
| 335 |
+
problem_size.N * problem_size.P * problem_size.Q,
|
| 336 |
+
problem_size.K,
|
| 337 |
+
problem_size.R * problem_size.S * problem_size.C / problem_size.groups
|
| 338 |
+
);
|
| 339 |
+
case Operator::kDeconv:
|
| 340 |
+
case Operator::kDgrad:
|
| 341 |
+
return gemm::GemmCoord(
|
| 342 |
+
problem_size.N * problem_size.H * problem_size.W,
|
| 343 |
+
problem_size.C,
|
| 344 |
+
problem_size.R * problem_size.S * problem_size.K
|
| 345 |
+
);
|
| 346 |
+
case Operator::kWgrad:
|
| 347 |
+
return gemm::GemmCoord(
|
| 348 |
+
problem_size.K,
|
| 349 |
+
problem_size.R * problem_size.S * problem_size.C,
|
| 350 |
+
problem_size.N * problem_size.P * problem_size.Q
|
| 351 |
+
);
|
| 352 |
+
default:
|
| 353 |
+
break;
|
| 354 |
+
}
|
| 355 |
+
return gemm::GemmCoord();
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
// Determine the number of gemm-k iterations for a conv2d problem using the
// implicit gemm algorithm.
//
// The iteration count depends on the group mode, the iterator algorithm, and
// the convolution operator; unsupported combinations return 0.
CUTLASS_HOST_DEVICE
int implicit_gemm_k_iterations(
  Operator conv_operator,
  int threadblock_K,
  Conv2dProblemSize const &problem_size,
  IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
  GroupMode group_mode = GroupMode::kNone,
  int threadblock_N = 0) {

  int iterations = 0;

  if (group_mode == GroupMode::kNone) {

    if (algorithm == IteratorAlgorithm::kFixedChannels) {

      // Several filter positions fit in one gemm-k tile when C < threadblock_K
      int positions_per_iteration = threadblock_K / problem_size.C;
      switch (conv_operator) {
      case Operator::kFprop:
        iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration;
        break;

      default:
        break;
      }
    }
    else if (algorithm == IteratorAlgorithm::kFewChannels) {

      // Tile the flattened R*S*C gemm-k extent directly
      switch (conv_operator) {
      case Operator::kFprop:
        iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K;
        break;

      default:
        break;
      }
    }
    else {
      // kAnalytic / kOptimized: divide the gemm-k "channel" dimension over the
      // split-K slices, then tile each slice by threadblock_K
      int elements_per_split_k_slice = 0;

      switch (conv_operator) {
      case Operator::kFprop:
        elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
        iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
        break;

      case Operator::kDeconv:
      case Operator::kDgrad:
        // Dgrad/Deconv iterate over K (output channels) instead of C
        elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
        iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
        break;

      case Operator::kWgrad:
        // Wgrad's gemm-k dimension is N*P*Q rather than a channel dimension,
        // so there is no R*S factor
        elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
        iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
        break;

      default:
        break;
      }
    }

  } else if (group_mode == GroupMode::kDepthwise) {
    // Depthwise: each CTA covers threadblock_N channels
    int channels_per_cta = threadblock_N;

    if (algorithm == IteratorAlgorithm::kAnalytic) {
      switch (conv_operator) {
      case Operator::kFprop:
        iterations = problem_size.R * problem_size.S *
                        ((channels_per_cta + threadblock_K - 1) / threadblock_K);
        break;

      default:
        break;
      }
    }
  } else {  // Group conv

    int channels_per_group = problem_size.C / problem_size.groups;
    int k_per_group = problem_size.K / problem_size.groups;

    if (algorithm == IteratorAlgorithm::kAnalytic) {
      switch (conv_operator) {
      case Operator::kFprop:
        iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
        // In group conv, if k_per_group < threadblock_N, one threadblock
        // calculates multiple groups, multiplying the iteration count
        if (problem_size.groups != 1) {
          if (k_per_group < threadblock_N) {
            iterations *= threadblock_N / k_per_group;
          }
        }
        break;

      default:
        break;
      }
    } else if (algorithm == IteratorAlgorithm::kOptimized) {
      // The current optimized iterator only supports GroupMode::kSingleGroup
      if (group_mode == GroupMode::kSingleGroup) {
        switch (conv_operator) {
        case Operator::kFprop:
          iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
          break;

        default:
          break;
        }
      }
    }

  }

  return iterations;
}
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
template <int N = 1, int Output_P = 1, int Output_Q = 1>
|
| 475 |
+
CUTLASS_HOST_DEVICE
|
| 476 |
+
int depthwise_gemm_k_iterations(
|
| 477 |
+
Operator conv_operator,
|
| 478 |
+
int threadblock_K,
|
| 479 |
+
Conv2dProblemSize const &problem_size,
|
| 480 |
+
IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
|
| 481 |
+
GroupMode group_mode = GroupMode::kNone,
|
| 482 |
+
int threadblock_N = 0) {
|
| 483 |
+
|
| 484 |
+
int n = problem_size.N;
|
| 485 |
+
int p = (problem_size.P + Output_P - 1) / Output_P;
|
| 486 |
+
int q = (problem_size.Q + Output_Q - 1) / Output_Q;
|
| 487 |
+
|
| 488 |
+
int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
| 489 |
+
return iterations;
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
/// Number of gemm-k iterations per channel for a conv2d problem; returns 0
/// when the combination of operator and iterator algorithm is not applicable.
CUTLASS_HOST_DEVICE
int implicit_gemm_k_iterations_per_channel(
  Operator conv_operator,
  Conv2dProblemSize const &problem_size,
  IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) {

  int iterations = 0;  // 0 means not applicable

  bool const algorithm_supported =
    (algorithm == IteratorAlgorithm::kAnalytic) ||
    (algorithm == IteratorAlgorithm::kOptimized);

  if (algorithm_supported) {
    switch (conv_operator) {
    case Operator::kFprop:
    case Operator::kDeconv:
    case Operator::kDgrad:
      // One iteration per filter position
      iterations = problem_size.R * problem_size.S;
      break;

    default:
      break;
    }
  }

  return iterations;
}
|
| 517 |
+
|
| 518 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 519 |
+
// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output)
|
| 520 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 521 |
+
/// Returns ImplicitGemm tensor A extent as Tensor4DCoord
|
| 522 |
+
CUTLASS_HOST_DEVICE
|
| 523 |
+
cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent(
|
| 524 |
+
Operator conv_operator,
|
| 525 |
+
Conv2dProblemSize const &problem_size) {
|
| 526 |
+
switch (conv_operator) {
|
| 527 |
+
case cutlass::conv::Operator::kFprop: return problem_size.activation_extent();
|
| 528 |
+
case cutlass::conv::Operator::kDeconv:
|
| 529 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.output_extent();
|
| 530 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.output_extent();
|
| 531 |
+
default : break;
|
| 532 |
+
}
|
| 533 |
+
return cutlass::Tensor4DCoord();
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
/// Returns ImplicitGemm tensor B extent as Tensor4DCoord
|
| 537 |
+
CUTLASS_HOST_DEVICE
|
| 538 |
+
cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent(
|
| 539 |
+
Operator conv_operator,
|
| 540 |
+
Conv2dProblemSize const &problem_size) {
|
| 541 |
+
switch (conv_operator) {
|
| 542 |
+
case cutlass::conv::Operator::kFprop: return problem_size.filter_extent();
|
| 543 |
+
case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true);
|
| 544 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent();
|
| 545 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent();
|
| 546 |
+
default : break;
|
| 547 |
+
}
|
| 548 |
+
return cutlass::Tensor4DCoord();
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
/// Returns ImplicitGemm tensor C extent as Tensor4DCoord
|
| 552 |
+
CUTLASS_HOST_DEVICE
|
| 553 |
+
cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent(
|
| 554 |
+
Operator conv_operator,
|
| 555 |
+
Conv2dProblemSize const &problem_size) {
|
| 556 |
+
switch (conv_operator) {
|
| 557 |
+
case cutlass::conv::Operator::kFprop: return problem_size.output_extent();
|
| 558 |
+
case cutlass::conv::Operator::kDeconv:
|
| 559 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent();
|
| 560 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent();
|
| 561 |
+
default : break;
|
| 562 |
+
}
|
| 563 |
+
return cutlass::Tensor4DCoord();
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
/// Returns ImplicitGemm tensor A size in number of elements
|
| 567 |
+
CUTLASS_HOST_DEVICE
|
| 568 |
+
int64_t implicit_gemm_tensor_a_size(
|
| 569 |
+
Operator conv_operator,
|
| 570 |
+
Conv2dProblemSize const &problem_size) {
|
| 571 |
+
switch (conv_operator) {
|
| 572 |
+
case cutlass::conv::Operator::kFprop: return problem_size.activation_size();
|
| 573 |
+
case cutlass::conv::Operator::kDeconv:
|
| 574 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.output_size();
|
| 575 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.output_size();
|
| 576 |
+
default : break;
|
| 577 |
+
}
|
| 578 |
+
return 0;
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
/// Returns ImplicitGemm tensor B size in number of elements
|
| 582 |
+
CUTLASS_HOST_DEVICE
|
| 583 |
+
int64_t implicit_gemm_tensor_b_size(
|
| 584 |
+
Operator conv_operator,
|
| 585 |
+
Conv2dProblemSize const &problem_size) {
|
| 586 |
+
switch (conv_operator) {
|
| 587 |
+
case cutlass::conv::Operator::kFprop: return problem_size.filter_size();
|
| 588 |
+
case cutlass::conv::Operator::kDeconv:
|
| 589 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.filter_size();
|
| 590 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.activation_size();
|
| 591 |
+
default : break;
|
| 592 |
+
}
|
| 593 |
+
return 0;
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
/// Returns ImplicitGemm tensor C size in number of elements
|
| 597 |
+
CUTLASS_HOST_DEVICE
|
| 598 |
+
int64_t implicit_gemm_tensor_c_size(
|
| 599 |
+
Operator conv_operator,
|
| 600 |
+
Conv2dProblemSize const &problem_size) {
|
| 601 |
+
switch (conv_operator) {
|
| 602 |
+
case cutlass::conv::Operator::kFprop: return problem_size.output_size();
|
| 603 |
+
case cutlass::conv::Operator::kDeconv:
|
| 604 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.activation_size();
|
| 605 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.filter_size();
|
| 606 |
+
default : break;
|
| 607 |
+
}
|
| 608 |
+
return 0;
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 612 |
+
|
| 613 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 614 |
+
// Strided dgrad helper functions //
|
| 615 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 616 |
+
// Returns number of CTAs tile M to cover valid MMAs per starting filter postion
|
| 617 |
+
CUTLASS_HOST_DEVICE
|
| 618 |
+
int strided_dgrad_tile_m_per_filter(
|
| 619 |
+
Conv2dProblemSize const &problem_size,
|
| 620 |
+
int tile_size_m) {
|
| 621 |
+
|
| 622 |
+
// Compute NHW rows in Dx output that needs MMA per starting filter position
|
| 623 |
+
int rows_h_per_filter = (problem_size.H + problem_size.stride_h - 1) / problem_size.stride_h;
|
| 624 |
+
int rows_w_per_filter = (problem_size.W + problem_size.stride_w - 1) / problem_size.stride_w;
|
| 625 |
+
int rows_nhw_per_filter = problem_size.N * rows_h_per_filter * rows_w_per_filter;
|
| 626 |
+
|
| 627 |
+
// Number of CTAs tile M to cover valid MMAs per starting filter postion
|
| 628 |
+
int tile_m_per_filter = (rows_nhw_per_filter + tile_size_m - 1) / tile_size_m;
|
| 629 |
+
|
| 630 |
+
return tile_m_per_filter;
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
// Computes the starting Dx coordinate (h, w) for a given starting filter
// position (r, s).
//
// stride_h_divmod / stride_w_divmod must encode problem_size.stride_h and
// problem_size.stride_w respectively; they replace the modulo operations in
// the commented reference expressions below.
CUTLASS_HOST_DEVICE
void strided_dgrad_starting_coords(
  Conv2dProblemSize const &problem_size,
  FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
  int r, int s,
  int &start_h, int &start_w) {

  // function locals that receive the remainder from fast divmod
  int pad_h_rem_, pad_w_rem_;

  // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
  stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
  int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r));
  stride_h_divmod.divmod(start_h, r_);

  //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
  stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
  int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s));
  stride_w_divmod.divmod(start_w, s_);
}
|
| 654 |
+
|
| 655 |
+
} // namespace conv
|
| 656 |
+
} // namespace cutlass
|
| 657 |
+
|
| 658 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief This file contains definitions and utility functions for describing convolution problem sizes.
|
| 33 |
+
|
| 34 |
+
    Conv3dProblem description:
|
| 35 |
+
activation (NDHWC),
|
| 36 |
+
filter (KTRSC),
|
| 37 |
+
output (NZPQK),
|
| 38 |
+
      padding (pad_d, pad_h, pad_w),
|
| 39 |
+
stride (stride_d, stride_h, stride_w),
|
| 40 |
+
dilation (dilation_d, dilation_h, dilation_w).
|
| 41 |
+
|
| 42 |
+
Free functions to map:
|
| 43 |
+
Map tensor extents (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator)
|
| 44 |
+
Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator)
|
| 45 |
+
Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator)
|
| 46 |
+
*/
|
| 47 |
+
|
| 48 |
+
#pragma once
|
| 49 |
+
|
| 50 |
+
#include "cutlass/conv/convolution.h"
|
| 51 |
+
#include "cutlass/conv/conv2d_problem_size.h"
|
| 52 |
+
|
| 53 |
+
namespace cutlass {
|
| 54 |
+
namespace conv {
|
| 55 |
+
|
| 56 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
/// Problem size structure
|
| 59 |
+
struct Conv3dProblemSize : public Conv2dProblemSize {
|
| 60 |
+
//
|
| 61 |
+
// Type definitions
|
| 62 |
+
//
|
| 63 |
+
|
| 64 |
+
// 3D coordinate for padding, stride, and dilation in (d, h, w) dimensions
|
| 65 |
+
using Coord3D = Coord<3>;
|
| 66 |
+
|
| 67 |
+
//
|
| 68 |
+
// Data members
|
| 69 |
+
//
|
| 70 |
+
|
| 71 |
+
// Conv3d strictly problem size parameters
|
| 72 |
+
int D, T, Z; // input depth, filter depth, output depth
|
| 73 |
+
int pad_d; // padding in depth dimension
|
| 74 |
+
int stride_d; // stride in depth dimension
|
| 75 |
+
int dilation_d; // dilation in depth dimension
|
| 76 |
+
|
| 77 |
+
//
|
| 78 |
+
// Methods
|
| 79 |
+
//
|
| 80 |
+
public:
|
| 81 |
+
  /// Default constructor: zero-extent problem with no depth padding and
  /// identity stride/dilation in the depth dimension
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize():
    Conv2dProblemSize(),
    D(0), T(0), Z(0),
    pad_d(0),
    stride_d(1),
    dilation_d(1) { }
|
| 88 |
+
|
| 89 |
+
  /// Constructor for default padding, stride, dilation, and split-K
  // Depth padding defaults to T/2 ("same"-style); depth stride and
  // dilation default to 1. The h/w defaults come from the Conv2d base ctor.
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize(
    int N,
    int D,
    int H,
    int W,
    int C,
    int Z,
    int P,
    int Q,
    int K,
    int T,
    int R,
    int S,
    Mode mode
  ):
    Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode),
    D(D), T(T), Z(Z),
    pad_d(T / 2), stride_d(1), dilation_d(1) { }
|
| 109 |
+
|
| 110 |
+
  /// Constructor taking every problem-size parameter as a scalar
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize(
    int N,
    int D,
    int H,
    int W,
    int C,
    int K,
    int T,
    int R,
    int S,
    int Z,
    int P,
    int Q,
    int pad_d,
    int pad_h,
    int pad_w,
    int stride_d,
    int stride_h,
    int stride_w,
    int dilation_d,
    int dilation_h,
    int dilation_w,
    Mode mode,
    int split_k_slices = 1,
    int groups = 1
  ):
    // The n/h/w/c (and k/r/s/p/q) components are owned by the Conv2d base
    Conv2dProblemSize(
      N, H, W, C, K, R, S, P, Q,
      pad_h, pad_w,
      stride_h, stride_w,
      dilation_h, dilation_w,
      mode, split_k_slices, groups),
    D(D), T(T), Z(Z),
    pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d) { }
|
| 146 |
+
|
| 147 |
+
  /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D
  // Uses the *user-defined* output size verbatim to set Z, P, and Q
  // (includes all data members in the ctor arguments).
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize(
    cutlass::Tensor5DCoord input_size,    // NDHWC
    cutlass::Tensor5DCoord filter_size,   // KTRSC
    Coord3D padding,                      // pad_d, pad_h, pad_w
    Coord3D stride,                       // stride_d, stride_h, stride_w
    Coord3D dilation,                     // dilation_d, dilation_h, dilation_w
    cutlass::Tensor5DCoord output_size,   // NZPQK
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    // Padding is symmetric: pad_h / pad_w are forwarded as both the lower
    // and upper pad of the Conv2d base problem
    Conv2dProblemSize(
      {input_size.n(), input_size.h(), input_size.w(), input_size.c()},
      {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
      {padding[1], padding[1], padding[2], padding[2]},
      {stride[1], stride[2]},
      {dilation[1], dilation[2]},
      {output_size.n(), output_size.h(), output_size.w(), output_size.c()},
      mode, split_k_slices, groups),
    D(input_size.d()), T(filter_size.d()), Z(output_size.d()),
    pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) { }
|
| 171 |
+
|
| 172 |
+
  /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D
  // *Computes* the output size: P and Q are derived by the Conv2d base-class
  // constructor, while Z is computed here from the depth parameters.
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize(
    cutlass::Tensor5DCoord input_size,    // NDHWC
    cutlass::Tensor5DCoord filter_size,   // KTRSC
    Coord3D padding,                      // pad_d, pad_h, pad_w
    Coord3D stride,                       // stride_d, stride_h, stride_w
    Coord3D dilation,                     // dilation_d, dilation_h, dilation_w
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    Conv2dProblemSize(
      {input_size.n(), input_size.h(), input_size.w(), input_size.c()},
      {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
      {padding[1], padding[1], padding[2], padding[2]},
      {stride[1], stride[2]},
      {dilation[1], dilation[2]},
      mode, split_k_slices, groups),
    D(input_size.d()), T(filter_size.d()),
    pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0])
  {
    // set output Z; depth padding is symmetric, hence pad_d * 2
    Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1;
  }
|
| 198 |
+
|
| 199 |
+
/// Constructs convolution problem size from cutlass Tensor5DCoord, Coord3D
|
| 200 |
+
// *computes* output size and sets Z, P and Q (include all data members in ctor)
|
| 201 |
+
CUTLASS_HOST_DEVICE
|
| 202 |
+
Conv3dProblemSize(
|
| 203 |
+
cutlass::Tensor5DCoord input_size, // NDHWC
|
| 204 |
+
cutlass::Tensor5DCoord filter_size, // KTRSC
|
| 205 |
+
CUTLASS_STL_NAMESPACE::tuple<Coord3D, Coord3D> padding, // Coord3D {pad_d, pad_h, pad_w} & Coord3D {far pad_d, pad_h, pad_w} to calculate o/p/q
|
| 206 |
+
Coord3D stride, // stride_d, stride_h, stride_w
|
| 207 |
+
Coord3D dilation, // dilation_d, dilation_h, dilation_w
|
| 208 |
+
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
|
| 209 |
+
int split_k_slices = 1,
|
| 210 |
+
int groups = 1
|
| 211 |
+
):
|
| 212 |
+
Conv2dProblemSize(
|
| 213 |
+
{input_size.n(), input_size.h(), input_size.w(), input_size.c()},
|
| 214 |
+
{filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
|
| 215 |
+
{CUTLASS_STL_NAMESPACE::get<0>(padding)[1], CUTLASS_STL_NAMESPACE::get<1>(padding)[1],
|
| 216 |
+
CUTLASS_STL_NAMESPACE::get<0>(padding)[2], CUTLASS_STL_NAMESPACE::get<1>(padding)[2]},
|
| 217 |
+
{stride[1], stride[2]},
|
| 218 |
+
{dilation[1], dilation[2]},
|
| 219 |
+
mode, split_k_slices, groups),
|
| 220 |
+
D(input_size.d()), T(filter_size.d()),
|
| 221 |
+
pad_d(CUTLASS_STL_NAMESPACE::get<0>(padding)[0]), stride_d(stride[0]), dilation_d(dilation[0])
|
| 222 |
+
{
|
| 223 |
+
// set output Z
|
| 224 |
+
Z = ((D + pad_d + CUTLASS_STL_NAMESPACE::get<1>(padding)[0] - T * dilation_d) / stride_d) + 1;
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
/// Equality operator (ignores mode and split_k_slice)
|
| 228 |
+
CUTLASS_HOST_DEVICE
|
| 229 |
+
bool operator==(Conv3dProblemSize const &conv) const {
|
| 230 |
+
return (
|
| 231 |
+
(N == conv.N) && (D == conv.D) && (H == conv.H) && (W == conv.W) && (C == conv.C) &&
|
| 232 |
+
(K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) &&
|
| 233 |
+
(Z == conv.Z) &&(P == conv.P) && (Q == conv.Q) &&
|
| 234 |
+
(pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) &&
|
| 235 |
+
(stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) &&
|
| 236 |
+
(dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w)
|
| 237 |
+
);
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
/// Inequality operator
|
| 241 |
+
CUTLASS_HOST_DEVICE
|
| 242 |
+
bool operator!=(Conv3dProblemSize const &rhs) const {
|
| 243 |
+
return !(*this == rhs);
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
// Reset convolution mode in the problem
|
| 247 |
+
CUTLASS_HOST_DEVICE
|
| 248 |
+
Conv3dProblemSize reset_mode(cutlass::conv::Mode mode_) {
|
| 249 |
+
Conv3dProblemSize tmp(*this);
|
| 250 |
+
tmp.mode = mode_;
|
| 251 |
+
return tmp;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
// Reset split-k slice count in the problem
|
| 255 |
+
CUTLASS_HOST_DEVICE
|
| 256 |
+
Conv3dProblemSize reset_split_k_slices(int split_k_slices_) {
|
| 257 |
+
Conv3dProblemSize tmp(*this);
|
| 258 |
+
tmp.split_k_slices = split_k_slices_;
|
| 259 |
+
return tmp;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
/// Returns activation extent as Tensor5DCoord
|
| 263 |
+
CUTLASS_HOST_DEVICE
|
| 264 |
+
cutlass::Tensor5DCoord activation_extent() const {
|
| 265 |
+
|
| 266 |
+
return cutlass::Tensor5DCoord ({N, D, H, W, C});
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
/// Returns filter extent as Tensor5DCoord
|
| 270 |
+
CUTLASS_HOST_DEVICE
|
| 271 |
+
cutlass::Tensor5DCoord filter_extent(bool is_deconv = false) const {
|
| 272 |
+
|
| 273 |
+
return is_deconv ? cutlass::Tensor5DCoord ({C, T, R, S, K})
|
| 274 |
+
: cutlass::Tensor5DCoord ({K, T, R, S, C});
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
/// Returns output extent as Tensor5DCoord
|
| 278 |
+
CUTLASS_HOST_DEVICE
|
| 279 |
+
cutlass::Tensor5DCoord output_extent() const {
|
| 280 |
+
|
| 281 |
+
return cutlass::Tensor5DCoord ({N, Z, P, Q, K});
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
/// Returns activation size in number of elements
|
| 285 |
+
CUTLASS_HOST_DEVICE
|
| 286 |
+
int64_t activation_size() const {
|
| 287 |
+
|
| 288 |
+
return static_cast<int64_t>(N) * static_cast<int64_t>(D) *
|
| 289 |
+
static_cast<int64_t>(H) * static_cast<int64_t>(W) *
|
| 290 |
+
static_cast<int64_t>(C);
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
/// Returns filter size in number of elements
|
| 294 |
+
CUTLASS_HOST_DEVICE
|
| 295 |
+
int64_t filter_size() const {
|
| 296 |
+
|
| 297 |
+
return static_cast<int64_t>(K) * static_cast<int64_t>(T) *
|
| 298 |
+
static_cast<int64_t>(R) * static_cast<int64_t>(S) *
|
| 299 |
+
static_cast<int64_t>(C);
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
/// Returns output size in number of elements
|
| 303 |
+
CUTLASS_HOST_DEVICE
|
| 304 |
+
int64_t output_size() const {
|
| 305 |
+
|
| 306 |
+
return static_cast<int64_t>(N) * static_cast<int64_t>(Z) *
|
| 307 |
+
static_cast<int64_t>(P) * static_cast<int64_t>(Q) *
|
| 308 |
+
static_cast<int64_t>(K);
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
/// Returns padding as Coord3D
|
| 312 |
+
CUTLASS_HOST_DEVICE
|
| 313 |
+
Coord3D padding() const {
|
| 314 |
+
|
| 315 |
+
return Coord3D ({pad_d, pad_h, pad_w});
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
/// Returns stride as MatrixCoord
|
| 319 |
+
CUTLASS_HOST_DEVICE
|
| 320 |
+
Coord3D stride() const {
|
| 321 |
+
|
| 322 |
+
return Coord3D ({stride_d, stride_h, stride_w});
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
/// Returns dilation as MatrixCoord
|
| 326 |
+
CUTLASS_HOST_DEVICE
|
| 327 |
+
Coord3D dilation() const {
|
| 328 |
+
|
| 329 |
+
return Coord3D ({dilation_d, dilation_h, dilation_w});
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
};
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 336 |
+
// ImplicitGemm helper functions //
|
| 337 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 338 |
+
|
| 339 |
+
/// Determine the problem size of the implicit GEMM operation
|
| 340 |
+
CUTLASS_HOST_DEVICE
|
| 341 |
+
cutlass::gemm::GemmCoord implicit_gemm_problem_size(
|
| 342 |
+
Operator conv_operator,
|
| 343 |
+
Conv3dProblemSize const &problem_size) {
|
| 344 |
+
// Compute problem size
|
| 345 |
+
switch (conv_operator) {
|
| 346 |
+
case Operator::kFprop:
|
| 347 |
+
return gemm::GemmCoord(
|
| 348 |
+
problem_size.N * problem_size.Z * problem_size.P * problem_size.Q,
|
| 349 |
+
problem_size.K,
|
| 350 |
+
problem_size.T * problem_size.R * problem_size.S * problem_size.C
|
| 351 |
+
);
|
| 352 |
+
case Operator::kDeconv:
|
| 353 |
+
case Operator::kDgrad:
|
| 354 |
+
return gemm::GemmCoord(
|
| 355 |
+
problem_size.N * problem_size.D * problem_size.H * problem_size.W,
|
| 356 |
+
problem_size.C,
|
| 357 |
+
problem_size.T * problem_size.R * problem_size.S * problem_size.K
|
| 358 |
+
);
|
| 359 |
+
case Operator::kWgrad:
|
| 360 |
+
return gemm::GemmCoord(
|
| 361 |
+
problem_size.K,
|
| 362 |
+
problem_size.T * problem_size.R * problem_size.S * problem_size.C,
|
| 363 |
+
problem_size.N * problem_size.Z * problem_size.P * problem_size.Q
|
| 364 |
+
);
|
| 365 |
+
default:
|
| 366 |
+
break;
|
| 367 |
+
}
|
| 368 |
+
return gemm::GemmCoord();
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm
|
| 372 |
+
CUTLASS_HOST_DEVICE
|
| 373 |
+
int implicit_gemm_k_iterations(
|
| 374 |
+
Operator conv_operator,
|
| 375 |
+
int threadblock_K,
|
| 376 |
+
Conv3dProblemSize const &problem_size,
|
| 377 |
+
IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
|
| 378 |
+
GroupMode group_mode = GroupMode::kNone,
|
| 379 |
+
int threadblock_N = 0) {
|
| 380 |
+
|
| 381 |
+
int iterations = 0;
|
| 382 |
+
int elements_per_split_k_slice = 0;
|
| 383 |
+
if (group_mode == GroupMode::kNone) {
|
| 384 |
+
switch (conv_operator) {
|
| 385 |
+
case Operator::kFprop:
|
| 386 |
+
elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
| 387 |
+
iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
|
| 388 |
+
break;
|
| 389 |
+
|
| 390 |
+
case Operator::kDeconv:
|
| 391 |
+
case Operator::kDgrad:
|
| 392 |
+
elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
| 393 |
+
iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
|
| 394 |
+
break;
|
| 395 |
+
|
| 396 |
+
case Operator::kWgrad:
|
| 397 |
+
elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
| 398 |
+
iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
|
| 399 |
+
break;
|
| 400 |
+
|
| 401 |
+
default:
|
| 402 |
+
break;
|
| 403 |
+
}
|
| 404 |
+
} else if (group_mode == GroupMode::kDepthwise) {
|
| 405 |
+
int channels_per_cta = threadblock_N;
|
| 406 |
+
|
| 407 |
+
if (algorithm == IteratorAlgorithm::kAnalytic) {
|
| 408 |
+
switch (conv_operator) {
|
| 409 |
+
case Operator::kFprop:
|
| 410 |
+
iterations = problem_size.T * problem_size.R * problem_size.S *
|
| 411 |
+
((channels_per_cta + threadblock_K - 1) / threadblock_K);
|
| 412 |
+
break;
|
| 413 |
+
|
| 414 |
+
default:
|
| 415 |
+
break;
|
| 416 |
+
}
|
| 417 |
+
}
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
return iterations;
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 424 |
+
// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output)
|
| 425 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 426 |
+
/// Returns ImplicitGemm tensor A extent as Tensor5DCoord
|
| 427 |
+
CUTLASS_HOST_DEVICE
|
| 428 |
+
cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent(
|
| 429 |
+
Operator conv_operator,
|
| 430 |
+
Conv3dProblemSize const &problem_size) {
|
| 431 |
+
switch (conv_operator) {
|
| 432 |
+
case cutlass::conv::Operator::kFprop: return problem_size.activation_extent();
|
| 433 |
+
case cutlass::conv::Operator::kDeconv:
|
| 434 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.output_extent();
|
| 435 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.output_extent();
|
| 436 |
+
default : break;
|
| 437 |
+
}
|
| 438 |
+
return cutlass::Tensor5DCoord();
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
/// Returns ImplicitGemm tensor B extent as Tensor5DCoord
|
| 442 |
+
CUTLASS_HOST_DEVICE
|
| 443 |
+
cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent(
|
| 444 |
+
Operator conv_operator,
|
| 445 |
+
Conv3dProblemSize const &problem_size) {
|
| 446 |
+
switch (conv_operator) {
|
| 447 |
+
case cutlass::conv::Operator::kFprop: return problem_size.filter_extent();
|
| 448 |
+
case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true);
|
| 449 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent();
|
| 450 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent();
|
| 451 |
+
default : break;
|
| 452 |
+
}
|
| 453 |
+
return cutlass::Tensor5DCoord();
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
/// Returns ImplicitGemm tensor C extent as Tensor5DCoord
|
| 457 |
+
CUTLASS_HOST_DEVICE
|
| 458 |
+
cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent(
|
| 459 |
+
Operator conv_operator,
|
| 460 |
+
Conv3dProblemSize const &problem_size) {
|
| 461 |
+
switch (conv_operator) {
|
| 462 |
+
case cutlass::conv::Operator::kFprop: return problem_size.output_extent();
|
| 463 |
+
case cutlass::conv::Operator::kDeconv:
|
| 464 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent();
|
| 465 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent();
|
| 466 |
+
default : break;
|
| 467 |
+
}
|
| 468 |
+
return cutlass::Tensor5DCoord();
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
/// Returns ImplicitGemm tensor A size in number of elements
|
| 472 |
+
CUTLASS_HOST_DEVICE
|
| 473 |
+
int64_t implicit_gemm_tensor_a_size(
|
| 474 |
+
Operator conv_operator,
|
| 475 |
+
Conv3dProblemSize const &problem_size) {
|
| 476 |
+
switch (conv_operator) {
|
| 477 |
+
case cutlass::conv::Operator::kFprop: return problem_size.activation_size();
|
| 478 |
+
case cutlass::conv::Operator::kDeconv:
|
| 479 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.output_size();
|
| 480 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.output_size();
|
| 481 |
+
default : break;
|
| 482 |
+
}
|
| 483 |
+
return 0;
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
/// Returns ImplicitGemm tensor B size in number of elements
|
| 487 |
+
CUTLASS_HOST_DEVICE
|
| 488 |
+
int64_t implicit_gemm_tensor_b_size(
|
| 489 |
+
Operator conv_operator,
|
| 490 |
+
Conv3dProblemSize const &problem_size) {
|
| 491 |
+
switch (conv_operator) {
|
| 492 |
+
case cutlass::conv::Operator::kFprop: return problem_size.filter_size();
|
| 493 |
+
case cutlass::conv::Operator::kDeconv:
|
| 494 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.filter_size();
|
| 495 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.activation_size();
|
| 496 |
+
default : break;
|
| 497 |
+
}
|
| 498 |
+
return 0;
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
/// Returns ImplicitGemm tensor C size in number of elements
|
| 502 |
+
CUTLASS_HOST_DEVICE
|
| 503 |
+
int64_t implicit_gemm_tensor_c_size(
|
| 504 |
+
Operator conv_operator,
|
| 505 |
+
Conv3dProblemSize const &problem_size) {
|
| 506 |
+
switch (conv_operator) {
|
| 507 |
+
case cutlass::conv::Operator::kFprop: return problem_size.output_size();
|
| 508 |
+
case cutlass::conv::Operator::kDeconv:
|
| 509 |
+
case cutlass::conv::Operator::kDgrad: return problem_size.activation_size();
|
| 510 |
+
case cutlass::conv::Operator::kWgrad: return problem_size.filter_size();
|
| 511 |
+
default : break;
|
| 512 |
+
}
|
| 513 |
+
return 0;
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
} // namespace conv
|
| 517 |
+
} // namespace cutlass
|
| 518 |
+
|
| 519 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp
ADDED
|
@@ -0,0 +1,601 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief This file contains definitions and utility functions for describing convolution problem shapes.
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include "cutlass/tensor_coord.h"
|
| 38 |
+
#include "cutlass/conv/convolution.h"
|
| 39 |
+
|
| 40 |
+
#include "cute/container/array.hpp"
|
| 41 |
+
|
| 42 |
+
#if ! defined(__CUDACC_RTC__)
|
| 43 |
+
#include <initializer_list>
|
| 44 |
+
#endif
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
namespace cutlass::conv {
|
| 50 |
+
|
| 51 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
// Implements the user facing argument for all CUTLASS 3.x convolutions in a rank agnostic fashion.
|
| 54 |
+
// All tensors are flat and by default treated as layout right (NDHWC, KTRSC, NZPQK)
|
| 55 |
+
// Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types.
|
| 56 |
+
template <
|
| 57 |
+
conv::Operator ConvOp_,
|
| 58 |
+
int NumSpatialDimensions_
|
| 59 |
+
>
|
| 60 |
+
struct ConvProblemShape {
|
| 61 |
+
//
|
| 62 |
+
// Alias types for members
|
| 63 |
+
//
|
| 64 |
+
|
| 65 |
+
static constexpr int RankS = NumSpatialDimensions_;
|
| 66 |
+
static constexpr int RankT = NumSpatialDimensions_ + 2;
|
| 67 |
+
static constexpr conv::Operator ConvOp = ConvOp_;
|
| 68 |
+
static constexpr int NumSpatialDimensions = NumSpatialDimensions_;
|
| 69 |
+
using SpatialExtent = cute::array<int, RankS>;
|
| 70 |
+
using TensorExtent = cute::array<int, RankT>;
|
| 71 |
+
using TensorStride = cute::array<int64_t, RankT>;
|
| 72 |
+
using ShapePadding = SpatialExtent;
|
| 73 |
+
using TraversalStride = SpatialExtent;
|
| 74 |
+
using ShapeDilation = SpatialExtent;
|
| 75 |
+
using Corner = SpatialExtent;
|
| 76 |
+
|
| 77 |
+
//
|
| 78 |
+
// Members
|
| 79 |
+
//
|
| 80 |
+
cutlass::conv::Mode mode{};
|
| 81 |
+
TensorExtent shape_A{};
|
| 82 |
+
TensorStride stride_A{};
|
| 83 |
+
TensorExtent shape_B{};
|
| 84 |
+
TensorStride stride_B{};
|
| 85 |
+
TensorExtent shape_C{};
|
| 86 |
+
TensorStride stride_C{};
|
| 87 |
+
|
| 88 |
+
// asymmetric padding, both upper and lower padding must be >= 0
|
| 89 |
+
ShapePadding lower_padding{};
|
| 90 |
+
ShapePadding upper_padding{};
|
| 91 |
+
TraversalStride traversal_stride{};
|
| 92 |
+
ShapeDilation dilation{};
|
| 93 |
+
int groups = 1;
|
| 94 |
+
|
| 95 |
+
//
|
| 96 |
+
// Methods
|
| 97 |
+
//
|
| 98 |
+
|
| 99 |
+
ConvProblemShape() = default;
|
| 100 |
+
|
| 101 |
+
// Constructor accepts user facing arguments and computes to stores the corners as its internal state
|
| 102 |
+
ConvProblemShape(
|
| 103 |
+
conv::Mode mode, // convolution/cross-correlation
|
| 104 |
+
TensorExtent shape_act, // [n,d,h,w,c]
|
| 105 |
+
TensorStride stride_act, // [n,d,h,w,c]
|
| 106 |
+
TensorExtent shape_flt, // [k,t,r,s,c]
|
| 107 |
+
TensorStride stride_flt, // [k,t,r,s,c]
|
| 108 |
+
ShapePadding lower_padding, // [pad_d, pad_h, pad_w]
|
| 109 |
+
ShapePadding upper_padding, // [pad_d, pad_h, pad_w]
|
| 110 |
+
TraversalStride tstride, // [stride_d, stride_h, stride_w]
|
| 111 |
+
ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w]
|
| 112 |
+
int groups)
|
| 113 |
+
: mode(mode)
|
| 114 |
+
, lower_padding(lower_padding)
|
| 115 |
+
, upper_padding(upper_padding)
|
| 116 |
+
, traversal_stride(tstride)
|
| 117 |
+
, dilation(dilation)
|
| 118 |
+
, groups(groups) {
|
| 119 |
+
|
| 120 |
+
auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
|
| 121 |
+
set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
// Allow user input of xformed activation stride to support non-packed strides.
|
| 125 |
+
ConvProblemShape(
|
| 126 |
+
conv::Mode mode, // convolution/cross-correlation
|
| 127 |
+
TensorExtent shape_act, // [n,d,h,w,c]
|
| 128 |
+
TensorStride stride_act, // [n,d,h,w,c]
|
| 129 |
+
TensorExtent shape_flt, // [k,t,r,s,c]
|
| 130 |
+
TensorStride stride_flt, // [k,t,r,s,c]
|
| 131 |
+
TensorStride stride_xformed_act, // [n,z,p,q,k]
|
| 132 |
+
ShapePadding lower_padding, // [pad_d, pad_h, pad_w]
|
| 133 |
+
ShapePadding upper_padding, // [pad_d, pad_h, pad_w]
|
| 134 |
+
TraversalStride tstride, // [stride_d, stride_h, stride_w]
|
| 135 |
+
ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w]
|
| 136 |
+
int groups)
|
| 137 |
+
: mode(mode)
|
| 138 |
+
, lower_padding(lower_padding)
|
| 139 |
+
, upper_padding(upper_padding)
|
| 140 |
+
, traversal_stride(tstride)
|
| 141 |
+
, dilation(dilation)
|
| 142 |
+
, groups(groups) {
|
| 143 |
+
|
| 144 |
+
CUTLASS_ASSERT(stride_act[RankT - 1] == 1);
|
| 145 |
+
CUTLASS_ASSERT(stride_flt[RankT - 1] == 1);
|
| 146 |
+
CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1);
|
| 147 |
+
|
| 148 |
+
auto stride_act_packed = packed_stride_right_major(shape_act);
|
| 149 |
+
auto stride_flt_packed = packed_stride_right_major(shape_flt);
|
| 150 |
+
auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt);
|
| 151 |
+
|
| 152 |
+
CUTLASS_PRAGMA_UNROLL
|
| 153 |
+
for(int i = 0; i < RankT - 1; ++i) {
|
| 154 |
+
CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]);
|
| 155 |
+
CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]);
|
| 156 |
+
CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]);
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// Constructor accepts user facing arguments and presume packed tensor strides in canonical (CWHDN) order.
|
| 163 |
+
ConvProblemShape(
|
| 164 |
+
conv::Mode mode,
|
| 165 |
+
TensorExtent shape_act,
|
| 166 |
+
TensorExtent shape_flt,
|
| 167 |
+
ShapePadding lower_padding,
|
| 168 |
+
ShapePadding upper_padding,
|
| 169 |
+
TraversalStride tstride,
|
| 170 |
+
ShapeDilation dilation,
|
| 171 |
+
int groups)
|
| 172 |
+
: ConvProblemShape(
|
| 173 |
+
mode,
|
| 174 |
+
shape_act,
|
| 175 |
+
packed_stride_right_major(shape_act),
|
| 176 |
+
shape_flt,
|
| 177 |
+
packed_stride_right_major(shape_flt),
|
| 178 |
+
lower_padding,
|
| 179 |
+
upper_padding,
|
| 180 |
+
tstride,
|
| 181 |
+
dilation,
|
| 182 |
+
groups) {
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
#if ! defined(__CUDACC_RTC__)
|
| 186 |
+
// Constructor accepts user facing arguments and computes to stores the corners as its internal state
|
| 187 |
+
ConvProblemShape(
|
| 188 |
+
conv::Mode mode,
|
| 189 |
+
std::initializer_list<int> shape_act_,
|
| 190 |
+
std::initializer_list<int64_t> stride_act_,
|
| 191 |
+
std::initializer_list<int> shape_flt_,
|
| 192 |
+
std::initializer_list<int64_t> stride_flt_,
|
| 193 |
+
std::initializer_list<int> lower_padding_,
|
| 194 |
+
std::initializer_list<int> upper_padding_,
|
| 195 |
+
std::initializer_list<int> traversal_stride_,
|
| 196 |
+
std::initializer_list<int> dilation_,
|
| 197 |
+
int groups)
|
| 198 |
+
: mode(mode)
|
| 199 |
+
, groups(groups) {
|
| 200 |
+
|
| 201 |
+
TensorExtent shape_act{};
|
| 202 |
+
TensorStride stride_act{};
|
| 203 |
+
TensorExtent shape_flt{};
|
| 204 |
+
TensorStride stride_flt{};
|
| 205 |
+
|
| 206 |
+
assert(shape_act_.size() == shape_act.size());
|
| 207 |
+
assert(stride_act_.size() == stride_act.size());
|
| 208 |
+
assert(shape_flt_.size() == shape_flt.size());
|
| 209 |
+
assert(stride_flt_.size() == stride_flt.size());
|
| 210 |
+
assert(lower_padding_.size() == lower_padding.size());
|
| 211 |
+
assert(upper_padding_.size() == upper_padding.size());
|
| 212 |
+
assert(traversal_stride_.size() == traversal_stride.size());
|
| 213 |
+
assert(dilation_.size() == dilation.size());
|
| 214 |
+
|
| 215 |
+
std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin());
|
| 216 |
+
std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin());
|
| 217 |
+
std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin());
|
| 218 |
+
std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin());
|
| 219 |
+
std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin());
|
| 220 |
+
std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin());
|
| 221 |
+
std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin());
|
| 222 |
+
std::copy(dilation_.begin(), dilation_.end(), dilation.begin());
|
| 223 |
+
|
| 224 |
+
auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
|
| 225 |
+
set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
// Allow user input of xformed activation stride to support non-packed strides.
|
| 229 |
+
ConvProblemShape(
|
| 230 |
+
conv::Mode mode,
|
| 231 |
+
std::initializer_list<int> shape_act_,
|
| 232 |
+
std::initializer_list<int64_t> stride_act_,
|
| 233 |
+
std::initializer_list<int> shape_flt_,
|
| 234 |
+
std::initializer_list<int64_t> stride_flt_,
|
| 235 |
+
std::initializer_list<int64_t> stride_xformed_act_,
|
| 236 |
+
std::initializer_list<int> lower_padding_,
|
| 237 |
+
std::initializer_list<int> upper_padding_,
|
| 238 |
+
std::initializer_list<int> traversal_stride_,
|
| 239 |
+
std::initializer_list<int> dilation_,
|
| 240 |
+
int groups)
|
| 241 |
+
: mode(mode)
|
| 242 |
+
, groups(groups) {
|
| 243 |
+
TensorExtent shape_act{};
|
| 244 |
+
TensorStride stride_act{};
|
| 245 |
+
TensorExtent shape_flt{};
|
| 246 |
+
TensorStride stride_flt{};
|
| 247 |
+
TensorStride stride_xformed_act{};
|
| 248 |
+
|
| 249 |
+
std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin());
|
| 250 |
+
std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin());
|
| 251 |
+
std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin());
|
| 252 |
+
std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin());
|
| 253 |
+
std::copy(stride_xformed_act_.begin(), stride_xformed_act_.end(), stride_xformed_act.begin());
|
| 254 |
+
std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin());
|
| 255 |
+
std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin());
|
| 256 |
+
std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin());
|
| 257 |
+
std::copy(dilation_.begin(), dilation_.end(), dilation.begin());
|
| 258 |
+
|
| 259 |
+
CUTLASS_ASSERT(stride_act[RankT - 1] == 1);
|
| 260 |
+
CUTLASS_ASSERT(stride_flt[RankT - 1] == 1);
|
| 261 |
+
CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1);
|
| 262 |
+
|
| 263 |
+
auto stride_act_packed = packed_stride_right_major(shape_act);
|
| 264 |
+
auto stride_flt_packed = packed_stride_right_major(shape_flt);
|
| 265 |
+
auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt);
|
| 266 |
+
|
| 267 |
+
CUTLASS_PRAGMA_UNROLL
|
| 268 |
+
for(int i = 0; i < RankT - 1; ++i) {
|
| 269 |
+
CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]);
|
| 270 |
+
CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]);
|
| 271 |
+
CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]);
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
// Constructor accepts user facing arguments and computes to stores the corners as its internal state
|
| 278 |
+
ConvProblemShape(
|
| 279 |
+
conv::Mode mode,
|
| 280 |
+
std::initializer_list<int> shape_act_,
|
| 281 |
+
std::initializer_list<int> shape_flt_,
|
| 282 |
+
std::initializer_list<int> lower_padding_,
|
| 283 |
+
std::initializer_list<int> upper_padding_,
|
| 284 |
+
std::initializer_list<int> traversal_stride_,
|
| 285 |
+
std::initializer_list<int> dilation_,
|
| 286 |
+
int groups)
|
| 287 |
+
: mode(mode)
|
| 288 |
+
, groups(groups) {
|
| 289 |
+
TensorExtent shape_act{};
|
| 290 |
+
TensorStride stride_act{};
|
| 291 |
+
TensorExtent shape_flt{};
|
| 292 |
+
TensorStride stride_flt{};
|
| 293 |
+
|
| 294 |
+
assert(shape_act_.size() == shape_act.size());
|
| 295 |
+
assert(shape_flt_.size() == shape_flt.size());
|
| 296 |
+
assert(lower_padding_.size() == lower_padding.size());
|
| 297 |
+
assert(upper_padding_.size() == upper_padding.size());
|
| 298 |
+
assert(traversal_stride_.size() == traversal_stride.size());
|
| 299 |
+
assert(dilation_.size() == dilation.size());
|
| 300 |
+
|
| 301 |
+
std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin());
|
| 302 |
+
std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin());
|
| 303 |
+
std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin());
|
| 304 |
+
std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin());
|
| 305 |
+
std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin());
|
| 306 |
+
std::copy(dilation_.begin(), dilation_.end(), dilation.begin());
|
| 307 |
+
stride_act = packed_stride_right_major(shape_act);
|
| 308 |
+
stride_flt = packed_stride_right_major(shape_flt);
|
| 309 |
+
|
| 310 |
+
auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
|
| 311 |
+
set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
|
| 312 |
+
}
|
| 313 |
+
#endif // not defined(__CUDACC_RTC__)
|
| 314 |
+
|
| 315 |
+
// Set shape and stride of tensor A/B/C according to following table:
|
| 316 |
+
// | | Fprop | Dgrad | Wgrad |
|
| 317 |
+
// | ------ | ------ | ------ | ------|
|
| 318 |
+
// | ShapeA | NDHWC | NZPQK | NZPQK |
|
| 319 |
+
// | ShapeB | KTRSC | KTRSC | NDHWC |
|
| 320 |
+
// | ShapeC | NZPQK | NDHWC | KTRSC |
|
| 321 |
+
//
|
| 322 |
+
// Input comes from calculate_xformed_act, which does NOT depend on ConvOp.
|
| 323 |
+
CUTLASS_HOST_DEVICE
|
| 324 |
+
constexpr void
|
| 325 |
+
set_shape_stride_ABC(
|
| 326 |
+
TensorExtent shape_act,
|
| 327 |
+
TensorStride stride_act,
|
| 328 |
+
TensorExtent shape_flt,
|
| 329 |
+
TensorStride stride_flt,
|
| 330 |
+
TensorExtent shape_xformed_act,
|
| 331 |
+
TensorStride stride_xformed_act) {
|
| 332 |
+
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 333 |
+
printf("*** set_shape_stride_ABC ***");
|
| 334 |
+
printf("\n shape_act: ");
|
| 335 |
+
print(shape_act);
|
| 336 |
+
printf("\n stride_act: ");
|
| 337 |
+
print(stride_act);
|
| 338 |
+
printf("\n shape_flt: ");
|
| 339 |
+
print(shape_flt);
|
| 340 |
+
printf("\n stride_flt: ");
|
| 341 |
+
print(stride_flt);
|
| 342 |
+
printf("\n shape_xformed_act: ");
|
| 343 |
+
print(shape_xformed_act);
|
| 344 |
+
printf("\n stride_xformed_act: ");
|
| 345 |
+
print(stride_xformed_act);
|
| 346 |
+
if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
|
| 347 |
+
printf("\n ConvOp: Fprop");
|
| 348 |
+
}
|
| 349 |
+
if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 350 |
+
printf("\n ConvOp: Dgrad");
|
| 351 |
+
}
|
| 352 |
+
if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 353 |
+
printf("\n ConvOp: Wgrad");
|
| 354 |
+
}
|
| 355 |
+
printf("\n");
|
| 356 |
+
#endif
|
| 357 |
+
|
| 358 |
+
if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
|
| 359 |
+
shape_A = shape_act;
|
| 360 |
+
stride_A = stride_act;
|
| 361 |
+
shape_B = shape_flt;
|
| 362 |
+
stride_B = stride_flt;
|
| 363 |
+
shape_C = shape_xformed_act;
|
| 364 |
+
stride_C = stride_xformed_act;
|
| 365 |
+
}
|
| 366 |
+
else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 367 |
+
shape_A = shape_xformed_act;
|
| 368 |
+
stride_A = stride_xformed_act;
|
| 369 |
+
shape_B = shape_flt;
|
| 370 |
+
stride_B = stride_flt;
|
| 371 |
+
shape_C = shape_act;
|
| 372 |
+
stride_C = stride_act;
|
| 373 |
+
}
|
| 374 |
+
else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 375 |
+
shape_A = shape_xformed_act;
|
| 376 |
+
stride_A = stride_xformed_act;
|
| 377 |
+
shape_B = shape_act;
|
| 378 |
+
stride_B = stride_act;
|
| 379 |
+
shape_C = shape_flt;
|
| 380 |
+
stride_C = stride_flt;
|
| 381 |
+
}
|
| 382 |
+
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 383 |
+
printf("\n shape_A: ");
|
| 384 |
+
print(shape_A);
|
| 385 |
+
printf("\n stride_A: ");
|
| 386 |
+
print(stride_A);
|
| 387 |
+
printf("\n shape_B: ");
|
| 388 |
+
print(shape_B);
|
| 389 |
+
printf("\n stride_B: ");
|
| 390 |
+
print(stride_B);
|
| 391 |
+
printf("\n shape_C: ");
|
| 392 |
+
print(shape_C);
|
| 393 |
+
printf("\n stride_C: ");
|
| 394 |
+
print(stride_C);
|
| 395 |
+
#endif
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
// Get A extents.
|
| 399 |
+
// fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C))
|
| 400 |
+
// dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K))
|
| 401 |
+
// wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N))
|
| 402 |
+
CUTLASS_HOST_DEVICE
|
| 403 |
+
constexpr auto
|
| 404 |
+
get_shape_A() const {
|
| 405 |
+
using cute::make_shape;
|
| 406 |
+
using cute::take;
|
| 407 |
+
|
| 408 |
+
if constexpr (ConvOp == conv::Operator::kFprop ||
|
| 409 |
+
ConvOp == conv::Operator::kDgrad) {
|
| 410 |
+
return make_shape(
|
| 411 |
+
cute::reverse(take<0, RankT - 1>(shape_A)),
|
| 412 |
+
shape_A[RankT - 1]);
|
| 413 |
+
}
|
| 414 |
+
// For wgrad kernel, we need to linearize NZPQ for tensor A
|
| 415 |
+
else if constexpr (ConvOp == conv::Operator::kWgrad) {
|
| 416 |
+
return make_shape(
|
| 417 |
+
shape_A[RankT - 1],
|
| 418 |
+
cute::product(take<0, RankT - 1>(shape_A)));
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
// Get B extents.
|
| 423 |
+
// fprop: B extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T))
|
| 424 |
+
// dgrad: B extents array contains [K,T,R,S,C]. Turn that into ((C), (K,S,R,T))
|
| 425 |
+
// wgrad: B extents array contains [N,D,H,W,C]. Turn that into ((C), (W,H,D,N))
|
| 426 |
+
CUTLASS_HOST_DEVICE
|
| 427 |
+
constexpr auto
|
| 428 |
+
get_shape_B() const {
|
| 429 |
+
using cute::make_shape;
|
| 430 |
+
using cute::reverse;
|
| 431 |
+
using cute::take;
|
| 432 |
+
|
| 433 |
+
if constexpr (ConvOp == conv::Operator::kFprop) {
|
| 434 |
+
return make_shape(
|
| 435 |
+
shape_B[0],
|
| 436 |
+
reverse(take<1, RankT>(shape_B)));
|
| 437 |
+
}
|
| 438 |
+
else if constexpr (ConvOp == conv::Operator::kWgrad) {
|
| 439 |
+
return make_shape(
|
| 440 |
+
shape_B[RankT - 1],
|
| 441 |
+
reverse(take<0, RankT - 1>(shape_B)));
|
| 442 |
+
}
|
| 443 |
+
else if constexpr (ConvOp == conv::Operator::kDgrad) {
|
| 444 |
+
// shape_B: [K,T,R,S,C], return: [(C),(K,S,R,T)]
|
| 445 |
+
return make_shape(
|
| 446 |
+
shape_B[RankT - 1],
|
| 447 |
+
cute::insert<0>(
|
| 448 |
+
reverse(take<1, RankT - 1>(shape_B)),
|
| 449 |
+
shape_B[0]));
|
| 450 |
+
}
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
// Get C extents.
|
| 454 |
+
// fprop: C extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K))
|
| 455 |
+
// dgrad: C extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C))
|
| 456 |
+
// wgrad: C extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T))
|
| 457 |
+
CUTLASS_HOST_DEVICE
|
| 458 |
+
constexpr auto
|
| 459 |
+
get_shape_C() const {
|
| 460 |
+
using cute::make_shape;
|
| 461 |
+
using cute::reverse;
|
| 462 |
+
using cute::take;
|
| 463 |
+
|
| 464 |
+
if constexpr (ConvOp == conv::Operator::kFprop ||
|
| 465 |
+
ConvOp == conv::Operator::kDgrad) {
|
| 466 |
+
return make_shape(
|
| 467 |
+
reverse(take<0, RankT - 1>(shape_C)),
|
| 468 |
+
shape_C[RankT - 1]);
|
| 469 |
+
}
|
| 470 |
+
else if constexpr (ConvOp == conv::Operator::kWgrad) {
|
| 471 |
+
return make_shape(
|
| 472 |
+
shape_C[0],
|
| 473 |
+
reverse(take<1, RankT>(shape_C)));
|
| 474 |
+
}
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
// Static method that returns the canonical strides of tensors (layouts are right major and compact)
|
| 478 |
+
CUTLASS_HOST_DEVICE
|
| 479 |
+
static constexpr TensorStride
|
| 480 |
+
packed_stride_right_major(TensorExtent const& extents) {
|
| 481 |
+
TensorStride strides{};
|
| 482 |
+
strides[RankT-1] = 1;
|
| 483 |
+
cute::for_each(cute::make_rseq<RankT-1>{}, [&](auto i) {
|
| 484 |
+
strides[i] = extents[i+1] * strides[i+1];
|
| 485 |
+
});
|
| 486 |
+
return strides;
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
// Static method that returns the packed logical size of any TensorExtent
|
| 490 |
+
CUTLASS_HOST_DEVICE
|
| 491 |
+
static constexpr size_t
|
| 492 |
+
size(TensorExtent const& extents) {
|
| 493 |
+
size_t size = 1;
|
| 494 |
+
cute::for_each(cute::make_seq<RankT>{}, [&](auto i) {
|
| 495 |
+
size *= extents[i];
|
| 496 |
+
});
|
| 497 |
+
return size;
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
CUTLASS_HOST_DEVICE
|
| 501 |
+
constexpr size_t
|
| 502 |
+
size_A() const {
|
| 503 |
+
return shape_A[0] * stride_A[0];
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
CUTLASS_HOST_DEVICE
|
| 507 |
+
constexpr size_t
|
| 508 |
+
size_B() const {
|
| 509 |
+
return shape_B[0] * stride_B[0];
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
CUTLASS_HOST_DEVICE
|
| 513 |
+
constexpr size_t
|
| 514 |
+
size_C() const {
|
| 515 |
+
return shape_C[0] * stride_C[0];
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
// Equality operator
|
| 519 |
+
CUTLASS_HOST_DEVICE
|
| 520 |
+
bool operator==(ConvProblemShape<ConvOp, NumSpatialDimensions> const& rhs) const {
|
| 521 |
+
using cute::for_each;
|
| 522 |
+
using cute::make_seq;
|
| 523 |
+
|
| 524 |
+
bool is_equal = true;
|
| 525 |
+
|
| 526 |
+
// Compare all tensor extents
|
| 527 |
+
for_each(make_seq<RankT>{}, [&](auto i) {
|
| 528 |
+
is_equal = is_equal
|
| 529 |
+
&& (shape_A[i] == rhs.shape_A[i])
|
| 530 |
+
&& (shape_B[i] == rhs.shape_B[i]);
|
| 531 |
+
});
|
| 532 |
+
|
| 533 |
+
// Compare all spatial extents
|
| 534 |
+
for_each(make_seq<RankS>{}, [&](auto i) {
|
| 535 |
+
is_equal = is_equal
|
| 536 |
+
&& (lower_padding[i] == rhs.lower_padding[i])
|
| 537 |
+
&& (upper_padding[i] == rhs.upper_padding[i])
|
| 538 |
+
&& (traversal_stride[i] == rhs.traversal_stride[i])
|
| 539 |
+
&& (dilation[i] == rhs.dilation[i]);
|
| 540 |
+
});
|
| 541 |
+
|
| 542 |
+
return is_equal;
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
/// Inequality operator
|
| 546 |
+
CUTLASS_HOST_DEVICE
|
| 547 |
+
bool operator!=(ConvProblemShape<ConvOp, NumSpatialDimensions> const &rhs) const {
|
| 548 |
+
return !(*this == rhs);
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
private:
|
| 552 |
+
CUTLASS_HOST_DEVICE
|
| 553 |
+
constexpr auto
|
| 554 |
+
calculate_xformed_act(TensorExtent shape_act, TensorExtent shape_flt) {
|
| 555 |
+
TensorExtent shape_xformed_act{};
|
| 556 |
+
// calculate n,z,p,q,k.
|
| 557 |
+
// a helper lambda to compute a single spatial extent of the nzpqk tensor
|
| 558 |
+
auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) {
|
| 559 |
+
return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride;
|
| 560 |
+
};
|
| 561 |
+
|
| 562 |
+
shape_xformed_act[0] = shape_act[0]; // Activation N extent
|
| 563 |
+
cute::for_each(cute::make_seq<RankS>{}, [&](auto i) {
|
| 564 |
+
shape_xformed_act[i+1] = nzpqk_extent(
|
| 565 |
+
shape_act[i+1], shape_flt[i+1], upper_padding[i] + lower_padding[i], dilation[i], traversal_stride[i]);
|
| 566 |
+
});
|
| 567 |
+
shape_xformed_act[RankT-1] = shape_flt[0]; // Filter K extent
|
| 568 |
+
|
| 569 |
+
TensorStride stride_xformed_act = packed_stride_right_major(shape_xformed_act);
|
| 570 |
+
|
| 571 |
+
return cute::make_tuple(shape_xformed_act, stride_xformed_act);
|
| 572 |
+
}
|
| 573 |
+
};
|
| 574 |
+
|
| 575 |
+
template<
|
| 576 |
+
conv::Operator ConvOp,
|
| 577 |
+
int SpatialDim
|
| 578 |
+
>
|
| 579 |
+
void print(ConvProblemShape<ConvOp, SpatialDim> const& problem) {
|
| 580 |
+
printf("ConvProblemShape with %d spatial dimensions implementing cutlass::conv::Operator::%d\n",
|
| 581 |
+
SpatialDim, int(ConvOp));
|
| 582 |
+
printf("\tTensorA: ");
|
| 583 |
+
cute::print(problem.shape_A); printf(":");
|
| 584 |
+
cute::print(problem.stride_A); printf("\n");
|
| 585 |
+
printf("\tTensorB: ");
|
| 586 |
+
cute::print(problem.shape_B); printf(":");
|
| 587 |
+
cute::print(problem.stride_B); printf("\n");
|
| 588 |
+
printf("\tTensorC: ");
|
| 589 |
+
cute::print(problem.shape_C); printf(":");
|
| 590 |
+
cute::print(problem.stride_C); printf("\n");
|
| 591 |
+
printf("\tLower padding: "); print(problem.lower_padding); printf("\n");
|
| 592 |
+
printf("\tUpper padding: "); print(problem.upper_padding); printf("\n");
|
| 593 |
+
printf("\tTraversal strides: "); print(problem.traversal_stride); printf("\n");
|
| 594 |
+
printf("\tDilation: "); print(problem.dilation); printf("\n");
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 598 |
+
|
| 599 |
+
} // namespace cutlass::conv
|
| 600 |
+
|
| 601 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief
|
| 33 |
+
|
| 34 |
+
This file contains definitions and utility functions for describing convolution problem sizes in terms of
|
| 35 |
+
activation (NHWC), filter (KRSC), output (NPQK), padding (pad_h, pad_w), stride (stride_h, stride_w), and
|
| 36 |
+
dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map CUTLASS's implicit gemm
|
| 37 |
+
tensor extents, sizes, and data types to that of the convolution's extents, sizes, and data types.
|
| 38 |
+
|
| 39 |
+
* Mapping convolutions to Gemm computation *
|
| 40 |
+
|
| 41 |
+
Cutlass implements convolutions with the Implicit Gemm algorithm. This algorithm performs a gemm
|
| 42 |
+
(general matrix-matrix multiply) on the convolution tensors Activation, Filter, and Output.
|
| 43 |
+
The underlying gemm operation follows the standard gemm definition:
|
| 44 |
+
|
| 45 |
+
C = A * B + C
|
| 46 |
+
|
| 47 |
+
A and B are input matrices
|
| 48 |
+
C is source and output matrix
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped
|
| 52 |
+
to convolution tensors Activation, Filter and Output as described in the table below.
|
| 53 |
+
|
| 54 |
+
___________________________________________________________________________
|
| 55 |
+
ConvolutionalOperator | A | B | C
|
| 56 |
+
___________________________________________________________________________
|
| 57 |
+
| | | | |
|
| 58 |
+
| Fprop | Activation | Filter | Output |
|
| 59 |
+
| Dgrad | Output | Filter | Activation |
|
| 60 |
+
| Wgrad | Output | Activation | Filter |
|
| 61 |
+
___________________________________________________________________________
|
| 62 |
+
|
| 63 |
+
In convolution codebase, DO NOT mix using (A, B, C) with (Activation, Filter, Output).
|
| 64 |
+
|
| 65 |
+
For example, it's confusing and error prone to document a convolution class or function
|
| 66 |
+
as operating on "A, B, Output." Instead, use the mapping functions below,
|
| 67 |
+
and adhere to using either A, B, C or Activation, Filter, Output.
|
| 68 |
+
|
| 69 |
+
Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap
|
| 70 |
+
Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap
|
| 71 |
+
*/
|
| 72 |
+
|
| 73 |
+
#pragma once
|
| 74 |
+
|
| 75 |
+
#include "cutlass/cutlass.h"
|
| 76 |
+
#include "cutlass/layout/tensor.h"
|
| 77 |
+
#include "cutlass/tensor_coord.h"
|
| 78 |
+
#include "cutlass/fast_math.h"
|
| 79 |
+
#include "cutlass/gemm/gemm_enumerated_types.h"
|
| 80 |
+
#include "cutlass/matrix_coord.h"
|
| 81 |
+
|
| 82 |
+
namespace cutlass {
|
| 83 |
+
namespace conv {
|
| 84 |
+
|
| 85 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 86 |
+
|
| 87 |
+
/// Convolutional operator
|
| 88 |
+
enum class Operator {
|
| 89 |
+
kFprop,
|
| 90 |
+
kDgrad,
|
| 91 |
+
kWgrad,
|
| 92 |
+
kDeconv
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
/// Distinguishes convolution from cross correlation
|
| 96 |
+
enum class Mode {
|
| 97 |
+
kCrossCorrelation,
|
| 98 |
+
kConvolution
|
| 99 |
+
};
|
| 100 |
+
|
| 101 |
+
/// Selects among several implementation variants trading off performance with simplicity
|
| 102 |
+
enum class IteratorAlgorithm {
|
| 103 |
+
kAnalytic, ///< functionally correct in all cases but lower performance
|
| 104 |
+
kOptimized, ///< optimized for R <= 32, S <= 32 and unity-stride dgrad
|
| 105 |
+
kFixedChannels, ///< Analytic algorithm optimized for fixed channel count (C == AccessSize)
|
| 106 |
+
kFewChannels, ///< Analytic algorithm optimized for few channels (C divisible by AccessSize)
|
| 107 |
+
kFixedStrideDilation ///< Optimized for fixed stride and dilation
|
| 108 |
+
};
|
| 109 |
+
|
| 110 |
+
/// Distinguishes among partial specializations that accelerate certain problems where convolution
|
| 111 |
+
/// stride is unit.
|
| 112 |
+
enum class StrideSupport {
|
| 113 |
+
kStrided, ///< arbitrary convolution stride
|
| 114 |
+
kUnity, ///< unit convolution stride
|
| 115 |
+
kFixed ///< fixed convolution stride
|
| 116 |
+
};
|
| 117 |
+
|
| 118 |
+
/// Identifies split-K mode
|
| 119 |
+
enum class SplitKMode {
|
| 120 |
+
kNone,
|
| 121 |
+
kSerial,
|
| 122 |
+
kParallel
|
| 123 |
+
};
|
| 124 |
+
|
| 125 |
+
/// Identifies group mode
|
| 126 |
+
enum class GroupMode {
|
| 127 |
+
kNone,
|
| 128 |
+
kSingleGroup, ///< One CTA calculates one group or less
|
| 129 |
+
kMultipleGroup, ///< One CTA calculates multiple groups
|
| 130 |
+
kDepthwise ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups)
|
| 131 |
+
};
|
| 132 |
+
|
| 133 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 134 |
+
|
| 135 |
+
/// Shape of a tensor
|
| 136 |
+
template <
|
| 137 |
+
int N = 1,
|
| 138 |
+
int H = 1,
|
| 139 |
+
int W = 1,
|
| 140 |
+
int C = 1
|
| 141 |
+
>
|
| 142 |
+
struct TensorNHWCShape {
|
| 143 |
+
static int const kN = N;
|
| 144 |
+
static int const kH = H;
|
| 145 |
+
static int const kW = W;
|
| 146 |
+
static int const kC = C;
|
| 147 |
+
|
| 148 |
+
static int const kHW = H * W;
|
| 149 |
+
static int const kNHW = N * kHW;
|
| 150 |
+
static int const kNHWC = N * H * W * C;
|
| 151 |
+
|
| 152 |
+
static int const kCount = kNHWC;
|
| 153 |
+
|
| 154 |
+
//
|
| 155 |
+
// Static member functions
|
| 156 |
+
//
|
| 157 |
+
|
| 158 |
+
/// Returns a Coord object
|
| 159 |
+
CUTLASS_HOST_DEVICE
|
| 160 |
+
static Coord<4> toCoord() {
|
| 161 |
+
return make_Coord(kN, kH, kW, kC);
|
| 162 |
+
}
|
| 163 |
+
};
|
| 164 |
+
|
| 165 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 166 |
+
|
| 167 |
+
/// Shape of a conv2d stride, which controls how the filter convolves around the input volume
|
| 168 |
+
template <
|
| 169 |
+
/// Stride in horizontal direction
|
| 170 |
+
int u = 1,
|
| 171 |
+
/// Stride in vertical direction
|
| 172 |
+
int v = 1
|
| 173 |
+
>
|
| 174 |
+
struct Stride2D {
|
| 175 |
+
static int const kU = u;
|
| 176 |
+
static int const kV = v;
|
| 177 |
+
|
| 178 |
+
//
|
| 179 |
+
// Static member functions
|
| 180 |
+
//
|
| 181 |
+
|
| 182 |
+
/// Returns a Coord object
|
| 183 |
+
CUTLASS_HOST_DEVICE
|
| 184 |
+
static Coord<2> toCoord() {
|
| 185 |
+
return make_Coord(kU, kV);
|
| 186 |
+
}
|
| 187 |
+
};
|
| 188 |
+
|
| 189 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 190 |
+
|
| 191 |
+
} // namespace conv
|
| 192 |
+
} // namespace cutlass
|
| 193 |
+
|
| 194 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
/***************************************************************************************************
|
| 3 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
*
|
| 6 |
+
* Redistribution and use in source and binary forms, with or without
|
| 7 |
+
* modification, are permitted provided that the following conditions are met:
|
| 8 |
+
*
|
| 9 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
* list of conditions and the following disclaimer.
|
| 11 |
+
*
|
| 12 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
* and/or other materials provided with the distribution.
|
| 15 |
+
*
|
| 16 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
* contributors may be used to endorse or promote products derived from
|
| 18 |
+
* this software without specific prior written permission.
|
| 19 |
+
*
|
| 20 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
*
|
| 31 |
+
**************************************************************************************************/
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include "cutlass/conv/convnd_problem_shape.hpp"
|
| 35 |
+
|
| 36 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 37 |
+
|
| 38 |
+
namespace cutlass::conv::detail {
|
| 39 |
+
|
| 40 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
// Helper function to get the problem shape
|
| 43 |
+
template <typename T, class ProblemShape>
|
| 44 |
+
auto get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::true_type) {
|
| 45 |
+
return T::get_problem_shape_MNKL(problem_shape);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
template <typename T, class ProblemShape>
|
| 49 |
+
ProblemShape get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::false_type) {
|
| 50 |
+
return problem_shape;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
// Get problem shape MNKL according to following table:
|
| 54 |
+
// | | Fprop | Dgrad | Wgrad |
|
| 55 |
+
// | ---- | --------- | -------- | -------- |
|
| 56 |
+
// | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) |
|
| 57 |
+
// | Shape_N | (K) | (C) | (C,S,R,T) |
|
| 58 |
+
// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) |
|
| 59 |
+
// | Shape_L | _1 | (V,U,O) | _1 |
|
| 60 |
+
|
| 61 |
+
template <class ProblemShape>
|
| 62 |
+
CUTLASS_HOST_DEVICE
|
| 63 |
+
constexpr auto
|
| 64 |
+
get_transformed_problem_shape_MNKL(ProblemShape const& problem_shape) {
|
| 65 |
+
return problem_shape;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
template <conv::Operator ConvOp, int SpatialDim>
|
| 70 |
+
CUTLASS_HOST_DEVICE
|
| 71 |
+
constexpr auto
|
| 72 |
+
get_transformed_problem_shape_MNKL(ConvProblemShape<ConvOp, SpatialDim> const& problem_shape) {
|
| 73 |
+
using cute::insert;
|
| 74 |
+
using cute::make_shape;
|
| 75 |
+
using cute::reverse;
|
| 76 |
+
using cute::take;
|
| 77 |
+
|
| 78 |
+
constexpr int RankT = SpatialDim + 2;
|
| 79 |
+
|
| 80 |
+
if constexpr (ConvOp == conv::Operator::kWgrad) {
|
| 81 |
+
auto M_xformed = problem_shape.shape_C[0];
|
| 82 |
+
auto N_xformed = reverse(take<1, RankT>(problem_shape.shape_C));
|
| 83 |
+
auto K_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_A));
|
| 84 |
+
auto L_xformed = cute::Int<1>{};
|
| 85 |
+
|
| 86 |
+
return make_shape(M_xformed, N_xformed, K_xformed, L_xformed);
|
| 87 |
+
}
|
| 88 |
+
else if constexpr (ConvOp == conv::Operator::kFprop){
|
| 89 |
+
auto M_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_C));
|
| 90 |
+
auto N_xformed = problem_shape.shape_C[RankT - 1];
|
| 91 |
+
auto K_xformed = reverse(take<1, RankT>(problem_shape.shape_B));
|
| 92 |
+
auto L_xformed = cute::Int<1>{};
|
| 93 |
+
|
| 94 |
+
return make_shape(M_xformed, N_xformed, K_xformed, L_xformed);
|
| 95 |
+
}
|
| 96 |
+
else if constexpr (ConvOp == conv::Operator::kDgrad) {
|
| 97 |
+
auto L_xformed = reverse(problem_shape.traversal_stride); // (V,U,O)
|
| 98 |
+
auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(problem_shape.shape_C)), L_xformed);
|
| 99 |
+
auto N_xformed = problem_shape.shape_C[RankT - 1];
|
| 100 |
+
// shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T]
|
| 101 |
+
auto K_xformed = insert<0>(
|
| 102 |
+
(reverse(take<1,RankT - 1>(problem_shape.shape_B))),
|
| 103 |
+
problem_shape.shape_B[0]);
|
| 104 |
+
|
| 105 |
+
return make_shape(M_xformed, N_xformed, K_xformed, L_xformed);
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
// Assuming im2col linearization
|
| 110 |
+
// Get problem shape MNKL according to following table:
|
| 111 |
+
// | | Fprop | Dgrad | Wgrad |
|
| 112 |
+
// | ---- | --------- | -------- | -------- |
|
| 113 |
+
// | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K) |
|
| 114 |
+
// | Shape_N | (K) | (C) | (C,S,R,T) |
|
| 115 |
+
// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q*P*Z*N) |
|
| 116 |
+
// | Shape_L | _1 | (V*U*O) | _1 |
|
| 117 |
+
template <conv::Operator ConvOp, int SpatialDim>
|
| 118 |
+
CUTLASS_HOST_DEVICE
|
| 119 |
+
constexpr auto
|
| 120 |
+
get_linearized_problem_shape_MNKL(ConvProblemShape<ConvOp, SpatialDim> const& problem_shape) {
|
| 121 |
+
|
| 122 |
+
auto [M, N, K, L] = get_transformed_problem_shape_MNKL(problem_shape);
|
| 123 |
+
|
| 124 |
+
if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) {
|
| 125 |
+
return cute::make_shape(cute::product(M), N, K, cute::product(L));
|
| 126 |
+
}
|
| 127 |
+
else if constexpr (ConvOp == conv::Operator::kWgrad) {
|
| 128 |
+
return cute::make_shape(M, N, cute::product(K), L);
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 134 |
+
|
| 135 |
+
} // namespace cutlass::conv::detail
|
| 136 |
+
|
| 137 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
// common
|
| 34 |
+
#include "cutlass/arch/mma.h"
|
| 35 |
+
#include "cutlass/cutlass.h"
|
| 36 |
+
#include "cutlass/arch/mma.h"
|
| 37 |
+
#include "cutlass/trace.h"
|
| 38 |
+
#include "cutlass/cluster_launch.hpp"
|
| 39 |
+
#include "cutlass/device_kernel.h"
|
| 40 |
+
|
| 41 |
+
#include "cutlass/conv/kernel/conv_universal.hpp"
|
| 42 |
+
#include "cutlass/gemm/gemm.h"
|
| 43 |
+
#include "cutlass/detail/layout.hpp"
|
| 44 |
+
#include "cutlass/cuda_host_adapter.hpp"
|
| 45 |
+
|
| 46 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass::conv::device {
|
| 49 |
+
|
| 50 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
/*!
|
| 53 |
+
ConvUniversalAdapter is a stateful, reusable handle built around a kernel
|
| 54 |
+
of type cutlass::conv::kernel::ConvUniversal.
|
| 55 |
+
|
| 56 |
+
It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs
|
| 57 |
+
to create it from the host facing arguments. For power users, static methods
|
| 58 |
+
are exposed that bypass the stateful methods or args->params lowering.
|
| 59 |
+
*/
|
| 60 |
+
template <class ConvKernel_>
|
| 61 |
+
class ConvUniversalAdapter
|
| 62 |
+
{
|
| 63 |
+
public:
|
| 64 |
+
using ConvKernel = GetUnderlyingKernel_t<ConvKernel_>;
|
| 65 |
+
using TileShape = typename ConvKernel::TileShape;
|
| 66 |
+
using ElementA = typename ConvKernel::ElementA;
|
| 67 |
+
using ElementB = typename ConvKernel::ElementB;
|
| 68 |
+
using ElementC = typename ConvKernel::ElementC;
|
| 69 |
+
using ElementD = typename ConvKernel::ElementD;
|
| 70 |
+
using ElementAccumulator = typename ConvKernel::TiledMma::ValTypeC;
|
| 71 |
+
using DispatchPolicy = typename ConvKernel::DispatchPolicy;
|
| 72 |
+
using CollectiveMainloop = typename ConvKernel::CollectiveMainloop;
|
| 73 |
+
using CollectiveEpilogue = typename ConvKernel::CollectiveEpilogue;
|
| 74 |
+
|
| 75 |
+
static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
|
| 76 |
+
|
| 77 |
+
// Tease out meta-information about the conv algorithm
|
| 78 |
+
static constexpr conv::Operator kConvolutionalOperator = DispatchPolicy::ConvOp;
|
| 79 |
+
static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions;
|
| 80 |
+
|
| 81 |
+
// If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop!
|
| 82 |
+
using OperatorClass = cute::conditional_t<
|
| 83 |
+
(cute::size(typename ConvKernel::TiledMma::AtomThrID{}) > 1),
|
| 84 |
+
cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>;
|
| 85 |
+
|
| 86 |
+
using ArchTag = typename ConvKernel::ArchTag;
|
| 87 |
+
|
| 88 |
+
// Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape
|
| 89 |
+
using ThreadblockShape = cutlass::gemm::GemmShape<
|
| 90 |
+
cute::size<0>(TileShape{}),
|
| 91 |
+
cute::size<1>(TileShape{}),
|
| 92 |
+
cute::size<2>(TileShape{})>;
|
| 93 |
+
|
| 94 |
+
using ClusterShape = cutlass::gemm::GemmShape<
|
| 95 |
+
cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
| 96 |
+
cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
| 97 |
+
cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})>;
|
| 98 |
+
|
| 99 |
+
// Instruction shape is easy too, since we get that directly from our TiledMma's atom shape
|
| 100 |
+
using InstructionShape = cutlass::gemm::GemmShape<
|
| 101 |
+
cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}),
|
| 102 |
+
cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}),
|
| 103 |
+
cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>;
|
| 104 |
+
|
| 105 |
+
// Legacy: provide a correct warp count, but no reliable warp shape
|
| 106 |
+
static int const kThreadCount = ConvKernel::MaxThreadsPerBlock;
|
| 107 |
+
|
| 108 |
+
// Warp shape is not a primary API type in 3.x
|
| 109 |
+
// But we can best approximate it by inspecting the TiledMma
|
| 110 |
+
// For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K
|
| 111 |
+
// We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads
|
| 112 |
+
static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename ConvKernel::TiledMma{})) / 32);
|
| 113 |
+
static constexpr int WarpsInMmaM = 4;
|
| 114 |
+
static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM);
|
| 115 |
+
using WarpCount = cutlass::gemm::GemmShape<WarpsInMmaM, WarpsInMmaN, 1>;
|
| 116 |
+
using WarpShape = cutlass::gemm::GemmShape<
|
| 117 |
+
CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM,
|
| 118 |
+
CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN,
|
| 119 |
+
CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>;
|
| 120 |
+
|
| 121 |
+
static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages;
|
| 122 |
+
|
| 123 |
+
// Inspect TiledCopy for A and B to compute the alignment size
|
| 124 |
+
static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
|
| 125 |
+
typename CollectiveMainloop::GmemTiledCopyA, ElementA>();
|
| 126 |
+
static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
|
| 127 |
+
typename CollectiveMainloop::GmemTiledCopyB, ElementB>();
|
| 128 |
+
static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
|
| 129 |
+
typename CollectiveEpilogue::GmemTiledCopyC, ElementC>();
|
| 130 |
+
static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy<
|
| 131 |
+
typename CollectiveEpilogue::GmemTiledCopyD, ElementD>();
|
| 132 |
+
|
| 133 |
+
using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp;
|
| 134 |
+
|
| 135 |
+
/// Argument structure: User API
|
| 136 |
+
using Arguments = typename ConvKernel::Arguments;
|
| 137 |
+
/// Argument structure: Kernel API
|
| 138 |
+
using Params = typename ConvKernel::Params;
|
| 139 |
+
|
| 140 |
+
private:
|
| 141 |
+
|
| 142 |
+
/// Kernel API parameters object
|
| 143 |
+
Params params_;
|
| 144 |
+
|
| 145 |
+
public:
|
| 146 |
+
|
| 147 |
+
/// Access the Params structure
|
| 148 |
+
Params const& params() const {
|
| 149 |
+
return params_;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
/// Determines whether the conv can execute the given problem.
|
| 153 |
+
static Status
|
| 154 |
+
can_implement(Arguments const& args) {
|
| 155 |
+
if (ConvKernel::can_implement(args)) {
|
| 156 |
+
return Status::kSuccess;
|
| 157 |
+
}
|
| 158 |
+
else {
|
| 159 |
+
return Status::kInvalid;
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
/// Gets the workspace size
|
| 164 |
+
static size_t
|
| 165 |
+
get_workspace_size(Arguments const& args) {
|
| 166 |
+
size_t workspace_bytes = 0;
|
| 167 |
+
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
| 168 |
+
|
| 169 |
+
workspace_bytes += ConvKernel::get_workspace_size(args);
|
| 170 |
+
return workspace_bytes;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
/// Computes the grid shape
|
| 174 |
+
static dim3
|
| 175 |
+
get_grid_shape(Arguments const& args, void* workspace = nullptr) {
|
| 176 |
+
auto tmp_params = ConvKernel::to_underlying_arguments(args, workspace);
|
| 177 |
+
return ConvKernel::get_grid_shape(tmp_params);
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
/// Computes the grid shape
|
| 181 |
+
static dim3
|
| 182 |
+
get_grid_shape(Params const& params) {
|
| 183 |
+
return ConvKernel::get_grid_shape(params);
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
/// Computes the maximum number of active blocks per multiprocessor
|
| 187 |
+
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
| 188 |
+
CUTLASS_TRACE_HOST("ConvUniversal::maximum_active_blocks()");
|
| 189 |
+
int max_active_blocks = -1;
|
| 190 |
+
int smem_size = ConvKernel::SharedStorageSize;
|
| 191 |
+
|
| 192 |
+
// first, account for dynamic smem capacity if needed
|
| 193 |
+
cudaError_t result;
|
| 194 |
+
if (smem_size >= (48 << 10)) {
|
| 195 |
+
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
| 196 |
+
result = cudaFuncSetAttribute(
|
| 197 |
+
device_kernel<ConvKernel>,
|
| 198 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 199 |
+
smem_size);
|
| 200 |
+
if (cudaSuccess != result) {
|
| 201 |
+
result = cudaGetLastError(); // to clear the error bit
|
| 202 |
+
CUTLASS_TRACE_HOST(
|
| 203 |
+
" cudaFuncSetAttribute() returned error: "
|
| 204 |
+
<< cudaGetErrorString(result));
|
| 205 |
+
return -1;
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
// query occupancy after setting smem size
|
| 210 |
+
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
| 211 |
+
&max_active_blocks,
|
| 212 |
+
device_kernel<ConvKernel>,
|
| 213 |
+
ConvKernel::MaxThreadsPerBlock,
|
| 214 |
+
smem_size);
|
| 215 |
+
|
| 216 |
+
if (cudaSuccess != result) {
|
| 217 |
+
result = cudaGetLastError(); // to clear the error bit
|
| 218 |
+
CUTLASS_TRACE_HOST(
|
| 219 |
+
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
| 220 |
+
<< cudaGetErrorString(result));
|
| 221 |
+
return -1;
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
| 225 |
+
return max_active_blocks;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
/// Initializes conv state from arguments.
|
| 229 |
+
Status
|
| 230 |
+
initialize(
|
| 231 |
+
Arguments const& args,
|
| 232 |
+
void* workspace = nullptr,
|
| 233 |
+
cudaStream_t stream = nullptr,
|
| 234 |
+
CudaHostAdapter *cuda_adapter = nullptr) {
|
| 235 |
+
|
| 236 |
+
CUTLASS_TRACE_HOST("ConvUniversal::initialize() - workspace "
|
| 237 |
+
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
| 238 |
+
|
| 239 |
+
// Initialize the workspace
|
| 240 |
+
Status status = ConvKernel::initialize_workspace(args, workspace, stream, cuda_adapter);
|
| 241 |
+
if (status != Status::kSuccess) {
|
| 242 |
+
return status;
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
// Initialize the Params structure
|
| 246 |
+
params_ = ConvKernel::to_underlying_arguments(args, workspace);
|
| 247 |
+
|
| 248 |
+
// Don't set the function attributes - require the CudaHostAdapter to set it.
|
| 249 |
+
if constexpr (kEnableCudaHostAdapter) {
|
| 250 |
+
CUTLASS_ASSERT(cuda_adapter);
|
| 251 |
+
return Status::kSuccess;
|
| 252 |
+
}
|
| 253 |
+
else {
|
| 254 |
+
// account for dynamic smem capacity if needed
|
| 255 |
+
int smem_size = ConvKernel::SharedStorageSize;
|
| 256 |
+
if (smem_size >= (48 << 10)) {
|
| 257 |
+
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
| 258 |
+
cudaError_t result = cudaFuncSetAttribute(
|
| 259 |
+
device_kernel<ConvKernel>,
|
| 260 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 261 |
+
smem_size);
|
| 262 |
+
if (cudaSuccess != result) {
|
| 263 |
+
result = cudaGetLastError(); // to clear the error bit
|
| 264 |
+
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
| 265 |
+
return Status::kErrorInternal;
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
return Status::kSuccess;
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
| 273 |
+
Status
|
| 274 |
+
update(Arguments const& args, void* workspace = nullptr) {
|
| 275 |
+
CUTLASS_TRACE_HOST("ConvUniversal()::update() - workspace: " << workspace);
|
| 276 |
+
|
| 277 |
+
size_t workspace_bytes = get_workspace_size(args);
|
| 278 |
+
if (workspace_bytes > 0 && nullptr == workspace) {
|
| 279 |
+
return Status::kErrorWorkspaceNull;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
params_ = ConvKernel::to_underlying_arguments(args, workspace);
|
| 283 |
+
return Status::kSuccess;
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
| 287 |
+
/// Supplied params struct must be construct by calling ConvKernel::to_underling_arguments()
|
| 288 |
+
static Status
|
| 289 |
+
run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
|
| 290 |
+
CUTLASS_TRACE_HOST("ConvUniversal::run()");
|
| 291 |
+
dim3 const block = ConvKernel::get_block_shape();
|
| 292 |
+
dim3 const grid = get_grid_shape(params);
|
| 293 |
+
|
| 294 |
+
// configure smem size and carveout
|
| 295 |
+
int smem_size = ConvKernel::SharedStorageSize;
|
| 296 |
+
|
| 297 |
+
Status launch_result;
|
| 298 |
+
// Use extended launch API only for mainloops that use it
|
| 299 |
+
if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) {
|
| 300 |
+
[[maybe_unused]] constexpr bool is_static_1x1x1 =
|
| 301 |
+
cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape> and
|
| 302 |
+
cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1;
|
| 303 |
+
dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
| 304 |
+
cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
| 305 |
+
cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
|
| 306 |
+
// Dynamic cluster support
|
| 307 |
+
[[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0};
|
| 308 |
+
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 ||
|
| 309 |
+
ConvKernel::ArchTag::kMinComputeCapability == 101) {
|
| 310 |
+
if constexpr (!cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape>) {
|
| 311 |
+
fallback_cluster = params.hw_info.cluster_shape_fallback;
|
| 312 |
+
cluster = params.hw_info.cluster_shape;
|
| 313 |
+
}
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
void* kernel_params[] = {¶ms};
|
| 317 |
+
if constexpr (kEnableCudaHostAdapter) {
|
| 318 |
+
//
|
| 319 |
+
// Use the cuda host adapter
|
| 320 |
+
//
|
| 321 |
+
CUTLASS_ASSERT(cuda_adapter);
|
| 322 |
+
if (cuda_adapter) {
|
| 323 |
+
|
| 324 |
+
launch_result = cuda_adapter->launch(grid,
|
| 325 |
+
cluster,
|
| 326 |
+
fallback_cluster,
|
| 327 |
+
block,
|
| 328 |
+
smem_size,
|
| 329 |
+
stream,
|
| 330 |
+
kernel_params,
|
| 331 |
+
kernel_index);
|
| 332 |
+
}
|
| 333 |
+
else {
|
| 334 |
+
return Status::kErrorInternal;
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
else {
|
| 338 |
+
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
| 339 |
+
void const* kernel = (void const*) device_kernel<ConvKernel>;
|
| 340 |
+
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90
|
| 341 |
+
|| ConvKernel::ArchTag::kMinComputeCapability == 100
|
| 342 |
+
) {
|
| 343 |
+
if constexpr (is_static_1x1x1) {
|
| 344 |
+
device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params);
|
| 345 |
+
launch_result = Status::kSuccess;
|
| 346 |
+
}
|
| 347 |
+
else {
|
| 348 |
+
launch_result = ClusterLauncher::launch(
|
| 349 |
+
grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
else {
|
| 353 |
+
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 ||
|
| 354 |
+
ConvKernel::ArchTag::kMinComputeCapability == 101) {
|
| 355 |
+
launch_result = ClusterLauncher::launch_with_fallback_cluster(
|
| 356 |
+
grid,
|
| 357 |
+
cluster,
|
| 358 |
+
fallback_cluster,
|
| 359 |
+
block,
|
| 360 |
+
smem_size,
|
| 361 |
+
stream,
|
| 362 |
+
kernel,
|
| 363 |
+
kernel_params);
|
| 364 |
+
}
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
}
|
| 368 |
+
else {
|
| 369 |
+
launch_result = Status::kSuccess;
|
| 370 |
+
|
| 371 |
+
if constexpr (kEnableCudaHostAdapter) {
|
| 372 |
+
CUTLASS_ASSERT(cuda_adapter);
|
| 373 |
+
if (cuda_adapter) {
|
| 374 |
+
void* kernel_params[] = {¶ms};
|
| 375 |
+
|
| 376 |
+
launch_result = cuda_adapter->launch(
|
| 377 |
+
grid, block, smem_size, stream, kernel_params, 0
|
| 378 |
+
);
|
| 379 |
+
|
| 380 |
+
}
|
| 381 |
+
else {
|
| 382 |
+
return Status::kErrorInternal;
|
| 383 |
+
}
|
| 384 |
+
}
|
| 385 |
+
else {
|
| 386 |
+
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
| 387 |
+
device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params);
|
| 388 |
+
}
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
cudaError_t result = cudaGetLastError();
|
| 392 |
+
if (cudaSuccess == result && Status::kSuccess == launch_result) {
|
| 393 |
+
return Status::kSuccess;
|
| 394 |
+
}
|
| 395 |
+
else {
|
| 396 |
+
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
| 397 |
+
return Status::kErrorInternal;
|
| 398 |
+
}
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
//
|
| 402 |
+
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
| 403 |
+
//
|
| 404 |
+
|
| 405 |
+
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
| 406 |
+
Status
|
| 407 |
+
run(
|
| 408 |
+
Arguments const& args,
|
| 409 |
+
void* workspace = nullptr,
|
| 410 |
+
cudaStream_t stream = nullptr,
|
| 411 |
+
CudaHostAdapter *cuda_adapter = nullptr,
|
| 412 |
+
int32_t kernel_index = 0
|
| 413 |
+
) {
|
| 414 |
+
Status status = initialize(args, workspace, stream, cuda_adapter);
|
| 415 |
+
if (Status::kSuccess == status) {
|
| 416 |
+
status = run(params_, stream, cuda_adapter, kernel_index);
|
| 417 |
+
}
|
| 418 |
+
return status;
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
| 422 |
+
Status
|
| 423 |
+
operator()(
|
| 424 |
+
Arguments const& args,
|
| 425 |
+
void* workspace = nullptr,
|
| 426 |
+
cudaStream_t stream = nullptr,
|
| 427 |
+
CudaHostAdapter *cuda_adapter = nullptr) {
|
| 428 |
+
return run(args, workspace, stream, cuda_adapter);
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
| 432 |
+
Status
|
| 433 |
+
run(cudaStream_t stream = nullptr) {
|
| 434 |
+
return run(params_, stream);
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
| 438 |
+
Status
|
| 439 |
+
operator()(cudaStream_t stream = nullptr) {
|
| 440 |
+
return run(params_, stream);
|
| 441 |
+
}
|
| 442 |
+
};
|
| 443 |
+
|
| 444 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 445 |
+
|
| 446 |
+
} // namespace cutlass::conv::device
|
| 447 |
+
|
| 448 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Template for device-level Depthwise Convolution
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include <limits>
|
| 38 |
+
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/device_kernel.h"
|
| 41 |
+
#include "cutlass/conv/convolution.h"
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
namespace cutlass {
|
| 46 |
+
namespace conv {
|
| 47 |
+
namespace device {
|
| 48 |
+
|
| 49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
template<typename DirectConvolutionKernel_>
|
| 52 |
+
class DirectConvolution {
|
| 53 |
+
public:
|
| 54 |
+
|
| 55 |
+
using UnderlyingKernel = DirectConvolutionKernel_;
|
| 56 |
+
|
| 57 |
+
using ElementA = typename UnderlyingKernel::ElementA;
|
| 58 |
+
using LayoutA = typename UnderlyingKernel::LayoutA;
|
| 59 |
+
using ElementB = typename UnderlyingKernel::ElementB;
|
| 60 |
+
using LayoutB = typename UnderlyingKernel::LayoutB;
|
| 61 |
+
using ElementC = typename UnderlyingKernel::ElementC;
|
| 62 |
+
using LayoutC = typename UnderlyingKernel::LayoutC;
|
| 63 |
+
using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
|
| 64 |
+
using ElementCompute = typename UnderlyingKernel::ElementCompute;
|
| 65 |
+
using OperatorClass = typename UnderlyingKernel::OperatorClass;
|
| 66 |
+
using ArchTag = typename UnderlyingKernel::ArchTag;
|
| 67 |
+
using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
|
| 68 |
+
using WarpShape = typename UnderlyingKernel::WarpShape;
|
| 69 |
+
using InstructionShape = typename UnderlyingKernel::InstructionShape;
|
| 70 |
+
using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
|
| 71 |
+
using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
|
| 72 |
+
static int const kStages = UnderlyingKernel::kStages;
|
| 73 |
+
static int const kConvDim = UnderlyingKernel::kConvDim;
|
| 74 |
+
using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
|
| 75 |
+
using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
|
| 76 |
+
using MathOperator = typename UnderlyingKernel::MathOperator;
|
| 77 |
+
|
| 78 |
+
static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
|
| 79 |
+
static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
|
| 80 |
+
static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
|
| 81 |
+
static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;
|
| 82 |
+
|
| 83 |
+
static int const kWarpCount =
|
| 84 |
+
(ThreadblockShape::kM / WarpShape::kM) *
|
| 85 |
+
(ThreadblockShape::kN / WarpShape::kN) *
|
| 86 |
+
(ThreadblockShape::kK / WarpShape::kK);
|
| 87 |
+
|
| 88 |
+
/// Argument structure
|
| 89 |
+
using Arguments = typename UnderlyingKernel::Arguments;
|
| 90 |
+
|
| 91 |
+
using ReorderKernel = typename UnderlyingKernel::ReorderKernel;
|
| 92 |
+
|
| 93 |
+
private:
|
| 94 |
+
|
| 95 |
+
/// Kernel parameters object
|
| 96 |
+
typename UnderlyingKernel::Params params_;
|
| 97 |
+
|
| 98 |
+
public:
|
| 99 |
+
|
| 100 |
+
/// Constructs Implicit GEMM
|
| 101 |
+
DirectConvolution() { }
|
| 102 |
+
|
| 103 |
+
/// Determines whether the Implicit GEMM can execute the given problem.
|
| 104 |
+
static Status can_implement(Arguments const &args) {
|
| 105 |
+
|
| 106 |
+
// dispatch to iterators
|
| 107 |
+
Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size);
|
| 108 |
+
if (Status::kSuccess != status) {
|
| 109 |
+
return status;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size);
|
| 113 |
+
if (Status::kSuccess != status) {
|
| 114 |
+
return status;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
if (kGroupMode != conv::GroupMode::kDepthwise) {
|
| 118 |
+
return Status::kErrorInvalidProblem;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
// C and K should be multiple of groups
|
| 122 |
+
if (args.problem_size.K != args.problem_size.groups &&
|
| 123 |
+
args.problem_size.C != args.problem_size.groups) {
|
| 124 |
+
return Status::kErrorInvalidProblem;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
|
| 129 |
+
if (kConvolutionalOperator == conv::Operator::kFprop) {
|
| 130 |
+
if (args.problem_size.K % kAlignmentC)
|
| 131 |
+
return Status::kErrorMisalignedOperand;
|
| 132 |
+
} else if (kConvolutionalOperator == conv::Operator::kDgrad) {
|
| 133 |
+
if (args.problem_size.C % kAlignmentC)
|
| 134 |
+
return Status::kErrorMisalignedOperand;
|
| 135 |
+
} else if (kConvolutionalOperator == conv::Operator::kWgrad) {
|
| 136 |
+
if (args.problem_size.C % kAlignmentC)
|
| 137 |
+
return Status::kErrorMisalignedOperand;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// Determine grid shape
|
| 141 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 142 |
+
|
| 143 |
+
dim3 grid = threadblock_swizzle.get_grid_shape(
|
| 144 |
+
threadblock_swizzle.get_tiled_shape(
|
| 145 |
+
kConvolutionalOperator,
|
| 146 |
+
args.problem_size,
|
| 147 |
+
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
| 148 |
+
args.problem_size.split_k_slices));
|
| 149 |
+
|
| 150 |
+
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
|
| 151 |
+
grid.z <= std::numeric_limits<uint16_t>::max())) {
|
| 152 |
+
|
| 153 |
+
return Status::kErrorInvalidProblem;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
return Status::kSuccess;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
/// Gets the workspace size
|
| 160 |
+
static size_t get_workspace_size(Arguments const &args) {
|
| 161 |
+
return 0;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
/// Initializes GEMM state from arguments.
|
| 165 |
+
Status initialize(
|
| 166 |
+
Arguments const &args,
|
| 167 |
+
void *workspace = nullptr,
|
| 168 |
+
cudaStream_t stream = nullptr) {
|
| 169 |
+
|
| 170 |
+
// initialize the params structure from the arguments
|
| 171 |
+
params_ = typename UnderlyingKernel::Params(
|
| 172 |
+
args,
|
| 173 |
+
static_cast<int *>(workspace)
|
| 174 |
+
);
|
| 175 |
+
|
| 176 |
+
int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
|
| 177 |
+
|
| 178 |
+
if (smem_size >= (48 << 10)) {
|
| 179 |
+
cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>,
|
| 180 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 181 |
+
smem_size);
|
| 182 |
+
|
| 183 |
+
if (result != cudaSuccess) {
|
| 184 |
+
return Status::kErrorInternal;
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
return Status::kSuccess;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
/// Initializes GEMM state from arguments.
|
| 192 |
+
Status update(Arguments const &args, void *workspace = nullptr) {
|
| 193 |
+
|
| 194 |
+
// update the params structure from the arguments
|
| 195 |
+
params_.ptr_A = args.ref_A.data();
|
| 196 |
+
params_.ptr_B = args.ref_B.data();
|
| 197 |
+
params_.ptr_C = args.ref_C.data();
|
| 198 |
+
params_.ptr_D = args.ref_D.data();
|
| 199 |
+
params_.output_op = args.output_op;
|
| 200 |
+
params_.ptr_reordered_B = args.ref_reordered_B.data();
|
| 201 |
+
params_.semaphore = static_cast<int *>(workspace);
|
| 202 |
+
|
| 203 |
+
return Status::kSuccess;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
/// Runs the kernel using initialized state.
|
| 207 |
+
Status run(cudaStream_t stream = nullptr) {
|
| 208 |
+
|
| 209 |
+
// Launch reorder kernel
|
| 210 |
+
if (params_.ptr_reordered_B != nullptr) {
|
| 211 |
+
dim3 grid = ReorderKernel::get_grid_shape(params_);
|
| 212 |
+
dim3 block = ReorderKernel::get_block_shape();
|
| 213 |
+
|
| 214 |
+
cutlass::arch::synclog_setup();
|
| 215 |
+
cutlass::Kernel<ReorderKernel><<<grid, block, 0, stream>>>(params_);
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
// Launch main kernel
|
| 219 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 220 |
+
|
| 221 |
+
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
| 222 |
+
dim3 block(32 * kWarpCount, 1, 1);
|
| 223 |
+
|
| 224 |
+
// Dynamic SMEM size based on input params.
|
| 225 |
+
int smem_size = int(params_.get_smem_size());
|
| 226 |
+
|
| 227 |
+
// Make sure we can use that much shared memory.
|
| 228 |
+
cudaError_t status =
|
| 229 |
+
cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
| 230 |
+
if (status != cudaSuccess)
|
| 231 |
+
return Status::kErrorInternal;
|
| 232 |
+
|
| 233 |
+
cutlass::arch::synclog_setup();
|
| 234 |
+
cutlass::Kernel<UnderlyingKernel><<<grid, block, smem_size, stream>>>(params_);
|
| 235 |
+
|
| 236 |
+
cudaError_t result = cudaGetLastError();
|
| 237 |
+
|
| 238 |
+
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
/// Runs the kernel using initialized state.
|
| 242 |
+
Status operator()(cudaStream_t stream = nullptr) {
|
| 243 |
+
return run(stream);
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
/// Runs the kernel using initialized state.
|
| 247 |
+
Status operator()(
|
| 248 |
+
Arguments const &args,
|
| 249 |
+
void *workspace = nullptr,
|
| 250 |
+
cudaStream_t stream = nullptr) {
|
| 251 |
+
|
| 252 |
+
Status status = initialize(args, workspace, stream);
|
| 253 |
+
|
| 254 |
+
if (status == Status::kSuccess) {
|
| 255 |
+
status = run(stream);
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
return status;
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
int get_smem_size() { return int(params_.get_smem_size()); }
|
| 262 |
+
};
|
| 263 |
+
|
| 264 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 265 |
+
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Template for device-level Implicit GEMM Convolution
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include <limits>
|
| 38 |
+
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/device_kernel.h"
|
| 41 |
+
#include "cutlass/conv/convolution.h"
|
| 42 |
+
#include "cutlass/cuda_host_adapter.hpp"
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
namespace conv {
|
| 48 |
+
namespace device {
|
| 49 |
+
|
| 50 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
template<typename ImplicitGemmKernel_>
|
| 53 |
+
class ImplicitGemmConvolution {
|
| 54 |
+
public:
|
| 55 |
+
|
| 56 |
+
using UnderlyingKernel = GetUnderlyingKernel_t<ImplicitGemmKernel_>;
|
| 57 |
+
|
| 58 |
+
using ElementA = typename UnderlyingKernel::ElementA;
|
| 59 |
+
using LayoutA = typename UnderlyingKernel::LayoutA;
|
| 60 |
+
using ElementB = typename UnderlyingKernel::ElementB;
|
| 61 |
+
using LayoutB = typename UnderlyingKernel::LayoutB;
|
| 62 |
+
using ElementC = typename UnderlyingKernel::ElementC;
|
| 63 |
+
using LayoutC = typename UnderlyingKernel::LayoutC;
|
| 64 |
+
using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
|
| 65 |
+
using ElementCompute = typename UnderlyingKernel::ElementCompute;
|
| 66 |
+
using OperatorClass = typename UnderlyingKernel::OperatorClass;
|
| 67 |
+
using ArchTag = typename UnderlyingKernel::ArchTag;
|
| 68 |
+
using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
|
| 69 |
+
using WarpShape = typename UnderlyingKernel::WarpShape;
|
| 70 |
+
using InstructionShape = typename UnderlyingKernel::InstructionShape;
|
| 71 |
+
using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
|
| 72 |
+
using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
|
| 73 |
+
static int const kStages = UnderlyingKernel::kStages;
|
| 74 |
+
static int const kConvDim = UnderlyingKernel::kConvDim;
|
| 75 |
+
using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
|
| 76 |
+
using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
|
| 77 |
+
using MathOperator = typename UnderlyingKernel::MathOperator;
|
| 78 |
+
|
| 79 |
+
static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
|
| 80 |
+
static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
|
| 81 |
+
static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
|
| 82 |
+
static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;
|
| 83 |
+
|
| 84 |
+
static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
|
| 85 |
+
|
| 86 |
+
static int const kWarpCount =
|
| 87 |
+
(ThreadblockShape::kM / WarpShape::kM) *
|
| 88 |
+
(ThreadblockShape::kN / WarpShape::kN) *
|
| 89 |
+
(ThreadblockShape::kK / WarpShape::kK);
|
| 90 |
+
|
| 91 |
+
/// Argument structure
|
| 92 |
+
using Arguments = typename UnderlyingKernel::Arguments;
|
| 93 |
+
|
| 94 |
+
private:
|
| 95 |
+
|
| 96 |
+
/// Kernel parameters object
|
| 97 |
+
typename UnderlyingKernel::Params params_;
|
| 98 |
+
|
| 99 |
+
public:
|
| 100 |
+
|
| 101 |
+
/// Constructs Implicit GEMM
|
| 102 |
+
ImplicitGemmConvolution() { }
|
| 103 |
+
|
| 104 |
+
/// Determines whether the Implicit GEMM can execute the given problem.
|
| 105 |
+
static Status can_implement(Arguments const &args) {
|
| 106 |
+
// dispatch to iterators
|
| 107 |
+
Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size);
|
| 108 |
+
if (Status::kSuccess != status) {
|
| 109 |
+
return status;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size);
|
| 113 |
+
if (Status::kSuccess != status) {
|
| 114 |
+
return status;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
// Check that tensor sizes don't exceed maximum supported size
|
| 118 |
+
if (kConvolutionalOperator == conv::Operator::kFprop) {
|
| 119 |
+
if (args.problem_size.activation_size() * sizeof(ElementA) >=
|
| 120 |
+
(1ull << 31) ||
|
| 121 |
+
args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) ||
|
| 122 |
+
args.problem_size.output_size() * sizeof(ElementC) >= (1ull << 31)) {
|
| 123 |
+
return Status::kErrorInvalidProblem;
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
else if (kConvolutionalOperator == conv::Operator::kDgrad ||
|
| 127 |
+
kConvolutionalOperator == conv::Operator::kDeconv) {
|
| 128 |
+
if (args.problem_size.activation_size() * sizeof(ElementC) >=
|
| 129 |
+
(1ull << 31) ||
|
| 130 |
+
args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) ||
|
| 131 |
+
args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) {
|
| 132 |
+
return Status::kErrorInvalidProblem;
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
else if (kConvolutionalOperator == conv::Operator::kWgrad) {
|
| 136 |
+
if (args.problem_size.activation_size() * sizeof(ElementB) >=
|
| 137 |
+
(1ull << 31) ||
|
| 138 |
+
args.problem_size.filter_size() * sizeof(ElementC) >= (1ull << 31) ||
|
| 139 |
+
args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) {
|
| 140 |
+
return Status::kErrorInvalidProblem;
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// check group conv constraint
|
| 145 |
+
if (args.problem_size.groups != 1) {
|
| 146 |
+
if (kGroupMode == conv::GroupMode::kNone) {
|
| 147 |
+
return Status::kErrorInvalidProblem;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// C and K should be multiple of groups
|
| 151 |
+
if (args.problem_size.K % args.problem_size.groups ||
|
| 152 |
+
args.problem_size.C % args.problem_size.groups) {
|
| 153 |
+
return Status::kErrorInvalidProblem;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// split-k is not supported
|
| 157 |
+
if (args.problem_size.split_k_slices != 1) {
|
| 158 |
+
return Status::kErrorInvalidProblem;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
int k_per_group = args.problem_size.K / args.problem_size.groups;
|
| 162 |
+
// k_per_group should be multiple of ThreadblockShape N, one CTA calculate one group
|
| 163 |
+
if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) {
|
| 164 |
+
return Status::kErrorInvalidProblem;
|
| 165 |
+
}
|
| 166 |
+
// ThreadblockShape::kN should be divisible by k_per_group, one CTA calculate multiple groups
|
| 167 |
+
if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) {
|
| 168 |
+
return Status::kErrorInvalidProblem;
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
// current optimized iterator algo only supports SingleGroup mode
|
| 172 |
+
if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized &&
|
| 173 |
+
kGroupMode != conv::GroupMode::kSingleGroup) {
|
| 174 |
+
return Status::kErrorInvalidProblem;
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
|
| 179 |
+
if (kConvolutionalOperator == conv::Operator::kFprop) {
|
| 180 |
+
if (args.problem_size.K % kAlignmentC)
|
| 181 |
+
return Status::kErrorMisalignedOperand;
|
| 182 |
+
} else if (kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) {
|
| 183 |
+
if (args.problem_size.C % kAlignmentC)
|
| 184 |
+
return Status::kErrorMisalignedOperand;
|
| 185 |
+
} else if (kConvolutionalOperator == conv::Operator::kWgrad) {
|
| 186 |
+
if (args.problem_size.C % kAlignmentC)
|
| 187 |
+
return Status::kErrorMisalignedOperand;
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
// check for unsupported problem sizes for strided dgrad / deconv implementation
|
| 191 |
+
if ((kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) &&
|
| 192 |
+
kStrideSupport == conv::StrideSupport::kStrided) {
|
| 193 |
+
// split-k (serial or parallel) is not supported for strided dgrad / deconv
|
| 194 |
+
if(args.problem_size.split_k_slices > 1 && (args.problem_size.stride().at(args.problem_size.stride().max_dim_index()) > 1)) {
|
| 195 |
+
return Status::kErrorNotSupported;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// dilation > {1x1} is not supported for strided dgrad / deconv
|
| 199 |
+
if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) {
|
| 200 |
+
return Status::kErrorNotSupported;
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
// Determine grid shape
|
| 205 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 206 |
+
|
| 207 |
+
dim3 grid = threadblock_swizzle.get_grid_shape(
|
| 208 |
+
threadblock_swizzle.get_tiled_shape(
|
| 209 |
+
kConvolutionalOperator,
|
| 210 |
+
args.problem_size,
|
| 211 |
+
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
| 212 |
+
args.problem_size.split_k_slices));
|
| 213 |
+
|
| 214 |
+
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
|
| 215 |
+
grid.z <= std::numeric_limits<uint16_t>::max())) {
|
| 216 |
+
|
| 217 |
+
return Status::kErrorInvalidProblem;
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
return Status::kSuccess;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
/// Gets the workspace size
|
| 224 |
+
static size_t get_workspace_size(Arguments const &args) {
|
| 225 |
+
|
| 226 |
+
size_t workspace_bytes = 0;
|
| 227 |
+
|
| 228 |
+
// Determine grid shape
|
| 229 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 230 |
+
|
| 231 |
+
cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
| 232 |
+
kConvolutionalOperator,
|
| 233 |
+
args.problem_size,
|
| 234 |
+
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
| 235 |
+
args.problem_size.split_k_slices);
|
| 236 |
+
|
| 237 |
+
if(args.split_k_mode == SplitKMode::kParallel) {
|
| 238 |
+
|
| 239 |
+
// Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
|
| 240 |
+
// The user needs to call a reduction operator to optain the final output tensor
|
| 241 |
+
workspace_bytes =
|
| 242 |
+
sizeof(ElementAccumulator) *
|
| 243 |
+
size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) *
|
| 244 |
+
size_t(grid_tiled_shape.k());
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) {
|
| 248 |
+
|
| 249 |
+
// Split-K serial: The user workspace is used to store semaphore and serialize writing the
|
| 250 |
+
// final reduced output to user's output tensor
|
| 251 |
+
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
return workspace_bytes;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
/// Initializes GEMM state from arguments.
|
| 258 |
+
Status initialize(
|
| 259 |
+
Arguments const &args,
|
| 260 |
+
void *workspace = nullptr,
|
| 261 |
+
cudaStream_t stream = nullptr,
|
| 262 |
+
CudaHostAdapter *cuda_adapter = nullptr) {
|
| 263 |
+
|
| 264 |
+
if (args.problem_size.split_k_slices > 1) {
|
| 265 |
+
|
| 266 |
+
if (!workspace) {
|
| 267 |
+
return Status::kErrorWorkspaceNull;
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
|
| 271 |
+
|
| 272 |
+
if (status != cudaSuccess) {
|
| 273 |
+
return Status::kErrorInternal;
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
// initialize the params structure from the arguments
|
| 278 |
+
params_ = typename UnderlyingKernel::Params(
|
| 279 |
+
args,
|
| 280 |
+
static_cast<int *>(workspace)
|
| 281 |
+
);
|
| 282 |
+
|
| 283 |
+
if constexpr (kEnableCudaHostAdapter) {
|
| 284 |
+
CUTLASS_ASSERT(cuda_adapter);
|
| 285 |
+
return Status::kSuccess;
|
| 286 |
+
}
|
| 287 |
+
else {
|
| 288 |
+
int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
|
| 289 |
+
|
| 290 |
+
if (smem_size >= (48 << 10)) {
|
| 291 |
+
cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>,
|
| 292 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 293 |
+
smem_size);
|
| 294 |
+
|
| 295 |
+
if (result != cudaSuccess) {
|
| 296 |
+
return Status::kErrorInternal;
|
| 297 |
+
}
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
return Status::kSuccess;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
/// Initializes GEMM state from arguments.
|
| 305 |
+
Status update(Arguments const &args, void *workspace = nullptr) {
|
| 306 |
+
|
| 307 |
+
// update the params structure from the arguments
|
| 308 |
+
params_.ptr_A = args.ref_A.data();
|
| 309 |
+
params_.ptr_B = args.ref_B.data();
|
| 310 |
+
params_.ptr_C = args.ref_C.data();
|
| 311 |
+
params_.ptr_D = args.ref_D.data();
|
| 312 |
+
params_.output_op = args.output_op;
|
| 313 |
+
params_.semaphore = static_cast<int *>(workspace);
|
| 314 |
+
|
| 315 |
+
return Status::kSuccess;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
/// Runs the kernel using initialized state.
|
| 319 |
+
Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 323 |
+
|
| 324 |
+
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
| 325 |
+
dim3 block(32 * kWarpCount, 1, 1);
|
| 326 |
+
|
| 327 |
+
int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
|
| 328 |
+
cutlass::Status launch_result = cutlass::Status::kSuccess ;
|
| 329 |
+
|
| 330 |
+
if constexpr (kEnableCudaHostAdapter) {
|
| 331 |
+
//
|
| 332 |
+
// Use the cuda host adapter
|
| 333 |
+
//
|
| 334 |
+
CUTLASS_ASSERT(cuda_adapter);
|
| 335 |
+
if (cuda_adapter) {
|
| 336 |
+
|
| 337 |
+
void* kernel_params[] = {¶ms_};
|
| 338 |
+
launch_result = cuda_adapter->launch(
|
| 339 |
+
grid, dim3(1,1,1), block, smem_size, stream, kernel_params, kernel_index
|
| 340 |
+
);
|
| 341 |
+
}
|
| 342 |
+
else {
|
| 343 |
+
launch_result = Status::kErrorInternal;
|
| 344 |
+
}
|
| 345 |
+
}
|
| 346 |
+
else {
|
| 347 |
+
cutlass::arch::synclog_setup();
|
| 348 |
+
cutlass::Kernel<UnderlyingKernel><<<grid, block, smem_size, stream>>>(params_);
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
cudaError_t result = cudaGetLastError();
|
| 352 |
+
if (cudaSuccess == result && Status::kSuccess == launch_result) {
|
| 353 |
+
return Status::kSuccess;
|
| 354 |
+
}
|
| 355 |
+
else {
|
| 356 |
+
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
| 357 |
+
return Status::kErrorInternal;
|
| 358 |
+
}
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
/// Runs the kernel using initialized state.
|
| 362 |
+
Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
|
| 363 |
+
return run(stream, cuda_adapter, kernel_index);
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
/// Runs the kernel using initialized state.
|
| 367 |
+
Status operator()(
|
| 368 |
+
Arguments const &args,
|
| 369 |
+
void *workspace = nullptr,
|
| 370 |
+
cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
|
| 371 |
+
|
| 372 |
+
Status status = initialize(args, workspace, stream, cuda_adapter);
|
| 373 |
+
|
| 374 |
+
if (status == Status::kSuccess) {
|
| 375 |
+
status = run(stream, cuda_adapter, kernel_index);
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
return status;
|
| 379 |
+
}
|
| 380 |
+
};
|
| 381 |
+
|
| 382 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 383 |
+
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Template for device-level fused activation's scale+bias+relu and Implicit GEMM Convolution
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include <limits>
|
| 38 |
+
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/device_kernel.h"
|
| 41 |
+
#include "cutlass/conv/convolution.h"
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
namespace cutlass {
namespace conv {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Device-level adapter for an implicit-GEMM convolution kernel fused with a
/// scale+bias(+activation) applied to the A operand. It wraps the kernel's
/// Params object, sizes the split-K workspace, opts into large dynamic shared
/// memory when needed, and launches the kernel on a CUDA stream.
template<typename ImplicitGemmFusionKernel_>
class ImplicitGemmConvolutionFusion {
public:

  using ImplicitGemmFusionKernel = ImplicitGemmFusionKernel_;

  // Operand element types and layouts, re-exported from the kernel.
  using ElementA = typename ImplicitGemmFusionKernel::ElementA;
  using LayoutA = typename ImplicitGemmFusionKernel::LayoutA;
  using ElementB = typename ImplicitGemmFusionKernel::ElementB;
  using LayoutB = typename ImplicitGemmFusionKernel::LayoutB;

  // using ElementScaleBias = typename ImplicitGemmFusionKernel::ElementScaleBias;
  // using LayoutScaleBias = typename ImplicitGemmFusionKernel::LayoutScaleBias;

  using ElementC = typename ImplicitGemmFusionKernel::ElementC;
  using LayoutC = typename ImplicitGemmFusionKernel::LayoutC;
  using ElementAccumulator = typename ImplicitGemmFusionKernel::ElementAccumulator;
  using ElementCompute = typename ImplicitGemmFusionKernel::ElementCompute;
  using OperatorClass = typename ImplicitGemmFusionKernel::OperatorClass;
  using ArchTag = typename ImplicitGemmFusionKernel::ArchTag;
  using ThreadblockShape = typename ImplicitGemmFusionKernel::ThreadblockShape;
  using WarpShape = typename ImplicitGemmFusionKernel::WarpShape;
  using InstructionShape = typename ImplicitGemmFusionKernel::InstructionShape;
  using ThreadblockSwizzle = typename ImplicitGemmFusionKernel::ThreadblockSwizzle;
  using EpilogueOutputOp = typename ImplicitGemmFusionKernel::EpilogueOutputOp;
  static int const kStages = ImplicitGemmFusionKernel::kStages;
  static int const kConvDim = ImplicitGemmFusionKernel::kConvDim;
  using WarpMmaOperator = typename ImplicitGemmFusionKernel::WarpMmaOperator;
  using ArchMmaOperator = typename ImplicitGemmFusionKernel::ArchMmaOperator;
  using MathOperator = typename ImplicitGemmFusionKernel::MathOperator;

  static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmFusionKernel::kConvolutionalOperator;
  static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmFusionKernel::kIteratorAlgorithm;

  /// Warps per threadblock, derived from the tile shapes; determines the
  /// launch block size of 32 * kWarpCount threads in run().
  static int const kWarpCount =
    (ThreadblockShape::kM / WarpShape::kM) *
    (ThreadblockShape::kN / WarpShape::kN) *
    (ThreadblockShape::kK / WarpShape::kK);

  /// Argument structure
  using Arguments = typename ImplicitGemmFusionKernel::Arguments;

private:

  /// Kernel parameters object
  typename ImplicitGemmFusionKernel::Params params_;

public:

  /// Constructs Implicit GEMM
  ImplicitGemmConvolutionFusion() { }

  /// Determines whether the Implicit GEMM can execute the given problem.
  /// Checks both tile iterators and that the swizzled launch grid fits within
  /// CUDA's grid-dimension limits.
  static Status can_implement(Arguments const &args) {

    // dispatch to iterators
    Status status = ImplicitGemmFusionKernel::Mma::IteratorA::can_implement(args.problem_size);
    if (Status::kSuccess != status) {
      return status;
    }

    status = ImplicitGemmFusionKernel::Mma::IteratorB::can_implement(args.problem_size);
    if (Status::kSuccess != status) {
      return status;
    }

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(
      threadblock_swizzle.get_tiled_shape(
        cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.problem_size.split_k_slices));

    // grid.y and grid.z must each fit in 16 bits (CUDA grid-dimension limit
    // of 65535 for the y and z dimensions).
    if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
          grid.z <= std::numeric_limits<uint16_t>::max())) {

      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }

  /// Gets the workspace size in bytes required for the given arguments.
  /// Zero unless split-K is in use (see branches below).
  static size_t get_workspace_size(Arguments const &args) {

    size_t workspace_bytes = 0;

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
      cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.problem_size.split_k_slices);

    if(args.split_k_mode == SplitKMode::kParallel) {

      // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
      // The user needs to call a reduction operator to obtain the final output tensor
      workspace_bytes =
        sizeof(ElementAccumulator) *
        size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) *
        size_t(grid_tiled_shape.k());
    }

    else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) {

      // Split-K serial: The user workspace is used to store semaphore and serialize writing the
      // final reduced output to user's output tensor
      workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
    }

    return workspace_bytes;
  }

  /// Initializes GEMM state from arguments: zeroes the split-K workspace,
  /// builds the kernel Params, and raises the dynamic shared memory limit
  /// when the kernel's SharedStorage exceeds the 48 KB default.
  Status initialize(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    if (args.problem_size.split_k_slices > 1) {

      // Split-K requires a workspace (semaphores or partial accumulators).
      if (!workspace) {
        return Status::kErrorWorkspaceNull;
      }

      // Zero the workspace before the kernel reads/updates it.
      cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);

      if (status != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    // initialize the params structure from the arguments
    params_ = typename ImplicitGemmFusionKernel::Params(
      args,
      static_cast<int *>(workspace)
    );

    int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage));

    // Kernels needing more than 48 KB of dynamic shared memory must opt in
    // explicitly via cudaFuncSetAttribute.
    if (smem_size >= (48 << 10)) {
      cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<ImplicitGemmFusionKernel>,
                                                cudaFuncAttributeMaxDynamicSharedMemorySize,
                                                smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }

  /// Updates pointer state of an already-initialized Implicit GEMM from new
  /// arguments. Only tensor pointers, the output op, and the semaphore are
  /// refreshed; the problem shape and workspace size must be unchanged.
  Status update(Arguments const &args, void *workspace = nullptr) {

    // update the params structure from the arguments
    params_.ptr_A = args.ref_A.data();
    params_.ptr_B = args.ref_B.data();
    params_.ptr_scale = args.ref_A_scale.data();
    params_.ptr_bias = args.ref_A_bias.data();
    params_.ptr_C = args.ref_C.data();
    params_.ptr_D = args.ref_D.data();
    params_.output_op = args.output_op;
    params_.semaphore = static_cast<int *>(workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state. Returns kErrorInternal if the
  /// launch itself fails; kernel execution is asynchronous on `stream`.
  Status run(cudaStream_t stream = nullptr) {

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(32 * kWarpCount, 1, 1);

    int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage));

    cutlass::arch::synclog_setup();
    cutlass::Kernel<ImplicitGemmFusionKernel><<<grid, block, smem_size, stream>>>(params_);

    cudaError_t result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Convenience entry point: initializes from arguments, then runs.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace conv
} // namespace cutlass
|
| 268 |
+
|
| 269 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include "cutlass/conv/convolution.h"
|
| 34 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 35 |
+
#include "cutlass/arch/arch.h"
|
| 36 |
+
|
| 37 |
+
#include "cute/layout.hpp"
|
| 38 |
+
#include "cute/numeric/integral_constant.hpp"
|
| 39 |
+
|
| 40 |
+
#include "cutlass/gemm/dispatch_policy.hpp"
|
| 41 |
+
|
| 42 |
+
//////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
//////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
namespace cutlass::conv {
|
| 47 |
+
|
| 48 |
+
//////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
//
|
| 51 |
+
// Policies for categorical dispatch of mainloop against kernel grid schedules
|
| 52 |
+
//
|
| 53 |
+
struct KernelImplicitTmaWarpSpecializedSm90 : cutlass::gemm::KernelTmaWarpSpecialized { };
|
| 54 |
+
struct KernelImplicitTmaWarpSpecializedSm90Cooperative { };
|
| 55 |
+
struct KernelImplicitTmaWarpSpecializedSm90Pingpong { };
|
| 56 |
+
|
| 57 |
+
//
|
| 58 |
+
// Collective Mainloop Policies
|
| 59 |
+
//
|
| 60 |
+
|
| 61 |
+
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA
|
| 62 |
+
// for fprop
|
| 63 |
+
template<
|
| 64 |
+
conv::Operator ConvOp_,
|
| 65 |
+
int Stages_,
|
| 66 |
+
int NumSpatialDimensions_,
|
| 67 |
+
class ClusterShape_ = cute::Shape<cute::C<1>,cute::C<1>,cute::C<1>>,
|
| 68 |
+
class KernelSchedule = KernelImplicitTmaWarpSpecializedSm90,
|
| 69 |
+
int PipelineAsyncMmaStages_ = 1
|
| 70 |
+
>
|
| 71 |
+
struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm {
|
| 72 |
+
static constexpr int Stages = Stages_;
|
| 73 |
+
static constexpr int NumSpatialDimensions = NumSpatialDimensions_;
|
| 74 |
+
static constexpr Operator ConvOp = ConvOp_;
|
| 75 |
+
static constexpr int PipelineAsyncMmaStages = PipelineAsyncMmaStages_;
|
| 76 |
+
using ClusterShape = ClusterShape_;
|
| 77 |
+
using ArchTag = arch::Sm90;
|
| 78 |
+
using Schedule = KernelSchedule;
|
| 79 |
+
|
| 80 |
+
static_assert(NumSpatialDimensions >= 1);
|
| 81 |
+
static_assert(! (cute::is_same_v<KernelSchedule,KernelImplicitTmaWarpSpecializedSm90Cooperative> ||
|
| 82 |
+
cute::is_same_v<KernelSchedule,KernelImplicitTmaWarpSpecializedSm90Pingpong>),
|
| 83 |
+
"Persistent schedules not support for conv yet.");
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
// SM100 tensor op kernel schedule
|
| 89 |
+
struct KernelImplicitTmaWarpSpecializedSm100 {
|
| 90 |
+
static constexpr int SchedulerPipelineStageCount = 0;
|
| 91 |
+
static constexpr int AccumulatorPipelineStageCount = 0;
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
// Pseudo-policies for builder auto override that dispatches to the KernelImplicitTmaWarpSpecializedSm100
|
| 95 |
+
// but for opting into 1 or 2 SM atoms
|
| 96 |
+
struct KernelImplicitTmaWarpSpecialized1SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { };
|
| 97 |
+
struct KernelImplicitTmaWarpSpecialized2SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { };
|
| 98 |
+
|
| 99 |
+
struct KernelStridedDgradTmaWs1SmSm100 { };
|
| 100 |
+
struct KernelStridedDgradTmaWs2SmSm100 { };
|
| 101 |
+
|
| 102 |
+
// Policy for implicit gemm kernel
|
| 103 |
+
template<
|
| 104 |
+
int SchedulerPipelineStageCount_,
|
| 105 |
+
int AccumulatorPipelineStageCount_
|
| 106 |
+
>
|
| 107 |
+
struct KernelScheduleImplicitTmaWarpSpecializedSm100 : KernelImplicitTmaWarpSpecializedSm100 {
|
| 108 |
+
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
| 109 |
+
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
// n-buffer in smem (Blackwell TMA), pipelined with Blackwell UMMA and TMA, fprop
|
| 113 |
+
template<
|
| 114 |
+
conv::Operator ConvOp_,
|
| 115 |
+
int Stages_,
|
| 116 |
+
int NumSpatialDimensions_,
|
| 117 |
+
int SchedulerPipelineStageCount_,
|
| 118 |
+
int AccumulatorPipelineStageCount_,
|
| 119 |
+
class ClusterShape_ = cute::Shape<cute::C<1>,cute::C<1>,cute::C<1>>
|
| 120 |
+
>
|
| 121 |
+
struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm {
|
| 122 |
+
static constexpr int Stages = Stages_;
|
| 123 |
+
static constexpr int NumSpatialDimensions = NumSpatialDimensions_;
|
| 124 |
+
static constexpr Operator ConvOp = ConvOp_;
|
| 125 |
+
using ClusterShape = ClusterShape_;
|
| 126 |
+
using ArchTag = arch::Sm100;
|
| 127 |
+
using Schedule = KernelScheduleImplicitTmaWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
| 128 |
+
|
| 129 |
+
static_assert(NumSpatialDimensions >= 1);
|
| 130 |
+
};
|
| 131 |
+
|
| 132 |
+
//////////////////////////////////////////////////////////////////////////////
|
| 133 |
+
|
| 134 |
+
} // namespace cutlass::conv
|
| 135 |
+
|
| 136 |
+
//////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include "cutlass/conv/convnd_problem_shape.hpp"
|
| 34 |
+
#include "cutlass/detail/dependent_false.hpp"
|
| 35 |
+
|
| 36 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 37 |
+
|
| 38 |
+
namespace cutlass::conv::kernel {

////////////////////////////////////////////////////////////////////////////////

/*
 * Stateless universal device CONV kernel type that treats CONV as
 * a composition of a collective mainloop and a collective epilogue.
 *
 * This primary template is deliberately unusable: valid combinations are
 * provided as partial specializations by the per-architecture kernel headers
 * included after this namespace.
 **/
template <
  class ProblemShape_,
  class CollectiveMainloop_,
  class CollectiveEpilogue_,
  class TileSchedulerTag_ = void,
  class Enable = void
>
class ConvUniversal {
  // dependent_false<Enable> defers evaluation to instantiation time, so this
  // assertion fires only when no partial specialization matches.
  static_assert(cutlass::detail::dependent_false<Enable>,
      "Could not find a valid specialization at the kernel layer to dispatch against.");
};

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::conv::kernel
|
| 61 |
+
|
| 62 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 63 |
+
#include "cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp"
|
| 64 |
+
#include "cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp"
|
| 65 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level implicit GEMM convolution definitions for threadblock-scoped epilogue.
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/gemm/threadblock/default_mma.h"
|
| 41 |
+
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
| 42 |
+
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
|
| 43 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
| 44 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 45 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
| 46 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h"
|
| 47 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h"
|
| 48 |
+
#include "cutlass/conv/convolution.h"
|
| 49 |
+
#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
|
| 50 |
+
#include "cutlass/conv/threadblock/implicit_gemm_pipelined.h"
|
| 51 |
+
#include "cutlass/conv/threadblock/implicit_gemm_multistage.h"
|
| 52 |
+
#include "cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h"
|
| 53 |
+
#include "cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h"
|
| 54 |
+
#include "cutlass/conv/kernel/implicit_gemm_convolution.h"
|
| 55 |
+
#include "cutlass/conv/kernel/implicit_gemm_convolution_fusion.h"
|
| 56 |
+
#include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h"
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace cutlass {
|
| 61 |
+
namespace conv {
|
| 62 |
+
namespace kernel {
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
|
| 66 |
+
namespace detail {
|
| 67 |
+
|
| 68 |
+
/// Selects the default threadblock-scoped epilogue for implicit-GEMM
/// convolution with a tensor-op warp MMA. This primary template covers
/// architectures without a dedicated specialization (Sm70 is specialized
/// below to use the Volta tensor-op epilogue).
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename OutputOp
>
struct DefaultConvEpilogue {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    OutputOp::kCount  // elements per vectorized access
  >::Epilogue;
};
|
| 84 |
+
|
| 85 |
+
/// Partial specialization for Sm70 (Volta): routes to the Volta tensor-op
/// epilogue, which differs in fragment layout from later architectures.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename OutputOp
>
struct DefaultConvEpilogue<
  arch::Sm70,
  Shape,
  WarpMmaTensorOp,
  PartitionsK,
  OutputOp
> {

  using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    OutputOp::kCount  // elements per vectorized access
  >::Epilogue;
};
|
| 107 |
+
|
| 108 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 109 |
+
/// Selects the default SIMT epilogue with an elementwise broadcast input
/// (e.g. per-channel bias/residual) for implicit-GEMM convolution.
/// ArchTag is accepted but unused here — all architectures map to the same
/// SIMT epilogue.
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaSimt,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess,
  typename PermuteDLayout = layout::NoPermute,
  conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity,
  int Rank = 4
>
struct DefaultConvEpilogueWithBroadcastSimt {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimt<
    Shape,
    WarpMmaSimt,
    ElementOutput,
    ElementTensor,
    ElementVector,
    OutputOp,
    ElementsPerAccess,
    false,  // NOTE(review): positional flag forwarded as-is — presumably ScatterD; confirm against DefaultEpilogueWithBroadcastSimt
    PermuteDLayout,
    StrideSupport,
    Rank
  >::Epilogue;
};
|
| 137 |
+
|
| 138 |
+
/// Selects the default SIMT epilogue with broadcast for the strided-Dgrad
/// convolution path.
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaSimt,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultConvEpilogueWithBroadcastSimtStridedDgrad {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimtStridedDgrad<
    Shape,
    WarpMmaSimt,
    ElementOutput,
    ElementTensor,
    ElementVector,
    OutputOp,
    ElementsPerAccess
  >::Epilogue;
};
|
| 159 |
+
|
| 160 |
+
/// Selects the default tensor-op epilogue with an elementwise broadcast input
/// for implicit-GEMM convolution. Primary template for architectures without
/// a dedicated specialization (Sm70 specialized below).
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultConvEpilogueWithBroadcastTensorOp {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    ElementOutput,
    ElementTensor,
    ElementVector,
    OutputOp,
    ElementsPerAccess
  >::Epilogue;
};
|
| 183 |
+
|
| 184 |
+
/// Partial specialization for Sm70 (Volta): routes the broadcast epilogue to
/// the Volta tensor-op variant.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultConvEpilogueWithBroadcastTensorOp<
  arch::Sm70,
  Shape,
  WarpMmaTensorOp,
  PartitionsK,
  ElementOutput,
  ElementTensor,
  ElementVector,
  OutputOp,
  ElementsPerAccess
> {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    ElementOutput,
    ElementTensor,
    ElementVector,
    OutputOp,
    ElementsPerAccess
  >::Epilogue;
};
|
| 216 |
+
|
| 217 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 218 |
+
|
| 219 |
+
template <
|
| 220 |
+
typename ArchTag,
|
| 221 |
+
typename Shape,
|
| 222 |
+
typename WarpMmaTensorOp,
|
| 223 |
+
int PartitionsK,
|
| 224 |
+
typename ElementOutput,
|
| 225 |
+
typename OutputOp,
|
| 226 |
+
typename ReductionOp,
|
| 227 |
+
int ElementsPerAccess
|
| 228 |
+
>
|
| 229 |
+
struct DefaultConvEpilogueWithReductionTensorOp {
|
| 230 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionTensorOp<
|
| 231 |
+
Shape,
|
| 232 |
+
WarpMmaTensorOp,
|
| 233 |
+
PartitionsK,
|
| 234 |
+
ElementOutput,
|
| 235 |
+
OutputOp,
|
| 236 |
+
ReductionOp,
|
| 237 |
+
ElementsPerAccess
|
| 238 |
+
>::Epilogue;
|
| 239 |
+
};
|
| 240 |
+
|
| 241 |
+
template <
|
| 242 |
+
typename Shape,
|
| 243 |
+
typename WarpMmaTensorOp,
|
| 244 |
+
int PartitionsK,
|
| 245 |
+
typename ElementOutput,
|
| 246 |
+
typename OutputOp,
|
| 247 |
+
typename ReductionOp,
|
| 248 |
+
int ElementsPerAccess
|
| 249 |
+
>
|
| 250 |
+
struct DefaultConvEpilogueWithReductionTensorOp<
|
| 251 |
+
arch::Sm70,
|
| 252 |
+
Shape,
|
| 253 |
+
WarpMmaTensorOp,
|
| 254 |
+
PartitionsK,
|
| 255 |
+
ElementOutput,
|
| 256 |
+
OutputOp,
|
| 257 |
+
ReductionOp,
|
| 258 |
+
ElementsPerAccess
|
| 259 |
+
> {
|
| 260 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp<
|
| 261 |
+
Shape,
|
| 262 |
+
WarpMmaTensorOp,
|
| 263 |
+
PartitionsK,
|
| 264 |
+
ElementOutput,
|
| 265 |
+
OutputOp,
|
| 266 |
+
ReductionOp,
|
| 267 |
+
ElementsPerAccess
|
| 268 |
+
>::Epilogue;
|
| 269 |
+
};
|
| 270 |
+
|
| 271 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 272 |
+
|
| 273 |
+
// Defaults for strided Dgrad
|
| 274 |
+
template <
|
| 275 |
+
typename ArchTag,
|
| 276 |
+
typename Shape,
|
| 277 |
+
typename WarpMmaTensorOp,
|
| 278 |
+
int PartitionsK,
|
| 279 |
+
typename OutputOp
|
| 280 |
+
>
|
| 281 |
+
struct DefaultConvEpilogueStridedDgrad {
|
| 282 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
|
| 283 |
+
Shape,
|
| 284 |
+
WarpMmaTensorOp,
|
| 285 |
+
PartitionsK,
|
| 286 |
+
OutputOp,
|
| 287 |
+
OutputOp::kCount
|
| 288 |
+
>::Epilogue;
|
| 289 |
+
};
|
| 290 |
+
|
| 291 |
+
template <
|
| 292 |
+
typename Shape,
|
| 293 |
+
typename WarpMmaTensorOp,
|
| 294 |
+
int PartitionsK,
|
| 295 |
+
typename OutputOp
|
| 296 |
+
>
|
| 297 |
+
struct DefaultConvEpilogueStridedDgrad<
|
| 298 |
+
arch::Sm70,
|
| 299 |
+
Shape,
|
| 300 |
+
WarpMmaTensorOp,
|
| 301 |
+
PartitionsK,
|
| 302 |
+
OutputOp
|
| 303 |
+
> {
|
| 304 |
+
|
| 305 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOpStridedDgrad<
|
| 306 |
+
Shape,
|
| 307 |
+
WarpMmaTensorOp,
|
| 308 |
+
PartitionsK,
|
| 309 |
+
OutputOp,
|
| 310 |
+
OutputOp::kCount
|
| 311 |
+
>::Epilogue;
|
| 312 |
+
};
|
| 313 |
+
|
| 314 |
+
} // namespace detail
|
| 315 |
+
|
| 316 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 317 |
+
|
| 318 |
+
} // namespace kernel
|
| 319 |
+
} // namespace conv
|
| 320 |
+
} // namespace cutlass
|
| 321 |
+
|
| 322 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h
ADDED
|
@@ -0,0 +1,1927 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
| 35 |
+
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/conv/kernel/default_conv2d.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h"
|
| 44 |
+
#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h"
|
| 45 |
+
#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h"
|
| 46 |
+
#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h"
|
| 47 |
+
#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
|
| 48 |
+
|
| 49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
namespace cutlass {
|
| 52 |
+
namespace conv {
|
| 53 |
+
namespace kernel {
|
| 54 |
+
|
| 55 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
/// Defines a kernel for Conv2dDgrad
|
| 57 |
+
template <
|
| 58 |
+
typename ElementA,
|
| 59 |
+
typename LayoutA,
|
| 60 |
+
typename ElementB,
|
| 61 |
+
typename LayoutB,
|
| 62 |
+
typename ElementC,
|
| 63 |
+
typename LayoutC,
|
| 64 |
+
typename ElementAccumulator,
|
| 65 |
+
typename OperatorClass,
|
| 66 |
+
typename ArchTag,
|
| 67 |
+
typename ThreadblockShape,
|
| 68 |
+
typename WarpShape,
|
| 69 |
+
typename InstructionShape,
|
| 70 |
+
typename EpilogueOutputOp,
|
| 71 |
+
typename ThreadblockSwizzle,
|
| 72 |
+
int Stages,
|
| 73 |
+
typename MathOperatorTag,
|
| 74 |
+
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
| 75 |
+
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
| 76 |
+
/// Access granularity of A matrix in units of elements
|
| 77 |
+
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
| 78 |
+
/// Access granularity of B matrix in units of elements
|
| 79 |
+
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
| 80 |
+
> struct DefaultConv2dDgrad;
|
| 81 |
+
|
| 82 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 83 |
+
// OpClassTensorOp convolutions
|
| 84 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 85 |
+
|
| 86 |
+
/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided and
|
| 87 |
+
// multistage pipeline.
|
| 88 |
+
template <
|
| 89 |
+
typename ElementA,
|
| 90 |
+
typename LayoutA,
|
| 91 |
+
typename ElementB,
|
| 92 |
+
typename LayoutB,
|
| 93 |
+
typename ElementC,
|
| 94 |
+
typename LayoutC,
|
| 95 |
+
typename ElementAccumulator,
|
| 96 |
+
typename ArchTag,
|
| 97 |
+
typename ThreadblockShape,
|
| 98 |
+
typename WarpShape,
|
| 99 |
+
typename InstructionShape,
|
| 100 |
+
typename EpilogueOutputOp,
|
| 101 |
+
typename ThreadblockSwizzle,
|
| 102 |
+
int Stages,
|
| 103 |
+
typename MathOperatorTag,
|
| 104 |
+
int AlignmentA,
|
| 105 |
+
int AlignmentB
|
| 106 |
+
>
|
| 107 |
+
struct DefaultConv2dDgrad <
|
| 108 |
+
ElementA,
|
| 109 |
+
LayoutA,
|
| 110 |
+
ElementB,
|
| 111 |
+
LayoutB,
|
| 112 |
+
ElementC,
|
| 113 |
+
LayoutC,
|
| 114 |
+
ElementAccumulator,
|
| 115 |
+
arch::OpClassTensorOp,
|
| 116 |
+
ArchTag,
|
| 117 |
+
ThreadblockShape,
|
| 118 |
+
WarpShape,
|
| 119 |
+
InstructionShape,
|
| 120 |
+
EpilogueOutputOp,
|
| 121 |
+
ThreadblockSwizzle,
|
| 122 |
+
Stages,
|
| 123 |
+
MathOperatorTag,
|
| 124 |
+
IteratorAlgorithm::kAnalytic,
|
| 125 |
+
StrideSupport::kStrided,
|
| 126 |
+
AlignmentA,
|
| 127 |
+
AlignmentB
|
| 128 |
+
> {
|
| 129 |
+
|
| 130 |
+
// Define the core components from GEMM
|
| 131 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 132 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 133 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 134 |
+
Stages, MathOperatorTag>;
|
| 135 |
+
|
| 136 |
+
// Define iterators over tiles from the A operand
|
| 137 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 138 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 139 |
+
using IteratorA =
|
| 140 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
| 141 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 142 |
+
ElementA,
|
| 143 |
+
ThreadMapA,
|
| 144 |
+
StrideSupport::kStrided,
|
| 145 |
+
AccessTypeA
|
| 146 |
+
>;
|
| 147 |
+
|
| 148 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 149 |
+
|
| 150 |
+
// Define iterators over tiles from the B operand
|
| 151 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 152 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 153 |
+
using IteratorB =
|
| 154 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
| 155 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 156 |
+
ElementB,
|
| 157 |
+
ThreadMapB,
|
| 158 |
+
StrideSupport::kStrided,
|
| 159 |
+
AccessTypeB
|
| 160 |
+
>;
|
| 161 |
+
|
| 162 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 163 |
+
|
| 164 |
+
// Warp-level GEMM components
|
| 165 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 166 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 167 |
+
|
| 168 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
| 169 |
+
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
| 170 |
+
? cutlass::arch::CacheOperation::Global
|
| 171 |
+
: cutlass::arch::CacheOperation::Always;
|
| 172 |
+
|
| 173 |
+
// Define the Mma
|
| 174 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 175 |
+
ThreadblockShape,
|
| 176 |
+
IteratorA,
|
| 177 |
+
SmemIteratorA,
|
| 178 |
+
arch::CacheOperation::Always,
|
| 179 |
+
IteratorB,
|
| 180 |
+
SmemIteratorB,
|
| 181 |
+
CacheOpB,
|
| 182 |
+
MmaPolicy,
|
| 183 |
+
Stages
|
| 184 |
+
>;
|
| 185 |
+
|
| 186 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 187 |
+
|
| 188 |
+
// Define the epilogue
|
| 189 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
|
| 190 |
+
ThreadblockShape,
|
| 191 |
+
WarpMmaTensorOp,
|
| 192 |
+
kPartitionsK,
|
| 193 |
+
EpilogueOutputOp,
|
| 194 |
+
EpilogueOutputOp::kCount
|
| 195 |
+
>::Epilogue;
|
| 196 |
+
|
| 197 |
+
// Define the kernel
|
| 198 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
| 199 |
+
Mma,
|
| 200 |
+
Epilogue,
|
| 201 |
+
ThreadblockSwizzle,
|
| 202 |
+
conv::Operator::kDgrad
|
| 203 |
+
>;
|
| 204 |
+
};
|
| 205 |
+
|
| 206 |
+
/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided
|
| 207 |
+
// and 2 stage pipeline.
|
| 208 |
+
template <
|
| 209 |
+
typename ElementA,
|
| 210 |
+
typename LayoutA,
|
| 211 |
+
typename ElementB,
|
| 212 |
+
typename LayoutB,
|
| 213 |
+
typename ElementC,
|
| 214 |
+
typename LayoutC,
|
| 215 |
+
typename ElementAccumulator,
|
| 216 |
+
typename ArchTag,
|
| 217 |
+
typename ThreadblockShape,
|
| 218 |
+
typename WarpShape,
|
| 219 |
+
typename InstructionShape,
|
| 220 |
+
typename EpilogueOutputOp,
|
| 221 |
+
typename ThreadblockSwizzle,
|
| 222 |
+
typename MathOperatorTag,
|
| 223 |
+
int AlignmentA,
|
| 224 |
+
int AlignmentB
|
| 225 |
+
>
|
| 226 |
+
struct DefaultConv2dDgrad <
|
| 227 |
+
ElementA,
|
| 228 |
+
LayoutA,
|
| 229 |
+
ElementB,
|
| 230 |
+
LayoutB,
|
| 231 |
+
ElementC,
|
| 232 |
+
LayoutC,
|
| 233 |
+
ElementAccumulator,
|
| 234 |
+
arch::OpClassTensorOp,
|
| 235 |
+
ArchTag,
|
| 236 |
+
ThreadblockShape,
|
| 237 |
+
WarpShape,
|
| 238 |
+
InstructionShape,
|
| 239 |
+
EpilogueOutputOp,
|
| 240 |
+
ThreadblockSwizzle,
|
| 241 |
+
2,
|
| 242 |
+
MathOperatorTag,
|
| 243 |
+
IteratorAlgorithm::kAnalytic,
|
| 244 |
+
StrideSupport::kStrided,
|
| 245 |
+
AlignmentA,
|
| 246 |
+
AlignmentB
|
| 247 |
+
> {
|
| 248 |
+
|
| 249 |
+
// Define the core components from GEMM
|
| 250 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 251 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 252 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 253 |
+
2, MathOperatorTag>;
|
| 254 |
+
|
| 255 |
+
// Define iterators over tiles from the A operand
|
| 256 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 257 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 258 |
+
using IteratorA =
|
| 259 |
+
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
| 260 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
| 261 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 262 |
+
ElementA,
|
| 263 |
+
ThreadMapA,
|
| 264 |
+
StrideSupport::kStrided,
|
| 265 |
+
AccessTypeA
|
| 266 |
+
>
|
| 267 |
+
>;
|
| 268 |
+
|
| 269 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 270 |
+
|
| 271 |
+
// Define iterators over tiles from the B operand
|
| 272 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 273 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 274 |
+
using IteratorB =
|
| 275 |
+
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
| 276 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
| 277 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 278 |
+
ElementB,
|
| 279 |
+
ThreadMapB,
|
| 280 |
+
StrideSupport::kStrided,
|
| 281 |
+
AccessTypeB
|
| 282 |
+
>
|
| 283 |
+
>;
|
| 284 |
+
|
| 285 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 286 |
+
|
| 287 |
+
// Warp-level GEMM components
|
| 288 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 289 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 290 |
+
|
| 291 |
+
// Define the Mma
|
| 292 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 293 |
+
ThreadblockShape,
|
| 294 |
+
IteratorA,
|
| 295 |
+
SmemIteratorA,
|
| 296 |
+
IteratorB,
|
| 297 |
+
SmemIteratorB,
|
| 298 |
+
ElementC,
|
| 299 |
+
LayoutC,
|
| 300 |
+
MmaPolicy
|
| 301 |
+
>;
|
| 302 |
+
|
| 303 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 304 |
+
|
| 305 |
+
// Define the epilogue
|
| 306 |
+
using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad<
|
| 307 |
+
ArchTag,
|
| 308 |
+
ThreadblockShape,
|
| 309 |
+
WarpMmaTensorOp,
|
| 310 |
+
kPartitionsK,
|
| 311 |
+
EpilogueOutputOp
|
| 312 |
+
>::Epilogue;
|
| 313 |
+
|
| 314 |
+
// Define the kernel
|
| 315 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
| 316 |
+
Mma,
|
| 317 |
+
Epilogue,
|
| 318 |
+
ThreadblockSwizzle,
|
| 319 |
+
conv::Operator::kDgrad
|
| 320 |
+
>;
|
| 321 |
+
};
|
| 322 |
+
|
| 323 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 324 |
+
|
| 325 |
+
/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity Strided
|
| 326 |
+
// and multistage pipeline.
|
| 327 |
+
template <
|
| 328 |
+
typename ElementA,
|
| 329 |
+
typename LayoutA,
|
| 330 |
+
typename ElementB,
|
| 331 |
+
typename LayoutB,
|
| 332 |
+
typename ElementC,
|
| 333 |
+
typename LayoutC,
|
| 334 |
+
typename ElementAccumulator,
|
| 335 |
+
typename ArchTag,
|
| 336 |
+
typename ThreadblockShape,
|
| 337 |
+
typename WarpShape,
|
| 338 |
+
typename InstructionShape,
|
| 339 |
+
typename EpilogueOutputOp,
|
| 340 |
+
typename ThreadblockSwizzle,
|
| 341 |
+
int Stages,
|
| 342 |
+
typename MathOperatorTag,
|
| 343 |
+
int AlignmentA,
|
| 344 |
+
int AlignmentB
|
| 345 |
+
>
|
| 346 |
+
struct DefaultConv2dDgrad <
|
| 347 |
+
ElementA,
|
| 348 |
+
LayoutA,
|
| 349 |
+
ElementB,
|
| 350 |
+
LayoutB,
|
| 351 |
+
ElementC,
|
| 352 |
+
LayoutC,
|
| 353 |
+
ElementAccumulator,
|
| 354 |
+
arch::OpClassTensorOp,
|
| 355 |
+
ArchTag,
|
| 356 |
+
ThreadblockShape,
|
| 357 |
+
WarpShape,
|
| 358 |
+
InstructionShape,
|
| 359 |
+
EpilogueOutputOp,
|
| 360 |
+
ThreadblockSwizzle,
|
| 361 |
+
Stages,
|
| 362 |
+
MathOperatorTag,
|
| 363 |
+
IteratorAlgorithm::kAnalytic,
|
| 364 |
+
StrideSupport::kUnity,
|
| 365 |
+
AlignmentA,
|
| 366 |
+
AlignmentB
|
| 367 |
+
> {
|
| 368 |
+
|
| 369 |
+
// Define the core components from GEMM
|
| 370 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 371 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 372 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 373 |
+
Stages, MathOperatorTag>;
|
| 374 |
+
|
| 375 |
+
// Define iterators over tiles from the A operand
|
| 376 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 377 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 378 |
+
using IteratorA =
|
| 379 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
| 380 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 381 |
+
ElementA,
|
| 382 |
+
ThreadMapA,
|
| 383 |
+
StrideSupport::kUnity,
|
| 384 |
+
AccessTypeA
|
| 385 |
+
>;
|
| 386 |
+
|
| 387 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 388 |
+
|
| 389 |
+
// Define iterators over tiles from the B operand
|
| 390 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 391 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 392 |
+
using IteratorB =
|
| 393 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
| 394 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 395 |
+
ElementB,
|
| 396 |
+
ThreadMapB,
|
| 397 |
+
StrideSupport::kUnity,
|
| 398 |
+
AccessTypeB
|
| 399 |
+
>;
|
| 400 |
+
|
| 401 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 402 |
+
|
| 403 |
+
// Warp-level GEMM components
|
| 404 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 405 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 406 |
+
|
| 407 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
| 408 |
+
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
| 409 |
+
? cutlass::arch::CacheOperation::Global
|
| 410 |
+
: cutlass::arch::CacheOperation::Always;
|
| 411 |
+
|
| 412 |
+
// Define the Mma
|
| 413 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 414 |
+
ThreadblockShape,
|
| 415 |
+
IteratorA,
|
| 416 |
+
SmemIteratorA,
|
| 417 |
+
arch::CacheOperation::Always,
|
| 418 |
+
IteratorB,
|
| 419 |
+
SmemIteratorB,
|
| 420 |
+
CacheOpB,
|
| 421 |
+
MmaPolicy,
|
| 422 |
+
Stages
|
| 423 |
+
>;
|
| 424 |
+
|
| 425 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 426 |
+
|
| 427 |
+
// Define the epilogue
|
| 428 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 429 |
+
ThreadblockShape,
|
| 430 |
+
WarpMmaTensorOp,
|
| 431 |
+
kPartitionsK,
|
| 432 |
+
EpilogueOutputOp,
|
| 433 |
+
EpilogueOutputOp::kCount
|
| 434 |
+
>::Epilogue;
|
| 435 |
+
|
| 436 |
+
// Define the kernel
|
| 437 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 438 |
+
Mma,
|
| 439 |
+
Epilogue,
|
| 440 |
+
ThreadblockSwizzle,
|
| 441 |
+
conv::Operator::kDgrad
|
| 442 |
+
>;
|
| 443 |
+
};
|
| 444 |
+
|
| 445 |
+
/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity
|
| 446 |
+
// 2 stage pipeline.
|
| 447 |
+
template <
|
| 448 |
+
typename ElementA,
|
| 449 |
+
typename LayoutA,
|
| 450 |
+
typename ElementB,
|
| 451 |
+
typename LayoutB,
|
| 452 |
+
typename ElementC,
|
| 453 |
+
typename LayoutC,
|
| 454 |
+
typename ElementAccumulator,
|
| 455 |
+
typename ArchTag,
|
| 456 |
+
typename ThreadblockShape,
|
| 457 |
+
typename WarpShape,
|
| 458 |
+
typename InstructionShape,
|
| 459 |
+
typename EpilogueOutputOp,
|
| 460 |
+
typename ThreadblockSwizzle,
|
| 461 |
+
typename MathOperatorTag,
|
| 462 |
+
int AlignmentA,
|
| 463 |
+
int AlignmentB
|
| 464 |
+
>
|
| 465 |
+
struct DefaultConv2dDgrad <
|
| 466 |
+
ElementA,
|
| 467 |
+
LayoutA,
|
| 468 |
+
ElementB,
|
| 469 |
+
LayoutB,
|
| 470 |
+
ElementC,
|
| 471 |
+
LayoutC,
|
| 472 |
+
ElementAccumulator,
|
| 473 |
+
arch::OpClassTensorOp,
|
| 474 |
+
ArchTag,
|
| 475 |
+
ThreadblockShape,
|
| 476 |
+
WarpShape,
|
| 477 |
+
InstructionShape,
|
| 478 |
+
EpilogueOutputOp,
|
| 479 |
+
ThreadblockSwizzle,
|
| 480 |
+
2,
|
| 481 |
+
MathOperatorTag,
|
| 482 |
+
IteratorAlgorithm::kAnalytic,
|
| 483 |
+
StrideSupport::kUnity,
|
| 484 |
+
AlignmentA,
|
| 485 |
+
AlignmentB
|
| 486 |
+
> {
|
| 487 |
+
|
| 488 |
+
// Define the core components from GEMM
|
| 489 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 490 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 491 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 492 |
+
2, MathOperatorTag>;
|
| 493 |
+
|
| 494 |
+
// Define iterators over tiles from the A operand
|
| 495 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 496 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 497 |
+
using IteratorA =
|
| 498 |
+
cutlass::conv::threadblock::TileIterator<
|
| 499 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
| 500 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 501 |
+
ElementA,
|
| 502 |
+
ThreadMapA,
|
| 503 |
+
StrideSupport::kUnity,
|
| 504 |
+
AccessTypeA
|
| 505 |
+
>
|
| 506 |
+
>;
|
| 507 |
+
|
| 508 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 509 |
+
|
| 510 |
+
// Define iterators over tiles from the B operand
|
| 511 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 512 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 513 |
+
using IteratorB =
|
| 514 |
+
cutlass::conv::threadblock::TileIterator<
|
| 515 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
| 516 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 517 |
+
ElementB,
|
| 518 |
+
ThreadMapB,
|
| 519 |
+
StrideSupport::kUnity,
|
| 520 |
+
AccessTypeB
|
| 521 |
+
>
|
| 522 |
+
>;
|
| 523 |
+
|
| 524 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 525 |
+
|
| 526 |
+
// Warp-level GEMM components
|
| 527 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 528 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 529 |
+
|
| 530 |
+
// Define the Mma
|
| 531 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 532 |
+
ThreadblockShape,
|
| 533 |
+
IteratorA,
|
| 534 |
+
SmemIteratorA,
|
| 535 |
+
IteratorB,
|
| 536 |
+
SmemIteratorB,
|
| 537 |
+
ElementC,
|
| 538 |
+
LayoutC,
|
| 539 |
+
MmaPolicy
|
| 540 |
+
>;
|
| 541 |
+
|
| 542 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 543 |
+
|
| 544 |
+
// Define the epilogue
|
| 545 |
+
using Epilogue = typename detail::DefaultConvEpilogue<
|
| 546 |
+
ArchTag,
|
| 547 |
+
ThreadblockShape,
|
| 548 |
+
WarpMmaTensorOp,
|
| 549 |
+
kPartitionsK,
|
| 550 |
+
EpilogueOutputOp
|
| 551 |
+
>::Epilogue;
|
| 552 |
+
|
| 553 |
+
// Define the kernel
|
| 554 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 555 |
+
Mma,
|
| 556 |
+
Epilogue,
|
| 557 |
+
ThreadblockSwizzle,
|
| 558 |
+
conv::Operator::kDgrad
|
| 559 |
+
>;
|
| 560 |
+
};
|
| 561 |
+
|
| 562 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 563 |
+
|
| 564 |
+
/// Defines a kernel for Conv2dDgrad specialization for optimized IteratorAlgorithm Dgrad Unity Strided
|
| 565 |
+
// and multistage pipeline.
|
| 566 |
+
template <
|
| 567 |
+
typename ElementA,
|
| 568 |
+
typename LayoutA,
|
| 569 |
+
typename ElementB,
|
| 570 |
+
typename LayoutB,
|
| 571 |
+
typename ElementC,
|
| 572 |
+
typename LayoutC,
|
| 573 |
+
typename ElementAccumulator,
|
| 574 |
+
typename ArchTag,
|
| 575 |
+
typename ThreadblockShape,
|
| 576 |
+
typename WarpShape,
|
| 577 |
+
typename InstructionShape,
|
| 578 |
+
typename EpilogueOutputOp,
|
| 579 |
+
typename ThreadblockSwizzle,
|
| 580 |
+
int Stages,
|
| 581 |
+
typename MathOperatorTag,
|
| 582 |
+
int AlignmentA,
|
| 583 |
+
int AlignmentB
|
| 584 |
+
>
|
| 585 |
+
struct DefaultConv2dDgrad <
|
| 586 |
+
ElementA,
|
| 587 |
+
LayoutA,
|
| 588 |
+
ElementB,
|
| 589 |
+
LayoutB,
|
| 590 |
+
ElementC,
|
| 591 |
+
LayoutC,
|
| 592 |
+
ElementAccumulator,
|
| 593 |
+
arch::OpClassTensorOp,
|
| 594 |
+
ArchTag,
|
| 595 |
+
ThreadblockShape,
|
| 596 |
+
WarpShape,
|
| 597 |
+
InstructionShape,
|
| 598 |
+
EpilogueOutputOp,
|
| 599 |
+
ThreadblockSwizzle,
|
| 600 |
+
Stages,
|
| 601 |
+
MathOperatorTag,
|
| 602 |
+
IteratorAlgorithm::kOptimized,
|
| 603 |
+
StrideSupport::kUnity,
|
| 604 |
+
AlignmentA,
|
| 605 |
+
AlignmentB
|
| 606 |
+
> {
|
| 607 |
+
|
| 608 |
+
// Define the core components from GEMM
|
| 609 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 610 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 611 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 612 |
+
Stages, MathOperatorTag>;
|
| 613 |
+
|
| 614 |
+
// Define iterators over tiles from the A operand
|
| 615 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 616 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 617 |
+
using IteratorA =
|
| 618 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
| 619 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 620 |
+
ElementA,
|
| 621 |
+
ThreadMapA,
|
| 622 |
+
StrideSupport::kUnity,
|
| 623 |
+
AccessTypeA
|
| 624 |
+
>;
|
| 625 |
+
|
| 626 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 627 |
+
|
| 628 |
+
// Define iterators over tiles from the B operand
|
| 629 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 630 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 631 |
+
using IteratorB =
|
| 632 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
| 633 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 634 |
+
ElementB,
|
| 635 |
+
ThreadMapB,
|
| 636 |
+
StrideSupport::kUnity,
|
| 637 |
+
AccessTypeB
|
| 638 |
+
>;
|
| 639 |
+
|
| 640 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 641 |
+
|
| 642 |
+
// Warp-level GEMM components
|
| 643 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 644 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 645 |
+
|
| 646 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
| 647 |
+
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
| 648 |
+
? cutlass::arch::CacheOperation::Global
|
| 649 |
+
: cutlass::arch::CacheOperation::Always;
|
| 650 |
+
|
| 651 |
+
// Define the Mma
|
| 652 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 653 |
+
ThreadblockShape,
|
| 654 |
+
IteratorA,
|
| 655 |
+
SmemIteratorA,
|
| 656 |
+
arch::CacheOperation::Always,
|
| 657 |
+
IteratorB,
|
| 658 |
+
SmemIteratorB,
|
| 659 |
+
CacheOpB,
|
| 660 |
+
MmaPolicy,
|
| 661 |
+
Stages
|
| 662 |
+
>;
|
| 663 |
+
|
| 664 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 665 |
+
|
| 666 |
+
// Define the epilogue
|
| 667 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 668 |
+
ThreadblockShape,
|
| 669 |
+
WarpMmaTensorOp,
|
| 670 |
+
kPartitionsK,
|
| 671 |
+
EpilogueOutputOp,
|
| 672 |
+
EpilogueOutputOp::kCount
|
| 673 |
+
>::Epilogue;
|
| 674 |
+
|
| 675 |
+
// Define the kernel
|
| 676 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 677 |
+
Mma,
|
| 678 |
+
Epilogue,
|
| 679 |
+
ThreadblockSwizzle,
|
| 680 |
+
conv::Operator::kDgrad
|
| 681 |
+
>;
|
| 682 |
+
};
|
| 683 |
+
|
| 684 |
+
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided and
|
| 685 |
+
// multistage pipeline.
|
| 686 |
+
template <
|
| 687 |
+
typename ElementA,
|
| 688 |
+
typename LayoutA,
|
| 689 |
+
typename ElementB,
|
| 690 |
+
typename LayoutB,
|
| 691 |
+
typename ElementC,
|
| 692 |
+
typename LayoutC,
|
| 693 |
+
typename ElementAccumulator,
|
| 694 |
+
typename ArchTag,
|
| 695 |
+
typename ThreadblockShape,
|
| 696 |
+
typename WarpShape,
|
| 697 |
+
typename InstructionShape,
|
| 698 |
+
typename EpilogueOutputOp,
|
| 699 |
+
typename ThreadblockSwizzle,
|
| 700 |
+
int Stages,
|
| 701 |
+
typename MathOperatorTag,
|
| 702 |
+
int AlignmentA,
|
| 703 |
+
int AlignmentB
|
| 704 |
+
>
|
| 705 |
+
struct DefaultConv2dDgrad <
|
| 706 |
+
ElementA,
|
| 707 |
+
LayoutA,
|
| 708 |
+
ElementB,
|
| 709 |
+
LayoutB,
|
| 710 |
+
ElementC,
|
| 711 |
+
LayoutC,
|
| 712 |
+
ElementAccumulator,
|
| 713 |
+
arch::OpClassTensorOp,
|
| 714 |
+
ArchTag,
|
| 715 |
+
ThreadblockShape,
|
| 716 |
+
WarpShape,
|
| 717 |
+
InstructionShape,
|
| 718 |
+
EpilogueOutputOp,
|
| 719 |
+
ThreadblockSwizzle,
|
| 720 |
+
Stages,
|
| 721 |
+
MathOperatorTag,
|
| 722 |
+
IteratorAlgorithm::kOptimized,
|
| 723 |
+
StrideSupport::kStrided,
|
| 724 |
+
AlignmentA,
|
| 725 |
+
AlignmentB
|
| 726 |
+
> {
|
| 727 |
+
|
| 728 |
+
// Define the core components from GEMM
|
| 729 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 730 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 731 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 732 |
+
Stages, MathOperatorTag>;
|
| 733 |
+
|
| 734 |
+
// Define iterators over tiles from the A operand
|
| 735 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 736 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 737 |
+
using IteratorA =
|
| 738 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
| 739 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 740 |
+
ElementA,
|
| 741 |
+
ThreadMapA,
|
| 742 |
+
StrideSupport::kStrided,
|
| 743 |
+
AccessTypeA
|
| 744 |
+
>;
|
| 745 |
+
|
| 746 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 747 |
+
|
| 748 |
+
// Define iterators over tiles from the B operand
|
| 749 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 750 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 751 |
+
using IteratorB =
|
| 752 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
| 753 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 754 |
+
ElementB,
|
| 755 |
+
ThreadMapB,
|
| 756 |
+
StrideSupport::kStrided,
|
| 757 |
+
AccessTypeB
|
| 758 |
+
>;
|
| 759 |
+
|
| 760 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 761 |
+
|
| 762 |
+
// Warp-level GEMM components
|
| 763 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 764 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 765 |
+
|
| 766 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
| 767 |
+
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
| 768 |
+
? cutlass::arch::CacheOperation::Global
|
| 769 |
+
: cutlass::arch::CacheOperation::Always;
|
| 770 |
+
|
| 771 |
+
// Define the Mma
|
| 772 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 773 |
+
ThreadblockShape,
|
| 774 |
+
IteratorA,
|
| 775 |
+
SmemIteratorA,
|
| 776 |
+
arch::CacheOperation::Always,
|
| 777 |
+
IteratorB,
|
| 778 |
+
SmemIteratorB,
|
| 779 |
+
CacheOpB,
|
| 780 |
+
MmaPolicy,
|
| 781 |
+
Stages
|
| 782 |
+
>;
|
| 783 |
+
|
| 784 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 785 |
+
|
| 786 |
+
// Define the epilogue
|
| 787 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
|
| 788 |
+
ThreadblockShape,
|
| 789 |
+
WarpMmaTensorOp,
|
| 790 |
+
kPartitionsK,
|
| 791 |
+
EpilogueOutputOp,
|
| 792 |
+
EpilogueOutputOp::kCount
|
| 793 |
+
>::Epilogue;
|
| 794 |
+
|
| 795 |
+
// Define the kernel
|
| 796 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
| 797 |
+
Mma,
|
| 798 |
+
Epilogue,
|
| 799 |
+
ThreadblockSwizzle,
|
| 800 |
+
conv::Operator::kDgrad
|
| 801 |
+
>;
|
| 802 |
+
};
|
| 803 |
+
|
| 804 |
+
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided
|
| 805 |
+
// and 2 stage pipeline.
|
| 806 |
+
template <
|
| 807 |
+
typename ElementA,
|
| 808 |
+
typename LayoutA,
|
| 809 |
+
typename ElementB,
|
| 810 |
+
typename LayoutB,
|
| 811 |
+
typename ElementC,
|
| 812 |
+
typename LayoutC,
|
| 813 |
+
typename ElementAccumulator,
|
| 814 |
+
typename ArchTag,
|
| 815 |
+
typename ThreadblockShape,
|
| 816 |
+
typename WarpShape,
|
| 817 |
+
typename InstructionShape,
|
| 818 |
+
typename EpilogueOutputOp,
|
| 819 |
+
typename ThreadblockSwizzle,
|
| 820 |
+
typename MathOperatorTag,
|
| 821 |
+
int AlignmentA,
|
| 822 |
+
int AlignmentB
|
| 823 |
+
>
|
| 824 |
+
struct DefaultConv2dDgrad <
|
| 825 |
+
ElementA,
|
| 826 |
+
LayoutA,
|
| 827 |
+
ElementB,
|
| 828 |
+
LayoutB,
|
| 829 |
+
ElementC,
|
| 830 |
+
LayoutC,
|
| 831 |
+
ElementAccumulator,
|
| 832 |
+
arch::OpClassTensorOp,
|
| 833 |
+
ArchTag,
|
| 834 |
+
ThreadblockShape,
|
| 835 |
+
WarpShape,
|
| 836 |
+
InstructionShape,
|
| 837 |
+
EpilogueOutputOp,
|
| 838 |
+
ThreadblockSwizzle,
|
| 839 |
+
2,
|
| 840 |
+
MathOperatorTag,
|
| 841 |
+
IteratorAlgorithm::kOptimized,
|
| 842 |
+
StrideSupport::kStrided,
|
| 843 |
+
AlignmentA,
|
| 844 |
+
AlignmentB
|
| 845 |
+
> {
|
| 846 |
+
|
| 847 |
+
// Define the core components from GEMM
|
| 848 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 849 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 850 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 851 |
+
2, MathOperatorTag>;
|
| 852 |
+
|
| 853 |
+
// Define iterators over tiles from the A operand
|
| 854 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 855 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 856 |
+
using IteratorA =
|
| 857 |
+
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
| 858 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
| 859 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 860 |
+
ElementA,
|
| 861 |
+
ThreadMapA,
|
| 862 |
+
StrideSupport::kStrided,
|
| 863 |
+
AccessTypeA
|
| 864 |
+
>
|
| 865 |
+
>;
|
| 866 |
+
|
| 867 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 868 |
+
|
| 869 |
+
// Define iterators over tiles from the B operand
|
| 870 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 871 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 872 |
+
using IteratorB =
|
| 873 |
+
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
| 874 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
| 875 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 876 |
+
ElementB,
|
| 877 |
+
ThreadMapB,
|
| 878 |
+
StrideSupport::kStrided,
|
| 879 |
+
AccessTypeB
|
| 880 |
+
>
|
| 881 |
+
>;
|
| 882 |
+
|
| 883 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 884 |
+
|
| 885 |
+
// Warp-level GEMM components
|
| 886 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 887 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 888 |
+
|
| 889 |
+
// Define the Mma
|
| 890 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 891 |
+
ThreadblockShape,
|
| 892 |
+
IteratorA,
|
| 893 |
+
SmemIteratorA,
|
| 894 |
+
IteratorB,
|
| 895 |
+
SmemIteratorB,
|
| 896 |
+
ElementC,
|
| 897 |
+
LayoutC,
|
| 898 |
+
MmaPolicy
|
| 899 |
+
>;
|
| 900 |
+
|
| 901 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 902 |
+
|
| 903 |
+
// Define the epilogue
|
| 904 |
+
using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad<
|
| 905 |
+
ArchTag,
|
| 906 |
+
ThreadblockShape,
|
| 907 |
+
WarpMmaTensorOp,
|
| 908 |
+
kPartitionsK,
|
| 909 |
+
EpilogueOutputOp
|
| 910 |
+
>::Epilogue;
|
| 911 |
+
|
| 912 |
+
// Define the kernel
|
| 913 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
| 914 |
+
Mma,
|
| 915 |
+
Epilogue,
|
| 916 |
+
ThreadblockSwizzle,
|
| 917 |
+
conv::Operator::kDgrad
|
| 918 |
+
>;
|
| 919 |
+
};
|
| 920 |
+
|
| 921 |
+
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity
|
| 922 |
+
// 2 stage pipeline
|
| 923 |
+
template <
|
| 924 |
+
typename ElementA,
|
| 925 |
+
typename LayoutA,
|
| 926 |
+
typename ElementB,
|
| 927 |
+
typename LayoutB,
|
| 928 |
+
typename ElementC,
|
| 929 |
+
typename LayoutC,
|
| 930 |
+
typename ElementAccumulator,
|
| 931 |
+
typename ArchTag,
|
| 932 |
+
typename ThreadblockShape,
|
| 933 |
+
typename WarpShape,
|
| 934 |
+
typename InstructionShape,
|
| 935 |
+
typename EpilogueOutputOp,
|
| 936 |
+
typename ThreadblockSwizzle,
|
| 937 |
+
typename MathOperatorTag,
|
| 938 |
+
int AlignmentA,
|
| 939 |
+
int AlignmentB
|
| 940 |
+
>
|
| 941 |
+
struct DefaultConv2dDgrad <
|
| 942 |
+
ElementA,
|
| 943 |
+
LayoutA,
|
| 944 |
+
ElementB,
|
| 945 |
+
LayoutB,
|
| 946 |
+
ElementC,
|
| 947 |
+
LayoutC,
|
| 948 |
+
ElementAccumulator,
|
| 949 |
+
arch::OpClassTensorOp,
|
| 950 |
+
ArchTag,
|
| 951 |
+
ThreadblockShape,
|
| 952 |
+
WarpShape,
|
| 953 |
+
InstructionShape,
|
| 954 |
+
EpilogueOutputOp,
|
| 955 |
+
ThreadblockSwizzle,
|
| 956 |
+
2,
|
| 957 |
+
MathOperatorTag,
|
| 958 |
+
IteratorAlgorithm::kOptimized,
|
| 959 |
+
StrideSupport::kUnity,
|
| 960 |
+
AlignmentA,
|
| 961 |
+
AlignmentB
|
| 962 |
+
> {
|
| 963 |
+
|
| 964 |
+
// Define the core components from GEMM
|
| 965 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 966 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 967 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 968 |
+
2, MathOperatorTag>;
|
| 969 |
+
|
| 970 |
+
// Define iterators over tiles from the A operand
|
| 971 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 972 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 973 |
+
using IteratorA =
|
| 974 |
+
cutlass::conv::threadblock::TileIterator<
|
| 975 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
| 976 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 977 |
+
ElementA,
|
| 978 |
+
ThreadMapA,
|
| 979 |
+
StrideSupport::kUnity,
|
| 980 |
+
AccessTypeA
|
| 981 |
+
>
|
| 982 |
+
>;
|
| 983 |
+
|
| 984 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 985 |
+
|
| 986 |
+
// Define iterators over tiles from the B operand
|
| 987 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 988 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 989 |
+
using IteratorB =
|
| 990 |
+
cutlass::conv::threadblock::TileIterator<
|
| 991 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
| 992 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 993 |
+
ElementB,
|
| 994 |
+
ThreadMapB,
|
| 995 |
+
StrideSupport::kUnity,
|
| 996 |
+
AccessTypeB
|
| 997 |
+
>
|
| 998 |
+
>;
|
| 999 |
+
|
| 1000 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1001 |
+
|
| 1002 |
+
// Warp-level GEMM components
|
| 1003 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 1004 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1005 |
+
|
| 1006 |
+
// Define the Mma
|
| 1007 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1008 |
+
ThreadblockShape,
|
| 1009 |
+
IteratorA,
|
| 1010 |
+
SmemIteratorA,
|
| 1011 |
+
IteratorB,
|
| 1012 |
+
SmemIteratorB,
|
| 1013 |
+
ElementC,
|
| 1014 |
+
LayoutC,
|
| 1015 |
+
MmaPolicy
|
| 1016 |
+
>;
|
| 1017 |
+
|
| 1018 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 1019 |
+
|
| 1020 |
+
// Define the epilogue
|
| 1021 |
+
using Epilogue = typename detail::DefaultConvEpilogue<
|
| 1022 |
+
ArchTag,
|
| 1023 |
+
ThreadblockShape,
|
| 1024 |
+
WarpMmaTensorOp,
|
| 1025 |
+
kPartitionsK,
|
| 1026 |
+
EpilogueOutputOp
|
| 1027 |
+
>::Epilogue;
|
| 1028 |
+
|
| 1029 |
+
// Define the kernel
|
| 1030 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1031 |
+
Mma,
|
| 1032 |
+
Epilogue,
|
| 1033 |
+
ThreadblockSwizzle,
|
| 1034 |
+
conv::Operator::kDgrad
|
| 1035 |
+
>;
|
| 1036 |
+
};
|
| 1037 |
+
|
| 1038 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1039 |
+
// OpClassSimt convolutions
|
| 1040 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1041 |
+
/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm,
|
| 1042 |
+
/// multi-stage pipeline, and FFMA-based mainloop for SM80
|
| 1043 |
+
|
| 1044 |
+
template <
|
| 1045 |
+
typename ElementA,
|
| 1046 |
+
typename LayoutA,
|
| 1047 |
+
typename ElementB,
|
| 1048 |
+
typename LayoutB,
|
| 1049 |
+
typename ElementC,
|
| 1050 |
+
typename LayoutC,
|
| 1051 |
+
typename ElementAccumulator,
|
| 1052 |
+
typename ArchTag,
|
| 1053 |
+
typename ThreadblockShape,
|
| 1054 |
+
typename WarpShape,
|
| 1055 |
+
typename InstructionShape,
|
| 1056 |
+
typename EpilogueOutputOp,
|
| 1057 |
+
typename ThreadblockSwizzle,
|
| 1058 |
+
int Stages,
|
| 1059 |
+
typename MathOperatorTag,
|
| 1060 |
+
int AlignmentA,
|
| 1061 |
+
int AlignmentB
|
| 1062 |
+
>
|
| 1063 |
+
struct DefaultConv2dDgrad <
|
| 1064 |
+
ElementA,
|
| 1065 |
+
LayoutA,
|
| 1066 |
+
ElementB,
|
| 1067 |
+
LayoutB,
|
| 1068 |
+
ElementC,
|
| 1069 |
+
LayoutC,
|
| 1070 |
+
ElementAccumulator,
|
| 1071 |
+
arch::OpClassSimt,
|
| 1072 |
+
ArchTag,
|
| 1073 |
+
ThreadblockShape,
|
| 1074 |
+
WarpShape,
|
| 1075 |
+
InstructionShape,
|
| 1076 |
+
EpilogueOutputOp,
|
| 1077 |
+
ThreadblockSwizzle,
|
| 1078 |
+
Stages,
|
| 1079 |
+
MathOperatorTag,
|
| 1080 |
+
IteratorAlgorithm::kAnalytic,
|
| 1081 |
+
conv::StrideSupport::kUnity,
|
| 1082 |
+
AlignmentA,
|
| 1083 |
+
AlignmentB
|
| 1084 |
+
> {
|
| 1085 |
+
|
| 1086 |
+
// Define the core components from GEMM
|
| 1087 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1088 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1089 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1090 |
+
Stages, MathOperatorTag>;
|
| 1091 |
+
|
| 1092 |
+
// Define iterators over tiles from the A operand
|
| 1093 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1094 |
+
using IteratorA =
|
| 1095 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
| 1096 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1097 |
+
ElementA,
|
| 1098 |
+
ThreadMapA,
|
| 1099 |
+
conv::StrideSupport::kUnity
|
| 1100 |
+
>;
|
| 1101 |
+
|
| 1102 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1103 |
+
|
| 1104 |
+
// Define iterators over tiles from the B operand
|
| 1105 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1106 |
+
using IteratorB =
|
| 1107 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
| 1108 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1109 |
+
ElementB,
|
| 1110 |
+
ThreadMapB,
|
| 1111 |
+
conv::StrideSupport::kUnity
|
| 1112 |
+
>;
|
| 1113 |
+
|
| 1114 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1115 |
+
|
| 1116 |
+
// Warp-level GEMM components
|
| 1117 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1118 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1119 |
+
|
| 1120 |
+
// Define the Mma
|
| 1121 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 1122 |
+
ThreadblockShape,
|
| 1123 |
+
IteratorA,
|
| 1124 |
+
SmemIteratorA,
|
| 1125 |
+
arch::CacheOperation::Always,
|
| 1126 |
+
IteratorB,
|
| 1127 |
+
SmemIteratorB,
|
| 1128 |
+
arch::CacheOperation::Always,
|
| 1129 |
+
MmaPolicy,
|
| 1130 |
+
Stages
|
| 1131 |
+
>;
|
| 1132 |
+
|
| 1133 |
+
// Define the epilogue
|
| 1134 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
|
| 1135 |
+
ThreadblockShape,
|
| 1136 |
+
WarpMmaSimtOp,
|
| 1137 |
+
EpilogueOutputOp,
|
| 1138 |
+
EpilogueOutputOp::kCount
|
| 1139 |
+
>::Epilogue;
|
| 1140 |
+
|
| 1141 |
+
// Define the kernel
|
| 1142 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1143 |
+
Mma,
|
| 1144 |
+
Epilogue,
|
| 1145 |
+
ThreadblockSwizzle,
|
| 1146 |
+
conv::Operator::kDgrad
|
| 1147 |
+
>;
|
| 1148 |
+
|
| 1149 |
+
};
|
| 1150 |
+
|
| 1151 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1152 |
+
|
| 1153 |
+
template <
|
| 1154 |
+
typename ElementA,
|
| 1155 |
+
typename LayoutA,
|
| 1156 |
+
typename ElementB,
|
| 1157 |
+
typename LayoutB,
|
| 1158 |
+
typename ElementC,
|
| 1159 |
+
typename LayoutC,
|
| 1160 |
+
typename ElementAccumulator,
|
| 1161 |
+
typename ArchTag,
|
| 1162 |
+
typename ThreadblockShape,
|
| 1163 |
+
typename WarpShape,
|
| 1164 |
+
typename InstructionShape,
|
| 1165 |
+
typename EpilogueOutputOp,
|
| 1166 |
+
typename ThreadblockSwizzle,
|
| 1167 |
+
int Stages,
|
| 1168 |
+
typename MathOperatorTag,
|
| 1169 |
+
int AlignmentA,
|
| 1170 |
+
int AlignmentB
|
| 1171 |
+
>
|
| 1172 |
+
struct DefaultConv2dDgrad <
|
| 1173 |
+
ElementA,
|
| 1174 |
+
LayoutA,
|
| 1175 |
+
ElementB,
|
| 1176 |
+
LayoutB,
|
| 1177 |
+
ElementC,
|
| 1178 |
+
LayoutC,
|
| 1179 |
+
ElementAccumulator,
|
| 1180 |
+
arch::OpClassSimt,
|
| 1181 |
+
ArchTag,
|
| 1182 |
+
ThreadblockShape,
|
| 1183 |
+
WarpShape,
|
| 1184 |
+
InstructionShape,
|
| 1185 |
+
EpilogueOutputOp,
|
| 1186 |
+
ThreadblockSwizzle,
|
| 1187 |
+
Stages,
|
| 1188 |
+
MathOperatorTag,
|
| 1189 |
+
IteratorAlgorithm::kAnalytic,
|
| 1190 |
+
conv::StrideSupport::kStrided,
|
| 1191 |
+
AlignmentA,
|
| 1192 |
+
AlignmentB
|
| 1193 |
+
> {
|
| 1194 |
+
|
| 1195 |
+
// Define the core components from GEMM
|
| 1196 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1197 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1198 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1199 |
+
Stages, MathOperatorTag>;
|
| 1200 |
+
|
| 1201 |
+
// Define iterators over tiles from the A operand
|
| 1202 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1203 |
+
using IteratorA =
|
| 1204 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
| 1205 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1206 |
+
ElementA,
|
| 1207 |
+
ThreadMapA,
|
| 1208 |
+
conv::StrideSupport::kStrided
|
| 1209 |
+
>;
|
| 1210 |
+
|
| 1211 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1212 |
+
|
| 1213 |
+
// Define iterators over tiles from the B operand
|
| 1214 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1215 |
+
using IteratorB =
|
| 1216 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
| 1217 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1218 |
+
ElementB,
|
| 1219 |
+
ThreadMapB,
|
| 1220 |
+
conv::StrideSupport::kStrided
|
| 1221 |
+
>;
|
| 1222 |
+
|
| 1223 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1224 |
+
|
| 1225 |
+
// Warp-level GEMM components
|
| 1226 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1227 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1228 |
+
|
| 1229 |
+
// Define the Mma
|
| 1230 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 1231 |
+
ThreadblockShape,
|
| 1232 |
+
IteratorA,
|
| 1233 |
+
SmemIteratorA,
|
| 1234 |
+
arch::CacheOperation::Always,
|
| 1235 |
+
IteratorB,
|
| 1236 |
+
SmemIteratorB,
|
| 1237 |
+
arch::CacheOperation::Always,
|
| 1238 |
+
MmaPolicy,
|
| 1239 |
+
Stages
|
| 1240 |
+
>;
|
| 1241 |
+
|
| 1242 |
+
// Define the epilogue
|
| 1243 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
|
| 1244 |
+
ThreadblockShape,
|
| 1245 |
+
WarpMmaSimtOp,
|
| 1246 |
+
EpilogueOutputOp,
|
| 1247 |
+
EpilogueOutputOp::kCount
|
| 1248 |
+
>::Epilogue;
|
| 1249 |
+
|
| 1250 |
+
// Define the kernel
|
| 1251 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
| 1252 |
+
Mma,
|
| 1253 |
+
Epilogue,
|
| 1254 |
+
ThreadblockSwizzle,
|
| 1255 |
+
conv::Operator::kDgrad
|
| 1256 |
+
>;
|
| 1257 |
+
|
| 1258 |
+
};
|
| 1259 |
+
|
| 1260 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1261 |
+
|
| 1262 |
+
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm,
|
| 1263 |
+
/// multi-stage pipeline, and FFMA-based mainloop for SM80
|
| 1264 |
+
|
| 1265 |
+
template <
|
| 1266 |
+
typename ElementA,
|
| 1267 |
+
typename LayoutA,
|
| 1268 |
+
typename ElementB,
|
| 1269 |
+
typename LayoutB,
|
| 1270 |
+
typename ElementC,
|
| 1271 |
+
typename LayoutC,
|
| 1272 |
+
typename ElementAccumulator,
|
| 1273 |
+
typename ArchTag,
|
| 1274 |
+
typename ThreadblockShape,
|
| 1275 |
+
typename WarpShape,
|
| 1276 |
+
typename InstructionShape,
|
| 1277 |
+
typename EpilogueOutputOp,
|
| 1278 |
+
typename ThreadblockSwizzle,
|
| 1279 |
+
int Stages,
|
| 1280 |
+
typename MathOperatorTag,
|
| 1281 |
+
int AlignmentA,
|
| 1282 |
+
int AlignmentB
|
| 1283 |
+
>
|
| 1284 |
+
struct DefaultConv2dDgrad <
|
| 1285 |
+
ElementA,
|
| 1286 |
+
LayoutA,
|
| 1287 |
+
ElementB,
|
| 1288 |
+
LayoutB,
|
| 1289 |
+
ElementC,
|
| 1290 |
+
LayoutC,
|
| 1291 |
+
ElementAccumulator,
|
| 1292 |
+
arch::OpClassSimt,
|
| 1293 |
+
ArchTag,
|
| 1294 |
+
ThreadblockShape,
|
| 1295 |
+
WarpShape,
|
| 1296 |
+
InstructionShape,
|
| 1297 |
+
EpilogueOutputOp,
|
| 1298 |
+
ThreadblockSwizzle,
|
| 1299 |
+
Stages,
|
| 1300 |
+
MathOperatorTag,
|
| 1301 |
+
IteratorAlgorithm::kOptimized,
|
| 1302 |
+
StrideSupport::kUnity,
|
| 1303 |
+
AlignmentA,
|
| 1304 |
+
AlignmentB
|
| 1305 |
+
> {
|
| 1306 |
+
|
| 1307 |
+
// Define the core components from GEMM
|
| 1308 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1309 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1310 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1311 |
+
Stages, MathOperatorTag>;
|
| 1312 |
+
|
| 1313 |
+
// Define iterators over tiles from the A operand
|
| 1314 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1315 |
+
using IteratorA =
|
| 1316 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
| 1317 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1318 |
+
ElementA,
|
| 1319 |
+
ThreadMapA,
|
| 1320 |
+
StrideSupport::kUnity
|
| 1321 |
+
>;
|
| 1322 |
+
|
| 1323 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1324 |
+
|
| 1325 |
+
// Define iterators over tiles from the B operand
|
| 1326 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1327 |
+
using IteratorB =
|
| 1328 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
| 1329 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1330 |
+
ElementB,
|
| 1331 |
+
ThreadMapB,
|
| 1332 |
+
StrideSupport::kUnity
|
| 1333 |
+
>;
|
| 1334 |
+
|
| 1335 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1336 |
+
|
| 1337 |
+
// Warp-level GEMM components
|
| 1338 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1339 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1340 |
+
|
| 1341 |
+
// Define the Mma
|
| 1342 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 1343 |
+
ThreadblockShape,
|
| 1344 |
+
IteratorA,
|
| 1345 |
+
SmemIteratorA,
|
| 1346 |
+
arch::CacheOperation::Always,
|
| 1347 |
+
IteratorB,
|
| 1348 |
+
SmemIteratorB,
|
| 1349 |
+
arch::CacheOperation::Always,
|
| 1350 |
+
MmaPolicy,
|
| 1351 |
+
Stages
|
| 1352 |
+
>;
|
| 1353 |
+
|
| 1354 |
+
// Define the epilogue
|
| 1355 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
|
| 1356 |
+
ThreadblockShape,
|
| 1357 |
+
WarpMmaSimtOp,
|
| 1358 |
+
EpilogueOutputOp,
|
| 1359 |
+
EpilogueOutputOp::kCount
|
| 1360 |
+
>::Epilogue;
|
| 1361 |
+
|
| 1362 |
+
// Define the kernel
|
| 1363 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1364 |
+
Mma,
|
| 1365 |
+
Epilogue,
|
| 1366 |
+
ThreadblockSwizzle,
|
| 1367 |
+
conv::Operator::kDgrad
|
| 1368 |
+
>;
|
| 1369 |
+
};
|
| 1370 |
+
|
| 1371 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1372 |
+
template <
|
| 1373 |
+
typename ElementA,
|
| 1374 |
+
typename LayoutA,
|
| 1375 |
+
typename ElementB,
|
| 1376 |
+
typename LayoutB,
|
| 1377 |
+
typename ElementC,
|
| 1378 |
+
typename LayoutC,
|
| 1379 |
+
typename ElementAccumulator,
|
| 1380 |
+
typename ArchTag,
|
| 1381 |
+
typename ThreadblockShape,
|
| 1382 |
+
typename WarpShape,
|
| 1383 |
+
typename InstructionShape,
|
| 1384 |
+
typename EpilogueOutputOp,
|
| 1385 |
+
typename ThreadblockSwizzle,
|
| 1386 |
+
int Stages,
|
| 1387 |
+
typename MathOperatorTag,
|
| 1388 |
+
int AlignmentA,
|
| 1389 |
+
int AlignmentB
|
| 1390 |
+
>
|
| 1391 |
+
struct DefaultConv2dDgrad <
|
| 1392 |
+
ElementA,
|
| 1393 |
+
LayoutA,
|
| 1394 |
+
ElementB,
|
| 1395 |
+
LayoutB,
|
| 1396 |
+
ElementC,
|
| 1397 |
+
LayoutC,
|
| 1398 |
+
ElementAccumulator,
|
| 1399 |
+
arch::OpClassSimt,
|
| 1400 |
+
ArchTag,
|
| 1401 |
+
ThreadblockShape,
|
| 1402 |
+
WarpShape,
|
| 1403 |
+
InstructionShape,
|
| 1404 |
+
EpilogueOutputOp,
|
| 1405 |
+
ThreadblockSwizzle,
|
| 1406 |
+
Stages,
|
| 1407 |
+
MathOperatorTag,
|
| 1408 |
+
IteratorAlgorithm::kOptimized,
|
| 1409 |
+
conv::StrideSupport::kStrided,
|
| 1410 |
+
AlignmentA,
|
| 1411 |
+
AlignmentB
|
| 1412 |
+
> {
|
| 1413 |
+
|
| 1414 |
+
// Define the core components from GEMM
|
| 1415 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1416 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1417 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1418 |
+
Stages, MathOperatorTag>;
|
| 1419 |
+
|
| 1420 |
+
// Define iterators over tiles from the A operand
|
| 1421 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1422 |
+
using IteratorA =
|
| 1423 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
| 1424 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1425 |
+
ElementA,
|
| 1426 |
+
ThreadMapA,
|
| 1427 |
+
conv::StrideSupport::kStrided
|
| 1428 |
+
>;
|
| 1429 |
+
|
| 1430 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1431 |
+
|
| 1432 |
+
// Define iterators over tiles from the B operand
|
| 1433 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1434 |
+
using IteratorB =
|
| 1435 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
| 1436 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1437 |
+
ElementB,
|
| 1438 |
+
ThreadMapB,
|
| 1439 |
+
conv::StrideSupport::kStrided
|
| 1440 |
+
>;
|
| 1441 |
+
|
| 1442 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1443 |
+
|
| 1444 |
+
// Warp-level GEMM components
|
| 1445 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1446 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1447 |
+
|
| 1448 |
+
// Define the Mma
|
| 1449 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 1450 |
+
ThreadblockShape,
|
| 1451 |
+
IteratorA,
|
| 1452 |
+
SmemIteratorA,
|
| 1453 |
+
arch::CacheOperation::Always,
|
| 1454 |
+
IteratorB,
|
| 1455 |
+
SmemIteratorB,
|
| 1456 |
+
arch::CacheOperation::Always,
|
| 1457 |
+
MmaPolicy,
|
| 1458 |
+
Stages
|
| 1459 |
+
>;
|
| 1460 |
+
|
| 1461 |
+
// Define the epilogue
|
| 1462 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
|
| 1463 |
+
ThreadblockShape,
|
| 1464 |
+
WarpMmaSimtOp,
|
| 1465 |
+
EpilogueOutputOp,
|
| 1466 |
+
EpilogueOutputOp::kCount
|
| 1467 |
+
>::Epilogue;
|
| 1468 |
+
|
| 1469 |
+
// Define the kernel
|
| 1470 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
| 1471 |
+
Mma,
|
| 1472 |
+
Epilogue,
|
| 1473 |
+
ThreadblockSwizzle,
|
| 1474 |
+
conv::Operator::kDgrad
|
| 1475 |
+
>;
|
| 1476 |
+
|
| 1477 |
+
};
|
| 1478 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1479 |
+
|
| 1480 |
+
/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm,
|
| 1481 |
+
/// 2 stage pipeline, and FFMA-based mainloop for SM50
|
| 1482 |
+
template <
|
| 1483 |
+
typename ElementA,
|
| 1484 |
+
typename LayoutA,
|
| 1485 |
+
typename ElementB,
|
| 1486 |
+
typename LayoutB,
|
| 1487 |
+
typename ElementC,
|
| 1488 |
+
typename LayoutC,
|
| 1489 |
+
typename ElementAccumulator,
|
| 1490 |
+
typename ArchTag,
|
| 1491 |
+
typename ThreadblockShape,
|
| 1492 |
+
typename WarpShape,
|
| 1493 |
+
typename InstructionShape,
|
| 1494 |
+
typename EpilogueOutputOp,
|
| 1495 |
+
typename ThreadblockSwizzle,
|
| 1496 |
+
typename MathOperatorTag,
|
| 1497 |
+
int AlignmentA,
|
| 1498 |
+
int AlignmentB
|
| 1499 |
+
>
|
| 1500 |
+
struct DefaultConv2dDgrad <
|
| 1501 |
+
ElementA,
|
| 1502 |
+
LayoutA,
|
| 1503 |
+
ElementB,
|
| 1504 |
+
LayoutB,
|
| 1505 |
+
ElementC,
|
| 1506 |
+
LayoutC,
|
| 1507 |
+
ElementAccumulator,
|
| 1508 |
+
arch::OpClassSimt,
|
| 1509 |
+
ArchTag,
|
| 1510 |
+
ThreadblockShape,
|
| 1511 |
+
WarpShape,
|
| 1512 |
+
InstructionShape,
|
| 1513 |
+
EpilogueOutputOp,
|
| 1514 |
+
ThreadblockSwizzle,
|
| 1515 |
+
2,
|
| 1516 |
+
MathOperatorTag,
|
| 1517 |
+
IteratorAlgorithm::kAnalytic,
|
| 1518 |
+
conv::StrideSupport::kUnity,
|
| 1519 |
+
AlignmentA,
|
| 1520 |
+
AlignmentB
|
| 1521 |
+
> {
|
| 1522 |
+
|
| 1523 |
+
// Define the core components from GEMM
|
| 1524 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1525 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1526 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1527 |
+
2, MathOperatorTag>;
|
| 1528 |
+
|
| 1529 |
+
// Define iterators over tiles from the A operand
|
| 1530 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1531 |
+
using IteratorA =
|
| 1532 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1533 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
| 1534 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1535 |
+
ElementA,
|
| 1536 |
+
ThreadMapA,
|
| 1537 |
+
conv::StrideSupport::kUnity
|
| 1538 |
+
>
|
| 1539 |
+
>;
|
| 1540 |
+
|
| 1541 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1542 |
+
|
| 1543 |
+
// Define iterators over tiles from the B operand
|
| 1544 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1545 |
+
using IteratorB =
|
| 1546 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1547 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
| 1548 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1549 |
+
ElementB,
|
| 1550 |
+
ThreadMapB,
|
| 1551 |
+
conv::StrideSupport::kUnity
|
| 1552 |
+
>
|
| 1553 |
+
>;
|
| 1554 |
+
|
| 1555 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1556 |
+
|
| 1557 |
+
// Warp-level GEMM components
|
| 1558 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1559 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1560 |
+
|
| 1561 |
+
// Define the Mma
|
| 1562 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1563 |
+
ThreadblockShape,
|
| 1564 |
+
IteratorA,
|
| 1565 |
+
SmemIteratorA,
|
| 1566 |
+
IteratorB,
|
| 1567 |
+
SmemIteratorB,
|
| 1568 |
+
ElementC,
|
| 1569 |
+
LayoutC,
|
| 1570 |
+
MmaPolicy
|
| 1571 |
+
>;
|
| 1572 |
+
|
| 1573 |
+
// Define the epilogue
|
| 1574 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
|
| 1575 |
+
ThreadblockShape,
|
| 1576 |
+
WarpMmaSimtOp,
|
| 1577 |
+
EpilogueOutputOp,
|
| 1578 |
+
EpilogueOutputOp::kCount
|
| 1579 |
+
>::Epilogue;
|
| 1580 |
+
|
| 1581 |
+
// Define the kernel
|
| 1582 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1583 |
+
Mma,
|
| 1584 |
+
Epilogue,
|
| 1585 |
+
ThreadblockSwizzle,
|
| 1586 |
+
conv::Operator::kDgrad
|
| 1587 |
+
>;
|
| 1588 |
+
|
| 1589 |
+
};
|
| 1590 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1591 |
+
|
| 1592 |
+
template <
|
| 1593 |
+
typename ElementA,
|
| 1594 |
+
typename LayoutA,
|
| 1595 |
+
typename ElementB,
|
| 1596 |
+
typename LayoutB,
|
| 1597 |
+
typename ElementC,
|
| 1598 |
+
typename LayoutC,
|
| 1599 |
+
typename ElementAccumulator,
|
| 1600 |
+
typename ArchTag,
|
| 1601 |
+
typename ThreadblockShape,
|
| 1602 |
+
typename WarpShape,
|
| 1603 |
+
typename InstructionShape,
|
| 1604 |
+
typename EpilogueOutputOp,
|
| 1605 |
+
typename ThreadblockSwizzle,
|
| 1606 |
+
typename MathOperatorTag,
|
| 1607 |
+
int AlignmentA,
|
| 1608 |
+
int AlignmentB
|
| 1609 |
+
>
|
| 1610 |
+
struct DefaultConv2dDgrad <
|
| 1611 |
+
ElementA,
|
| 1612 |
+
LayoutA,
|
| 1613 |
+
ElementB,
|
| 1614 |
+
LayoutB,
|
| 1615 |
+
ElementC,
|
| 1616 |
+
LayoutC,
|
| 1617 |
+
ElementAccumulator,
|
| 1618 |
+
arch::OpClassSimt,
|
| 1619 |
+
ArchTag,
|
| 1620 |
+
ThreadblockShape,
|
| 1621 |
+
WarpShape,
|
| 1622 |
+
InstructionShape,
|
| 1623 |
+
EpilogueOutputOp,
|
| 1624 |
+
ThreadblockSwizzle,
|
| 1625 |
+
2,
|
| 1626 |
+
MathOperatorTag,
|
| 1627 |
+
IteratorAlgorithm::kAnalytic,
|
| 1628 |
+
conv::StrideSupport::kStrided,
|
| 1629 |
+
AlignmentA,
|
| 1630 |
+
AlignmentB
|
| 1631 |
+
> {
|
| 1632 |
+
|
| 1633 |
+
// Define the core components from GEMM
|
| 1634 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1635 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1636 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1637 |
+
2, MathOperatorTag>;
|
| 1638 |
+
|
| 1639 |
+
// Define iterators over tiles from the A operand
|
| 1640 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1641 |
+
using IteratorA =
|
| 1642 |
+
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
| 1643 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
| 1644 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1645 |
+
ElementA,
|
| 1646 |
+
ThreadMapA,
|
| 1647 |
+
conv::StrideSupport::kStrided
|
| 1648 |
+
>
|
| 1649 |
+
>;
|
| 1650 |
+
|
| 1651 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1652 |
+
|
| 1653 |
+
// Define iterators over tiles from the B operand
|
| 1654 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1655 |
+
using IteratorB =
|
| 1656 |
+
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
| 1657 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
| 1658 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1659 |
+
ElementB,
|
| 1660 |
+
ThreadMapB,
|
| 1661 |
+
conv::StrideSupport::kStrided
|
| 1662 |
+
>
|
| 1663 |
+
>;
|
| 1664 |
+
|
| 1665 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1666 |
+
|
| 1667 |
+
// Warp-level GEMM components
|
| 1668 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1669 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1670 |
+
|
| 1671 |
+
// Define the Mma
|
| 1672 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1673 |
+
ThreadblockShape,
|
| 1674 |
+
IteratorA,
|
| 1675 |
+
SmemIteratorA,
|
| 1676 |
+
IteratorB,
|
| 1677 |
+
SmemIteratorB,
|
| 1678 |
+
ElementC,
|
| 1679 |
+
LayoutC,
|
| 1680 |
+
MmaPolicy
|
| 1681 |
+
>;
|
| 1682 |
+
|
| 1683 |
+
// Define the epilogue
|
| 1684 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
|
| 1685 |
+
ThreadblockShape,
|
| 1686 |
+
WarpMmaSimtOp,
|
| 1687 |
+
EpilogueOutputOp,
|
| 1688 |
+
EpilogueOutputOp::kCount
|
| 1689 |
+
>::Epilogue;
|
| 1690 |
+
|
| 1691 |
+
// Define the kernel
|
| 1692 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
| 1693 |
+
Mma,
|
| 1694 |
+
Epilogue,
|
| 1695 |
+
ThreadblockSwizzle,
|
| 1696 |
+
conv::Operator::kDgrad
|
| 1697 |
+
>;
|
| 1698 |
+
};
|
| 1699 |
+
|
| 1700 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1701 |
+
|
| 1702 |
+
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm,
|
| 1703 |
+
/// 2 stage pipeline, and FFMA-based mainloop for SM50
|
| 1704 |
+
template <
|
| 1705 |
+
typename ElementA,
|
| 1706 |
+
typename LayoutA,
|
| 1707 |
+
typename ElementB,
|
| 1708 |
+
typename LayoutB,
|
| 1709 |
+
typename ElementC,
|
| 1710 |
+
typename LayoutC,
|
| 1711 |
+
typename ElementAccumulator,
|
| 1712 |
+
typename ArchTag,
|
| 1713 |
+
typename ThreadblockShape,
|
| 1714 |
+
typename WarpShape,
|
| 1715 |
+
typename InstructionShape,
|
| 1716 |
+
typename EpilogueOutputOp,
|
| 1717 |
+
typename ThreadblockSwizzle,
|
| 1718 |
+
typename MathOperatorTag,
|
| 1719 |
+
int AlignmentA,
|
| 1720 |
+
int AlignmentB
|
| 1721 |
+
>
|
| 1722 |
+
struct DefaultConv2dDgrad <
|
| 1723 |
+
ElementA,
|
| 1724 |
+
LayoutA,
|
| 1725 |
+
ElementB,
|
| 1726 |
+
LayoutB,
|
| 1727 |
+
ElementC,
|
| 1728 |
+
LayoutC,
|
| 1729 |
+
ElementAccumulator,
|
| 1730 |
+
arch::OpClassSimt,
|
| 1731 |
+
ArchTag,
|
| 1732 |
+
ThreadblockShape,
|
| 1733 |
+
WarpShape,
|
| 1734 |
+
InstructionShape,
|
| 1735 |
+
EpilogueOutputOp,
|
| 1736 |
+
ThreadblockSwizzle,
|
| 1737 |
+
2,
|
| 1738 |
+
MathOperatorTag,
|
| 1739 |
+
IteratorAlgorithm::kOptimized,
|
| 1740 |
+
StrideSupport::kUnity,
|
| 1741 |
+
AlignmentA,
|
| 1742 |
+
AlignmentB
|
| 1743 |
+
> {
|
| 1744 |
+
|
| 1745 |
+
// Define the core components from GEMM
|
| 1746 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1747 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1748 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1749 |
+
2, MathOperatorTag>;
|
| 1750 |
+
|
| 1751 |
+
// Define iterators over tiles from the A operand
|
| 1752 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1753 |
+
using IteratorA =
|
| 1754 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1755 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
| 1756 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1757 |
+
ElementA,
|
| 1758 |
+
ThreadMapA,
|
| 1759 |
+
StrideSupport::kUnity
|
| 1760 |
+
>
|
| 1761 |
+
>;
|
| 1762 |
+
|
| 1763 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1764 |
+
|
| 1765 |
+
// Define iterators over tiles from the B operand
|
| 1766 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1767 |
+
using IteratorB =
|
| 1768 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1769 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
| 1770 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1771 |
+
ElementB,
|
| 1772 |
+
ThreadMapB,
|
| 1773 |
+
StrideSupport::kUnity
|
| 1774 |
+
>
|
| 1775 |
+
>;
|
| 1776 |
+
|
| 1777 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1778 |
+
|
| 1779 |
+
// Warp-level GEMM components
|
| 1780 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1781 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1782 |
+
|
| 1783 |
+
// Define the Mma
|
| 1784 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1785 |
+
ThreadblockShape,
|
| 1786 |
+
IteratorA,
|
| 1787 |
+
SmemIteratorA,
|
| 1788 |
+
IteratorB,
|
| 1789 |
+
SmemIteratorB,
|
| 1790 |
+
ElementC,
|
| 1791 |
+
LayoutC,
|
| 1792 |
+
MmaPolicy
|
| 1793 |
+
>;
|
| 1794 |
+
|
| 1795 |
+
// Define the epilogue
|
| 1796 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
|
| 1797 |
+
ThreadblockShape,
|
| 1798 |
+
WarpMmaSimtOp,
|
| 1799 |
+
EpilogueOutputOp,
|
| 1800 |
+
EpilogueOutputOp::kCount
|
| 1801 |
+
>::Epilogue;
|
| 1802 |
+
|
| 1803 |
+
// Define the kernel
|
| 1804 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1805 |
+
Mma,
|
| 1806 |
+
Epilogue,
|
| 1807 |
+
ThreadblockSwizzle,
|
| 1808 |
+
conv::Operator::kDgrad
|
| 1809 |
+
>;
|
| 1810 |
+
|
| 1811 |
+
};
|
| 1812 |
+
|
| 1813 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1814 |
+
template <
|
| 1815 |
+
typename ElementA,
|
| 1816 |
+
typename LayoutA,
|
| 1817 |
+
typename ElementB,
|
| 1818 |
+
typename LayoutB,
|
| 1819 |
+
typename ElementC,
|
| 1820 |
+
typename LayoutC,
|
| 1821 |
+
typename ElementAccumulator,
|
| 1822 |
+
typename ArchTag,
|
| 1823 |
+
typename ThreadblockShape,
|
| 1824 |
+
typename WarpShape,
|
| 1825 |
+
typename InstructionShape,
|
| 1826 |
+
typename EpilogueOutputOp,
|
| 1827 |
+
typename ThreadblockSwizzle,
|
| 1828 |
+
typename MathOperatorTag,
|
| 1829 |
+
int AlignmentA,
|
| 1830 |
+
int AlignmentB
|
| 1831 |
+
>
|
| 1832 |
+
struct DefaultConv2dDgrad <
|
| 1833 |
+
ElementA,
|
| 1834 |
+
LayoutA,
|
| 1835 |
+
ElementB,
|
| 1836 |
+
LayoutB,
|
| 1837 |
+
ElementC,
|
| 1838 |
+
LayoutC,
|
| 1839 |
+
ElementAccumulator,
|
| 1840 |
+
arch::OpClassSimt,
|
| 1841 |
+
ArchTag,
|
| 1842 |
+
ThreadblockShape,
|
| 1843 |
+
WarpShape,
|
| 1844 |
+
InstructionShape,
|
| 1845 |
+
EpilogueOutputOp,
|
| 1846 |
+
ThreadblockSwizzle,
|
| 1847 |
+
2,
|
| 1848 |
+
MathOperatorTag,
|
| 1849 |
+
IteratorAlgorithm::kOptimized,
|
| 1850 |
+
conv::StrideSupport::kStrided,
|
| 1851 |
+
AlignmentA,
|
| 1852 |
+
AlignmentB
|
| 1853 |
+
> {
|
| 1854 |
+
|
| 1855 |
+
// Define the core components from GEMM
|
| 1856 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1857 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1858 |
+
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1859 |
+
2, MathOperatorTag>;
|
| 1860 |
+
|
| 1861 |
+
// Define iterators over tiles from the A operand
|
| 1862 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1863 |
+
using IteratorA =
|
| 1864 |
+
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
| 1865 |
+
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
| 1866 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1867 |
+
ElementA,
|
| 1868 |
+
ThreadMapA,
|
| 1869 |
+
conv::StrideSupport::kStrided
|
| 1870 |
+
>
|
| 1871 |
+
>;
|
| 1872 |
+
|
| 1873 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1874 |
+
|
| 1875 |
+
// Define iterators over tiles from the B operand
|
| 1876 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1877 |
+
using IteratorB =
|
| 1878 |
+
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
| 1879 |
+
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
| 1880 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1881 |
+
ElementB,
|
| 1882 |
+
ThreadMapB,
|
| 1883 |
+
conv::StrideSupport::kStrided
|
| 1884 |
+
>
|
| 1885 |
+
>;
|
| 1886 |
+
|
| 1887 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1888 |
+
|
| 1889 |
+
// Warp-level GEMM components
|
| 1890 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1891 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1892 |
+
|
| 1893 |
+
// Define the Mma
|
| 1894 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1895 |
+
ThreadblockShape,
|
| 1896 |
+
IteratorA,
|
| 1897 |
+
SmemIteratorA,
|
| 1898 |
+
IteratorB,
|
| 1899 |
+
SmemIteratorB,
|
| 1900 |
+
ElementC,
|
| 1901 |
+
LayoutC,
|
| 1902 |
+
MmaPolicy
|
| 1903 |
+
>;
|
| 1904 |
+
|
| 1905 |
+
// Define the epilogue
|
| 1906 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
|
| 1907 |
+
ThreadblockShape,
|
| 1908 |
+
WarpMmaSimtOp,
|
| 1909 |
+
EpilogueOutputOp,
|
| 1910 |
+
EpilogueOutputOp::kCount
|
| 1911 |
+
>::Epilogue;
|
| 1912 |
+
|
| 1913 |
+
// Define the kernel
|
| 1914 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
| 1915 |
+
Mma,
|
| 1916 |
+
Epilogue,
|
| 1917 |
+
ThreadblockSwizzle,
|
| 1918 |
+
conv::Operator::kDgrad
|
| 1919 |
+
>;
|
| 1920 |
+
|
| 1921 |
+
};
|
| 1922 |
+
|
| 1923 |
+
} // namespace kernel
|
| 1924 |
+
} // namespace conv
|
| 1925 |
+
} // namespace cutlass
|
| 1926 |
+
|
| 1927 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h
ADDED
|
@@ -0,0 +1,2007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
| 35 |
+
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/conv/kernel/default_conv2d.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
| 44 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
| 45 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h"
|
| 46 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
| 49 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
| 50 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h"
|
| 51 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h"
|
| 52 |
+
|
| 53 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
namespace cutlass {
|
| 56 |
+
namespace conv {
|
| 57 |
+
namespace kernel {
|
| 58 |
+
|
| 59 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
/// Defines a kernel for Conv2dFprop
|
| 61 |
+
template <
|
| 62 |
+
typename ElementA,
|
| 63 |
+
typename LayoutA,
|
| 64 |
+
typename ElementB,
|
| 65 |
+
typename LayoutB,
|
| 66 |
+
typename ElementC,
|
| 67 |
+
typename LayoutC,
|
| 68 |
+
typename ElementAccumulator,
|
| 69 |
+
typename OperatorClass,
|
| 70 |
+
typename ArchTag,
|
| 71 |
+
typename ThreadblockShape,
|
| 72 |
+
typename WarpShape,
|
| 73 |
+
typename InstructionShape,
|
| 74 |
+
typename EpilogueOutputOp,
|
| 75 |
+
typename ThreadblockSwizzle,
|
| 76 |
+
int Stages,
|
| 77 |
+
typename MathOperatorTag,
|
| 78 |
+
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
| 79 |
+
conv::StrideSupport StrideSupport = StrideSupport::kUnity,
|
| 80 |
+
/// Access granularity of A matrix in units of elements
|
| 81 |
+
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
| 82 |
+
/// Access granularity of B matrix in units of elements
|
| 83 |
+
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
| 84 |
+
> struct DefaultConv2dFprop;
|
| 85 |
+
|
| 86 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 87 |
+
// OpClassTensorOp convolutions
|
| 88 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 89 |
+
|
| 90 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 91 |
+
/// pipeline.
|
| 92 |
+
template <
|
| 93 |
+
typename ElementA,
|
| 94 |
+
typename LayoutA,
|
| 95 |
+
typename ElementB,
|
| 96 |
+
typename LayoutB,
|
| 97 |
+
typename ElementC,
|
| 98 |
+
typename LayoutC,
|
| 99 |
+
typename ElementAccumulator,
|
| 100 |
+
typename ArchTag,
|
| 101 |
+
typename ThreadblockShape,
|
| 102 |
+
typename WarpShape,
|
| 103 |
+
typename InstructionShape,
|
| 104 |
+
typename EpilogueOutputOp,
|
| 105 |
+
typename ThreadblockSwizzle,
|
| 106 |
+
int Stages,
|
| 107 |
+
typename MathOperatorTag,
|
| 108 |
+
conv::StrideSupport StrideSupport,
|
| 109 |
+
int AlignmentA,
|
| 110 |
+
int AlignmentB
|
| 111 |
+
>
|
| 112 |
+
struct DefaultConv2dFprop <
|
| 113 |
+
ElementA,
|
| 114 |
+
LayoutA,
|
| 115 |
+
ElementB,
|
| 116 |
+
LayoutB,
|
| 117 |
+
ElementC,
|
| 118 |
+
LayoutC,
|
| 119 |
+
ElementAccumulator,
|
| 120 |
+
arch::OpClassTensorOp,
|
| 121 |
+
ArchTag,
|
| 122 |
+
ThreadblockShape,
|
| 123 |
+
WarpShape,
|
| 124 |
+
InstructionShape,
|
| 125 |
+
EpilogueOutputOp,
|
| 126 |
+
ThreadblockSwizzle,
|
| 127 |
+
Stages,
|
| 128 |
+
MathOperatorTag,
|
| 129 |
+
IteratorAlgorithm::kAnalytic,
|
| 130 |
+
StrideSupport,
|
| 131 |
+
AlignmentA,
|
| 132 |
+
AlignmentB
|
| 133 |
+
> {
|
| 134 |
+
|
| 135 |
+
// Define the core components from GEMM
|
| 136 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 137 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 138 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 139 |
+
Stages, MathOperatorTag>;
|
| 140 |
+
|
| 141 |
+
// Define iterators over tiles from the A operand
|
| 142 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 143 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 144 |
+
using IteratorA =
|
| 145 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 146 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 147 |
+
ElementA, LayoutA,
|
| 148 |
+
ThreadMapA,
|
| 149 |
+
AccessTypeA
|
| 150 |
+
>;
|
| 151 |
+
|
| 152 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 153 |
+
|
| 154 |
+
// Define iterators over tiles from the B operand
|
| 155 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 156 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 157 |
+
using IteratorB =
|
| 158 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 159 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 160 |
+
ElementB, LayoutB,
|
| 161 |
+
ThreadMapB,
|
| 162 |
+
AccessTypeB
|
| 163 |
+
>;
|
| 164 |
+
|
| 165 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 166 |
+
|
| 167 |
+
// Warp-level GEMM components
|
| 168 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 169 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 170 |
+
|
| 171 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
| 172 |
+
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
| 173 |
+
? cutlass::arch::CacheOperation::Global
|
| 174 |
+
: cutlass::arch::CacheOperation::Always;
|
| 175 |
+
|
| 176 |
+
// Define the Mma
|
| 177 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 178 |
+
ThreadblockShape,
|
| 179 |
+
IteratorA,
|
| 180 |
+
SmemIteratorA,
|
| 181 |
+
arch::CacheOperation::Always,
|
| 182 |
+
IteratorB,
|
| 183 |
+
SmemIteratorB,
|
| 184 |
+
CacheOpB,
|
| 185 |
+
MmaPolicy,
|
| 186 |
+
Stages
|
| 187 |
+
>;
|
| 188 |
+
|
| 189 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 190 |
+
|
| 191 |
+
// Define the epilogue
|
| 192 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 193 |
+
ThreadblockShape,
|
| 194 |
+
WarpMmaTensorOp,
|
| 195 |
+
kPartitionsK,
|
| 196 |
+
EpilogueOutputOp,
|
| 197 |
+
EpilogueOutputOp::kCount
|
| 198 |
+
>::Epilogue;
|
| 199 |
+
|
| 200 |
+
// Define the kernel
|
| 201 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 202 |
+
Mma,
|
| 203 |
+
Epilogue,
|
| 204 |
+
ThreadblockSwizzle,
|
| 205 |
+
conv::Operator::kFprop
|
| 206 |
+
>;
|
| 207 |
+
};
|
| 208 |
+
|
| 209 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 210 |
+
|
| 211 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 212 |
+
/// pipeline.
|
| 213 |
+
template <
|
| 214 |
+
typename ElementA,
|
| 215 |
+
typename LayoutA,
|
| 216 |
+
typename ElementB,
|
| 217 |
+
typename LayoutB,
|
| 218 |
+
typename ElementC,
|
| 219 |
+
typename LayoutC,
|
| 220 |
+
typename ElementAccumulator,
|
| 221 |
+
typename ArchTag,
|
| 222 |
+
typename ThreadblockShape,
|
| 223 |
+
typename WarpShape,
|
| 224 |
+
typename InstructionShape,
|
| 225 |
+
typename EpilogueOutputOp,
|
| 226 |
+
typename ThreadblockSwizzle,
|
| 227 |
+
int Stages,
|
| 228 |
+
typename MathOperatorTag,
|
| 229 |
+
conv::StrideSupport StrideSupport,
|
| 230 |
+
int AlignmentA,
|
| 231 |
+
int AlignmentB
|
| 232 |
+
>
|
| 233 |
+
struct DefaultConv2dFprop <
|
| 234 |
+
ElementA,
|
| 235 |
+
LayoutA,
|
| 236 |
+
ElementB,
|
| 237 |
+
LayoutB,
|
| 238 |
+
ElementC,
|
| 239 |
+
LayoutC,
|
| 240 |
+
ElementAccumulator,
|
| 241 |
+
arch::OpClassTensorOp,
|
| 242 |
+
ArchTag,
|
| 243 |
+
ThreadblockShape,
|
| 244 |
+
WarpShape,
|
| 245 |
+
InstructionShape,
|
| 246 |
+
EpilogueOutputOp,
|
| 247 |
+
ThreadblockSwizzle,
|
| 248 |
+
Stages,
|
| 249 |
+
MathOperatorTag,
|
| 250 |
+
IteratorAlgorithm::kFixedChannels,
|
| 251 |
+
StrideSupport,
|
| 252 |
+
AlignmentA,
|
| 253 |
+
AlignmentB
|
| 254 |
+
> {
|
| 255 |
+
|
| 256 |
+
// Define the core components from GEMM
|
| 257 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 258 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 259 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 260 |
+
Stages, MathOperatorTag>;
|
| 261 |
+
|
| 262 |
+
// Define iterators over tiles from the A operand
|
| 263 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 264 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 265 |
+
using IteratorA =
|
| 266 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels<
|
| 267 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 268 |
+
ElementA, LayoutA,
|
| 269 |
+
ThreadMapA,
|
| 270 |
+
AccessTypeA
|
| 271 |
+
>;
|
| 272 |
+
|
| 273 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 274 |
+
|
| 275 |
+
// Define iterators over tiles from the B operand
|
| 276 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 277 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 278 |
+
using IteratorB =
|
| 279 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels<
|
| 280 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 281 |
+
ElementB, LayoutB,
|
| 282 |
+
ThreadMapB,
|
| 283 |
+
AccessTypeB
|
| 284 |
+
>;
|
| 285 |
+
|
| 286 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 287 |
+
|
| 288 |
+
// Warp-level GEMM components
|
| 289 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 290 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 291 |
+
|
| 292 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
| 293 |
+
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
| 294 |
+
? cutlass::arch::CacheOperation::Global
|
| 295 |
+
: cutlass::arch::CacheOperation::Always;
|
| 296 |
+
|
| 297 |
+
// Define the Mma
|
| 298 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 299 |
+
ThreadblockShape,
|
| 300 |
+
IteratorA,
|
| 301 |
+
SmemIteratorA,
|
| 302 |
+
arch::CacheOperation::Always,
|
| 303 |
+
IteratorB,
|
| 304 |
+
SmemIteratorB,
|
| 305 |
+
CacheOpB,
|
| 306 |
+
MmaPolicy,
|
| 307 |
+
Stages
|
| 308 |
+
>;
|
| 309 |
+
|
| 310 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 311 |
+
|
| 312 |
+
// Define the epilogue
|
| 313 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 314 |
+
ThreadblockShape,
|
| 315 |
+
WarpMmaTensorOp,
|
| 316 |
+
kPartitionsK,
|
| 317 |
+
EpilogueOutputOp,
|
| 318 |
+
EpilogueOutputOp::kCount
|
| 319 |
+
>::Epilogue;
|
| 320 |
+
|
| 321 |
+
// Define the kernel
|
| 322 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 323 |
+
Mma,
|
| 324 |
+
Epilogue,
|
| 325 |
+
ThreadblockSwizzle,
|
| 326 |
+
conv::Operator::kFprop
|
| 327 |
+
>;
|
| 328 |
+
};
|
| 329 |
+
|
| 330 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 331 |
+
|
| 332 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and two stage
|
| 333 |
+
/// pipeline.
|
| 334 |
+
template <
|
| 335 |
+
typename ElementA,
|
| 336 |
+
typename LayoutA,
|
| 337 |
+
typename ElementB,
|
| 338 |
+
typename LayoutB,
|
| 339 |
+
typename ElementC,
|
| 340 |
+
typename LayoutC,
|
| 341 |
+
typename ElementAccumulator,
|
| 342 |
+
typename ArchTag,
|
| 343 |
+
typename ThreadblockShape,
|
| 344 |
+
typename WarpShape,
|
| 345 |
+
typename InstructionShape,
|
| 346 |
+
typename EpilogueOutputOp,
|
| 347 |
+
typename ThreadblockSwizzle,
|
| 348 |
+
typename MathOperatorTag,
|
| 349 |
+
conv::StrideSupport StrideSupport,
|
| 350 |
+
int AlignmentA,
|
| 351 |
+
int AlignmentB
|
| 352 |
+
>
|
| 353 |
+
struct DefaultConv2dFprop <
|
| 354 |
+
ElementA,
|
| 355 |
+
LayoutA,
|
| 356 |
+
ElementB,
|
| 357 |
+
LayoutB,
|
| 358 |
+
ElementC,
|
| 359 |
+
LayoutC,
|
| 360 |
+
ElementAccumulator,
|
| 361 |
+
arch::OpClassTensorOp,
|
| 362 |
+
ArchTag,
|
| 363 |
+
ThreadblockShape,
|
| 364 |
+
WarpShape,
|
| 365 |
+
InstructionShape,
|
| 366 |
+
EpilogueOutputOp,
|
| 367 |
+
ThreadblockSwizzle,
|
| 368 |
+
2,
|
| 369 |
+
MathOperatorTag,
|
| 370 |
+
IteratorAlgorithm::kFixedChannels,
|
| 371 |
+
StrideSupport,
|
| 372 |
+
AlignmentA,
|
| 373 |
+
AlignmentB
|
| 374 |
+
> {
|
| 375 |
+
|
| 376 |
+
// Define the core components from GEMM
|
| 377 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 378 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 379 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 380 |
+
2, MathOperatorTag>;
|
| 381 |
+
|
| 382 |
+
// Define iterators over tiles from the A operand
|
| 383 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 384 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 385 |
+
using IteratorA =
|
| 386 |
+
cutlass::conv::threadblock::TileIterator<
|
| 387 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels<
|
| 388 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 389 |
+
ElementA, LayoutA,
|
| 390 |
+
ThreadMapA,
|
| 391 |
+
AccessTypeA
|
| 392 |
+
>
|
| 393 |
+
>;
|
| 394 |
+
|
| 395 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 396 |
+
|
| 397 |
+
// Define iterators over tiles from the B operand
|
| 398 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 399 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 400 |
+
using IteratorB =
|
| 401 |
+
cutlass::conv::threadblock::TileIterator<
|
| 402 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels<
|
| 403 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 404 |
+
ElementB, LayoutB,
|
| 405 |
+
ThreadMapB,
|
| 406 |
+
AccessTypeB
|
| 407 |
+
>
|
| 408 |
+
>;
|
| 409 |
+
|
| 410 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 411 |
+
|
| 412 |
+
// Warp-level GEMM components
|
| 413 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 414 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 415 |
+
|
| 416 |
+
// Define the Mma
|
| 417 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 418 |
+
ThreadblockShape,
|
| 419 |
+
IteratorA,
|
| 420 |
+
SmemIteratorA,
|
| 421 |
+
IteratorB,
|
| 422 |
+
SmemIteratorB,
|
| 423 |
+
ElementC,
|
| 424 |
+
LayoutC,
|
| 425 |
+
MmaPolicy
|
| 426 |
+
>;
|
| 427 |
+
|
| 428 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 429 |
+
|
| 430 |
+
// Define the epilogue
|
| 431 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 432 |
+
ThreadblockShape,
|
| 433 |
+
WarpMmaTensorOp,
|
| 434 |
+
kPartitionsK,
|
| 435 |
+
EpilogueOutputOp,
|
| 436 |
+
EpilogueOutputOp::kCount
|
| 437 |
+
>::Epilogue;
|
| 438 |
+
|
| 439 |
+
// Define the kernel
|
| 440 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 441 |
+
Mma,
|
| 442 |
+
Epilogue,
|
| 443 |
+
ThreadblockSwizzle,
|
| 444 |
+
conv::Operator::kFprop
|
| 445 |
+
>;
|
| 446 |
+
};
|
| 447 |
+
|
| 448 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 449 |
+
|
| 450 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 451 |
+
/// pipeline.
|
| 452 |
+
template <
|
| 453 |
+
typename ElementA,
|
| 454 |
+
typename LayoutA,
|
| 455 |
+
typename ElementB,
|
| 456 |
+
typename LayoutB,
|
| 457 |
+
typename ElementC,
|
| 458 |
+
typename LayoutC,
|
| 459 |
+
typename ElementAccumulator,
|
| 460 |
+
typename ArchTag,
|
| 461 |
+
typename ThreadblockShape,
|
| 462 |
+
typename WarpShape,
|
| 463 |
+
typename InstructionShape,
|
| 464 |
+
typename EpilogueOutputOp,
|
| 465 |
+
typename ThreadblockSwizzle,
|
| 466 |
+
int Stages,
|
| 467 |
+
typename MathOperatorTag,
|
| 468 |
+
conv::StrideSupport StrideSupport,
|
| 469 |
+
int AlignmentA,
|
| 470 |
+
int AlignmentB
|
| 471 |
+
>
|
| 472 |
+
struct DefaultConv2dFprop <
|
| 473 |
+
ElementA,
|
| 474 |
+
LayoutA,
|
| 475 |
+
ElementB,
|
| 476 |
+
LayoutB,
|
| 477 |
+
ElementC,
|
| 478 |
+
LayoutC,
|
| 479 |
+
ElementAccumulator,
|
| 480 |
+
arch::OpClassTensorOp,
|
| 481 |
+
ArchTag,
|
| 482 |
+
ThreadblockShape,
|
| 483 |
+
WarpShape,
|
| 484 |
+
InstructionShape,
|
| 485 |
+
EpilogueOutputOp,
|
| 486 |
+
ThreadblockSwizzle,
|
| 487 |
+
Stages,
|
| 488 |
+
MathOperatorTag,
|
| 489 |
+
IteratorAlgorithm::kFewChannels,
|
| 490 |
+
StrideSupport,
|
| 491 |
+
AlignmentA,
|
| 492 |
+
AlignmentB
|
| 493 |
+
> {
|
| 494 |
+
|
| 495 |
+
// Define the core components from GEMM
|
| 496 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 497 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 498 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 499 |
+
Stages, MathOperatorTag>;
|
| 500 |
+
|
| 501 |
+
// Define iterators over tiles from the A operand
|
| 502 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 503 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 504 |
+
using IteratorA =
|
| 505 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels<
|
| 506 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 507 |
+
ElementA, LayoutA,
|
| 508 |
+
ThreadMapA,
|
| 509 |
+
AccessTypeA
|
| 510 |
+
>;
|
| 511 |
+
|
| 512 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 513 |
+
|
| 514 |
+
// Define iterators over tiles from the B operand
|
| 515 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 516 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 517 |
+
using IteratorB =
|
| 518 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels<
|
| 519 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 520 |
+
ElementB, LayoutB,
|
| 521 |
+
ThreadMapB,
|
| 522 |
+
AccessTypeB
|
| 523 |
+
>;
|
| 524 |
+
|
| 525 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 526 |
+
|
| 527 |
+
// Warp-level GEMM components
|
| 528 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 529 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 530 |
+
|
| 531 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
| 532 |
+
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
| 533 |
+
? cutlass::arch::CacheOperation::Global
|
| 534 |
+
: cutlass::arch::CacheOperation::Always;
|
| 535 |
+
|
| 536 |
+
// Define the Mma
|
| 537 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 538 |
+
ThreadblockShape,
|
| 539 |
+
IteratorA,
|
| 540 |
+
SmemIteratorA,
|
| 541 |
+
arch::CacheOperation::Always,
|
| 542 |
+
IteratorB,
|
| 543 |
+
SmemIteratorB,
|
| 544 |
+
CacheOpB,
|
| 545 |
+
MmaPolicy,
|
| 546 |
+
Stages
|
| 547 |
+
>;
|
| 548 |
+
|
| 549 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 550 |
+
|
| 551 |
+
// Define the epilogue
|
| 552 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 553 |
+
ThreadblockShape,
|
| 554 |
+
WarpMmaTensorOp,
|
| 555 |
+
kPartitionsK,
|
| 556 |
+
EpilogueOutputOp,
|
| 557 |
+
EpilogueOutputOp::kCount
|
| 558 |
+
>::Epilogue;
|
| 559 |
+
|
| 560 |
+
// Define the kernel
|
| 561 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 562 |
+
Mma,
|
| 563 |
+
Epilogue,
|
| 564 |
+
ThreadblockSwizzle,
|
| 565 |
+
conv::Operator::kFprop
|
| 566 |
+
>;
|
| 567 |
+
};
|
| 568 |
+
|
| 569 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 570 |
+
/// pipeline.
|
| 571 |
+
template <
|
| 572 |
+
typename ElementA,
|
| 573 |
+
typename LayoutA,
|
| 574 |
+
typename ElementB,
|
| 575 |
+
typename LayoutB,
|
| 576 |
+
typename ElementC,
|
| 577 |
+
typename LayoutC,
|
| 578 |
+
typename ElementAccumulator,
|
| 579 |
+
typename ArchTag,
|
| 580 |
+
typename ThreadblockShape,
|
| 581 |
+
typename WarpShape,
|
| 582 |
+
typename InstructionShape,
|
| 583 |
+
typename EpilogueOutputOp,
|
| 584 |
+
typename ThreadblockSwizzle,
|
| 585 |
+
typename MathOperatorTag,
|
| 586 |
+
conv::StrideSupport StrideSupport,
|
| 587 |
+
int AlignmentA,
|
| 588 |
+
int AlignmentB
|
| 589 |
+
>
|
| 590 |
+
struct DefaultConv2dFprop <
|
| 591 |
+
ElementA,
|
| 592 |
+
LayoutA,
|
| 593 |
+
ElementB,
|
| 594 |
+
LayoutB,
|
| 595 |
+
ElementC,
|
| 596 |
+
LayoutC,
|
| 597 |
+
ElementAccumulator,
|
| 598 |
+
arch::OpClassTensorOp,
|
| 599 |
+
ArchTag,
|
| 600 |
+
ThreadblockShape,
|
| 601 |
+
WarpShape,
|
| 602 |
+
InstructionShape,
|
| 603 |
+
EpilogueOutputOp,
|
| 604 |
+
ThreadblockSwizzle,
|
| 605 |
+
2,
|
| 606 |
+
MathOperatorTag,
|
| 607 |
+
IteratorAlgorithm::kFewChannels,
|
| 608 |
+
StrideSupport,
|
| 609 |
+
AlignmentA,
|
| 610 |
+
AlignmentB
|
| 611 |
+
> {
|
| 612 |
+
|
| 613 |
+
// Define the core components from GEMM
|
| 614 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 615 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 616 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 617 |
+
2, MathOperatorTag>;
|
| 618 |
+
|
| 619 |
+
// Define iterators over tiles from the A operand
|
| 620 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 621 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 622 |
+
using IteratorA =
|
| 623 |
+
cutlass::conv::threadblock::TileIterator<
|
| 624 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels<
|
| 625 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 626 |
+
ElementA, LayoutA,
|
| 627 |
+
ThreadMapA,
|
| 628 |
+
AccessTypeA
|
| 629 |
+
>
|
| 630 |
+
>;
|
| 631 |
+
|
| 632 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 633 |
+
|
| 634 |
+
// Define iterators over tiles from the B operand
|
| 635 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 636 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 637 |
+
using IteratorB =
|
| 638 |
+
|
| 639 |
+
cutlass::conv::threadblock::TileIterator<
|
| 640 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels<
|
| 641 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 642 |
+
ElementB, LayoutB,
|
| 643 |
+
ThreadMapB,
|
| 644 |
+
AccessTypeB
|
| 645 |
+
>
|
| 646 |
+
>;
|
| 647 |
+
|
| 648 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 649 |
+
|
| 650 |
+
// Warp-level GEMM components
|
| 651 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 652 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 653 |
+
|
| 654 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
| 655 |
+
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
| 656 |
+
? cutlass::arch::CacheOperation::Global
|
| 657 |
+
: cutlass::arch::CacheOperation::Always;
|
| 658 |
+
|
| 659 |
+
// Define the Mma
|
| 660 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 661 |
+
ThreadblockShape,
|
| 662 |
+
IteratorA,
|
| 663 |
+
SmemIteratorA,
|
| 664 |
+
IteratorB,
|
| 665 |
+
SmemIteratorB,
|
| 666 |
+
ElementC,
|
| 667 |
+
LayoutC,
|
| 668 |
+
MmaPolicy
|
| 669 |
+
>;
|
| 670 |
+
|
| 671 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 672 |
+
|
| 673 |
+
// Define the epilogue
|
| 674 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 675 |
+
ThreadblockShape,
|
| 676 |
+
WarpMmaTensorOp,
|
| 677 |
+
kPartitionsK,
|
| 678 |
+
EpilogueOutputOp,
|
| 679 |
+
EpilogueOutputOp::kCount
|
| 680 |
+
>::Epilogue;
|
| 681 |
+
|
| 682 |
+
// Define the kernel
|
| 683 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 684 |
+
Mma,
|
| 685 |
+
Epilogue,
|
| 686 |
+
ThreadblockSwizzle,
|
| 687 |
+
conv::Operator::kFprop
|
| 688 |
+
>;
|
| 689 |
+
};
|
| 690 |
+
|
| 691 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 692 |
+
|
| 693 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 694 |
+
/// pipeline with interleaved layout.
|
| 695 |
+
template <
|
| 696 |
+
typename ElementA,
|
| 697 |
+
typename ElementB,
|
| 698 |
+
typename ElementC,
|
| 699 |
+
typename LayoutC,
|
| 700 |
+
typename ElementAccumulator,
|
| 701 |
+
typename ArchTag,
|
| 702 |
+
typename ThreadblockShape,
|
| 703 |
+
typename WarpShape,
|
| 704 |
+
typename InstructionShape,
|
| 705 |
+
typename EpilogueOutputOp,
|
| 706 |
+
typename ThreadblockSwizzle,
|
| 707 |
+
int Stages,
|
| 708 |
+
typename MathOperatorTag,
|
| 709 |
+
conv::StrideSupport StrideSupport,
|
| 710 |
+
int AlignmentA,
|
| 711 |
+
int AlignmentB,
|
| 712 |
+
int InterleavedK
|
| 713 |
+
>
|
| 714 |
+
struct DefaultConv2dFprop <
|
| 715 |
+
ElementA,
|
| 716 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 717 |
+
ElementB,
|
| 718 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 719 |
+
ElementC,
|
| 720 |
+
LayoutC,
|
| 721 |
+
ElementAccumulator,
|
| 722 |
+
arch::OpClassTensorOp,
|
| 723 |
+
ArchTag,
|
| 724 |
+
ThreadblockShape,
|
| 725 |
+
WarpShape,
|
| 726 |
+
InstructionShape,
|
| 727 |
+
EpilogueOutputOp,
|
| 728 |
+
ThreadblockSwizzle,
|
| 729 |
+
Stages,
|
| 730 |
+
MathOperatorTag,
|
| 731 |
+
IteratorAlgorithm::kAnalytic,
|
| 732 |
+
StrideSupport,
|
| 733 |
+
AlignmentA,
|
| 734 |
+
AlignmentB
|
| 735 |
+
> {
|
| 736 |
+
|
| 737 |
+
// Define the core components from GEMM
|
| 738 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 739 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 740 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 741 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 742 |
+
Stages, MathOperatorTag, true>;
|
| 743 |
+
|
| 744 |
+
// Define iterators over tiles from the A operand
|
| 745 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 746 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 747 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 748 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 749 |
+
// layout.
|
| 750 |
+
using ThreadMapA = typename MmaCore::SmemThreadMapA;
|
| 751 |
+
using IteratorA =
|
| 752 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 753 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 754 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 755 |
+
ThreadMapA
|
| 756 |
+
>;
|
| 757 |
+
|
| 758 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 759 |
+
|
| 760 |
+
// Define iterators over tiles from the B operand
|
| 761 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 762 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 763 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 764 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 765 |
+
// layout.
|
| 766 |
+
using ThreadMapB = typename MmaCore::SmemThreadMapB;
|
| 767 |
+
using IteratorB =
|
| 768 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 769 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 770 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 771 |
+
ThreadMapB
|
| 772 |
+
>;
|
| 773 |
+
|
| 774 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 775 |
+
|
| 776 |
+
// Warp-level GEMM components
|
| 777 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 778 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 779 |
+
|
| 780 |
+
// Define the Mma
|
| 781 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 782 |
+
ThreadblockShape,
|
| 783 |
+
IteratorA,
|
| 784 |
+
SmemIteratorA,
|
| 785 |
+
arch::CacheOperation::Always,
|
| 786 |
+
IteratorB,
|
| 787 |
+
SmemIteratorB,
|
| 788 |
+
arch::CacheOperation::Global,
|
| 789 |
+
MmaPolicy,
|
| 790 |
+
Stages
|
| 791 |
+
>;
|
| 792 |
+
|
| 793 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 794 |
+
|
| 795 |
+
// Define the epilogue
|
| 796 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 797 |
+
ThreadblockShape,
|
| 798 |
+
WarpMmaTensorOp,
|
| 799 |
+
kPartitionsK,
|
| 800 |
+
EpilogueOutputOp,
|
| 801 |
+
EpilogueOutputOp::kCount,
|
| 802 |
+
InterleavedK
|
| 803 |
+
>::Epilogue;
|
| 804 |
+
|
| 805 |
+
// Define the kernel
|
| 806 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 807 |
+
Mma,
|
| 808 |
+
Epilogue,
|
| 809 |
+
ThreadblockSwizzle,
|
| 810 |
+
conv::Operator::kFprop
|
| 811 |
+
>;
|
| 812 |
+
};
|
| 813 |
+
|
| 814 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 815 |
+
|
| 816 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
|
| 817 |
+
/// and 2 stage pipeline.
|
| 818 |
+
template <
|
| 819 |
+
typename ElementA,
|
| 820 |
+
typename LayoutA,
|
| 821 |
+
typename ElementB,
|
| 822 |
+
typename LayoutB,
|
| 823 |
+
typename ElementC,
|
| 824 |
+
typename LayoutC,
|
| 825 |
+
typename ElementAccumulator,
|
| 826 |
+
typename ArchTag,
|
| 827 |
+
typename ThreadblockShape,
|
| 828 |
+
typename WarpShape,
|
| 829 |
+
typename InstructionShape,
|
| 830 |
+
typename EpilogueOutputOp,
|
| 831 |
+
typename ThreadblockSwizzle,
|
| 832 |
+
typename MathOperatorTag,
|
| 833 |
+
conv::StrideSupport StrideSupport,
|
| 834 |
+
int AlignmentA,
|
| 835 |
+
int AlignmentB
|
| 836 |
+
>
|
| 837 |
+
struct DefaultConv2dFprop <
|
| 838 |
+
ElementA,
|
| 839 |
+
LayoutA,
|
| 840 |
+
ElementB,
|
| 841 |
+
LayoutB,
|
| 842 |
+
ElementC,
|
| 843 |
+
LayoutC,
|
| 844 |
+
ElementAccumulator,
|
| 845 |
+
arch::OpClassTensorOp,
|
| 846 |
+
ArchTag,
|
| 847 |
+
ThreadblockShape,
|
| 848 |
+
WarpShape,
|
| 849 |
+
InstructionShape,
|
| 850 |
+
EpilogueOutputOp,
|
| 851 |
+
ThreadblockSwizzle,
|
| 852 |
+
2,
|
| 853 |
+
MathOperatorTag,
|
| 854 |
+
IteratorAlgorithm::kAnalytic,
|
| 855 |
+
StrideSupport,
|
| 856 |
+
AlignmentA,
|
| 857 |
+
AlignmentB
|
| 858 |
+
> {
|
| 859 |
+
|
| 860 |
+
// Define the core components from GEMM
|
| 861 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 862 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 863 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 864 |
+
2, MathOperatorTag>;
|
| 865 |
+
|
| 866 |
+
// Define iterators over tiles from the A operand
|
| 867 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 868 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 869 |
+
using IteratorA =
|
| 870 |
+
cutlass::conv::threadblock::TileIterator<
|
| 871 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 872 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 873 |
+
ElementA, LayoutA,
|
| 874 |
+
ThreadMapA,
|
| 875 |
+
AccessTypeA
|
| 876 |
+
>
|
| 877 |
+
>;
|
| 878 |
+
|
| 879 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 880 |
+
|
| 881 |
+
// Define iterators over tiles from the B operand
|
| 882 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 883 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 884 |
+
using IteratorB =
|
| 885 |
+
cutlass::conv::threadblock::TileIterator<
|
| 886 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 887 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 888 |
+
ElementB, LayoutB,
|
| 889 |
+
ThreadMapB,
|
| 890 |
+
AccessTypeB
|
| 891 |
+
>
|
| 892 |
+
>;
|
| 893 |
+
|
| 894 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 895 |
+
|
| 896 |
+
// Warp-level GEMM components
|
| 897 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 898 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 899 |
+
|
| 900 |
+
// Define the Mma
|
| 901 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 902 |
+
ThreadblockShape,
|
| 903 |
+
IteratorA,
|
| 904 |
+
SmemIteratorA,
|
| 905 |
+
IteratorB,
|
| 906 |
+
SmemIteratorB,
|
| 907 |
+
ElementC,
|
| 908 |
+
LayoutC,
|
| 909 |
+
MmaPolicy
|
| 910 |
+
>;
|
| 911 |
+
|
| 912 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 913 |
+
|
| 914 |
+
// Define the epilogue
|
| 915 |
+
using Epilogue = typename detail::DefaultConvEpilogue<
|
| 916 |
+
ArchTag,
|
| 917 |
+
ThreadblockShape,
|
| 918 |
+
WarpMmaTensorOp,
|
| 919 |
+
kPartitionsK,
|
| 920 |
+
EpilogueOutputOp
|
| 921 |
+
>::Epilogue;
|
| 922 |
+
|
| 923 |
+
// Define the kernel
|
| 924 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 925 |
+
Mma,
|
| 926 |
+
Epilogue,
|
| 927 |
+
ThreadblockSwizzle,
|
| 928 |
+
conv::Operator::kFprop
|
| 929 |
+
>;
|
| 930 |
+
};
|
| 931 |
+
|
| 932 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 933 |
+
|
| 934 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
|
| 935 |
+
/// pipeline with interleaved layout.
|
| 936 |
+
template <
|
| 937 |
+
typename ElementA,
|
| 938 |
+
typename ElementB,
|
| 939 |
+
typename ElementC,
|
| 940 |
+
typename LayoutC,
|
| 941 |
+
typename ElementAccumulator,
|
| 942 |
+
typename ArchTag,
|
| 943 |
+
typename ThreadblockShape,
|
| 944 |
+
typename WarpShape,
|
| 945 |
+
typename InstructionShape,
|
| 946 |
+
typename EpilogueOutputOp,
|
| 947 |
+
typename ThreadblockSwizzle,
|
| 948 |
+
typename MathOperatorTag,
|
| 949 |
+
conv::StrideSupport StrideSupport,
|
| 950 |
+
int AlignmentA,
|
| 951 |
+
int AlignmentB,
|
| 952 |
+
int InterleavedK
|
| 953 |
+
>
|
| 954 |
+
struct DefaultConv2dFprop <
|
| 955 |
+
ElementA,
|
| 956 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 957 |
+
ElementB,
|
| 958 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 959 |
+
ElementC,
|
| 960 |
+
LayoutC,
|
| 961 |
+
ElementAccumulator,
|
| 962 |
+
arch::OpClassTensorOp,
|
| 963 |
+
ArchTag,
|
| 964 |
+
ThreadblockShape,
|
| 965 |
+
WarpShape,
|
| 966 |
+
InstructionShape,
|
| 967 |
+
EpilogueOutputOp,
|
| 968 |
+
ThreadblockSwizzle,
|
| 969 |
+
2,
|
| 970 |
+
MathOperatorTag,
|
| 971 |
+
IteratorAlgorithm::kAnalytic,
|
| 972 |
+
StrideSupport,
|
| 973 |
+
AlignmentA,
|
| 974 |
+
AlignmentB
|
| 975 |
+
> {
|
| 976 |
+
|
| 977 |
+
// Define the core components from GEMM
|
| 978 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 979 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 980 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 981 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 982 |
+
2, MathOperatorTag, true>;
|
| 983 |
+
|
| 984 |
+
// Define iterators over tiles from the A operand
|
| 985 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 986 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 987 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 988 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 989 |
+
// layout.
|
| 990 |
+
using ThreadMapA = typename MmaCore::SmemThreadMapA;
|
| 991 |
+
using IteratorA =
|
| 992 |
+
cutlass::conv::threadblock::TileIterator<
|
| 993 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 994 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 995 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 996 |
+
ThreadMapA
|
| 997 |
+
>
|
| 998 |
+
>;
|
| 999 |
+
|
| 1000 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1001 |
+
|
| 1002 |
+
// Define iterators over tiles from the B operand
|
| 1003 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 1004 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 1005 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 1006 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 1007 |
+
// layout.
|
| 1008 |
+
using ThreadMapB = typename MmaCore::SmemThreadMapB;
|
| 1009 |
+
using IteratorB =
|
| 1010 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1011 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 1012 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1013 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 1014 |
+
ThreadMapB
|
| 1015 |
+
>
|
| 1016 |
+
>;
|
| 1017 |
+
|
| 1018 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1019 |
+
|
| 1020 |
+
// Warp-level GEMM components
|
| 1021 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 1022 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1023 |
+
|
| 1024 |
+
// Define the Mma
|
| 1025 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1026 |
+
ThreadblockShape,
|
| 1027 |
+
IteratorA,
|
| 1028 |
+
SmemIteratorA,
|
| 1029 |
+
IteratorB,
|
| 1030 |
+
SmemIteratorB,
|
| 1031 |
+
ElementC,
|
| 1032 |
+
LayoutC,
|
| 1033 |
+
MmaPolicy
|
| 1034 |
+
>;
|
| 1035 |
+
|
| 1036 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 1037 |
+
|
| 1038 |
+
// Define the epilogue
|
| 1039 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 1040 |
+
ThreadblockShape,
|
| 1041 |
+
WarpMmaTensorOp,
|
| 1042 |
+
kPartitionsK,
|
| 1043 |
+
EpilogueOutputOp,
|
| 1044 |
+
EpilogueOutputOp::kCount,
|
| 1045 |
+
InterleavedK
|
| 1046 |
+
>::Epilogue;
|
| 1047 |
+
|
| 1048 |
+
// Define the kernel
|
| 1049 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1050 |
+
Mma,
|
| 1051 |
+
Epilogue,
|
| 1052 |
+
ThreadblockSwizzle,
|
| 1053 |
+
conv::Operator::kFprop
|
| 1054 |
+
>;
|
| 1055 |
+
};
|
| 1056 |
+
|
| 1057 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1058 |
+
|
| 1059 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
| 1060 |
+
/// multistage pipeline.
|
| 1061 |
+
template <
|
| 1062 |
+
typename ElementA,
|
| 1063 |
+
typename LayoutA,
|
| 1064 |
+
typename ElementB,
|
| 1065 |
+
typename LayoutB,
|
| 1066 |
+
typename ElementC,
|
| 1067 |
+
typename LayoutC,
|
| 1068 |
+
typename ElementAccumulator,
|
| 1069 |
+
typename ArchTag,
|
| 1070 |
+
typename ThreadblockShape,
|
| 1071 |
+
typename WarpShape,
|
| 1072 |
+
typename InstructionShape,
|
| 1073 |
+
typename EpilogueOutputOp,
|
| 1074 |
+
typename ThreadblockSwizzle,
|
| 1075 |
+
int Stages,
|
| 1076 |
+
typename MathOperatorTag,
|
| 1077 |
+
conv::StrideSupport StrideSupport,
|
| 1078 |
+
int AlignmentA,
|
| 1079 |
+
int AlignmentB
|
| 1080 |
+
>
|
| 1081 |
+
struct DefaultConv2dFprop <
|
| 1082 |
+
ElementA,
|
| 1083 |
+
LayoutA,
|
| 1084 |
+
ElementB,
|
| 1085 |
+
LayoutB,
|
| 1086 |
+
ElementC,
|
| 1087 |
+
LayoutC,
|
| 1088 |
+
ElementAccumulator,
|
| 1089 |
+
arch::OpClassTensorOp,
|
| 1090 |
+
ArchTag,
|
| 1091 |
+
ThreadblockShape,
|
| 1092 |
+
WarpShape,
|
| 1093 |
+
InstructionShape,
|
| 1094 |
+
EpilogueOutputOp,
|
| 1095 |
+
ThreadblockSwizzle,
|
| 1096 |
+
Stages,
|
| 1097 |
+
MathOperatorTag,
|
| 1098 |
+
IteratorAlgorithm::kOptimized,
|
| 1099 |
+
StrideSupport,
|
| 1100 |
+
AlignmentA,
|
| 1101 |
+
AlignmentB
|
| 1102 |
+
> {
|
| 1103 |
+
|
| 1104 |
+
// Define the core components from GEMM
|
| 1105 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1106 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1107 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 1108 |
+
Stages, MathOperatorTag
|
| 1109 |
+
>;
|
| 1110 |
+
|
| 1111 |
+
// Define iterators over tiles from the A operand
|
| 1112 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1113 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 1114 |
+
using IteratorA =
|
| 1115 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 1116 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1117 |
+
ElementA,
|
| 1118 |
+
LayoutA,
|
| 1119 |
+
ThreadMapA,
|
| 1120 |
+
AccessTypeA
|
| 1121 |
+
>;
|
| 1122 |
+
|
| 1123 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1124 |
+
|
| 1125 |
+
// Define iterators over tiles from the B operand
|
| 1126 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1127 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 1128 |
+
using IteratorB =
|
| 1129 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 1130 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1131 |
+
ElementB,
|
| 1132 |
+
LayoutB,
|
| 1133 |
+
ThreadMapB,
|
| 1134 |
+
AccessTypeB
|
| 1135 |
+
>;
|
| 1136 |
+
|
| 1137 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1138 |
+
|
| 1139 |
+
// Warp-level GEMM components
|
| 1140 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 1141 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1142 |
+
|
| 1143 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
| 1144 |
+
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
| 1145 |
+
? cutlass::arch::CacheOperation::Global
|
| 1146 |
+
: cutlass::arch::CacheOperation::Always;
|
| 1147 |
+
|
| 1148 |
+
// Define the Mma
|
| 1149 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 1150 |
+
ThreadblockShape,
|
| 1151 |
+
IteratorA,
|
| 1152 |
+
SmemIteratorA,
|
| 1153 |
+
arch::CacheOperation::Always,
|
| 1154 |
+
IteratorB,
|
| 1155 |
+
SmemIteratorB,
|
| 1156 |
+
CacheOpB,
|
| 1157 |
+
MmaPolicy,
|
| 1158 |
+
Stages
|
| 1159 |
+
>;
|
| 1160 |
+
|
| 1161 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 1162 |
+
|
| 1163 |
+
// Define the epilogue
|
| 1164 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 1165 |
+
ThreadblockShape,
|
| 1166 |
+
WarpMmaTensorOp,
|
| 1167 |
+
kPartitionsK,
|
| 1168 |
+
EpilogueOutputOp,
|
| 1169 |
+
EpilogueOutputOp::kCount,
|
| 1170 |
+
false,
|
| 1171 |
+
layout::NoPermute,
|
| 1172 |
+
StrideSupport,
|
| 1173 |
+
4
|
| 1174 |
+
>::Epilogue;
|
| 1175 |
+
|
| 1176 |
+
// Define the kernel
|
| 1177 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1178 |
+
Mma,
|
| 1179 |
+
Epilogue,
|
| 1180 |
+
ThreadblockSwizzle,
|
| 1181 |
+
conv::Operator::kFprop
|
| 1182 |
+
>;
|
| 1183 |
+
};
|
| 1184 |
+
|
| 1185 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1186 |
+
|
| 1187 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
| 1188 |
+
// multistage pipeline with interleaved layout.
|
| 1189 |
+
template <
|
| 1190 |
+
typename ElementA,
|
| 1191 |
+
typename ElementB,
|
| 1192 |
+
typename ElementC,
|
| 1193 |
+
typename LayoutC,
|
| 1194 |
+
typename ElementAccumulator,
|
| 1195 |
+
typename ArchTag,
|
| 1196 |
+
typename ThreadblockShape,
|
| 1197 |
+
typename WarpShape,
|
| 1198 |
+
typename InstructionShape,
|
| 1199 |
+
typename EpilogueOutputOp,
|
| 1200 |
+
typename ThreadblockSwizzle,
|
| 1201 |
+
int Stages,
|
| 1202 |
+
typename MathOperatorTag,
|
| 1203 |
+
conv::StrideSupport StrideSupport,
|
| 1204 |
+
int AlignmentA,
|
| 1205 |
+
int AlignmentB,
|
| 1206 |
+
int InterleavedK
|
| 1207 |
+
>
|
| 1208 |
+
struct DefaultConv2dFprop <
|
| 1209 |
+
ElementA,
|
| 1210 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 1211 |
+
ElementB,
|
| 1212 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 1213 |
+
ElementC,
|
| 1214 |
+
LayoutC,
|
| 1215 |
+
ElementAccumulator,
|
| 1216 |
+
arch::OpClassTensorOp,
|
| 1217 |
+
ArchTag,
|
| 1218 |
+
ThreadblockShape,
|
| 1219 |
+
WarpShape,
|
| 1220 |
+
InstructionShape,
|
| 1221 |
+
EpilogueOutputOp,
|
| 1222 |
+
ThreadblockSwizzle,
|
| 1223 |
+
Stages,
|
| 1224 |
+
MathOperatorTag,
|
| 1225 |
+
IteratorAlgorithm::kOptimized,
|
| 1226 |
+
StrideSupport,
|
| 1227 |
+
AlignmentA,
|
| 1228 |
+
AlignmentB
|
| 1229 |
+
> {
|
| 1230 |
+
|
| 1231 |
+
// Define the core components from GEMM
|
| 1232 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1233 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 1234 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>, ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 1235 |
+
Stages, MathOperatorTag, true
|
| 1236 |
+
>;
|
| 1237 |
+
|
| 1238 |
+
// Define iterators over tiles from the A operand
|
| 1239 |
+
using ThreadMapA = typename MmaCore::SmemThreadMapA;
|
| 1240 |
+
using IteratorA =
|
| 1241 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 1242 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1243 |
+
ElementA,
|
| 1244 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 1245 |
+
ThreadMapA
|
| 1246 |
+
>;
|
| 1247 |
+
|
| 1248 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1249 |
+
|
| 1250 |
+
// Define iterators over tiles from the B operand
|
| 1251 |
+
using ThreadMapB = typename MmaCore::SmemThreadMapB;
|
| 1252 |
+
using IteratorB =
|
| 1253 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 1254 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1255 |
+
ElementB,
|
| 1256 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 1257 |
+
ThreadMapB
|
| 1258 |
+
>;
|
| 1259 |
+
|
| 1260 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1261 |
+
|
| 1262 |
+
// Warp-level GEMM components
|
| 1263 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 1264 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1265 |
+
|
| 1266 |
+
// Define the Mma
|
| 1267 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 1268 |
+
ThreadblockShape,
|
| 1269 |
+
IteratorA,
|
| 1270 |
+
SmemIteratorA,
|
| 1271 |
+
arch::CacheOperation::Always,
|
| 1272 |
+
IteratorB,
|
| 1273 |
+
SmemIteratorB,
|
| 1274 |
+
arch::CacheOperation::Global,
|
| 1275 |
+
MmaPolicy,
|
| 1276 |
+
Stages
|
| 1277 |
+
>;
|
| 1278 |
+
|
| 1279 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 1280 |
+
|
| 1281 |
+
// Define the epilogue
|
| 1282 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 1283 |
+
ThreadblockShape,
|
| 1284 |
+
WarpMmaTensorOp,
|
| 1285 |
+
kPartitionsK,
|
| 1286 |
+
EpilogueOutputOp,
|
| 1287 |
+
EpilogueOutputOp::kCount,
|
| 1288 |
+
InterleavedK
|
| 1289 |
+
>::Epilogue;
|
| 1290 |
+
|
| 1291 |
+
// Define the kernel
|
| 1292 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1293 |
+
Mma,
|
| 1294 |
+
Epilogue,
|
| 1295 |
+
ThreadblockSwizzle,
|
| 1296 |
+
conv::Operator::kFprop
|
| 1297 |
+
>;
|
| 1298 |
+
};
|
| 1299 |
+
|
| 1300 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1301 |
+
|
| 1302 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
|
| 1303 |
+
/// and 2 stage pipeline.
|
| 1304 |
+
template <
|
| 1305 |
+
typename ElementA,
|
| 1306 |
+
typename LayoutA,
|
| 1307 |
+
typename ElementB,
|
| 1308 |
+
typename LayoutB,
|
| 1309 |
+
typename ElementC,
|
| 1310 |
+
typename LayoutC,
|
| 1311 |
+
typename ElementAccumulator,
|
| 1312 |
+
typename ArchTag,
|
| 1313 |
+
typename ThreadblockShape,
|
| 1314 |
+
typename WarpShape,
|
| 1315 |
+
typename InstructionShape,
|
| 1316 |
+
typename EpilogueOutputOp,
|
| 1317 |
+
typename ThreadblockSwizzle,
|
| 1318 |
+
typename MathOperatorTag,
|
| 1319 |
+
conv::StrideSupport StrideSupport,
|
| 1320 |
+
int AlignmentA,
|
| 1321 |
+
int AlignmentB
|
| 1322 |
+
>
|
| 1323 |
+
struct DefaultConv2dFprop <
|
| 1324 |
+
ElementA,
|
| 1325 |
+
LayoutA,
|
| 1326 |
+
ElementB,
|
| 1327 |
+
LayoutB,
|
| 1328 |
+
ElementC,
|
| 1329 |
+
LayoutC,
|
| 1330 |
+
ElementAccumulator,
|
| 1331 |
+
arch::OpClassTensorOp,
|
| 1332 |
+
ArchTag,
|
| 1333 |
+
ThreadblockShape,
|
| 1334 |
+
WarpShape,
|
| 1335 |
+
InstructionShape,
|
| 1336 |
+
EpilogueOutputOp,
|
| 1337 |
+
ThreadblockSwizzle,
|
| 1338 |
+
2,
|
| 1339 |
+
MathOperatorTag,
|
| 1340 |
+
IteratorAlgorithm::kOptimized,
|
| 1341 |
+
StrideSupport,
|
| 1342 |
+
AlignmentA,
|
| 1343 |
+
AlignmentB
|
| 1344 |
+
> {
|
| 1345 |
+
|
| 1346 |
+
// Define the core components from GEMM
|
| 1347 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1348 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1349 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 1350 |
+
2, MathOperatorTag>;
|
| 1351 |
+
|
| 1352 |
+
// Define iterators over tiles from the A operand
|
| 1353 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1354 |
+
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
| 1355 |
+
using IteratorA =
|
| 1356 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1357 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 1358 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1359 |
+
ElementA,
|
| 1360 |
+
LayoutA,
|
| 1361 |
+
ThreadMapA,
|
| 1362 |
+
AccessTypeA
|
| 1363 |
+
>
|
| 1364 |
+
>;
|
| 1365 |
+
|
| 1366 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1367 |
+
|
| 1368 |
+
// Define iterators over tiles from the B operand
|
| 1369 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1370 |
+
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
| 1371 |
+
using IteratorB =
|
| 1372 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1373 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 1374 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1375 |
+
ElementB,
|
| 1376 |
+
LayoutB,
|
| 1377 |
+
ThreadMapB,
|
| 1378 |
+
AccessTypeB
|
| 1379 |
+
>
|
| 1380 |
+
>;
|
| 1381 |
+
|
| 1382 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1383 |
+
|
| 1384 |
+
// Warp-level GEMM components
|
| 1385 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 1386 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1387 |
+
|
| 1388 |
+
// Define the Mma
|
| 1389 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1390 |
+
ThreadblockShape,
|
| 1391 |
+
IteratorA,
|
| 1392 |
+
SmemIteratorA,
|
| 1393 |
+
IteratorB,
|
| 1394 |
+
SmemIteratorB,
|
| 1395 |
+
ElementC,
|
| 1396 |
+
LayoutC,
|
| 1397 |
+
MmaPolicy
|
| 1398 |
+
>;
|
| 1399 |
+
|
| 1400 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 1401 |
+
|
| 1402 |
+
// Define the epilogue
|
| 1403 |
+
using Epilogue = typename detail::DefaultConvEpilogue<
|
| 1404 |
+
ArchTag,
|
| 1405 |
+
ThreadblockShape,
|
| 1406 |
+
WarpMmaTensorOp,
|
| 1407 |
+
kPartitionsK,
|
| 1408 |
+
EpilogueOutputOp
|
| 1409 |
+
>::Epilogue;
|
| 1410 |
+
|
| 1411 |
+
// Define the kernel
|
| 1412 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1413 |
+
Mma,
|
| 1414 |
+
Epilogue,
|
| 1415 |
+
ThreadblockSwizzle,
|
| 1416 |
+
conv::Operator::kFprop
|
| 1417 |
+
>;
|
| 1418 |
+
};
|
| 1419 |
+
|
| 1420 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1421 |
+
|
| 1422 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
|
| 1423 |
+
/// pipeline with interleaved layout.
|
| 1424 |
+
template <
|
| 1425 |
+
typename ElementA,
|
| 1426 |
+
typename ElementB,
|
| 1427 |
+
typename ElementC,
|
| 1428 |
+
typename LayoutC,
|
| 1429 |
+
typename ElementAccumulator,
|
| 1430 |
+
typename ArchTag,
|
| 1431 |
+
typename ThreadblockShape,
|
| 1432 |
+
typename WarpShape,
|
| 1433 |
+
typename InstructionShape,
|
| 1434 |
+
typename EpilogueOutputOp,
|
| 1435 |
+
typename ThreadblockSwizzle,
|
| 1436 |
+
typename MathOperatorTag,
|
| 1437 |
+
conv::StrideSupport StrideSupport,
|
| 1438 |
+
int AlignmentA,
|
| 1439 |
+
int AlignmentB,
|
| 1440 |
+
int InterleavedK
|
| 1441 |
+
>
|
| 1442 |
+
struct DefaultConv2dFprop <
|
| 1443 |
+
ElementA,
|
| 1444 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 1445 |
+
ElementB,
|
| 1446 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 1447 |
+
ElementC,
|
| 1448 |
+
LayoutC,
|
| 1449 |
+
ElementAccumulator,
|
| 1450 |
+
arch::OpClassTensorOp,
|
| 1451 |
+
ArchTag,
|
| 1452 |
+
ThreadblockShape,
|
| 1453 |
+
WarpShape,
|
| 1454 |
+
InstructionShape,
|
| 1455 |
+
EpilogueOutputOp,
|
| 1456 |
+
ThreadblockSwizzle,
|
| 1457 |
+
2,
|
| 1458 |
+
MathOperatorTag,
|
| 1459 |
+
IteratorAlgorithm::kOptimized,
|
| 1460 |
+
StrideSupport,
|
| 1461 |
+
AlignmentA,
|
| 1462 |
+
AlignmentB
|
| 1463 |
+
> {
|
| 1464 |
+
|
| 1465 |
+
// Define the core components from GEMM
|
| 1466 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1467 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 1468 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 1469 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 1470 |
+
2, MathOperatorTag, true>;
|
| 1471 |
+
|
| 1472 |
+
// Define iterators over tiles from the A operand
|
| 1473 |
+
using ThreadMapA = typename MmaCore::SmemThreadMapA;
|
| 1474 |
+
using IteratorA =
|
| 1475 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1476 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 1477 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1478 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 1479 |
+
ThreadMapA
|
| 1480 |
+
>
|
| 1481 |
+
>;
|
| 1482 |
+
|
| 1483 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1484 |
+
|
| 1485 |
+
// Define iterators over tiles from the B operand
|
| 1486 |
+
using ThreadMapB = typename MmaCore::SmemThreadMapB;
|
| 1487 |
+
using IteratorB =
|
| 1488 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1489 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 1490 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1491 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 1492 |
+
ThreadMapB
|
| 1493 |
+
>
|
| 1494 |
+
>;
|
| 1495 |
+
|
| 1496 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1497 |
+
|
| 1498 |
+
// Warp-level GEMM components
|
| 1499 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 1500 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1501 |
+
|
| 1502 |
+
// Define the Mma
|
| 1503 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1504 |
+
ThreadblockShape,
|
| 1505 |
+
IteratorA,
|
| 1506 |
+
SmemIteratorA,
|
| 1507 |
+
IteratorB,
|
| 1508 |
+
SmemIteratorB,
|
| 1509 |
+
ElementC,
|
| 1510 |
+
LayoutC,
|
| 1511 |
+
MmaPolicy
|
| 1512 |
+
>;
|
| 1513 |
+
|
| 1514 |
+
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
| 1515 |
+
|
| 1516 |
+
// Define the epilogue
|
| 1517 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 1518 |
+
ThreadblockShape,
|
| 1519 |
+
WarpMmaTensorOp,
|
| 1520 |
+
kPartitionsK,
|
| 1521 |
+
EpilogueOutputOp,
|
| 1522 |
+
EpilogueOutputOp::kCount,
|
| 1523 |
+
InterleavedK
|
| 1524 |
+
>::Epilogue;
|
| 1525 |
+
|
| 1526 |
+
// Define the kernel
|
| 1527 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1528 |
+
Mma,
|
| 1529 |
+
Epilogue,
|
| 1530 |
+
ThreadblockSwizzle,
|
| 1531 |
+
conv::Operator::kFprop
|
| 1532 |
+
>;
|
| 1533 |
+
};
|
| 1534 |
+
|
| 1535 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1536 |
+
// OpClassSimt convolutions
|
| 1537 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1538 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm,
|
| 1539 |
+
/// multi-stage pipeline, and FFMA-based mainloop for SM80
|
| 1540 |
+
|
| 1541 |
+
template <
|
| 1542 |
+
typename ElementA,
|
| 1543 |
+
typename LayoutA,
|
| 1544 |
+
typename ElementB,
|
| 1545 |
+
typename LayoutB,
|
| 1546 |
+
typename ElementC,
|
| 1547 |
+
typename LayoutC,
|
| 1548 |
+
typename ElementAccumulator,
|
| 1549 |
+
typename ArchTag,
|
| 1550 |
+
typename ThreadblockShape,
|
| 1551 |
+
typename WarpShape,
|
| 1552 |
+
typename InstructionShape,
|
| 1553 |
+
typename EpilogueOutputOp,
|
| 1554 |
+
typename ThreadblockSwizzle,
|
| 1555 |
+
int Stages,
|
| 1556 |
+
typename MathOperatorTag,
|
| 1557 |
+
conv::StrideSupport StrideSupport,
|
| 1558 |
+
int AlignmentA,
|
| 1559 |
+
int AlignmentB
|
| 1560 |
+
>
|
| 1561 |
+
struct DefaultConv2dFprop <
|
| 1562 |
+
ElementA,
|
| 1563 |
+
LayoutA,
|
| 1564 |
+
ElementB,
|
| 1565 |
+
LayoutB,
|
| 1566 |
+
ElementC,
|
| 1567 |
+
LayoutC,
|
| 1568 |
+
ElementAccumulator,
|
| 1569 |
+
arch::OpClassSimt,
|
| 1570 |
+
ArchTag,
|
| 1571 |
+
ThreadblockShape,
|
| 1572 |
+
WarpShape,
|
| 1573 |
+
InstructionShape,
|
| 1574 |
+
EpilogueOutputOp,
|
| 1575 |
+
ThreadblockSwizzle,
|
| 1576 |
+
Stages,
|
| 1577 |
+
MathOperatorTag,
|
| 1578 |
+
IteratorAlgorithm::kAnalytic,
|
| 1579 |
+
StrideSupport,
|
| 1580 |
+
AlignmentA,
|
| 1581 |
+
AlignmentB
|
| 1582 |
+
> {
|
| 1583 |
+
|
| 1584 |
+
// Define the core components from GEMM
|
| 1585 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1586 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1587 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1588 |
+
Stages, MathOperatorTag>;
|
| 1589 |
+
|
| 1590 |
+
// Define iterators over tiles from the A operand
|
| 1591 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1592 |
+
using IteratorA =
|
| 1593 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 1594 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1595 |
+
ElementA, LayoutA,
|
| 1596 |
+
ThreadMapA
|
| 1597 |
+
>;
|
| 1598 |
+
|
| 1599 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1600 |
+
|
| 1601 |
+
// Define iterators over tiles from the B operand
|
| 1602 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1603 |
+
using IteratorB =
|
| 1604 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 1605 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1606 |
+
ElementB, LayoutB,
|
| 1607 |
+
ThreadMapB
|
| 1608 |
+
>;
|
| 1609 |
+
|
| 1610 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1611 |
+
|
| 1612 |
+
// Warp-level GEMM components
|
| 1613 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1614 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1615 |
+
|
| 1616 |
+
// Define the Mma
|
| 1617 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 1618 |
+
ThreadblockShape,
|
| 1619 |
+
IteratorA,
|
| 1620 |
+
SmemIteratorA,
|
| 1621 |
+
arch::CacheOperation::Always,
|
| 1622 |
+
IteratorB,
|
| 1623 |
+
SmemIteratorB,
|
| 1624 |
+
arch::CacheOperation::Always,
|
| 1625 |
+
MmaPolicy,
|
| 1626 |
+
Stages
|
| 1627 |
+
>;
|
| 1628 |
+
|
| 1629 |
+
// Define the epilogue
|
| 1630 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
|
| 1631 |
+
ThreadblockShape,
|
| 1632 |
+
WarpMmaSimtOp,
|
| 1633 |
+
EpilogueOutputOp,
|
| 1634 |
+
EpilogueOutputOp::kCount,
|
| 1635 |
+
false,
|
| 1636 |
+
layout::NoPermute,
|
| 1637 |
+
StrideSupport,
|
| 1638 |
+
4
|
| 1639 |
+
>::Epilogue;
|
| 1640 |
+
|
| 1641 |
+
// Define the kernel
|
| 1642 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1643 |
+
Mma,
|
| 1644 |
+
Epilogue,
|
| 1645 |
+
ThreadblockSwizzle,
|
| 1646 |
+
conv::Operator::kFprop
|
| 1647 |
+
>;
|
| 1648 |
+
|
| 1649 |
+
};
|
| 1650 |
+
|
| 1651 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1652 |
+
|
| 1653 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm,
|
| 1654 |
+
/// multi-stage pipeline, and FFMA-based mainloop for SM80
|
| 1655 |
+
|
| 1656 |
+
template <
|
| 1657 |
+
typename ElementA,
|
| 1658 |
+
typename LayoutA,
|
| 1659 |
+
typename ElementB,
|
| 1660 |
+
typename LayoutB,
|
| 1661 |
+
typename ElementC,
|
| 1662 |
+
typename LayoutC,
|
| 1663 |
+
typename ElementAccumulator,
|
| 1664 |
+
typename ArchTag,
|
| 1665 |
+
typename ThreadblockShape,
|
| 1666 |
+
typename WarpShape,
|
| 1667 |
+
typename InstructionShape,
|
| 1668 |
+
typename EpilogueOutputOp,
|
| 1669 |
+
typename ThreadblockSwizzle,
|
| 1670 |
+
int Stages,
|
| 1671 |
+
typename MathOperatorTag,
|
| 1672 |
+
conv::StrideSupport StrideSupport,
|
| 1673 |
+
int AlignmentA,
|
| 1674 |
+
int AlignmentB
|
| 1675 |
+
>
|
| 1676 |
+
struct DefaultConv2dFprop <
|
| 1677 |
+
ElementA,
|
| 1678 |
+
LayoutA,
|
| 1679 |
+
ElementB,
|
| 1680 |
+
LayoutB,
|
| 1681 |
+
ElementC,
|
| 1682 |
+
LayoutC,
|
| 1683 |
+
ElementAccumulator,
|
| 1684 |
+
arch::OpClassSimt,
|
| 1685 |
+
ArchTag,
|
| 1686 |
+
ThreadblockShape,
|
| 1687 |
+
WarpShape,
|
| 1688 |
+
InstructionShape,
|
| 1689 |
+
EpilogueOutputOp,
|
| 1690 |
+
ThreadblockSwizzle,
|
| 1691 |
+
Stages,
|
| 1692 |
+
MathOperatorTag,
|
| 1693 |
+
IteratorAlgorithm::kOptimized,
|
| 1694 |
+
StrideSupport,
|
| 1695 |
+
AlignmentA,
|
| 1696 |
+
AlignmentB
|
| 1697 |
+
> {
|
| 1698 |
+
|
| 1699 |
+
// Define the core components from GEMM
|
| 1700 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1701 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1702 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1703 |
+
Stages, MathOperatorTag>;
|
| 1704 |
+
|
| 1705 |
+
// Define iterators over tiles from the A operand
|
| 1706 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1707 |
+
using IteratorA =
|
| 1708 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 1709 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1710 |
+
ElementA,
|
| 1711 |
+
LayoutA,
|
| 1712 |
+
ThreadMapA
|
| 1713 |
+
>;
|
| 1714 |
+
|
| 1715 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1716 |
+
|
| 1717 |
+
// Define iterators over tiles from the B operand
|
| 1718 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1719 |
+
using IteratorB =
|
| 1720 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 1721 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1722 |
+
ElementB,
|
| 1723 |
+
LayoutB,
|
| 1724 |
+
ThreadMapB
|
| 1725 |
+
>;
|
| 1726 |
+
|
| 1727 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1728 |
+
|
| 1729 |
+
// Warp-level GEMM components
|
| 1730 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1731 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1732 |
+
|
| 1733 |
+
// Define the Mma
|
| 1734 |
+
using Mma = threadblock::ImplicitGemmMultistage<
|
| 1735 |
+
ThreadblockShape,
|
| 1736 |
+
IteratorA,
|
| 1737 |
+
SmemIteratorA,
|
| 1738 |
+
arch::CacheOperation::Always,
|
| 1739 |
+
IteratorB,
|
| 1740 |
+
SmemIteratorB,
|
| 1741 |
+
arch::CacheOperation::Always,
|
| 1742 |
+
MmaPolicy,
|
| 1743 |
+
Stages
|
| 1744 |
+
>;
|
| 1745 |
+
|
| 1746 |
+
// Define the epilogue
|
| 1747 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
|
| 1748 |
+
ThreadblockShape,
|
| 1749 |
+
WarpMmaSimtOp,
|
| 1750 |
+
EpilogueOutputOp,
|
| 1751 |
+
EpilogueOutputOp::kCount,
|
| 1752 |
+
false,
|
| 1753 |
+
layout::NoPermute,
|
| 1754 |
+
StrideSupport,
|
| 1755 |
+
4
|
| 1756 |
+
>::Epilogue;
|
| 1757 |
+
|
| 1758 |
+
// Define the kernel
|
| 1759 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1760 |
+
Mma,
|
| 1761 |
+
Epilogue,
|
| 1762 |
+
ThreadblockSwizzle,
|
| 1763 |
+
conv::Operator::kFprop
|
| 1764 |
+
>;
|
| 1765 |
+
};
|
| 1766 |
+
|
| 1767 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1768 |
+
|
| 1769 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm,
|
| 1770 |
+
/// 2 stage pipeline, and FFMA-based mainloop for SM50
|
| 1771 |
+
template <
|
| 1772 |
+
typename ElementA,
|
| 1773 |
+
typename LayoutA,
|
| 1774 |
+
typename ElementB,
|
| 1775 |
+
typename LayoutB,
|
| 1776 |
+
typename ElementC,
|
| 1777 |
+
typename LayoutC,
|
| 1778 |
+
typename ElementAccumulator,
|
| 1779 |
+
typename ArchTag,
|
| 1780 |
+
typename ThreadblockShape,
|
| 1781 |
+
typename WarpShape,
|
| 1782 |
+
typename InstructionShape,
|
| 1783 |
+
typename EpilogueOutputOp,
|
| 1784 |
+
typename ThreadblockSwizzle,
|
| 1785 |
+
typename MathOperatorTag,
|
| 1786 |
+
conv::StrideSupport StrideSupport,
|
| 1787 |
+
int AlignmentA,
|
| 1788 |
+
int AlignmentB
|
| 1789 |
+
>
|
| 1790 |
+
struct DefaultConv2dFprop <
|
| 1791 |
+
ElementA,
|
| 1792 |
+
LayoutA,
|
| 1793 |
+
ElementB,
|
| 1794 |
+
LayoutB,
|
| 1795 |
+
ElementC,
|
| 1796 |
+
LayoutC,
|
| 1797 |
+
ElementAccumulator,
|
| 1798 |
+
arch::OpClassSimt,
|
| 1799 |
+
ArchTag,
|
| 1800 |
+
ThreadblockShape,
|
| 1801 |
+
WarpShape,
|
| 1802 |
+
InstructionShape,
|
| 1803 |
+
EpilogueOutputOp,
|
| 1804 |
+
ThreadblockSwizzle,
|
| 1805 |
+
2,
|
| 1806 |
+
MathOperatorTag,
|
| 1807 |
+
IteratorAlgorithm::kAnalytic,
|
| 1808 |
+
StrideSupport,
|
| 1809 |
+
AlignmentA,
|
| 1810 |
+
AlignmentB
|
| 1811 |
+
> {
|
| 1812 |
+
|
| 1813 |
+
// Define the core components from GEMM
|
| 1814 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1815 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1816 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1817 |
+
2, MathOperatorTag>;
|
| 1818 |
+
|
| 1819 |
+
// Define iterators over tiles from the A operand
|
| 1820 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1821 |
+
using IteratorA =
|
| 1822 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1823 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 1824 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1825 |
+
ElementA, LayoutA,
|
| 1826 |
+
ThreadMapA
|
| 1827 |
+
>
|
| 1828 |
+
>;
|
| 1829 |
+
|
| 1830 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1831 |
+
|
| 1832 |
+
// Define iterators over tiles from the B operand
|
| 1833 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1834 |
+
using IteratorB =
|
| 1835 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1836 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 1837 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1838 |
+
ElementB, LayoutB,
|
| 1839 |
+
ThreadMapB
|
| 1840 |
+
>
|
| 1841 |
+
>;
|
| 1842 |
+
|
| 1843 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1844 |
+
|
| 1845 |
+
// Warp-level GEMM components
|
| 1846 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1847 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1848 |
+
|
| 1849 |
+
// Define the Mma
|
| 1850 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1851 |
+
ThreadblockShape,
|
| 1852 |
+
IteratorA,
|
| 1853 |
+
SmemIteratorA,
|
| 1854 |
+
IteratorB,
|
| 1855 |
+
SmemIteratorB,
|
| 1856 |
+
ElementC,
|
| 1857 |
+
LayoutC,
|
| 1858 |
+
MmaPolicy
|
| 1859 |
+
>;
|
| 1860 |
+
|
| 1861 |
+
// Define the epilogue
|
| 1862 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
|
| 1863 |
+
ThreadblockShape,
|
| 1864 |
+
WarpMmaSimtOp,
|
| 1865 |
+
EpilogueOutputOp,
|
| 1866 |
+
EpilogueOutputOp::kCount,
|
| 1867 |
+
false,
|
| 1868 |
+
layout::NoPermute,
|
| 1869 |
+
StrideSupport,
|
| 1870 |
+
4
|
| 1871 |
+
>::Epilogue;
|
| 1872 |
+
|
| 1873 |
+
// Define the kernel
|
| 1874 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1875 |
+
Mma,
|
| 1876 |
+
Epilogue,
|
| 1877 |
+
ThreadblockSwizzle,
|
| 1878 |
+
conv::Operator::kFprop
|
| 1879 |
+
>;
|
| 1880 |
+
|
| 1881 |
+
};
|
| 1882 |
+
|
| 1883 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1884 |
+
|
| 1885 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm,
|
| 1886 |
+
/// 2 stage pipeline, and FFMA-based mainloop for SM50
|
| 1887 |
+
template <
|
| 1888 |
+
typename ElementA,
|
| 1889 |
+
typename LayoutA,
|
| 1890 |
+
typename ElementB,
|
| 1891 |
+
typename LayoutB,
|
| 1892 |
+
typename ElementC,
|
| 1893 |
+
typename LayoutC,
|
| 1894 |
+
typename ElementAccumulator,
|
| 1895 |
+
typename ArchTag,
|
| 1896 |
+
typename ThreadblockShape,
|
| 1897 |
+
typename WarpShape,
|
| 1898 |
+
typename InstructionShape,
|
| 1899 |
+
typename EpilogueOutputOp,
|
| 1900 |
+
typename ThreadblockSwizzle,
|
| 1901 |
+
typename MathOperatorTag,
|
| 1902 |
+
conv::StrideSupport StrideSupport,
|
| 1903 |
+
int AlignmentA,
|
| 1904 |
+
int AlignmentB
|
| 1905 |
+
>
|
| 1906 |
+
struct DefaultConv2dFprop <
|
| 1907 |
+
ElementA,
|
| 1908 |
+
LayoutA,
|
| 1909 |
+
ElementB,
|
| 1910 |
+
LayoutB,
|
| 1911 |
+
ElementC,
|
| 1912 |
+
LayoutC,
|
| 1913 |
+
ElementAccumulator,
|
| 1914 |
+
arch::OpClassSimt,
|
| 1915 |
+
ArchTag,
|
| 1916 |
+
ThreadblockShape,
|
| 1917 |
+
WarpShape,
|
| 1918 |
+
InstructionShape,
|
| 1919 |
+
EpilogueOutputOp,
|
| 1920 |
+
ThreadblockSwizzle,
|
| 1921 |
+
2,
|
| 1922 |
+
MathOperatorTag,
|
| 1923 |
+
IteratorAlgorithm::kOptimized,
|
| 1924 |
+
StrideSupport,
|
| 1925 |
+
AlignmentA,
|
| 1926 |
+
AlignmentB
|
| 1927 |
+
> {
|
| 1928 |
+
|
| 1929 |
+
// Define the core components from GEMM
|
| 1930 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 1931 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 1932 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
| 1933 |
+
2, MathOperatorTag>;
|
| 1934 |
+
|
| 1935 |
+
// Define iterators over tiles from the A operand
|
| 1936 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 1937 |
+
using IteratorA =
|
| 1938 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1939 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 1940 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 1941 |
+
ElementA,
|
| 1942 |
+
LayoutA,
|
| 1943 |
+
ThreadMapA
|
| 1944 |
+
>
|
| 1945 |
+
>;
|
| 1946 |
+
|
| 1947 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 1948 |
+
|
| 1949 |
+
// Define iterators over tiles from the B operand
|
| 1950 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 1951 |
+
using IteratorB =
|
| 1952 |
+
cutlass::conv::threadblock::TileIterator<
|
| 1953 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 1954 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 1955 |
+
ElementB,
|
| 1956 |
+
LayoutB,
|
| 1957 |
+
ThreadMapB
|
| 1958 |
+
>
|
| 1959 |
+
>;
|
| 1960 |
+
|
| 1961 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 1962 |
+
|
| 1963 |
+
// Warp-level GEMM components
|
| 1964 |
+
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
| 1965 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 1966 |
+
|
| 1967 |
+
// Define the Mma
|
| 1968 |
+
using Mma = threadblock::ImplicitGemmPipelined<
|
| 1969 |
+
ThreadblockShape,
|
| 1970 |
+
IteratorA,
|
| 1971 |
+
SmemIteratorA,
|
| 1972 |
+
IteratorB,
|
| 1973 |
+
SmemIteratorB,
|
| 1974 |
+
ElementC,
|
| 1975 |
+
LayoutC,
|
| 1976 |
+
MmaPolicy
|
| 1977 |
+
>;
|
| 1978 |
+
|
| 1979 |
+
// Define the epilogue
|
| 1980 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
|
| 1981 |
+
ThreadblockShape,
|
| 1982 |
+
WarpMmaSimtOp,
|
| 1983 |
+
EpilogueOutputOp,
|
| 1984 |
+
EpilogueOutputOp::kCount,
|
| 1985 |
+
false,
|
| 1986 |
+
layout::NoPermute,
|
| 1987 |
+
StrideSupport,
|
| 1988 |
+
4
|
| 1989 |
+
>::Epilogue;
|
| 1990 |
+
|
| 1991 |
+
// Define the kernel
|
| 1992 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
| 1993 |
+
Mma,
|
| 1994 |
+
Epilogue,
|
| 1995 |
+
ThreadblockSwizzle,
|
| 1996 |
+
conv::Operator::kFprop
|
| 1997 |
+
>;
|
| 1998 |
+
|
| 1999 |
+
};
|
| 2000 |
+
|
| 2001 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2002 |
+
|
| 2003 |
+
} // namespace kernel
|
| 2004 |
+
} // namespace conv
|
| 2005 |
+
} // namespace cutlass
|
| 2006 |
+
|
| 2007 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief
|
| 33 |
+
Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution
|
| 34 |
+
definitions that combine threadblock-scoped matrix multiply-add with the
|
| 35 |
+
appropriate threadblock-scoped epilogue.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/conv/kernel/default_conv2d.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
| 44 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
| 45 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
| 46 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
| 47 |
+
#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h"
|
| 48 |
+
#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h"
|
| 49 |
+
#include "cutlass/gemm/warp/scale_bias_tile_iterator.h"
|
| 50 |
+
|
| 51 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
namespace cutlass {
|
| 54 |
+
namespace conv {
|
| 55 |
+
namespace kernel {
|
| 56 |
+
|
| 57 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 58 |
+
/// Defines a kernel for fused batch norm and Conv2dFprop
|
| 59 |
+
template <
|
| 60 |
+
typename ElementA,
|
| 61 |
+
typename LayoutA,
|
| 62 |
+
typename ElementB,
|
| 63 |
+
typename LayoutB,
|
| 64 |
+
typename ElementScaleBias,
|
| 65 |
+
typename LayoutScaleBias,
|
| 66 |
+
typename ElementC,
|
| 67 |
+
typename LayoutC,
|
| 68 |
+
typename ElementAccumulator,
|
| 69 |
+
typename OperatorClass,
|
| 70 |
+
typename ArchTag,
|
| 71 |
+
typename ThreadblockShape,
|
| 72 |
+
typename WarpShape,
|
| 73 |
+
typename InstructionShape,
|
| 74 |
+
typename EpilogueOutputOp,
|
| 75 |
+
typename ThreadblockSwizzle,
|
| 76 |
+
int Stages,
|
| 77 |
+
typename MathOperatorTag,
|
| 78 |
+
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
| 79 |
+
conv::StrideSupport StrideSupport = StrideSupport::kUnity
|
| 80 |
+
> struct DefaultConv2dFpropFusion;
|
| 81 |
+
|
| 82 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 83 |
+
// OpClassTensorOp convolutions
|
| 84 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 85 |
+
|
| 86 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 87 |
+
/// pipeline.
|
| 88 |
+
template <
|
| 89 |
+
typename ElementA,
|
| 90 |
+
typename LayoutA,
|
| 91 |
+
typename ElementB,
|
| 92 |
+
typename LayoutB,
|
| 93 |
+
typename ElementScaleBias,
|
| 94 |
+
typename LayoutScaleBias,
|
| 95 |
+
typename ElementC,
|
| 96 |
+
typename LayoutC,
|
| 97 |
+
typename ElementAccumulator,
|
| 98 |
+
typename ArchTag,
|
| 99 |
+
typename ThreadblockShape,
|
| 100 |
+
typename WarpShape,
|
| 101 |
+
typename InstructionShape,
|
| 102 |
+
typename EpilogueOutputOp,
|
| 103 |
+
typename ThreadblockSwizzle,
|
| 104 |
+
int Stages,
|
| 105 |
+
typename MathOperatorTag
|
| 106 |
+
>
|
| 107 |
+
struct DefaultConv2dFpropFusion <
|
| 108 |
+
ElementA,
|
| 109 |
+
LayoutA,
|
| 110 |
+
ElementB,
|
| 111 |
+
LayoutB,
|
| 112 |
+
ElementScaleBias,
|
| 113 |
+
LayoutScaleBias,
|
| 114 |
+
ElementC,
|
| 115 |
+
LayoutC,
|
| 116 |
+
ElementAccumulator,
|
| 117 |
+
arch::OpClassTensorOp,
|
| 118 |
+
ArchTag,
|
| 119 |
+
ThreadblockShape,
|
| 120 |
+
WarpShape,
|
| 121 |
+
InstructionShape,
|
| 122 |
+
EpilogueOutputOp,
|
| 123 |
+
ThreadblockSwizzle,
|
| 124 |
+
Stages,
|
| 125 |
+
MathOperatorTag,
|
| 126 |
+
IteratorAlgorithm::kAnalytic
|
| 127 |
+
> {
|
| 128 |
+
|
| 129 |
+
// Define the core components from GEMM
|
| 130 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 131 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 132 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 133 |
+
Stages, MathOperatorTag>;
|
| 134 |
+
|
| 135 |
+
// Define iterators over tiles from the A operand
|
| 136 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 137 |
+
using IteratorA =
|
| 138 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 139 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 140 |
+
ElementA, LayoutA,
|
| 141 |
+
ThreadMapA
|
| 142 |
+
>;
|
| 143 |
+
|
| 144 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 145 |
+
|
| 146 |
+
// Define iterators over tiles from the B operand
|
| 147 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 148 |
+
using IteratorB =
|
| 149 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 150 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 151 |
+
ElementB, LayoutB,
|
| 152 |
+
ThreadMapB
|
| 153 |
+
>;
|
| 154 |
+
|
| 155 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 156 |
+
|
| 157 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 158 |
+
using IteratorScaleBias =
|
| 159 |
+
cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
|
| 160 |
+
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
|
| 161 |
+
LayoutScaleBias>;
|
| 162 |
+
|
| 163 |
+
using SmemIteratorScaleBias =
|
| 164 |
+
cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
|
| 165 |
+
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
|
| 166 |
+
LayoutScaleBias>;
|
| 167 |
+
|
| 168 |
+
// Warp-level GEMM components
|
| 169 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 170 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 171 |
+
|
| 172 |
+
static int const kThreadCount = 32;
|
| 173 |
+
|
| 174 |
+
// Warp-level iterators to load scale and bias vectors
|
| 175 |
+
using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
|
| 176 |
+
MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
|
| 177 |
+
LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 178 |
+
typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
|
| 179 |
+
MmaCore::WarpCount::kK>;
|
| 180 |
+
|
| 181 |
+
// Define the Mma
|
| 182 |
+
using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
|
| 183 |
+
ThreadblockShape,
|
| 184 |
+
IteratorA,
|
| 185 |
+
SmemIteratorA,
|
| 186 |
+
arch::CacheOperation::Always,
|
| 187 |
+
IteratorB,
|
| 188 |
+
SmemIteratorB,
|
| 189 |
+
arch::CacheOperation::Global,
|
| 190 |
+
IteratorScaleBias,
|
| 191 |
+
SmemIteratorScaleBias,
|
| 192 |
+
arch::CacheOperation::Always,
|
| 193 |
+
MmaPolicy,
|
| 194 |
+
WarpIteratorScaleBias,
|
| 195 |
+
Stages
|
| 196 |
+
>;
|
| 197 |
+
|
| 198 |
+
// Define the epilogue
|
| 199 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 200 |
+
ThreadblockShape,
|
| 201 |
+
WarpMmaTensorOp,
|
| 202 |
+
1,
|
| 203 |
+
EpilogueOutputOp,
|
| 204 |
+
EpilogueOutputOp::kCount
|
| 205 |
+
>::Epilogue;
|
| 206 |
+
|
| 207 |
+
// Define the kernel
|
| 208 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
|
| 209 |
+
Mma,
|
| 210 |
+
Epilogue,
|
| 211 |
+
ThreadblockSwizzle,
|
| 212 |
+
conv::Operator::kFprop
|
| 213 |
+
>;
|
| 214 |
+
};
|
| 215 |
+
|
| 216 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 217 |
+
|
| 218 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
| 219 |
+
/// multistage pipeline.
|
| 220 |
+
template <
|
| 221 |
+
typename ElementA,
|
| 222 |
+
typename LayoutA,
|
| 223 |
+
typename ElementB,
|
| 224 |
+
typename LayoutB,
|
| 225 |
+
typename ElementScaleBias,
|
| 226 |
+
typename LayoutScaleBias,
|
| 227 |
+
typename ElementC,
|
| 228 |
+
typename LayoutC,
|
| 229 |
+
typename ElementAccumulator,
|
| 230 |
+
typename ArchTag,
|
| 231 |
+
typename ThreadblockShape,
|
| 232 |
+
typename WarpShape,
|
| 233 |
+
typename InstructionShape,
|
| 234 |
+
typename EpilogueOutputOp,
|
| 235 |
+
typename ThreadblockSwizzle,
|
| 236 |
+
int Stages,
|
| 237 |
+
typename MathOperatorTag
|
| 238 |
+
>
|
| 239 |
+
struct DefaultConv2dFpropFusion <
|
| 240 |
+
ElementA,
|
| 241 |
+
LayoutA,
|
| 242 |
+
ElementB,
|
| 243 |
+
LayoutB,
|
| 244 |
+
ElementScaleBias,
|
| 245 |
+
LayoutScaleBias,
|
| 246 |
+
ElementC,
|
| 247 |
+
LayoutC,
|
| 248 |
+
ElementAccumulator,
|
| 249 |
+
arch::OpClassTensorOp,
|
| 250 |
+
ArchTag,
|
| 251 |
+
ThreadblockShape,
|
| 252 |
+
WarpShape,
|
| 253 |
+
InstructionShape,
|
| 254 |
+
EpilogueOutputOp,
|
| 255 |
+
ThreadblockSwizzle,
|
| 256 |
+
Stages,
|
| 257 |
+
MathOperatorTag,
|
| 258 |
+
IteratorAlgorithm::kOptimized
|
| 259 |
+
> {
|
| 260 |
+
|
| 261 |
+
// Define the core components from GEMM
|
| 262 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 263 |
+
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
| 264 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 265 |
+
Stages, MathOperatorTag
|
| 266 |
+
>;
|
| 267 |
+
|
| 268 |
+
// Define iterators over tiles from the A operand
|
| 269 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 270 |
+
using IteratorA =
|
| 271 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 272 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 273 |
+
ElementA,
|
| 274 |
+
LayoutA,
|
| 275 |
+
ThreadMapA
|
| 276 |
+
>;
|
| 277 |
+
|
| 278 |
+
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
| 279 |
+
|
| 280 |
+
// Define iterators over tiles from the B operand
|
| 281 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 282 |
+
using IteratorB =
|
| 283 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 284 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 285 |
+
ElementB,
|
| 286 |
+
LayoutB,
|
| 287 |
+
ThreadMapB
|
| 288 |
+
>;
|
| 289 |
+
|
| 290 |
+
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
| 291 |
+
|
| 292 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 293 |
+
using IteratorScaleBias =
|
| 294 |
+
cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
|
| 295 |
+
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
|
| 296 |
+
LayoutScaleBias>;
|
| 297 |
+
|
| 298 |
+
using SmemIteratorScaleBias =
|
| 299 |
+
cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
|
| 300 |
+
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
|
| 301 |
+
LayoutScaleBias>;
|
| 302 |
+
|
| 303 |
+
// Warp-level GEMM components
|
| 304 |
+
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
| 305 |
+
using MmaPolicy = typename MmaCore::MmaPolicy;
|
| 306 |
+
|
| 307 |
+
static int const kThreadCount = 32;
|
| 308 |
+
|
| 309 |
+
// Warp-level iterators to load scale and bias vectors
|
| 310 |
+
using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
|
| 311 |
+
MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
|
| 312 |
+
LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 313 |
+
typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
|
| 314 |
+
MmaCore::WarpCount::kK>;
|
| 315 |
+
|
| 316 |
+
// Define the Mma
|
| 317 |
+
using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
|
| 318 |
+
ThreadblockShape,
|
| 319 |
+
IteratorA,
|
| 320 |
+
SmemIteratorA,
|
| 321 |
+
arch::CacheOperation::Always,
|
| 322 |
+
IteratorB,
|
| 323 |
+
SmemIteratorB,
|
| 324 |
+
arch::CacheOperation::Global,
|
| 325 |
+
IteratorScaleBias,
|
| 326 |
+
SmemIteratorScaleBias,
|
| 327 |
+
arch::CacheOperation::Always,
|
| 328 |
+
MmaPolicy,
|
| 329 |
+
WarpIteratorScaleBias,
|
| 330 |
+
Stages
|
| 331 |
+
>;
|
| 332 |
+
|
| 333 |
+
// Define the epilogue
|
| 334 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 335 |
+
ThreadblockShape,
|
| 336 |
+
WarpMmaTensorOp,
|
| 337 |
+
1,
|
| 338 |
+
EpilogueOutputOp,
|
| 339 |
+
EpilogueOutputOp::kCount
|
| 340 |
+
>::Epilogue;
|
| 341 |
+
|
| 342 |
+
// Define the kernel
|
| 343 |
+
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
|
| 344 |
+
Mma,
|
| 345 |
+
Epilogue,
|
| 346 |
+
ThreadblockSwizzle,
|
| 347 |
+
conv::Operator::kFprop
|
| 348 |
+
>;
|
| 349 |
+
};
|
| 350 |
+
|
| 351 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 352 |
+
|
| 353 |
+
} // namespace kernel
|
| 354 |
+
} // namespace conv
|
| 355 |
+
} // namespace cutlass
|
| 356 |
+
|
| 357 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|