Kernels:
Trusted publisher
Uploaded using `kernel-builder` (batch 8/32).
Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full list.
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h +0 -222
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp +0 -67
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h +0 -324
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp +0 -528
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h +0 -143
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h +0 -187
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h +0 -402
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h +0 -644
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h +0 -375
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h +0 -141
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h +0 -276
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h +0 -573
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h +0 -144
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h +0 -186
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h +0 -127
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h +0 -157
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h +0 -69
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp +0 -369
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp +0 -116
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h +0 -111
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h +0 -541
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h +0 -591
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h +0 -157
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h +0 -38
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +0 -472
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp +0 -570
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp +0 -341
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h +0 -135
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h +0 -94
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h +0 -1549
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h +0 -385
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h +0 -350
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h +0 -311
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp +0 -146
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h +0 -162
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h +0 -168
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h +0 -159
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h +0 -355
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h +0 -250
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h +0 -2075
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +0 -142
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h +0 -514
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h +0 -141
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h +0 -186
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp +0 -782
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h +0 -802
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h +0 -66
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h +0 -531
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h +0 -210
- build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h +0 -228
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h
DELETED
|
@@ -1,222 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/* \file
|
| 32 |
-
\brief Declares TrmmOperationProfiler, the profiler for TRMM operations
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
*/
|
| 36 |
-
|
| 37 |
-
#pragma once
|
| 38 |
-
|
| 39 |
-
#include <vector>
|
| 40 |
-
#include <string>
|
| 41 |
-
#include <memory>
|
| 42 |
-
#include <algorithm>
|
| 43 |
-
#include <unordered_map>
|
| 44 |
-
|
| 45 |
-
// CUTLASS Library includes
|
| 46 |
-
#include "cutlass/blas3.h"
|
| 47 |
-
#include "cutlass/library/library.h"
|
| 48 |
-
#include "cutlass/library/util.h"
|
| 49 |
-
#include "cutlass/library/manifest.h"
|
| 50 |
-
|
| 51 |
-
// Profiler includes
|
| 52 |
-
#include "options.h"
|
| 53 |
-
#include "device_context.h"
|
| 54 |
-
#include "operation_profiler.h"
|
| 55 |
-
#include "performance_result.h"
|
| 56 |
-
#include "problem_space.h"
|
| 57 |
-
|
| 58 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
-
|
| 60 |
-
namespace cutlass {
|
| 61 |
-
namespace profiler {
|
| 62 |
-
|
| 63 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 64 |
-
|
| 65 |
-
/// Abstract base class for each math function
|
| 66 |
-
class TrmmOperationProfiler : public OperationProfiler {
|
| 67 |
-
public:
|
| 68 |
-
|
| 69 |
-
/// Problem structure obtained from problem space
|
| 70 |
-
struct TrmmProblem {
|
| 71 |
-
int64_t m;
|
| 72 |
-
int64_t n;
|
| 73 |
-
int64_t lda;
|
| 74 |
-
int64_t ldb;
|
| 75 |
-
int64_t ldd;
|
| 76 |
-
SideMode side_mode;
|
| 77 |
-
FillMode fill_mode;
|
| 78 |
-
DiagType diag_type;
|
| 79 |
-
std::vector<uint8_t> alpha;
|
| 80 |
-
std::vector<uint8_t> beta;
|
| 81 |
-
int64_t split_k_slices;
|
| 82 |
-
int64_t batch_count;
|
| 83 |
-
|
| 84 |
-
//
|
| 85 |
-
// Methods
|
| 86 |
-
//
|
| 87 |
-
|
| 88 |
-
TrmmProblem():
|
| 89 |
-
m(16), n(16), lda(0), ldb(0), ldd(0), split_k_slices(1), batch_count(1) { }
|
| 90 |
-
|
| 91 |
-
/// Parses the problem
|
| 92 |
-
Status parse(
|
| 93 |
-
library::TrmmDescription const &operation_desc,
|
| 94 |
-
ProblemSpace const &problem_space,
|
| 95 |
-
ProblemSpace::Problem const &problem);
|
| 96 |
-
|
| 97 |
-
/// Initializes a performance result
|
| 98 |
-
void initialize_result(
|
| 99 |
-
PerformanceResult &result,
|
| 100 |
-
library::TrmmDescription const &operation_desc,
|
| 101 |
-
ProblemSpace const &problem_space);
|
| 102 |
-
};
|
| 103 |
-
|
| 104 |
-
/// Workspace used
|
| 105 |
-
struct TrmmWorkspace {
|
| 106 |
-
|
| 107 |
-
DeviceAllocation *A;
|
| 108 |
-
DeviceAllocation *B;
|
| 109 |
-
DeviceAllocation *D;
|
| 110 |
-
DeviceAllocation *Computed;
|
| 111 |
-
DeviceAllocation *Reference;
|
| 112 |
-
|
| 113 |
-
library::TrmmConfiguration configuration;
|
| 114 |
-
library::TrmmArguments arguments;
|
| 115 |
-
|
| 116 |
-
/// Buffer used for the operation's host workspace
|
| 117 |
-
std::vector<uint8_t> host_workspace;
|
| 118 |
-
|
| 119 |
-
/// Buffer used for the operations' device workspace
|
| 120 |
-
DeviceAllocation device_workspace;
|
| 121 |
-
|
| 122 |
-
//
|
| 123 |
-
// Methods
|
| 124 |
-
//
|
| 125 |
-
|
| 126 |
-
TrmmWorkspace():
|
| 127 |
-
A(nullptr), B(nullptr), D(nullptr), Computed(nullptr), Reference(nullptr) { }
|
| 128 |
-
};
|
| 129 |
-
|
| 130 |
-
protected:
|
| 131 |
-
|
| 132 |
-
//
|
| 133 |
-
// Data members
|
| 134 |
-
//
|
| 135 |
-
|
| 136 |
-
/// GEMM problem obtained from problem space
|
| 137 |
-
TrmmProblem problem_;
|
| 138 |
-
|
| 139 |
-
/// Device memory allocations
|
| 140 |
-
TrmmWorkspace trmm_workspace_;
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
public:
|
| 144 |
-
//
|
| 145 |
-
// Methods
|
| 146 |
-
//
|
| 147 |
-
|
| 148 |
-
/// Ctor
|
| 149 |
-
TrmmOperationProfiler(Options const &options);
|
| 150 |
-
|
| 151 |
-
/// Destructor
|
| 152 |
-
virtual ~TrmmOperationProfiler();
|
| 153 |
-
|
| 154 |
-
/// Prints usage statement for the math function
|
| 155 |
-
virtual void print_usage(std::ostream &out) const;
|
| 156 |
-
|
| 157 |
-
/// Prints examples
|
| 158 |
-
virtual void print_examples(std::ostream &out) const;
|
| 159 |
-
|
| 160 |
-
/// Extracts the problem dimensions
|
| 161 |
-
virtual Status initialize_configuration(
|
| 162 |
-
Options const &options,
|
| 163 |
-
PerformanceReport &report,
|
| 164 |
-
DeviceContext &device_context,
|
| 165 |
-
library::Operation const *operation,
|
| 166 |
-
ProblemSpace const &problem_space,
|
| 167 |
-
ProblemSpace::Problem const &problem);
|
| 168 |
-
|
| 169 |
-
/// Initializes workspace
|
| 170 |
-
virtual Status initialize_workspace(
|
| 171 |
-
Options const &options,
|
| 172 |
-
PerformanceReport &report,
|
| 173 |
-
DeviceContext &device_context,
|
| 174 |
-
library::Operation const *operation,
|
| 175 |
-
ProblemSpace const &problem_space,
|
| 176 |
-
ProblemSpace::Problem const &problem);
|
| 177 |
-
|
| 178 |
-
/// Verifies CUTLASS against references
|
| 179 |
-
virtual bool verify_cutlass(
|
| 180 |
-
Options const &options,
|
| 181 |
-
PerformanceReport &report,
|
| 182 |
-
DeviceContext &device_context,
|
| 183 |
-
library::Operation const *operation,
|
| 184 |
-
ProblemSpace const &problem_space,
|
| 185 |
-
ProblemSpace::Problem const &problem);
|
| 186 |
-
|
| 187 |
-
/// Measures performance results
|
| 188 |
-
virtual bool profile(
|
| 189 |
-
Options const &options,
|
| 190 |
-
PerformanceReport &report,
|
| 191 |
-
DeviceContext &device_context,
|
| 192 |
-
library::Operation const *operation,
|
| 193 |
-
ProblemSpace const &problem_space,
|
| 194 |
-
ProblemSpace::Problem const &problem);
|
| 195 |
-
|
| 196 |
-
protected:
|
| 197 |
-
|
| 198 |
-
/// Initializes the performance result
|
| 199 |
-
void initialize_result_(
|
| 200 |
-
PerformanceResult &result,
|
| 201 |
-
Options const &options,
|
| 202 |
-
library::TrmmDescription const &operation_desc,
|
| 203 |
-
ProblemSpace const &problem_space);
|
| 204 |
-
|
| 205 |
-
/// Verifies CUTLASS against references
|
| 206 |
-
bool verify_with_cublas_(
|
| 207 |
-
Options const &options,
|
| 208 |
-
PerformanceReport &report,
|
| 209 |
-
DeviceContext &device_context,
|
| 210 |
-
library::Operation const *operation,
|
| 211 |
-
ProblemSpace const &problem_space,
|
| 212 |
-
ProblemSpace::Problem const &problem);
|
| 213 |
-
|
| 214 |
-
};
|
| 215 |
-
|
| 216 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 217 |
-
|
| 218 |
-
} // namespace profiler
|
| 219 |
-
} // namespace cutlass
|
| 220 |
-
|
| 221 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp
DELETED
|
@@ -1,67 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include <cuda_runtime.h>
|
| 35 |
-
|
| 36 |
-
struct GPU_Clock
|
| 37 |
-
{
|
| 38 |
-
GPU_Clock() {
|
| 39 |
-
cudaEventCreate(&start_);
|
| 40 |
-
cudaEventCreate(&stop_);
|
| 41 |
-
cudaEventRecord(start_);
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
~GPU_Clock() {
|
| 45 |
-
cudaEventDestroy(start_);
|
| 46 |
-
cudaEventDestroy(stop_);
|
| 47 |
-
}
|
| 48 |
-
|
| 49 |
-
void start() {
|
| 50 |
-
cudaEventRecord(start_);
|
| 51 |
-
}
|
| 52 |
-
|
| 53 |
-
float milliseconds() {
|
| 54 |
-
cudaEventRecord(stop_);
|
| 55 |
-
cudaEventSynchronize(stop_);
|
| 56 |
-
float time;
|
| 57 |
-
cudaEventElapsedTime(&time, start_, stop_);
|
| 58 |
-
return time;
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
float seconds() {
|
| 62 |
-
return milliseconds() * float(1e-3);
|
| 63 |
-
}
|
| 64 |
-
|
| 65 |
-
private:
|
| 66 |
-
cudaEvent_t start_, stop_;
|
| 67 |
-
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h
DELETED
|
@@ -1,324 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
/**
|
| 35 |
-
* \file
|
| 36 |
-
* Utility for parsing command line arguments
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#include <iostream>
|
| 40 |
-
#include <limits>
|
| 41 |
-
#include <sstream>
|
| 42 |
-
#include <string>
|
| 43 |
-
#include <vector>
|
| 44 |
-
#include <unordered_map>
|
| 45 |
-
|
| 46 |
-
#include <cuda_runtime.h>
|
| 47 |
-
|
| 48 |
-
#include "cutlass/cutlass.h"
|
| 49 |
-
|
| 50 |
-
namespace cutlass {
|
| 51 |
-
|
| 52 |
-
/******************************************************************************
|
| 53 |
-
* command_line
|
| 54 |
-
******************************************************************************/
|
| 55 |
-
|
| 56 |
-
/**
|
| 57 |
-
* Utility for parsing command line arguments
|
| 58 |
-
*/
|
| 59 |
-
struct CommandLine {
|
| 60 |
-
std::vector<std::string> keys;
|
| 61 |
-
std::vector<std::string> values;
|
| 62 |
-
std::vector<std::string> args;
|
| 63 |
-
|
| 64 |
-
/**
|
| 65 |
-
* Constructor
|
| 66 |
-
*/
|
| 67 |
-
CommandLine(int argc, const char** argv) {
|
| 68 |
-
using namespace std;
|
| 69 |
-
|
| 70 |
-
for (int i = 1; i < argc; i++) {
|
| 71 |
-
string arg = argv[i];
|
| 72 |
-
|
| 73 |
-
if ((arg[0] != '-') || (arg[1] != '-')) {
|
| 74 |
-
args.push_back(arg);
|
| 75 |
-
continue;
|
| 76 |
-
}
|
| 77 |
-
|
| 78 |
-
string::size_type pos;
|
| 79 |
-
string key, val;
|
| 80 |
-
if ((pos = arg.find('=')) == string::npos) {
|
| 81 |
-
key = string(arg, 2, arg.length() - 2);
|
| 82 |
-
val = "";
|
| 83 |
-
} else {
|
| 84 |
-
key = string(arg, 2, pos - 2);
|
| 85 |
-
val = string(arg, pos + 1, arg.length() - 1);
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
keys.push_back(key);
|
| 89 |
-
values.push_back(val);
|
| 90 |
-
}
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
/**
|
| 94 |
-
* Constructor to represent a command line from a map of [argument] -> [value]
|
| 95 |
-
*/
|
| 96 |
-
CommandLine(std::unordered_map<std::string, std::string>& arg_map) {
|
| 97 |
-
for (const auto& [key, value] : arg_map) {
|
| 98 |
-
keys.push_back(key);
|
| 99 |
-
values.push_back(value);
|
| 100 |
-
}
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
/**
|
| 104 |
-
* Checks whether a flag "--<flag>" is present in the commandline
|
| 105 |
-
*/
|
| 106 |
-
bool check_cmd_line_flag(const char* arg_name) const {
|
| 107 |
-
using namespace std;
|
| 108 |
-
|
| 109 |
-
for (int i = 0; i < int(keys.size()); ++i) {
|
| 110 |
-
if (keys[i] == string(arg_name)) return true;
|
| 111 |
-
}
|
| 112 |
-
return false;
|
| 113 |
-
}
|
| 114 |
-
|
| 115 |
-
/**
|
| 116 |
-
* Returns number of naked (non-flag and non-key-value) commandline parameters
|
| 117 |
-
*/
|
| 118 |
-
size_t num_naked_args() const {
|
| 119 |
-
return args.size();
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
/**
|
| 123 |
-
* Print naked (non-flag and non-key-value) commandline parameters
|
| 124 |
-
*/
|
| 125 |
-
void print_naked_args(std::ostream &out) const {
|
| 126 |
-
for (auto arg : args) {
|
| 127 |
-
out << " " << arg <<"\n";
|
| 128 |
-
}
|
| 129 |
-
}
|
| 130 |
-
|
| 131 |
-
/**
|
| 132 |
-
* Returns the commandline parameter for a given index (not including flags)
|
| 133 |
-
*/
|
| 134 |
-
template <typename value_t>
|
| 135 |
-
void get_cmd_line_argument(size_t index, value_t& val) const {
|
| 136 |
-
using namespace std;
|
| 137 |
-
if (index < args.size()) {
|
| 138 |
-
istringstream str_stream(args[index]);
|
| 139 |
-
str_stream >> val;
|
| 140 |
-
}
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
/**
|
| 144 |
-
* Obtains the boolean value specified for a given commandline parameter --<flag>=<bool>
|
| 145 |
-
*/
|
| 146 |
-
void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const {
|
| 147 |
-
val = _default;
|
| 148 |
-
if (check_cmd_line_flag(arg_name)) {
|
| 149 |
-
std::string value;
|
| 150 |
-
get_cmd_line_argument(arg_name, value);
|
| 151 |
-
|
| 152 |
-
val = !(value == "0" || value == "false");
|
| 153 |
-
}
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
/**
|
| 157 |
-
* Obtains the value specified for a given commandline parameter --<flag>=<value>
|
| 158 |
-
*/
|
| 159 |
-
template <typename value_t>
|
| 160 |
-
void get_cmd_line_argument(const char* arg_name,
|
| 161 |
-
value_t& val) const {
|
| 162 |
-
|
| 163 |
-
get_cmd_line_argument(arg_name, val, val);
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
-
/**
|
| 167 |
-
* Obtains the value specified for a given commandline parameter --<flag>=<value>
|
| 168 |
-
*/
|
| 169 |
-
template <typename value_t>
|
| 170 |
-
void get_cmd_line_argument(const char* arg_name,
|
| 171 |
-
value_t& val,
|
| 172 |
-
value_t const& _default) const {
|
| 173 |
-
using namespace std;
|
| 174 |
-
|
| 175 |
-
val = _default;
|
| 176 |
-
|
| 177 |
-
for (int i = 0; i < int(keys.size()); ++i) {
|
| 178 |
-
if (keys[i] == string(arg_name)) {
|
| 179 |
-
istringstream str_stream(values[i]);
|
| 180 |
-
str_stream >> val;
|
| 181 |
-
}
|
| 182 |
-
}
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
/**
|
| 186 |
-
* Returns the values specified for a given commandline parameter --<flag>=<value>,<value>*
|
| 187 |
-
*/
|
| 188 |
-
template <typename value_t>
|
| 189 |
-
void get_cmd_line_arguments(const char* arg_name,
|
| 190 |
-
std::vector<value_t>& vals,
|
| 191 |
-
char sep = ',') const {
|
| 192 |
-
using namespace std;
|
| 193 |
-
|
| 194 |
-
if (check_cmd_line_flag(arg_name)) {
|
| 195 |
-
// Clear any default values
|
| 196 |
-
vals.clear();
|
| 197 |
-
|
| 198 |
-
// Recover from multi-value string
|
| 199 |
-
for (size_t i = 0; i < keys.size(); ++i) {
|
| 200 |
-
if (keys[i] == string(arg_name)) {
|
| 201 |
-
string val_string(values[i]);
|
| 202 |
-
separate_string(val_string, vals, sep);
|
| 203 |
-
}
|
| 204 |
-
}
|
| 205 |
-
}
|
| 206 |
-
}
|
| 207 |
-
|
| 208 |
-
/**
|
| 209 |
-
* Returns the values specified for a given commandline parameter
|
| 210 |
-
* --<flag>=<value>,<value_start:value_end>*
|
| 211 |
-
*/
|
| 212 |
-
void get_cmd_line_argument_pairs(const char* arg_name,
|
| 213 |
-
std::vector<std::pair<std::string, std::string> >& tokens,
|
| 214 |
-
char delim = ',',
|
| 215 |
-
char sep = ':') const {
|
| 216 |
-
if (check_cmd_line_flag(arg_name)) {
|
| 217 |
-
std::string value;
|
| 218 |
-
get_cmd_line_argument(arg_name, value);
|
| 219 |
-
|
| 220 |
-
tokenize(tokens, value, delim, sep);
|
| 221 |
-
}
|
| 222 |
-
}
|
| 223 |
-
|
| 224 |
-
/**
|
| 225 |
-
* Returns a list of ranges specified for a given commandline parameter
|
| 226 |
-
* --<flag>=<key:value>,<key:value>*
|
| 227 |
-
*/
|
| 228 |
-
void get_cmd_line_argument_ranges(const char* arg_name,
|
| 229 |
-
std::vector<std::vector<std::string> >& vals,
|
| 230 |
-
char delim = ',',
|
| 231 |
-
char sep = ':') const {
|
| 232 |
-
std::vector<std::string> ranges;
|
| 233 |
-
get_cmd_line_arguments(arg_name, ranges, delim);
|
| 234 |
-
|
| 235 |
-
for (std::vector<std::string>::const_iterator range = ranges.begin();
|
| 236 |
-
range != ranges.end(); ++range) {
|
| 237 |
-
|
| 238 |
-
std::vector<std::string> range_vals;
|
| 239 |
-
separate_string(*range, range_vals, sep);
|
| 240 |
-
vals.push_back(range_vals);
|
| 241 |
-
}
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
/**
|
| 245 |
-
* The number of pairs parsed
|
| 246 |
-
*/
|
| 247 |
-
int parsed_argc() const { return (int)keys.size(); }
|
| 248 |
-
|
| 249 |
-
//-------------------------------------------------------------------------
|
| 250 |
-
// Utility functions
|
| 251 |
-
//-------------------------------------------------------------------------
|
| 252 |
-
|
| 253 |
-
/// Tokenizes a comma-delimited list of string pairs delimited by ':'
|
| 254 |
-
static void tokenize(std::vector<std::pair<std::string, std::string> >& tokens,
|
| 255 |
-
std::string const& str,
|
| 256 |
-
char delim = ',',
|
| 257 |
-
char sep = ':') {
|
| 258 |
-
// Home-built to avoid Boost dependency
|
| 259 |
-
size_t s_idx = 0;
|
| 260 |
-
size_t d_idx = std::string::npos;
|
| 261 |
-
while (s_idx < str.size()) {
|
| 262 |
-
d_idx = str.find_first_of(delim, s_idx);
|
| 263 |
-
|
| 264 |
-
size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size());
|
| 265 |
-
size_t sep_idx = str.find_first_of(sep, s_idx);
|
| 266 |
-
size_t offset = 1;
|
| 267 |
-
if (sep_idx == std::string::npos || sep_idx >= end_idx) {
|
| 268 |
-
sep_idx = end_idx;
|
| 269 |
-
offset = 0;
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
std::pair<std::string, std::string> item(
|
| 273 |
-
str.substr(s_idx, sep_idx - s_idx),
|
| 274 |
-
str.substr(sep_idx + offset, end_idx - sep_idx - offset));
|
| 275 |
-
|
| 276 |
-
tokens.push_back(item);
|
| 277 |
-
s_idx = end_idx + 1;
|
| 278 |
-
}
|
| 279 |
-
}
|
| 280 |
-
|
| 281 |
-
/// Tokenizes a comma-delimited list of string pairs delimited by ':'
|
| 282 |
-
static void tokenize(std::vector<std::string>& tokens,
|
| 283 |
-
std::string const& str,
|
| 284 |
-
char delim = ',',
|
| 285 |
-
char sep = ':') {
|
| 286 |
-
typedef std::vector<std::pair<std::string, std::string> > TokenVector;
|
| 287 |
-
typedef TokenVector::const_iterator token_iterator;
|
| 288 |
-
|
| 289 |
-
std::vector<std::pair<std::string, std::string> > token_pairs;
|
| 290 |
-
tokenize(token_pairs, str, delim, sep);
|
| 291 |
-
for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) {
|
| 292 |
-
tokens.push_back(tok->first);
|
| 293 |
-
}
|
| 294 |
-
}
|
| 295 |
-
|
| 296 |
-
/// Splits `str` on `sep` and appends each piece to `vals`, converting it to `value_t` with
/// stream extraction (`operator>>`).
///
/// Empty pieces between two separators are skipped. The text after the last separator (or
/// the whole string when there is no separator) is always extracted and appended, so an
/// empty input still appends exactly one element.
///
/// \param str  Text to split.
/// \param vals Destination vector; parsed values are appended to it.
/// \param sep  Separator character.
template <typename value_t>
static void separate_string(std::string const& str,
                            std::vector<value_t>& vals,
                            char sep = ',') {
  std::istringstream str_stream(str);
  std::string::size_type old_pos = 0;
  std::string::size_type new_pos = 0;

  // Value-initialized: when extraction cannot even start (e.g. empty input) the stream
  // leaves `val` untouched, and pushing an uninitialized arithmetic value would be
  // undefined behaviour.
  value_t val = value_t();

  // Iterate <sep>-delimited values
  while ((new_pos = str.find(sep, old_pos)) != std::string::npos) {
    if (new_pos != old_pos) {
      // The field width caps string extraction at the length of this piece.
      str_stream.width(new_pos - old_pos);
      str_stream >> val;
      vals.push_back(val);
    }

    // skip over delimiter
    str_stream.ignore(1);
    old_pos = new_pos + 1;
  }

  // Read last value
  str_stream >> val;
  vals.push_back(val);
}
|
| 322 |
-
};
|
| 323 |
-
|
| 324 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp
DELETED
|
@@ -1,528 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include <cuda_runtime.h>
|
| 35 |
-
#include <cublas_v2.h>
|
| 36 |
-
|
| 37 |
-
//-- BLAM_DEBUG_OUT ---------------------------------------------------------
|
| 38 |
-
#ifdef BLAM_DEBUG
|
| 39 |
-
# include <iostream>
|
| 40 |
-
# ifndef BLAM_DEBUG_OUT
|
| 41 |
-
# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl
|
| 42 |
-
# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl
|
| 43 |
-
# endif // BLAM_DEBUG_OUT
|
| 44 |
-
#else
|
| 45 |
-
# ifndef BLAM_DEBUG_OUT
|
| 46 |
-
# define BLAM_DEBUG_OUT(msg)
|
| 47 |
-
# define BLAM_DEBUG_OUT_2(msg)
|
| 48 |
-
# endif // BLAM_DEBUG_OUT
|
| 49 |
-
#endif // BLAM_DEBUG
|
| 50 |
-
|
| 51 |
-
// User could potentially define ComplexFloat/ComplexDouble instead of std::
|
| 52 |
-
#ifndef BLAM_COMPLEX_TYPES
|
| 53 |
-
#define BLAM_COMPLEX_TYPES 1
|
| 54 |
-
#include "cutlass/cutlass.h"
|
| 55 |
-
#include CUDA_STD_HEADER(complex)
|
| 56 |
-
|
| 57 |
-
namespace blam {
|
| 58 |
-
template <typename T>
|
| 59 |
-
using Complex = cuda::std::complex<T>;
|
| 60 |
-
using ComplexFloat = cuda::std::complex<float>;
|
| 61 |
-
using ComplexDouble = cuda::std::complex<double>;
|
| 62 |
-
}
|
| 63 |
-
#endif // BLAM_COMPLEX_TYPES
|
| 64 |
-
|
| 65 |
-
// User could potentially define Half instead of cute::
|
| 66 |
-
#ifndef BLAM_HALF_TYPE
|
| 67 |
-
#define BLAM_HALF_TYPE 1
|
| 68 |
-
#include <cute/numeric/numeric_types.hpp>
|
| 69 |
-
namespace blam {
|
| 70 |
-
using Half = cute::half_t;
|
| 71 |
-
}
|
| 72 |
-
#endif // BLAM_HALF_TYPE
|
| 73 |
-
|
| 74 |
-
namespace blam
|
| 75 |
-
{
|
| 76 |
-
namespace cublas
|
| 77 |
-
{
|
| 78 |
-
|
| 79 |
-
inline const char*
|
| 80 |
-
cublas_get_error(cublasStatus_t status)
|
| 81 |
-
{
|
| 82 |
-
switch (status) {
|
| 83 |
-
case CUBLAS_STATUS_SUCCESS:
|
| 84 |
-
return "CUBLAS_STATUS_SUCCESS";
|
| 85 |
-
case CUBLAS_STATUS_NOT_INITIALIZED:
|
| 86 |
-
return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized.";
|
| 87 |
-
case CUBLAS_STATUS_ALLOC_FAILED:
|
| 88 |
-
return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library.";
|
| 89 |
-
case CUBLAS_STATUS_INVALID_VALUE:
|
| 90 |
-
return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function.";
|
| 91 |
-
case CUBLAS_STATUS_ARCH_MISMATCH:
|
| 92 |
-
return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture.";
|
| 93 |
-
case CUBLAS_STATUS_MAPPING_ERROR:
|
| 94 |
-
return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed.";
|
| 95 |
-
case CUBLAS_STATUS_EXECUTION_FAILED:
|
| 96 |
-
return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute.";
|
| 97 |
-
case CUBLAS_STATUS_INTERNAL_ERROR:
|
| 98 |
-
return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed.";
|
| 99 |
-
case CUBLAS_STATUS_NOT_SUPPORTED:
|
| 100 |
-
return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported.";
|
| 101 |
-
case CUBLAS_STATUS_LICENSE_ERROR:
|
| 102 |
-
return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing.";
|
| 103 |
-
default:
|
| 104 |
-
return "CUBLAS_ERROR -- <unknown>";
|
| 105 |
-
}
|
| 106 |
-
}
|
| 107 |
-
|
| 108 |
-
inline bool
|
| 109 |
-
cublas_is_error(cublasStatus_t status)
|
| 110 |
-
{
|
| 111 |
-
return status != CUBLAS_STATUS_SUCCESS;
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
// hgemm
|
| 116 |
-
inline cublasStatus_t
|
| 117 |
-
gemm(cublasHandle_t handle,
|
| 118 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 119 |
-
int m, int n, int k,
|
| 120 |
-
const Half* alpha,
|
| 121 |
-
const Half* A, int ldA,
|
| 122 |
-
const Half* B, int ldB,
|
| 123 |
-
const Half* beta,
|
| 124 |
-
Half* C, int ldC)
|
| 125 |
-
{
|
| 126 |
-
BLAM_DEBUG_OUT("cublasHgemm");
|
| 127 |
-
|
| 128 |
-
return cublasGemmEx(handle, transA, transB,
|
| 129 |
-
m, n, k,
|
| 130 |
-
reinterpret_cast<const __half*>(alpha),
|
| 131 |
-
reinterpret_cast<const __half*>(A), CUDA_R_16F, ldA,
|
| 132 |
-
reinterpret_cast<const __half*>(B), CUDA_R_16F, ldB,
|
| 133 |
-
reinterpret_cast<const __half*>(beta),
|
| 134 |
-
reinterpret_cast< __half*>(C), CUDA_R_16F, ldC,
|
| 135 |
-
CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
// mixed hf gemm
|
| 139 |
-
inline cublasStatus_t
|
| 140 |
-
gemm(cublasHandle_t handle,
|
| 141 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 142 |
-
int m, int n, int k,
|
| 143 |
-
const float* alpha,
|
| 144 |
-
const Half* A, int ldA,
|
| 145 |
-
const Half* B, int ldB,
|
| 146 |
-
const float* beta,
|
| 147 |
-
float* C, int ldC)
|
| 148 |
-
{
|
| 149 |
-
BLAM_DEBUG_OUT("cublasGemmEx mixed half-float");
|
| 150 |
-
|
| 151 |
-
return cublasGemmEx(handle, transA, transB,
|
| 152 |
-
m, n, k,
|
| 153 |
-
alpha,
|
| 154 |
-
reinterpret_cast<const __half*>(A), CUDA_R_16F, ldA,
|
| 155 |
-
reinterpret_cast<const __half*>(B), CUDA_R_16F, ldB,
|
| 156 |
-
beta,
|
| 157 |
-
C, CUDA_R_32F, ldC,
|
| 158 |
-
CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
| 159 |
-
}
|
| 160 |
-
|
| 161 |
-
// igemm
|
| 162 |
-
inline cublasStatus_t
|
| 163 |
-
gemm(cublasHandle_t handle,
|
| 164 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 165 |
-
int m, int n, int k,
|
| 166 |
-
const int32_t* alpha,
|
| 167 |
-
const int8_t* A, int ldA,
|
| 168 |
-
const int8_t* B, int ldB,
|
| 169 |
-
const int32_t* beta,
|
| 170 |
-
int32_t* C, int ldC)
|
| 171 |
-
{
|
| 172 |
-
BLAM_DEBUG_OUT("cublasIgemm");
|
| 173 |
-
|
| 174 |
-
return cublasGemmEx(handle, transA, transB,
|
| 175 |
-
m, n, k,
|
| 176 |
-
alpha,
|
| 177 |
-
A, CUDA_R_8I, ldA,
|
| 178 |
-
B, CUDA_R_8I, ldB,
|
| 179 |
-
beta,
|
| 180 |
-
C, CUDA_R_32I, ldC,
|
| 181 |
-
CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
| 182 |
-
}
|
| 183 |
-
|
| 184 |
-
// sgemm
|
| 185 |
-
inline cublasStatus_t
|
| 186 |
-
gemm(cublasHandle_t handle,
|
| 187 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 188 |
-
int m, int n, int k,
|
| 189 |
-
const float* alpha,
|
| 190 |
-
const float* A, int ldA,
|
| 191 |
-
const float* B, int ldB,
|
| 192 |
-
const float* beta,
|
| 193 |
-
float* C, int ldC)
|
| 194 |
-
{
|
| 195 |
-
BLAM_DEBUG_OUT("cublasSgemm");
|
| 196 |
-
|
| 197 |
-
return cublasSgemm(handle, transA, transB,
|
| 198 |
-
m, n, k,
|
| 199 |
-
alpha,
|
| 200 |
-
A, ldA,
|
| 201 |
-
B, ldB,
|
| 202 |
-
beta,
|
| 203 |
-
C, ldC);
|
| 204 |
-
}
|
| 205 |
-
|
| 206 |
-
// dgemm
|
| 207 |
-
inline cublasStatus_t
|
| 208 |
-
gemm(cublasHandle_t handle,
|
| 209 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 210 |
-
int m, int n, int k,
|
| 211 |
-
const double* alpha,
|
| 212 |
-
const double* A, int ldA,
|
| 213 |
-
const double* B, int ldB,
|
| 214 |
-
const double* beta,
|
| 215 |
-
double* C, int ldC)
|
| 216 |
-
{
|
| 217 |
-
BLAM_DEBUG_OUT("cublasDgemm");
|
| 218 |
-
|
| 219 |
-
return cublasDgemm(handle, transA, transB,
|
| 220 |
-
m, n, k,
|
| 221 |
-
alpha,
|
| 222 |
-
A, ldA,
|
| 223 |
-
B, ldB,
|
| 224 |
-
beta,
|
| 225 |
-
C, ldC);
|
| 226 |
-
}
|
| 227 |
-
|
| 228 |
-
// cgemm
|
| 229 |
-
inline cublasStatus_t
|
| 230 |
-
gemm(cublasHandle_t handle,
|
| 231 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 232 |
-
int m, int n, int k,
|
| 233 |
-
const ComplexFloat* alpha,
|
| 234 |
-
const ComplexFloat* A, int ldA,
|
| 235 |
-
const ComplexFloat* B, int ldB,
|
| 236 |
-
const ComplexFloat* beta,
|
| 237 |
-
ComplexFloat* C, int ldC)
|
| 238 |
-
{
|
| 239 |
-
BLAM_DEBUG_OUT("cublasCgemm");
|
| 240 |
-
|
| 241 |
-
return cublasCgemm(handle, transA, transB,
|
| 242 |
-
m, n, k,
|
| 243 |
-
reinterpret_cast<const cuFloatComplex*>(alpha),
|
| 244 |
-
reinterpret_cast<const cuFloatComplex*>(A), ldA,
|
| 245 |
-
reinterpret_cast<const cuFloatComplex*>(B), ldB,
|
| 246 |
-
reinterpret_cast<const cuFloatComplex*>(beta),
|
| 247 |
-
reinterpret_cast<cuFloatComplex*>(C), ldC);
|
| 248 |
-
}
|
| 249 |
-
|
| 250 |
-
// zgemm
|
| 251 |
-
inline cublasStatus_t
|
| 252 |
-
gemm(cublasHandle_t handle,
|
| 253 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 254 |
-
int m, int n, int k,
|
| 255 |
-
const ComplexDouble* alpha,
|
| 256 |
-
const ComplexDouble* A, int ldA,
|
| 257 |
-
const ComplexDouble* B, int ldB,
|
| 258 |
-
const ComplexDouble* beta,
|
| 259 |
-
ComplexDouble* C, int ldC)
|
| 260 |
-
{
|
| 261 |
-
BLAM_DEBUG_OUT("cublasZgemm");
|
| 262 |
-
|
| 263 |
-
return cublasZgemm(handle, transA, transB,
|
| 264 |
-
m, n, k,
|
| 265 |
-
reinterpret_cast<const cuDoubleComplex*>(alpha),
|
| 266 |
-
reinterpret_cast<const cuDoubleComplex*>(A), ldA,
|
| 267 |
-
reinterpret_cast<const cuDoubleComplex*>(B), ldB,
|
| 268 |
-
reinterpret_cast<const cuDoubleComplex*>(beta),
|
| 269 |
-
reinterpret_cast<cuDoubleComplex*>(C), ldC);
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
// hgemm
|
| 273 |
-
inline cublasStatus_t
|
| 274 |
-
gemm_batch(cublasHandle_t handle,
|
| 275 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 276 |
-
int m, int n, int k,
|
| 277 |
-
const Half* alpha,
|
| 278 |
-
const Half* A, int ldA, int loA,
|
| 279 |
-
const Half* B, int ldB, int loB,
|
| 280 |
-
const Half* beta,
|
| 281 |
-
Half* C, int ldC, int loC,
|
| 282 |
-
int batch_size)
|
| 283 |
-
{
|
| 284 |
-
BLAM_DEBUG_OUT("cublasHgemmStridedBatched");
|
| 285 |
-
|
| 286 |
-
return cublasHgemmStridedBatched(handle, transA, transB,
|
| 287 |
-
m, n, k,
|
| 288 |
-
reinterpret_cast<const __half*>(alpha),
|
| 289 |
-
reinterpret_cast<const __half*>(A), ldA, loA,
|
| 290 |
-
reinterpret_cast<const __half*>(B), ldB, loB,
|
| 291 |
-
reinterpret_cast<const __half*>(beta),
|
| 292 |
-
reinterpret_cast<__half*>(C), ldC, loC,
|
| 293 |
-
batch_size);
|
| 294 |
-
}
|
| 295 |
-
|
| 296 |
-
// sgemm
|
| 297 |
-
inline cublasStatus_t
|
| 298 |
-
gemm_batch(cublasHandle_t handle,
|
| 299 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 300 |
-
int m, int n, int k,
|
| 301 |
-
const float* alpha,
|
| 302 |
-
const float* A, int ldA, int loA,
|
| 303 |
-
const float* B, int ldB, int loB,
|
| 304 |
-
const float* beta,
|
| 305 |
-
float* C, int ldC, int loC,
|
| 306 |
-
int batch_size)
|
| 307 |
-
{
|
| 308 |
-
BLAM_DEBUG_OUT("cublasSgemmStridedBatched");
|
| 309 |
-
|
| 310 |
-
return cublasSgemmStridedBatched(handle, transA, transB,
|
| 311 |
-
m, n, k,
|
| 312 |
-
alpha,
|
| 313 |
-
A, ldA, loA,
|
| 314 |
-
B, ldB, loB,
|
| 315 |
-
beta,
|
| 316 |
-
C, ldC, loC,
|
| 317 |
-
batch_size);
|
| 318 |
-
}
|
| 319 |
-
|
| 320 |
-
// dgemm
|
| 321 |
-
inline cublasStatus_t
|
| 322 |
-
gemm_batch(cublasHandle_t handle,
|
| 323 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 324 |
-
int m, int n, int k,
|
| 325 |
-
const double* alpha,
|
| 326 |
-
const double* A, int ldA, int loA,
|
| 327 |
-
const double* B, int ldB, int loB,
|
| 328 |
-
const double* beta,
|
| 329 |
-
double* C, int ldC, int loC,
|
| 330 |
-
int batch_size)
|
| 331 |
-
{
|
| 332 |
-
BLAM_DEBUG_OUT("cublasDgemmStridedBatched");
|
| 333 |
-
|
| 334 |
-
return cublasDgemmStridedBatched(handle, transA, transB,
|
| 335 |
-
m, n, k,
|
| 336 |
-
alpha,
|
| 337 |
-
A, ldA, loA,
|
| 338 |
-
B, ldB, loB,
|
| 339 |
-
beta,
|
| 340 |
-
C, ldC, loC,
|
| 341 |
-
batch_size);
|
| 342 |
-
}
|
| 343 |
-
|
| 344 |
-
// cgemm
|
| 345 |
-
inline cublasStatus_t
|
| 346 |
-
gemm_batch(cublasHandle_t handle,
|
| 347 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 348 |
-
int m, int n, int k,
|
| 349 |
-
const ComplexFloat* alpha,
|
| 350 |
-
const ComplexFloat* A, int ldA, int loA,
|
| 351 |
-
const ComplexFloat* B, int ldB, int loB,
|
| 352 |
-
const ComplexFloat* beta,
|
| 353 |
-
ComplexFloat* C, int ldC, int loC,
|
| 354 |
-
int batch_size)
|
| 355 |
-
{
|
| 356 |
-
BLAM_DEBUG_OUT("cublasCgemmStridedBatched");
|
| 357 |
-
|
| 358 |
-
return cublasCgemmStridedBatched(handle, transA, transB,
|
| 359 |
-
m, n, k,
|
| 360 |
-
reinterpret_cast<const cuFloatComplex*>(alpha),
|
| 361 |
-
reinterpret_cast<const cuFloatComplex*>(A), ldA, loA,
|
| 362 |
-
reinterpret_cast<const cuFloatComplex*>(B), ldB, loB,
|
| 363 |
-
reinterpret_cast<const cuFloatComplex*>(beta),
|
| 364 |
-
reinterpret_cast<cuFloatComplex*>(C), ldC, loC,
|
| 365 |
-
batch_size);
|
| 366 |
-
}
|
| 367 |
-
|
| 368 |
-
// zgemm
|
| 369 |
-
inline cublasStatus_t
|
| 370 |
-
gemm_batch(cublasHandle_t handle,
|
| 371 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 372 |
-
int m, int n, int k,
|
| 373 |
-
const ComplexDouble* alpha,
|
| 374 |
-
const ComplexDouble* A, int ldA, int loA,
|
| 375 |
-
const ComplexDouble* B, int ldB, int loB,
|
| 376 |
-
const ComplexDouble* beta,
|
| 377 |
-
ComplexDouble* C, int ldC, int loC,
|
| 378 |
-
int batch_size)
|
| 379 |
-
{
|
| 380 |
-
BLAM_DEBUG_OUT("cublasZgemmStridedBatched");
|
| 381 |
-
|
| 382 |
-
return cublasZgemmStridedBatched(handle, transA, transB,
|
| 383 |
-
m, n, k,
|
| 384 |
-
reinterpret_cast<const cuDoubleComplex*>(alpha),
|
| 385 |
-
reinterpret_cast<const cuDoubleComplex*>(A), ldA, loA,
|
| 386 |
-
reinterpret_cast<const cuDoubleComplex*>(B), ldB, loB,
|
| 387 |
-
reinterpret_cast<const cuDoubleComplex*>(beta),
|
| 388 |
-
reinterpret_cast<cuDoubleComplex*>(C), ldC, loC,
|
| 389 |
-
batch_size);
|
| 390 |
-
}
|
| 391 |
-
|
| 392 |
-
// hgemm
|
| 393 |
-
inline cublasStatus_t
|
| 394 |
-
gemm_batch(cublasHandle_t handle,
|
| 395 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 396 |
-
int m, int n, int k,
|
| 397 |
-
const Half* alpha,
|
| 398 |
-
const Half* const A[], int ldA,
|
| 399 |
-
const Half* const B[], int ldB,
|
| 400 |
-
const Half* beta,
|
| 401 |
-
Half* const C[], int ldC,
|
| 402 |
-
int batch_size)
|
| 403 |
-
{
|
| 404 |
-
BLAM_DEBUG_OUT("cublasHgemmBatched");
|
| 405 |
-
|
| 406 |
-
return cublasHgemmBatched(handle, transA, transB,
|
| 407 |
-
m, n, k,
|
| 408 |
-
reinterpret_cast<const __half*>(alpha),
|
| 409 |
-
reinterpret_cast<const __half**>(const_cast<const Half**>(A)), ldA,
|
| 410 |
-
// A, ldA, // cuBLAS 9.2
|
| 411 |
-
reinterpret_cast<const __half**>(const_cast<const Half**>(B)), ldB,
|
| 412 |
-
// B, ldB, // cuBLAS 9.2
|
| 413 |
-
reinterpret_cast<const __half*>(beta),
|
| 414 |
-
reinterpret_cast<__half**>(const_cast<Half**>(C)), ldC,
|
| 415 |
-
// C, ldC, // cuBLAS 9.2
|
| 416 |
-
batch_size);
|
| 417 |
-
}
|
| 418 |
-
|
| 419 |
-
// sgemm
|
| 420 |
-
inline cublasStatus_t
|
| 421 |
-
gemm_batch(cublasHandle_t handle,
|
| 422 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 423 |
-
int m, int n, int k,
|
| 424 |
-
const float* alpha,
|
| 425 |
-
const float* const A[], int ldA,
|
| 426 |
-
const float* const B[], int ldB,
|
| 427 |
-
const float* beta,
|
| 428 |
-
float* const C[], int ldC,
|
| 429 |
-
int batch_size)
|
| 430 |
-
{
|
| 431 |
-
BLAM_DEBUG_OUT("cublasSgemmBatched");
|
| 432 |
-
|
| 433 |
-
return cublasSgemmBatched(handle, transA, transB,
|
| 434 |
-
m, n, k,
|
| 435 |
-
alpha,
|
| 436 |
-
const_cast<const float**>(A), ldA,
|
| 437 |
-
// A, ldA, // cuBLAS 9.2
|
| 438 |
-
const_cast<const float**>(B), ldB,
|
| 439 |
-
// B, ldB, // cuBLAS 9.2
|
| 440 |
-
beta,
|
| 441 |
-
const_cast<float**>(C), ldC,
|
| 442 |
-
// C, ldC, // cuBLAS 9.2
|
| 443 |
-
batch_size);
|
| 444 |
-
}
|
| 445 |
-
|
| 446 |
-
// dgemm
|
| 447 |
-
inline cublasStatus_t
|
| 448 |
-
gemm_batch(cublasHandle_t handle,
|
| 449 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 450 |
-
int m, int n, int k,
|
| 451 |
-
const double* alpha,
|
| 452 |
-
const double* const A[], int ldA,
|
| 453 |
-
const double* const B[], int ldB,
|
| 454 |
-
const double* beta,
|
| 455 |
-
double* const C[], int ldC,
|
| 456 |
-
int batch_size)
|
| 457 |
-
{
|
| 458 |
-
BLAM_DEBUG_OUT("cublasDgemmBatched");
|
| 459 |
-
|
| 460 |
-
return cublasDgemmBatched(handle, transA, transB,
|
| 461 |
-
m, n, k,
|
| 462 |
-
alpha,
|
| 463 |
-
const_cast<const double**>(A), ldA,
|
| 464 |
-
// A, ldA, // cuBLAS 9.2
|
| 465 |
-
const_cast<const double**>(B), ldB,
|
| 466 |
-
// B, ldB, // cuBLAS 9.2
|
| 467 |
-
beta,
|
| 468 |
-
const_cast<double**>(C), ldC,
|
| 469 |
-
// C, ldC, // cuBLAS 9.2
|
| 470 |
-
batch_size);
|
| 471 |
-
}
|
| 472 |
-
|
| 473 |
-
// cgemm
|
| 474 |
-
inline cublasStatus_t
|
| 475 |
-
gemm_batch(cublasHandle_t handle,
|
| 476 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 477 |
-
int m, int n, int k,
|
| 478 |
-
const ComplexFloat* alpha,
|
| 479 |
-
const ComplexFloat* const A[], int ldA,
|
| 480 |
-
const ComplexFloat* const B[], int ldB,
|
| 481 |
-
const ComplexFloat* beta,
|
| 482 |
-
ComplexFloat* const C[], int ldC,
|
| 483 |
-
int batch_size)
|
| 484 |
-
{
|
| 485 |
-
BLAM_DEBUG_OUT("cublasCgemmBatched");
|
| 486 |
-
|
| 487 |
-
return cublasCgemmBatched(handle, transA, transB,
|
| 488 |
-
m, n, k,
|
| 489 |
-
reinterpret_cast<const cuFloatComplex*>(alpha),
|
| 490 |
-
const_cast<const cuFloatComplex**>(reinterpret_cast<const cuFloatComplex* const *>(A)), ldA,
|
| 491 |
-
//reinterpret_cast<const cuFloatComplex* const *>(A), ldA, // cuBLAS 9.2
|
| 492 |
-
const_cast<const cuFloatComplex**>(reinterpret_cast<const cuFloatComplex* const *>(B)), ldB,
|
| 493 |
-
//reinterpret_cast<const cuFloatComplex* const *>(B), ldB, // cuBLAS 9.2
|
| 494 |
-
reinterpret_cast<const cuFloatComplex*>(beta),
|
| 495 |
-
const_cast<cuFloatComplex**>(reinterpret_cast<cuFloatComplex* const *>(C)), ldC,
|
| 496 |
-
//reinterpret_cast<cuFloatComplex* const *>(C), ldC, // cuBLAS 9.2
|
| 497 |
-
batch_size);
|
| 498 |
-
}
|
| 499 |
-
|
| 500 |
-
// zgemm
|
| 501 |
-
inline cublasStatus_t
|
| 502 |
-
gemm_batch(cublasHandle_t handle,
|
| 503 |
-
cublasOperation_t transA, cublasOperation_t transB,
|
| 504 |
-
int m, int n, int k,
|
| 505 |
-
const ComplexDouble* alpha,
|
| 506 |
-
const ComplexDouble* const A[], int ldA,
|
| 507 |
-
const ComplexDouble* const B[], int ldB,
|
| 508 |
-
const ComplexDouble* beta,
|
| 509 |
-
ComplexDouble* const C[], int ldC,
|
| 510 |
-
int batch_size)
|
| 511 |
-
{
|
| 512 |
-
BLAM_DEBUG_OUT("cublasZgemmBatched");
|
| 513 |
-
|
| 514 |
-
return cublasZgemmBatched(handle, transA, transB,
|
| 515 |
-
m, n, k,
|
| 516 |
-
reinterpret_cast<const cuDoubleComplex*>(alpha),
|
| 517 |
-
const_cast<const cuDoubleComplex**>(reinterpret_cast<const cuDoubleComplex* const *>(A)), ldA,
|
| 518 |
-
//reinterpret_cast<const cuDoubleComplex* const *>(A), ldA, // cuBLAS 9.2
|
| 519 |
-
const_cast<const cuDoubleComplex**>(reinterpret_cast<const cuDoubleComplex* const *>(B)), ldB,
|
| 520 |
-
//reinterpret_cast<const cuDoubleComplex* const *>(B), ldB, // cuBLAS 9.2
|
| 521 |
-
reinterpret_cast<const cuDoubleComplex*>(beta),
|
| 522 |
-
const_cast<cuDoubleComplex**>(reinterpret_cast<cuDoubleComplex* const *>(C)), ldC,
|
| 523 |
-
//reinterpret_cast<cuDoubleComplex* const *>(C), ldC, // cuBLAS 9.2
|
| 524 |
-
batch_size);
|
| 525 |
-
}
|
| 526 |
-
|
| 527 |
-
} // end namespace cublas
|
| 528 |
-
} // end namespace blam
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h
DELETED
|
@@ -1,143 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief Contains code for debugging cutlass code
|
| 34 |
-
*/
|
| 35 |
-
|
| 36 |
-
#pragma once
|
| 37 |
-
|
| 38 |
-
#include "device_dump.h"
|
| 39 |
-
|
| 40 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 41 |
-
|
| 42 |
-
/******************************************************************************
|
| 43 |
-
* Debug and logging macros
|
| 44 |
-
******************************************************************************/
|
| 45 |
-
|
| 46 |
-
/**
|
| 47 |
-
* Formats and prints the given message to stdout
|
| 48 |
-
*/
|
| 49 |
-
#if !defined(CUDA_LOG)
|
| 50 |
-
#if !defined(__CUDA_ARCH__)
|
| 51 |
-
#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__)
|
| 52 |
-
#else
|
| 53 |
-
#define CUDA_LOG(format, ...) \
|
| 54 |
-
printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \
|
| 55 |
-
blockIdx.x, \
|
| 56 |
-
blockIdx.y, \
|
| 57 |
-
blockIdx.z, \
|
| 58 |
-
threadIdx.x, \
|
| 59 |
-
threadIdx.y, \
|
| 60 |
-
threadIdx.z, \
|
| 61 |
-
__VA_ARGS__);
|
| 62 |
-
#endif
|
| 63 |
-
#endif
|
| 64 |
-
|
| 65 |
-
/**
|
| 66 |
-
* Formats and prints the given message to stdout only if DEBUG is defined
|
| 67 |
-
*/
|
| 68 |
-
#if !defined(CUDA_LOG_DEBUG)
|
| 69 |
-
#ifdef DEBUG
|
| 70 |
-
#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__)
|
| 71 |
-
#else
|
| 72 |
-
#define CUDA_LOG_DEBUG(format, ...)
|
| 73 |
-
#endif
|
| 74 |
-
#endif
|
| 75 |
-
|
| 76 |
-
/**
|
| 77 |
-
* \brief The corresponding error message is printed to \p stderr (or \p stdout in device code)
|
| 78 |
-
* along with the supplied source context.
|
| 79 |
-
*
|
| 80 |
-
* \return The CUDA error.
|
| 81 |
-
*/
|
| 82 |
-
__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error,
|
| 83 |
-
const char* expression,
|
| 84 |
-
const char* filename,
|
| 85 |
-
int line) {
|
| 86 |
-
(void)filename;
|
| 87 |
-
(void)line;
|
| 88 |
-
if (error) {
|
| 89 |
-
#if !defined(__CUDA_ARCH__)
|
| 90 |
-
fprintf(
|
| 91 |
-
stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error));
|
| 92 |
-
fflush(stderr);
|
| 93 |
-
#else
|
| 94 |
-
printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression);
|
| 95 |
-
#endif
|
| 96 |
-
}
|
| 97 |
-
return error;
|
| 98 |
-
}
|
| 99 |
-
|
| 100 |
-
/**
 * \brief Perror macro: passes the value of \p e, its source text and the
 * current file/line to cuda_perror_impl and yields that function's result.
 */
#if !defined(CUDA_PERROR)
#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)
#endif

/**
 * \brief Perror macro with exit: same report as CUDA_PERROR, then calls
 * exit(1) when the reported error is non-zero.
 */
#if !defined(CUDA_PERROR_EXIT)
#define CUDA_PERROR_EXIT(e)                                             \
  do {                                                                  \
    if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) {   \
      exit(1);                                                          \
    }                                                                   \
  } while (0)
#endif

/**
 * \brief Perror macro only if DEBUG is defined; otherwise it just evaluates
 * to \p e.
 */
#if !defined(CUDA_PERROR_DEBUG)
#if defined(DEBUG)
#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e)
#else
#define CUDA_PERROR_DEBUG(e) (e)
#endif
#endif
|
| 127 |
-
|
| 128 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 129 |
-
|
| 130 |
-
// A small helper class to dump a type at compile time.
// Usage: DebugType<Class>::Class
// The struct has no members, so naming one makes the compiler emit a
// diagnostic that spells out the full template argument.
template <typename T>
struct DebugType {};

// Function form of the same trick: instantiating it makes the compiler
// evaluate `T::t`, which fails for any T without a member named `t`, and the
// resulting diagnostic names the deduced type T.
template <typename T>
void DebugTypeFunc(T const& t) {
  T::t;
}

// A small helper class to dump a compile time constant at compile time.
// Usage: DebugValue<Class::kConstant>::kConstant
template <int Value>
struct DebugValue {};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h
DELETED
|
@@ -1,187 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include <cstdio>
|
| 35 |
-
#include "cutlass/cutlass.h"
|
| 36 |
-
|
| 37 |
-
/**
|
| 38 |
-
* \file
|
| 39 |
-
* \brief C++ interface to dump fragments and shared memory contents for
|
| 40 |
-
* debugging.
|
| 41 |
-
*/
|
| 42 |
-
|
| 43 |
-
namespace cutlass {
|
| 44 |
-
namespace debug {
|
| 45 |
-
|
| 46 |
-
/******************************************************************************
|
| 47 |
-
* Dump the fragments
|
| 48 |
-
******************************************************************************/
|
| 49 |
-
|
| 50 |
-
/// The first N threads dump the first M elements from their fragments with a
|
| 51 |
-
/// stride of S elements. If N is not specified, dump the data of all the
|
| 52 |
-
/// threads. If M is not specified, dump all the elements of the fragment.
|
| 53 |
-
template <typename Fragment>
|
| 54 |
-
CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0,
|
| 55 |
-
int S = 1) {
|
| 56 |
-
int total_threads = blockDim.x * blockDim.y * blockDim.z;
|
| 57 |
-
int block_id =
|
| 58 |
-
blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
|
| 59 |
-
int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) +
|
| 60 |
-
(threadIdx.y * blockDim.x) + threadIdx.x;
|
| 61 |
-
|
| 62 |
-
if (N < 0 || N > total_threads) {
|
| 63 |
-
if (thread_id == 0 && block_id == 0)
|
| 64 |
-
printf("Thread number N = %d should between [1, %d].\n", N,
|
| 65 |
-
total_threads);
|
| 66 |
-
|
| 67 |
-
__syncthreads();
|
| 68 |
-
|
| 69 |
-
return;
|
| 70 |
-
}
|
| 71 |
-
|
| 72 |
-
int total_elements = int(frag.size());
|
| 73 |
-
|
| 74 |
-
if (M < 0 || M > total_elements) {
|
| 75 |
-
if (thread_id == 0 && block_id == 0)
|
| 76 |
-
printf("Element number M = %d should between [1, %d].\n", M,
|
| 77 |
-
total_elements);
|
| 78 |
-
|
| 79 |
-
__syncthreads();
|
| 80 |
-
|
| 81 |
-
return;
|
| 82 |
-
}
|
| 83 |
-
|
| 84 |
-
if (N == 0) N = total_threads;
|
| 85 |
-
|
| 86 |
-
if (M == 0) M = total_elements;
|
| 87 |
-
|
| 88 |
-
if (S < 1 || S > M) {
|
| 89 |
-
if (thread_id == 0 && block_id == 0)
|
| 90 |
-
printf("Stride S = %d should between [1, %d].\n", S, M);
|
| 91 |
-
|
| 92 |
-
__syncthreads();
|
| 93 |
-
|
| 94 |
-
return;
|
| 95 |
-
}
|
| 96 |
-
|
| 97 |
-
if (thread_id == 0 && block_id == 0)
|
| 98 |
-
printf("\n*******************Dumping the fragments*******************\n\n");
|
| 99 |
-
|
| 100 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 101 |
-
for (int tid = 0; tid < N; ++tid) {
|
| 102 |
-
if (tid == thread_id) {
|
| 103 |
-
printf("TB%d W%d T%d: ", block_id, tid / 32, tid & 31);
|
| 104 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 105 |
-
for (int i = 0; i < M; i += S) {
|
| 106 |
-
printf("%.0f ", float(typename Fragment::value_type(frag[i])));
|
| 107 |
-
}
|
| 108 |
-
printf("\n");
|
| 109 |
-
}
|
| 110 |
-
|
| 111 |
-
__syncthreads();
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
if (thread_id == 0 && block_id == 0)
|
| 115 |
-
printf("\n***********************************************************\n\n");
|
| 116 |
-
|
| 117 |
-
__syncthreads();
|
| 118 |
-
|
| 119 |
-
return;
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
/******************************************************************************
|
| 123 |
-
* Dump the shared memory
|
| 124 |
-
******************************************************************************/
|
| 125 |
-
|
| 126 |
-
#define SHMEM_ROW_SIZE 128
|
| 127 |
-
|
| 128 |
-
/// Dump the shared memory contents. ptr is the begin address, size specifies
|
| 129 |
-
/// the number of elements that need to be dumped, and S specifies the stride.
|
| 130 |
-
template <typename Element>
|
| 131 |
-
CUTLASS_DEVICE void dump_shmem(Element const* ptr, size_t size, int S = 1) {
|
| 132 |
-
int block_id =
|
| 133 |
-
blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
|
| 134 |
-
int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) +
|
| 135 |
-
(threadIdx.y * blockDim.x) + threadIdx.x;
|
| 136 |
-
|
| 137 |
-
if (ptr == nullptr) {
|
| 138 |
-
if (thread_id == 0 && block_id == 0) printf("ptr is null.\n");
|
| 139 |
-
|
| 140 |
-
__syncthreads();
|
| 141 |
-
return;
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
if (size < 1) {
|
| 145 |
-
if (thread_id == 0 && block_id == 0)
|
| 146 |
-
printf("Element size is less than 1\n");
|
| 147 |
-
|
| 148 |
-
__syncthreads();
|
| 149 |
-
|
| 150 |
-
return;
|
| 151 |
-
}
|
| 152 |
-
|
| 153 |
-
int row_elements = SHMEM_ROW_SIZE / sizeof(Element);
|
| 154 |
-
|
| 155 |
-
if (S < 1 || S > row_elements) {
|
| 156 |
-
if (thread_id == 0 && block_id == 0)
|
| 157 |
-
printf("Stride S = %d should between [1, %d].\n", S, row_elements);
|
| 158 |
-
|
| 159 |
-
__syncthreads();
|
| 160 |
-
|
| 161 |
-
return;
|
| 162 |
-
}
|
| 163 |
-
|
| 164 |
-
__syncthreads();
|
| 165 |
-
|
| 166 |
-
if (thread_id == 0)
|
| 167 |
-
printf("\n********Dumping the shared memory of TB %d*******\n\n", block_id);
|
| 168 |
-
|
| 169 |
-
if (thread_id == 0) {
|
| 170 |
-
for (int i = 0; i < size; i += row_elements) {
|
| 171 |
-
for (int j = 0; j < row_elements; j += S) {
|
| 172 |
-
printf("%.0f ", float(ptr[i + j]));
|
| 173 |
-
}
|
| 174 |
-
|
| 175 |
-
printf("\n");
|
| 176 |
-
}
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
if (thread_id == 0)
|
| 180 |
-
printf("\n***********************************************************\n\n");
|
| 181 |
-
|
| 182 |
-
__syncthreads();
|
| 183 |
-
|
| 184 |
-
return;
|
| 185 |
-
}
|
| 186 |
-
} // namespace debug
|
| 187 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h
DELETED
|
@@ -1,402 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
/**
|
| 35 |
-
* \file
|
| 36 |
-
* \brief cuda kernels to do group norm on a device memory tensor with NHWC layout. The tensor will be divided into [N, H, W, G, C'] and then we do normalization on [H, W, C'].
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#include "cutlass/cutlass.h"
|
| 40 |
-
#include "cutlass/layout/tensor.h"
|
| 41 |
-
#include "cutlass/numeric_types.h"
|
| 42 |
-
#include "cutlass/tensor_coord.h"
|
| 43 |
-
#include "cutlass/tensor_ref.h"
|
| 44 |
-
#include "device_utils.h"
|
| 45 |
-
#include <cfloat>
|
| 46 |
-
|
| 47 |
-
namespace cutlass {
|
| 48 |
-
|
| 49 |
-
/** \brief interface to do group norm on a device memory tensor with NHWC layout.
|
| 50 |
-
* \tparam T: data type
|
| 51 |
-
*/
|
| 52 |
-
template <typename T>
|
| 53 |
-
void groupnorm(cutlass::Tensor4DCoord input_size,
|
| 54 |
-
const int num_groups,
|
| 55 |
-
const float eps,
|
| 56 |
-
TensorRef<T, layout::TensorNHWC> ref_output,
|
| 57 |
-
TensorRef<T, layout::TensorNHWC> ref_input,
|
| 58 |
-
TensorRef<T, layout::TensorNHWC> ref_gamma,
|
| 59 |
-
TensorRef<T, layout::TensorNHWC> ref_beta,
|
| 60 |
-
cudaStream_t stream);
|
| 61 |
-
|
| 62 |
-
extern __shared__ char groupnorm_shm[];
|
| 63 |
-
|
| 64 |
-
// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory,
|
| 65 |
-
// we store the input in the shared memory.
|
| 66 |
-
// grid(num_groups, dim0)
|
| 67 |
-
// block(BLOCKSIZE)
|
| 68 |
-
// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group
|
| 69 |
-
template<typename TVec, typename T, int T_PER_TVec>
|
| 70 |
-
__global__ void groupnorm_twopass_store_locally(T* output,
|
| 71 |
-
const T* input,
|
| 72 |
-
const T* gamma,
|
| 73 |
-
const T* beta,
|
| 74 |
-
int num_groups,
|
| 75 |
-
int prod_dim1_to_last_dim,
|
| 76 |
-
int last_dim,
|
| 77 |
-
const float eps,
|
| 78 |
-
const int TVecs_PER_THREAD)
|
| 79 |
-
{
|
| 80 |
-
const int bid = blockIdx.y; // index of batch
|
| 81 |
-
const int gid = blockIdx.x; // index of group
|
| 82 |
-
const int tid = threadIdx.x; // index of thread
|
| 83 |
-
const int bdimx = blockDim.x;
|
| 84 |
-
const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
|
| 85 |
-
const int v_reduce_elements = s_reduce_elements / T_PER_TVec;
|
| 86 |
-
const int s_group_stride = last_dim / num_groups;
|
| 87 |
-
const int v_group_stride = s_group_stride / T_PER_TVec;
|
| 88 |
-
const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec;
|
| 89 |
-
const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group;
|
| 90 |
-
TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group;
|
| 91 |
-
T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid;
|
| 92 |
-
float local_sum[1] = {0.0f};
|
| 93 |
-
|
| 94 |
-
// load from global memory into shared memory
|
| 95 |
-
#pragma unroll
|
| 96 |
-
for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
|
| 97 |
-
const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
|
| 98 |
-
const int offset_in_group =
|
| 99 |
-
((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
|
| 100 |
-
/ T_PER_TVec;
|
| 101 |
-
if (current_load_start_idx < s_reduce_elements) {
|
| 102 |
-
TVec tmp_vec = input_TVec_ptr[offset_in_group];
|
| 103 |
-
T* tmp_vec_ptr = (T*)(&tmp_vec);
|
| 104 |
-
const int local_val_offset = i * T_PER_TVec;
|
| 105 |
-
#pragma unroll
|
| 106 |
-
for (int j = 0; j < T_PER_TVec; j++) {
|
| 107 |
-
float tmp = static_cast<float>(tmp_vec_ptr[j]);
|
| 108 |
-
local_sum[0] += tmp;
|
| 109 |
-
local_val[local_val_offset + j] = tmp_vec_ptr[j];
|
| 110 |
-
}
|
| 111 |
-
}
|
| 112 |
-
}
|
| 113 |
-
__shared__ float s_mean, s_variance;
|
| 114 |
-
|
| 115 |
-
// reduction for mean
|
| 116 |
-
if (bdimx <= 32) {
|
| 117 |
-
warpReduceSum<float, 1>(local_sum);
|
| 118 |
-
}
|
| 119 |
-
else {
|
| 120 |
-
blockReduceSum<float, 1>(local_sum);
|
| 121 |
-
}
|
| 122 |
-
if (tid == 0) {
|
| 123 |
-
s_mean = local_sum[0] / s_reduce_elements;
|
| 124 |
-
}
|
| 125 |
-
__syncthreads();
|
| 126 |
-
|
| 127 |
-
// reduction for std
|
| 128 |
-
local_sum[0] = 0.0f;
|
| 129 |
-
#pragma unroll
|
| 130 |
-
for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
|
| 131 |
-
const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
|
| 132 |
-
if (current_load_start_idx < s_reduce_elements) {
|
| 133 |
-
const int local_val_offset = i * T_PER_TVec;
|
| 134 |
-
#pragma unroll
|
| 135 |
-
for (int j = 0; j < T_PER_TVec; j++) {
|
| 136 |
-
float tmp = static_cast<float>(local_val[local_val_offset + j]);
|
| 137 |
-
tmp -= s_mean;
|
| 138 |
-
local_sum[0] += tmp * tmp;
|
| 139 |
-
}
|
| 140 |
-
}
|
| 141 |
-
}
|
| 142 |
-
if (bdimx <= 32) {
|
| 143 |
-
warpReduceSum<float, 1>(local_sum);
|
| 144 |
-
}
|
| 145 |
-
else {
|
| 146 |
-
blockReduceSum<float, 1>(local_sum);
|
| 147 |
-
}
|
| 148 |
-
if (tid == 0) {
|
| 149 |
-
s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps);
|
| 150 |
-
}
|
| 151 |
-
__syncthreads();
|
| 152 |
-
|
| 153 |
-
// normalize
|
| 154 |
-
const int gamma_offset_of_group = gid * v_group_stride;
|
| 155 |
-
const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group;
|
| 156 |
-
const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group;
|
| 157 |
-
#pragma unroll
|
| 158 |
-
for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
|
| 159 |
-
const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
|
| 160 |
-
const int offset_in_group =
|
| 161 |
-
((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
|
| 162 |
-
/ T_PER_TVec;
|
| 163 |
-
const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec;
|
| 164 |
-
const int local_val_offset = i * T_PER_TVec;
|
| 165 |
-
if (current_load_start_idx < s_reduce_elements) {
|
| 166 |
-
TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group];
|
| 167 |
-
TVec beta_val = beta_TVec_ptr[gamma_offset_in_group];
|
| 168 |
-
T* gamma_val_ptr = (T*)(&gamma_val);
|
| 169 |
-
T* beta_val_ptr = (T*)(&beta_val);
|
| 170 |
-
TVec tmp_vec;
|
| 171 |
-
T* tmp_vec_ptr = (T*)(&tmp_vec);
|
| 172 |
-
#pragma unroll
|
| 173 |
-
for (int j = 0; j < T_PER_TVec; j++) {
|
| 174 |
-
float tmp = (static_cast<float>(local_val[local_val_offset + j]) - s_mean) * s_variance
|
| 175 |
-
* static_cast<float>(gamma_val_ptr[j])
|
| 176 |
-
+ static_cast<float>(beta_val_ptr[j]);
|
| 177 |
-
if (sizeof(T) == sizeof(half)) {
|
| 178 |
-
tmp_vec_ptr[j] = T(__float2half_rn(tmp));
|
| 179 |
-
}
|
| 180 |
-
else {
|
| 181 |
-
tmp_vec_ptr[j] = T(tmp);
|
| 182 |
-
}
|
| 183 |
-
}
|
| 184 |
-
output_TVec_ptr[offset_in_group] = tmp_vec;
|
| 185 |
-
}
|
| 186 |
-
}
|
| 187 |
-
}
|
| 188 |
-
|
| 189 |
-
// For large prod_dim1_to_last_dim/num_groups,
|
| 190 |
-
// in which the data cannot be stored locally,
|
| 191 |
-
// we will load from global memory multiple times,
|
| 192 |
-
// grid(num_groups, dim0)
|
| 193 |
-
// block(BLOCKSIZE)
|
| 194 |
-
// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group
|
| 195 |
-
template<typename TVec, typename T, int T_PER_TVec>
|
| 196 |
-
__global__ void groupnorm_twopass_multiple_load(T* output,
|
| 197 |
-
const T* input,
|
| 198 |
-
const T* gamma,
|
| 199 |
-
const T* beta,
|
| 200 |
-
int num_groups,
|
| 201 |
-
int prod_dim1_to_last_dim,
|
| 202 |
-
int last_dim,
|
| 203 |
-
const float eps,
|
| 204 |
-
const int TVecs_PER_THREAD)
|
| 205 |
-
{
|
| 206 |
-
const int bid = blockIdx.y; // index of batch
|
| 207 |
-
const int gid = blockIdx.x; // index of group
|
| 208 |
-
const int tid = threadIdx.x; // index of thread
|
| 209 |
-
const int bdimx = blockDim.x;
|
| 210 |
-
const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
|
| 211 |
-
const int v_reduce_elements = s_reduce_elements / T_PER_TVec;
|
| 212 |
-
const int s_group_stride = last_dim / num_groups;
|
| 213 |
-
const int v_group_stride = s_group_stride / T_PER_TVec;
|
| 214 |
-
const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec;
|
| 215 |
-
const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group;
|
| 216 |
-
TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group;
|
| 217 |
-
float local_sum[1] = {0.0f};
|
| 218 |
-
|
| 219 |
-
#pragma unroll
|
| 220 |
-
for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
|
| 221 |
-
const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
|
| 222 |
-
if (current_load_start_idx < s_reduce_elements) {
|
| 223 |
-
const int offset_in_group =
|
| 224 |
-
((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
|
| 225 |
-
/ T_PER_TVec;
|
| 226 |
-
TVec tmp_vec = input_TVec_ptr[offset_in_group];
|
| 227 |
-
T* tmp_vec_ptr = (T*)(&tmp_vec);
|
| 228 |
-
#pragma unroll
|
| 229 |
-
for (int j = 0; j < T_PER_TVec; j++) {
|
| 230 |
-
float tmp = static_cast<float>(tmp_vec_ptr[j]);
|
| 231 |
-
local_sum[0] += tmp;
|
| 232 |
-
}
|
| 233 |
-
}
|
| 234 |
-
}
|
| 235 |
-
__shared__ float s_mean, s_variance;
|
| 236 |
-
|
| 237 |
-
// reduction for mean
|
| 238 |
-
if (bdimx <= 32) {
|
| 239 |
-
warpReduceSum<float, 1>(local_sum);
|
| 240 |
-
}
|
| 241 |
-
else {
|
| 242 |
-
blockReduceSum<float, 1>(local_sum);
|
| 243 |
-
}
|
| 244 |
-
if (tid == 0) {
|
| 245 |
-
s_mean = local_sum[0] / s_reduce_elements;
|
| 246 |
-
}
|
| 247 |
-
__syncthreads();
|
| 248 |
-
|
| 249 |
-
// reduction for std
|
| 250 |
-
local_sum[0] = 0.0f;
|
| 251 |
-
#pragma unroll
|
| 252 |
-
for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
|
| 253 |
-
const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
|
| 254 |
-
if (current_load_start_idx < s_reduce_elements) {
|
| 255 |
-
const int offset_in_group =
|
| 256 |
-
((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
|
| 257 |
-
/ T_PER_TVec;
|
| 258 |
-
TVec tmp_vec = input_TVec_ptr[offset_in_group];
|
| 259 |
-
T* tmp_vec_ptr = (T*)(&tmp_vec);
|
| 260 |
-
#pragma unroll
|
| 261 |
-
for (int j = 0; j < T_PER_TVec; j++) {
|
| 262 |
-
float tmp = static_cast<float>(tmp_vec_ptr[j]);
|
| 263 |
-
tmp -= s_mean;
|
| 264 |
-
local_sum[0] += tmp * tmp;
|
| 265 |
-
}
|
| 266 |
-
}
|
| 267 |
-
}
|
| 268 |
-
if (bdimx <= 32) {
|
| 269 |
-
warpReduceSum<float, 1>(local_sum);
|
| 270 |
-
}
|
| 271 |
-
else {
|
| 272 |
-
blockReduceSum<float, 1>(local_sum);
|
| 273 |
-
}
|
| 274 |
-
if (tid == 0) {
|
| 275 |
-
s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps);
|
| 276 |
-
}
|
| 277 |
-
__syncthreads();
|
| 278 |
-
|
| 279 |
-
// normalize
|
| 280 |
-
const int gamma_offset_of_group = gid * v_group_stride;
|
| 281 |
-
const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group;
|
| 282 |
-
const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group;
|
| 283 |
-
#pragma unroll
|
| 284 |
-
for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
|
| 285 |
-
const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
|
| 286 |
-
if (current_load_start_idx < s_reduce_elements) {
|
| 287 |
-
const int offset_in_group =
|
| 288 |
-
((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
|
| 289 |
-
/ T_PER_TVec;
|
| 290 |
-
const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec;
|
| 291 |
-
TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group];
|
| 292 |
-
TVec beta_val = beta_TVec_ptr[gamma_offset_in_group];
|
| 293 |
-
T* gamma_val_ptr = (T*)(&gamma_val);
|
| 294 |
-
T* beta_val_ptr = (T*)(&beta_val);
|
| 295 |
-
TVec tmp_vec = input_TVec_ptr[offset_in_group];
|
| 296 |
-
T* tmp_vec_ptr = (T*)(&tmp_vec);
|
| 297 |
-
TVec output_tmp_vec;
|
| 298 |
-
T* output_tmp_vec_ptr = (T*)(&output_tmp_vec);
|
| 299 |
-
#pragma unroll
|
| 300 |
-
for (int j = 0; j < T_PER_TVec; j++) {
|
| 301 |
-
float tmp =
|
| 302 |
-
(static_cast<float>(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast<float>(gamma_val_ptr[j])
|
| 303 |
-
+ static_cast<float>(beta_val_ptr[j]);
|
| 304 |
-
if (sizeof(T) == sizeof(half)) {
|
| 305 |
-
output_tmp_vec_ptr[j] = T(__float2half_rn(tmp));
|
| 306 |
-
}
|
| 307 |
-
else {
|
| 308 |
-
output_tmp_vec_ptr[j] = T(tmp);
|
| 309 |
-
}
|
| 310 |
-
}
|
| 311 |
-
output_TVec_ptr[offset_in_group] = output_tmp_vec;
|
| 312 |
-
}
|
| 313 |
-
}
|
| 314 |
-
}
|
| 315 |
-
|
| 316 |
-
//ref_input & ref_output should be [N, H, W, C]
|
| 317 |
-
//ref_gamma & ref_beta should be [1, 1, 1, C]
|
| 318 |
-
template <typename T>
|
| 319 |
-
void groupnorm(cutlass::Tensor4DCoord input_size,
|
| 320 |
-
const int num_groups,
|
| 321 |
-
const float eps,
|
| 322 |
-
TensorRef<T, layout::TensorNHWC> ref_output,
|
| 323 |
-
TensorRef<T, layout::TensorNHWC> ref_input,
|
| 324 |
-
TensorRef<T, layout::TensorNHWC> ref_gamma,
|
| 325 |
-
TensorRef<T, layout::TensorNHWC> ref_beta,
|
| 326 |
-
cudaStream_t stream){
|
| 327 |
-
const int N = input_size.n();
|
| 328 |
-
const int H = input_size.h();
|
| 329 |
-
const int W = input_size.w();
|
| 330 |
-
const int C = input_size.c();
|
| 331 |
-
if (C % num_groups != 0){
|
| 332 |
-
printf("[ERROR] C should be a multiple of num_groups.\n");
|
| 333 |
-
}
|
| 334 |
-
T* output = ref_output.data();
|
| 335 |
-
const T* input = ref_input.data();
|
| 336 |
-
const T* gamma = ref_gamma.data();
|
| 337 |
-
const T* beta = ref_beta.data();
|
| 338 |
-
|
| 339 |
-
const int dim0 = N;
|
| 340 |
-
const int last_dim = C;
|
| 341 |
-
const int prod_dim1_to_last_dim = H*W*C;
|
| 342 |
-
const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
|
| 343 |
-
const int s_group_stride = last_dim / num_groups;
|
| 344 |
-
dim3 grid(num_groups, dim0);
|
| 345 |
-
int threadblock_size = 32;
|
| 346 |
-
if (s_group_stride % 2 == 0) {
|
| 347 |
-
const int T_PER_TVec = 2;
|
| 348 |
-
while (threadblock_size < 1024) {
|
| 349 |
-
if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8)
|
| 350 |
-
break;
|
| 351 |
-
threadblock_size *= 2;
|
| 352 |
-
}
|
| 353 |
-
dim3 block(threadblock_size);
|
| 354 |
-
const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size;
|
| 355 |
-
const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T);
|
| 356 |
-
// for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32;
|
| 357 |
-
// the size of grid & block may have better choice for different cases.
|
| 358 |
-
// ensure shared memory is smaller than 48KB
|
| 359 |
-
if (std::is_same<T, float>::value){
|
| 360 |
-
if (shm_size < 48 * 1024) {
|
| 361 |
-
groupnorm_twopass_store_locally<float2, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
|
| 362 |
-
output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
|
| 363 |
-
}
|
| 364 |
-
else {
|
| 365 |
-
groupnorm_twopass_multiple_load<float2, T, T_PER_TVec><<<grid, block, 0, stream>>>(
|
| 366 |
-
output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
|
| 367 |
-
}
|
| 368 |
-
}
|
| 369 |
-
else{
|
| 370 |
-
if (shm_size < 48 * 1024) {
|
| 371 |
-
groupnorm_twopass_store_locally<half2, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
|
| 372 |
-
output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
|
| 373 |
-
}
|
| 374 |
-
else {
|
| 375 |
-
groupnorm_twopass_multiple_load<half2, T, T_PER_TVec><<<grid, block, 0, stream>>>(
|
| 376 |
-
output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
|
| 377 |
-
}
|
| 378 |
-
}
|
| 379 |
-
}
|
| 380 |
-
else {
|
| 381 |
-
const int T_PER_TVec = 1;
|
| 382 |
-
while (threadblock_size < 1024) {
|
| 383 |
-
if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8)
|
| 384 |
-
break;
|
| 385 |
-
threadblock_size *= 2;
|
| 386 |
-
}
|
| 387 |
-
dim3 block(threadblock_size);
|
| 388 |
-
const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size;
|
| 389 |
-
const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T);
|
| 390 |
-
if (shm_size < 48 * 1024) {
|
| 391 |
-
groupnorm_twopass_store_locally<T, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
|
| 392 |
-
output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
|
| 393 |
-
}
|
| 394 |
-
else {
|
| 395 |
-
groupnorm_twopass_multiple_load<T, T, T_PER_TVec><<<grid, block, 0, stream>>>(
|
| 396 |
-
output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
|
| 397 |
-
}
|
| 398 |
-
}
|
| 399 |
-
|
| 400 |
-
}
|
| 401 |
-
|
| 402 |
-
} //namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h
DELETED
|
@@ -1,644 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
/**
|
| 35 |
-
* \file
|
| 36 |
-
* \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout.
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#include "cutlass/cutlass.h"
|
| 40 |
-
#include "cutlass/layout/tensor.h"
|
| 41 |
-
#include "cutlass/numeric_types.h"
|
| 42 |
-
#include "cutlass/tensor_coord.h"
|
| 43 |
-
#include "cutlass/tensor_ref.h"
|
| 44 |
-
#include "device_utils.h"
|
| 45 |
-
#include <cfloat>
|
| 46 |
-
|
| 47 |
-
namespace cutlass {
|
| 48 |
-
|
| 49 |
-
/** \brief interface to do layernorm on a device memory tensor with RowMajor layout.
|
| 50 |
-
* \tparam T: data type
|
| 51 |
-
*/
|
| 52 |
-
template <typename T>
|
| 53 |
-
void layernorm(cutlass::MatrixCoord tensor_size,
|
| 54 |
-
TensorRef<T, layout::RowMajor> ref_output,
|
| 55 |
-
TensorRef<T, layout::RowMajor> ref_input,
|
| 56 |
-
TensorRef<T, layout::RowMajor> ref_gamma,
|
| 57 |
-
TensorRef<T, layout::RowMajor> ref_beta,
|
| 58 |
-
cudaStream_t stream);
|
| 59 |
-
|
| 60 |
-
/**
|
| 61 |
-
* output [m, n] row-major
|
| 62 |
-
* input [m, n] row-major
|
| 63 |
-
* gamma [n]
|
| 64 |
-
* beta [n]
|
| 65 |
-
* grid(m)
|
| 66 |
-
* block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements
|
| 67 |
-
*/
|
| 68 |
-
template<typename T, int ITEM_PER_THREAD>
|
| 69 |
-
__global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output,
|
| 70 |
-
const T* input,
|
| 71 |
-
const T* gamma,
|
| 72 |
-
const T* beta,
|
| 73 |
-
const int m,
|
| 74 |
-
const int n)
|
| 75 |
-
{
|
| 76 |
-
const int m_idx = blockIdx.x;
|
| 77 |
-
const int tid = threadIdx.x;
|
| 78 |
-
const int bdimx = blockDim.x;
|
| 79 |
-
__shared__ float s_mean, s_variance;
|
| 80 |
-
T local_val[ITEM_PER_THREAD];
|
| 81 |
-
float local_sums[1] = {0.0f};
|
| 82 |
-
int offset = m_idx * n;
|
| 83 |
-
input += offset;
|
| 84 |
-
output += offset;
|
| 85 |
-
|
| 86 |
-
const T zero = T(0.0f);
|
| 87 |
-
#pragma unroll
|
| 88 |
-
for (int i = 0 ; i < ITEM_PER_THREAD ; i++){
|
| 89 |
-
int index = tid + i*bdimx;
|
| 90 |
-
local_val[i] = index < n ? input[index] : zero;
|
| 91 |
-
local_sums[0] += static_cast<float>(local_val[i]);
|
| 92 |
-
}
|
| 93 |
-
if (blockDim.x <= 32) {
|
| 94 |
-
warpReduceSum<float, 1>(local_sums);
|
| 95 |
-
}
|
| 96 |
-
else {
|
| 97 |
-
blockReduceSum<float, 1>(local_sums);
|
| 98 |
-
}
|
| 99 |
-
if (threadIdx.x == 0) {
|
| 100 |
-
s_mean = local_sums[0] / n;
|
| 101 |
-
}
|
| 102 |
-
__syncthreads();
|
| 103 |
-
|
| 104 |
-
local_sums[0] = 0.0f;
|
| 105 |
-
#pragma unroll
|
| 106 |
-
for (int i = 0 ; i < ITEM_PER_THREAD ; i++){
|
| 107 |
-
int index = tid + i*bdimx;
|
| 108 |
-
if (index < n){
|
| 109 |
-
const float tmp = static_cast<float>(local_val[i]) - s_mean;
|
| 110 |
-
local_sums[0] += tmp * tmp;
|
| 111 |
-
}
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
if (blockDim.x <= 32) {
|
| 115 |
-
warpReduceSum<float, 1>(local_sums);
|
| 116 |
-
}
|
| 117 |
-
else {
|
| 118 |
-
blockReduceSum<float, 1>(local_sums);
|
| 119 |
-
}
|
| 120 |
-
if (threadIdx.x == 0) {
|
| 121 |
-
s_variance = rsqrtf(local_sums[0] / n + 1e-5);
|
| 122 |
-
}
|
| 123 |
-
__syncthreads();
|
| 124 |
-
|
| 125 |
-
#pragma unroll
|
| 126 |
-
for (int i = 0 ; i < ITEM_PER_THREAD ; i++){
|
| 127 |
-
int index = tid + i*bdimx;
|
| 128 |
-
if (index < n) {
|
| 129 |
-
const T gamma_val = gamma[index];
|
| 130 |
-
const T beta_val = beta[index];
|
| 131 |
-
output[index] = T((static_cast<float>(local_val[i]) - s_mean) * s_variance * static_cast<float>(gamma_val) + static_cast<float>(beta_val));
|
| 132 |
-
}
|
| 133 |
-
}
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
/**
|
| 137 |
-
* output [m, n] row-major
|
| 138 |
-
* input [m, n] row-major
|
| 139 |
-
* gamma [n]
|
| 140 |
-
* beta [n]
|
| 141 |
-
* grid(m)
|
| 142 |
-
* block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements;
|
| 143 |
-
*/
|
| 144 |
-
template<typename T2, typename T, int ITEM_PER_THREAD>
|
| 145 |
-
__global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output,
|
| 146 |
-
const T2* input,
|
| 147 |
-
const T2* gamma,
|
| 148 |
-
const T2* beta,
|
| 149 |
-
const int m,
|
| 150 |
-
const int n)
|
| 151 |
-
{
|
| 152 |
-
const int m_idx = blockIdx.x;
|
| 153 |
-
const int tid = threadIdx.x;
|
| 154 |
-
const int bdimx = blockDim.x;
|
| 155 |
-
__shared__ float s_mean, s_variance;
|
| 156 |
-
float local_sums[1] = {0.0f};
|
| 157 |
-
T2 local_val[ITEM_PER_THREAD];
|
| 158 |
-
const int n_2 = n / 2;
|
| 159 |
-
int offset = m_idx * n_2;
|
| 160 |
-
input += offset;
|
| 161 |
-
output += offset;
|
| 162 |
-
|
| 163 |
-
const T2 zero = {T(0.0f), T(0.0f)};
|
| 164 |
-
#pragma UNROLL
|
| 165 |
-
for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
|
| 166 |
-
const int index = i*bdimx + tid;
|
| 167 |
-
local_val[i] = index < n_2 ? input[index] : zero;
|
| 168 |
-
local_sums[0] += static_cast<float>(local_val[i].x) + static_cast<float>(local_val[i].y);
|
| 169 |
-
}
|
| 170 |
-
|
| 171 |
-
if (blockDim.x <= 32) {
|
| 172 |
-
warpReduceSum<float, 1>(local_sums);
|
| 173 |
-
}
|
| 174 |
-
else {
|
| 175 |
-
blockReduceSum<float, 1>(local_sums);
|
| 176 |
-
}
|
| 177 |
-
if (threadIdx.x == 0) {
|
| 178 |
-
s_mean = local_sums[0] / n;
|
| 179 |
-
}
|
| 180 |
-
__syncthreads();
|
| 181 |
-
|
| 182 |
-
local_sums[0] = 0.0f;
|
| 183 |
-
#pragma UNROLL
|
| 184 |
-
for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
|
| 185 |
-
const int index = i*bdimx + tid;
|
| 186 |
-
if (index < n_2){
|
| 187 |
-
const float2 tmp = {static_cast<float>(local_val[i].x) - s_mean,
|
| 188 |
-
static_cast<float>(local_val[i].y) - s_mean};
|
| 189 |
-
local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y;
|
| 190 |
-
}
|
| 191 |
-
}
|
| 192 |
-
if (blockDim.x <= 32) {
|
| 193 |
-
warpReduceSum<float, 1>(local_sums);
|
| 194 |
-
}
|
| 195 |
-
else {
|
| 196 |
-
blockReduceSum<float, 1>(local_sums);
|
| 197 |
-
}
|
| 198 |
-
if (threadIdx.x == 0) {
|
| 199 |
-
s_variance = rsqrtf(local_sums[0] / n + 1e-5);
|
| 200 |
-
}
|
| 201 |
-
__syncthreads();
|
| 202 |
-
|
| 203 |
-
#pragma UNROLL
|
| 204 |
-
for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
|
| 205 |
-
const int index = i*bdimx + tid;
|
| 206 |
-
if (index < n_2){
|
| 207 |
-
const T2 gamma_val = gamma[index];
|
| 208 |
-
const T2 beta_val = beta[index];
|
| 209 |
-
T2 tmp;
|
| 210 |
-
tmp.x = T((static_cast<float>(local_val[i].x) - s_mean)*s_variance*static_cast<float>(gamma_val.x) + static_cast<float>(beta_val.x));
|
| 211 |
-
tmp.y = T((static_cast<float>(local_val[i].y) - s_mean)*s_variance*static_cast<float>(gamma_val.y) + static_cast<float>(beta_val.y));
|
| 212 |
-
output[index] = tmp;
|
| 213 |
-
}
|
| 214 |
-
}
|
| 215 |
-
}
|
| 216 |
-
|
| 217 |
-
/**
|
| 218 |
-
* output [m, n] row-major
|
| 219 |
-
* input [m, n] row-major
|
| 220 |
-
* gamma [n]
|
| 221 |
-
* beta [n]
|
| 222 |
-
* grid(m)
|
| 223 |
-
* block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements;
|
| 224 |
-
*/
|
| 225 |
-
template<typename T4, typename T, int ITEM_PER_THREAD>
|
| 226 |
-
__global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output,
|
| 227 |
-
const T4* input,
|
| 228 |
-
const T4* gamma,
|
| 229 |
-
const T4* beta,
|
| 230 |
-
const int m,
|
| 231 |
-
const int n)
|
| 232 |
-
{
|
| 233 |
-
const int m_idx = blockIdx.x;
|
| 234 |
-
const int tid = threadIdx.x;
|
| 235 |
-
const int bdimx = blockDim.x;
|
| 236 |
-
__shared__ float s_mean, s_variance;
|
| 237 |
-
float local_sums[1] = {0.0f};
|
| 238 |
-
T4 local_val[ITEM_PER_THREAD];
|
| 239 |
-
const int n_4 = n / 4;
|
| 240 |
-
int offset = m_idx * n_4;
|
| 241 |
-
input += offset;
|
| 242 |
-
output += offset;
|
| 243 |
-
|
| 244 |
-
const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)};
|
| 245 |
-
#pragma UNROLL
|
| 246 |
-
for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
|
| 247 |
-
const int index = i*bdimx + tid;
|
| 248 |
-
local_val[i] = index < n_4 ? input[index] : zero;
|
| 249 |
-
local_sums[0] += static_cast<float>(local_val[i].x) + static_cast<float>(local_val[i].y) +
|
| 250 |
-
static_cast<float>(local_val[i].z) + static_cast<float>(local_val[i].w);
|
| 251 |
-
}
|
| 252 |
-
|
| 253 |
-
if (blockDim.x <= 32) {
|
| 254 |
-
warpReduceSum<float, 1>(local_sums);
|
| 255 |
-
}
|
| 256 |
-
else {
|
| 257 |
-
blockReduceSum<float, 1>(local_sums);
|
| 258 |
-
}
|
| 259 |
-
if (threadIdx.x == 0) {
|
| 260 |
-
s_mean = local_sums[0] / n;
|
| 261 |
-
}
|
| 262 |
-
__syncthreads();
|
| 263 |
-
|
| 264 |
-
local_sums[0] = 0.0f;
|
| 265 |
-
#pragma UNROLL
|
| 266 |
-
for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
|
| 267 |
-
const int index = i*bdimx + tid;
|
| 268 |
-
if (index < n_4){
|
| 269 |
-
const float4 tmp = {static_cast<float>(local_val[i].x) - s_mean,
|
| 270 |
-
static_cast<float>(local_val[i].y) - s_mean,
|
| 271 |
-
static_cast<float>(local_val[i].z) - s_mean,
|
| 272 |
-
static_cast<float>(local_val[i].w) - s_mean};
|
| 273 |
-
local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w;
|
| 274 |
-
}
|
| 275 |
-
}
|
| 276 |
-
if (blockDim.x <= 32) {
|
| 277 |
-
warpReduceSum<float, 1>(local_sums);
|
| 278 |
-
}
|
| 279 |
-
else {
|
| 280 |
-
blockReduceSum<float, 1>(local_sums);
|
| 281 |
-
}
|
| 282 |
-
if (threadIdx.x == 0) {
|
| 283 |
-
s_variance = rsqrtf(local_sums[0] / n + 1e-5);
|
| 284 |
-
}
|
| 285 |
-
__syncthreads();
|
| 286 |
-
|
| 287 |
-
#pragma UNROLL
|
| 288 |
-
for (int i = 0; i < ITEM_PER_THREAD; i += 1) {
|
| 289 |
-
const int index = i*bdimx + tid;
|
| 290 |
-
if (index < n_4){
|
| 291 |
-
const T4 gamma_val = gamma[index];
|
| 292 |
-
const T4 beta_val = beta[index];
|
| 293 |
-
T4 tmp;
|
| 294 |
-
tmp.x = T((static_cast<float>(local_val[i].x) - s_mean)*s_variance*static_cast<float>(gamma_val.x) + static_cast<float>(beta_val.x));
|
| 295 |
-
tmp.y = T((static_cast<float>(local_val[i].y) - s_mean)*s_variance*static_cast<float>(gamma_val.y) + static_cast<float>(beta_val.y));
|
| 296 |
-
tmp.z = T((static_cast<float>(local_val[i].z) - s_mean)*s_variance*static_cast<float>(gamma_val.z) + static_cast<float>(beta_val.z));
|
| 297 |
-
tmp.w = T((static_cast<float>(local_val[i].w) - s_mean)*s_variance*static_cast<float>(gamma_val.w) + static_cast<float>(beta_val.w));
|
| 298 |
-
output[index] = tmp;
|
| 299 |
-
}
|
| 300 |
-
}
|
| 301 |
-
}
|
| 302 |
-
|
| 303 |
-
/**
|
| 304 |
-
* output [m, n] row-major
|
| 305 |
-
* input [m, n] row-major
|
| 306 |
-
* gamma [n]
|
| 307 |
-
* beta [n]
|
| 308 |
-
* grid(m)
|
| 309 |
-
* block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements
|
| 310 |
-
*/
|
| 311 |
-
template<typename T>
|
| 312 |
-
__global__ void layernorm_twoPassAlgo_e1(T* output,
|
| 313 |
-
const T* input,
|
| 314 |
-
const T* gamma,
|
| 315 |
-
const T* beta,
|
| 316 |
-
const int m,
|
| 317 |
-
const int n)
|
| 318 |
-
{
|
| 319 |
-
const int m_idx = blockIdx.x;
|
| 320 |
-
const int tid = threadIdx.x;
|
| 321 |
-
const int bdimx = blockDim.x;
|
| 322 |
-
__shared__ float s_mean, s_variance;
|
| 323 |
-
float local_sums[1] = {0.0f};
|
| 324 |
-
int offset = m_idx * n;
|
| 325 |
-
input += offset;
|
| 326 |
-
output += offset;
|
| 327 |
-
|
| 328 |
-
for (int index = tid ; index < n ; index += bdimx){
|
| 329 |
-
float local_val = static_cast<float>(input[index]);
|
| 330 |
-
local_sums[0] += local_val;
|
| 331 |
-
}
|
| 332 |
-
if (blockDim.x <= 32) {
|
| 333 |
-
warpReduceSum<float, 1>(local_sums);
|
| 334 |
-
}
|
| 335 |
-
else {
|
| 336 |
-
blockReduceSum<float, 1>(local_sums);
|
| 337 |
-
}
|
| 338 |
-
if (threadIdx.x == 0) {
|
| 339 |
-
s_mean = local_sums[0] / n;
|
| 340 |
-
}
|
| 341 |
-
__syncthreads();
|
| 342 |
-
|
| 343 |
-
local_sums[0] = 0.0f;
|
| 344 |
-
for (int index = tid ; index < n ; index += bdimx){
|
| 345 |
-
float local_val = static_cast<float>(input[index]);
|
| 346 |
-
local_val = local_val - s_mean;
|
| 347 |
-
local_sums[0] += local_val * local_val;
|
| 348 |
-
}
|
| 349 |
-
|
| 350 |
-
if (blockDim.x <= 32) {
|
| 351 |
-
warpReduceSum<float, 1>(local_sums);
|
| 352 |
-
}
|
| 353 |
-
else {
|
| 354 |
-
blockReduceSum<float, 1>(local_sums);
|
| 355 |
-
}
|
| 356 |
-
if (threadIdx.x == 0) {
|
| 357 |
-
s_variance = rsqrtf(local_sums[0] / n + 1e-5);
|
| 358 |
-
}
|
| 359 |
-
__syncthreads();
|
| 360 |
-
|
| 361 |
-
for (int index = tid ; index < n ; index += bdimx){
|
| 362 |
-
const T gamma_val = gamma[index];
|
| 363 |
-
const T beta_val = beta[index];
|
| 364 |
-
const T local_val = input[index];
|
| 365 |
-
output[index] = T((static_cast<float>(local_val) - s_mean) * s_variance * static_cast<float>(gamma_val) + static_cast<float>(beta_val));
|
| 366 |
-
}
|
| 367 |
-
}
|
| 368 |
-
|
| 369 |
-
/**
|
| 370 |
-
* output [m, n] row-major
|
| 371 |
-
* input [m, n] row-major
|
| 372 |
-
* gamma [n]
|
| 373 |
-
* beta [n]
|
| 374 |
-
* grid(m)
|
| 375 |
-
* block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements;
|
| 376 |
-
*/
|
| 377 |
-
template<typename T2, typename T>
|
| 378 |
-
__global__ void layernorm_twoPassAlgo_e2(T2* output,
|
| 379 |
-
const T2* input,
|
| 380 |
-
const T2* gamma,
|
| 381 |
-
const T2* beta,
|
| 382 |
-
const int m,
|
| 383 |
-
const int n)
|
| 384 |
-
{
|
| 385 |
-
const int m_idx = blockIdx.x;
|
| 386 |
-
const int tid = threadIdx.x;
|
| 387 |
-
const int bdimx = blockDim.x;
|
| 388 |
-
__shared__ float s_mean, s_variance;
|
| 389 |
-
float local_sums[1] = {0.0f};
|
| 390 |
-
const int n_2 = n / 2;
|
| 391 |
-
int offset = m_idx * n_2;
|
| 392 |
-
input += offset;
|
| 393 |
-
output += offset;
|
| 394 |
-
|
| 395 |
-
for (int index = tid; index < n_2; index += bdimx) {
|
| 396 |
-
const T2 local_val = input[index];
|
| 397 |
-
local_sums[0] += static_cast<float>(local_val.x) + static_cast<float>(local_val.y);
|
| 398 |
-
}
|
| 399 |
-
|
| 400 |
-
if (blockDim.x <= 32) {
|
| 401 |
-
warpReduceSum<float, 1>(local_sums);
|
| 402 |
-
}
|
| 403 |
-
else {
|
| 404 |
-
blockReduceSum<float, 1>(local_sums);
|
| 405 |
-
}
|
| 406 |
-
if (threadIdx.x == 0) {
|
| 407 |
-
s_mean = local_sums[0] / n;
|
| 408 |
-
}
|
| 409 |
-
__syncthreads();
|
| 410 |
-
|
| 411 |
-
local_sums[0] = 0.0f;
|
| 412 |
-
for (int index = tid; index < n_2; index += bdimx) {
|
| 413 |
-
const T2 local_val = input[index];
|
| 414 |
-
const float2 tmp = {static_cast<float>(local_val.x) - s_mean,
|
| 415 |
-
static_cast<float>(local_val.y) - s_mean};
|
| 416 |
-
local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y;
|
| 417 |
-
}
|
| 418 |
-
if (blockDim.x <= 32) {
|
| 419 |
-
warpReduceSum<float, 1>(local_sums);
|
| 420 |
-
}
|
| 421 |
-
else {
|
| 422 |
-
blockReduceSum<float, 1>(local_sums);
|
| 423 |
-
}
|
| 424 |
-
if (threadIdx.x == 0) {
|
| 425 |
-
s_variance = rsqrtf(local_sums[0] / n + 1e-5);
|
| 426 |
-
}
|
| 427 |
-
__syncthreads();
|
| 428 |
-
|
| 429 |
-
for (int index = tid; index < n_2; index += bdimx) {
|
| 430 |
-
const T2 local_val = input[index];
|
| 431 |
-
const T2 gamma_val = gamma[index];
|
| 432 |
-
const T2 beta_val = beta[index];
|
| 433 |
-
T2 tmp;
|
| 434 |
-
tmp.x = T((static_cast<float>(local_val.x) - s_mean)*s_variance*static_cast<float>(gamma_val.x) + static_cast<float>(beta_val.x));
|
| 435 |
-
tmp.y = T((static_cast<float>(local_val.y) - s_mean)*s_variance*static_cast<float>(gamma_val.y) + static_cast<float>(beta_val.y));
|
| 436 |
-
output[index] = tmp;
|
| 437 |
-
}
|
| 438 |
-
}
|
| 439 |
-
|
| 440 |
-
template <typename T>
|
| 441 |
-
void layernorm(cutlass::MatrixCoord tensor_size,
|
| 442 |
-
TensorRef<T, layout::RowMajor> ref_output,
|
| 443 |
-
TensorRef<T, layout::RowMajor> ref_input,
|
| 444 |
-
TensorRef<T, layout::RowMajor> ref_gamma,
|
| 445 |
-
TensorRef<T, layout::RowMajor> ref_beta,
|
| 446 |
-
cudaStream_t stream){
|
| 447 |
-
const int m = tensor_size.row();
|
| 448 |
-
const int n = tensor_size.column();
|
| 449 |
-
T* output = ref_output.data();
|
| 450 |
-
const T* input = ref_input.data();
|
| 451 |
-
const T* gamma = ref_gamma.data();
|
| 452 |
-
const T* beta = ref_beta.data();
|
| 453 |
-
dim3 grid(m);
|
| 454 |
-
dim3 block((n + 31)/32*32);
|
| 455 |
-
if (block.x > 1024){
|
| 456 |
-
block.x = 1024;
|
| 457 |
-
}
|
| 458 |
-
// TODO : There should be better configs for different cases, we only use several samples to show how to use here
|
| 459 |
-
// TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels.
|
| 460 |
-
if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) {
|
| 461 |
-
block.x = (n/4 + 31)/32*32;
|
| 462 |
-
if (std::is_same<T, float>::value) {
|
| 463 |
-
layernorm_twoPassAlgo_stored_locally_e4<float4, float, 1><<<grid, block, 0, stream>>>(
|
| 464 |
-
(float4*)output,
|
| 465 |
-
(const float4*)input,
|
| 466 |
-
(const float4*)gamma,
|
| 467 |
-
(const float4*)beta,
|
| 468 |
-
m,
|
| 469 |
-
n);
|
| 470 |
-
} // if (std::is_same<T, float>::value)
|
| 471 |
-
else {
|
| 472 |
-
layernorm_twoPassAlgo_stored_locally_e4<half4, half, 1><<<grid, block, 0, stream>>>(
|
| 473 |
-
(half4*)output,
|
| 474 |
-
(const half4*)input,
|
| 475 |
-
(const half4*)gamma,
|
| 476 |
-
(const half4*)beta,
|
| 477 |
-
m,
|
| 478 |
-
n);
|
| 479 |
-
}
|
| 480 |
-
} //if ((n % 4 == 0) && (n >= 128) && (n <= 4096))
|
| 481 |
-
else if (n % 2 == 0) {
|
| 482 |
-
if (n / 2 <= 1024) {
|
| 483 |
-
block.x = (n/2 + 31)/32*32;
|
| 484 |
-
if (std::is_same<T, float>::value) {
|
| 485 |
-
layernorm_twoPassAlgo_stored_locally_e2<float2, float, 1><<<grid, block, 0, stream>>>(
|
| 486 |
-
(float2*)output,
|
| 487 |
-
(const float2*)input,
|
| 488 |
-
(const float2*)gamma,
|
| 489 |
-
(const float2*)beta,
|
| 490 |
-
m,
|
| 491 |
-
n);
|
| 492 |
-
} //if (std::is_same<T, float>::value)
|
| 493 |
-
else {
|
| 494 |
-
layernorm_twoPassAlgo_stored_locally_e2<half2, half, 1><<<grid, block, 0, stream>>>(
|
| 495 |
-
(half2*)output,
|
| 496 |
-
(const half2*)input,
|
| 497 |
-
(const half2*)gamma,
|
| 498 |
-
(const half2*)beta,
|
| 499 |
-
m,
|
| 500 |
-
n);
|
| 501 |
-
}
|
| 502 |
-
} // if (n / 2 <= 1024)
|
| 503 |
-
else if (n <= 8192) {
|
| 504 |
-
block.x = ((n + 7)/8 + 31)/32*32;
|
| 505 |
-
if (std::is_same<T, float>::value) {
|
| 506 |
-
layernorm_twoPassAlgo_stored_locally_e2<float2, float, 4><<<grid, block, 0, stream>>>(
|
| 507 |
-
(float2*)output,
|
| 508 |
-
(const float2*)input,
|
| 509 |
-
(const float2*)gamma,
|
| 510 |
-
(const float2*)beta,
|
| 511 |
-
m,
|
| 512 |
-
n);
|
| 513 |
-
} // if (std::is_same<T, float>::value)
|
| 514 |
-
else {
|
| 515 |
-
layernorm_twoPassAlgo_stored_locally_e2<half2, half, 4><<<grid, block, 0, stream>>>(
|
| 516 |
-
(half2*)output,
|
| 517 |
-
(const half2*)input,
|
| 518 |
-
(const half2*)gamma,
|
| 519 |
-
(const half2*)beta,
|
| 520 |
-
m,
|
| 521 |
-
n);
|
| 522 |
-
}
|
| 523 |
-
} // if (n <= 8192)
|
| 524 |
-
else if (n <= 16384) {
|
| 525 |
-
block.x = ((n + 15)/ 16 + 31)/32*32;
|
| 526 |
-
if (std::is_same<T, float>::value) {
|
| 527 |
-
layernorm_twoPassAlgo_stored_locally_e2<float2, float, 8><<<grid, block, 0, stream>>>(
|
| 528 |
-
(float2*)output,
|
| 529 |
-
(const float2*)input,
|
| 530 |
-
(const float2*)gamma,
|
| 531 |
-
(const float2*)beta,
|
| 532 |
-
m,
|
| 533 |
-
n);
|
| 534 |
-
} // if (std::is_same<T, float>::value)
|
| 535 |
-
else {
|
| 536 |
-
layernorm_twoPassAlgo_stored_locally_e2<half2, half, 8><<<grid, block, 0, stream>>>(
|
| 537 |
-
(half2*)output,
|
| 538 |
-
(const half2*)input,
|
| 539 |
-
(const half2*)gamma,
|
| 540 |
-
(const half2*)beta,
|
| 541 |
-
m,
|
| 542 |
-
n);
|
| 543 |
-
}
|
| 544 |
-
} // if (n <= 16384)
|
| 545 |
-
else if (n <= 32768) {
|
| 546 |
-
block.x = ((n + 31)/32 + 31)/32*32;
|
| 547 |
-
if (std::is_same<T, float>::value) {
|
| 548 |
-
layernorm_twoPassAlgo_stored_locally_e2<float2, float, 16><<<grid, block, 0, stream>>>(
|
| 549 |
-
(float2*)output,
|
| 550 |
-
(const float2*)input,
|
| 551 |
-
(const float2*)gamma,
|
| 552 |
-
(const float2*)beta,
|
| 553 |
-
m,
|
| 554 |
-
n);
|
| 555 |
-
} // if (std::is_same<T, float>::value)
|
| 556 |
-
else {
|
| 557 |
-
layernorm_twoPassAlgo_stored_locally_e2<half2, half, 16><<<grid, block, 0, stream>>>(
|
| 558 |
-
(half2*)output,
|
| 559 |
-
(const half2*)input,
|
| 560 |
-
(const half2*)gamma,
|
| 561 |
-
(const half2*)beta,
|
| 562 |
-
m,
|
| 563 |
-
n);
|
| 564 |
-
}
|
| 565 |
-
} // if (n <= 32768)
|
| 566 |
-
else {
|
| 567 |
-
if (block.x > 512)
|
| 568 |
-
block.x = 512;
|
| 569 |
-
if (std::is_same<T, float>::value) {
|
| 570 |
-
layernorm_twoPassAlgo_e2<float2, float><<<grid, block, 0, stream>>>(
|
| 571 |
-
(float2 *)output,
|
| 572 |
-
(const float2 *)input,
|
| 573 |
-
(const float2 *)gamma,
|
| 574 |
-
(const float2 *)beta,
|
| 575 |
-
m,
|
| 576 |
-
n);
|
| 577 |
-
} // if (std::is_same<T, float>::value)
|
| 578 |
-
else {
|
| 579 |
-
layernorm_twoPassAlgo_e2<half2, half><<<grid, block, 0, stream>>>(
|
| 580 |
-
(half2 *)output,
|
| 581 |
-
(const half2 *)input,
|
| 582 |
-
(const half2 *)gamma,
|
| 583 |
-
(const half2 *)beta,
|
| 584 |
-
m,
|
| 585 |
-
n);
|
| 586 |
-
}
|
| 587 |
-
}
|
| 588 |
-
} // if (n % 2 == 0)
|
| 589 |
-
else {
|
| 590 |
-
if (n <= 1024) {
|
| 591 |
-
layernorm_twoPassAlgo_stored_locally_e1<T, 1><<<grid, block, 0, stream>>>(
|
| 592 |
-
output,
|
| 593 |
-
input,
|
| 594 |
-
gamma,
|
| 595 |
-
beta,
|
| 596 |
-
m,
|
| 597 |
-
n);
|
| 598 |
-
} // if (n <= 1024)
|
| 599 |
-
else if (n <= 8192) {
|
| 600 |
-
block.x = ((n + 7)/8 + 31)/32*32;
|
| 601 |
-
layernorm_twoPassAlgo_stored_locally_e1<T, 8><<<grid, block, 0, stream>>>(
|
| 602 |
-
output,
|
| 603 |
-
input,
|
| 604 |
-
gamma,
|
| 605 |
-
beta,
|
| 606 |
-
m,
|
| 607 |
-
n);
|
| 608 |
-
} // if (n <= 8192)
|
| 609 |
-
else if (n <= 16384) {
|
| 610 |
-
block.x = ((n + 15)/16 + 32)/32*32;
|
| 611 |
-
layernorm_twoPassAlgo_stored_locally_e1<T, 16><<<grid, block, 0, stream>>>(
|
| 612 |
-
output,
|
| 613 |
-
input,
|
| 614 |
-
gamma,
|
| 615 |
-
beta,
|
| 616 |
-
m,
|
| 617 |
-
n);
|
| 618 |
-
} // if (n <= 16384)
|
| 619 |
-
else if (n <= 32768) {
|
| 620 |
-
block.x = ((n + 31)/32 + 31)/32*32;
|
| 621 |
-
layernorm_twoPassAlgo_stored_locally_e1<T, 32><<<grid, block, 0, stream>>>(
|
| 622 |
-
output,
|
| 623 |
-
input,
|
| 624 |
-
gamma,
|
| 625 |
-
beta,
|
| 626 |
-
m,
|
| 627 |
-
n);
|
| 628 |
-
} // if (n <= 32768)
|
| 629 |
-
else{
|
| 630 |
-
if (block.x > 512) {
|
| 631 |
-
block.x = 512;
|
| 632 |
-
}
|
| 633 |
-
layernorm_twoPassAlgo_e1<<<grid, block, 0, stream>>>(
|
| 634 |
-
output,
|
| 635 |
-
input,
|
| 636 |
-
gamma,
|
| 637 |
-
beta,
|
| 638 |
-
m,
|
| 639 |
-
n);
|
| 640 |
-
}
|
| 641 |
-
}
|
| 642 |
-
}
|
| 643 |
-
|
| 644 |
-
} //namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h
DELETED
|
@@ -1,375 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
/**
|
| 35 |
-
* \file
|
| 36 |
-
* \brief C++ interface to CUDA device memory management functions.
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#include <memory>
|
| 40 |
-
#include <sstream>
|
| 41 |
-
|
| 42 |
-
#include "cutlass/platform/platform.h"
|
| 43 |
-
#include "cutlass/numeric_types.h"
|
| 44 |
-
#include "cutlass/trace.h"
|
| 45 |
-
#include "exceptions.h"
|
| 46 |
-
|
| 47 |
-
namespace cutlass {
|
| 48 |
-
namespace device_memory {
|
| 49 |
-
|
| 50 |
-
/******************************************************************************
|
| 51 |
-
* Allocation lifetime
|
| 52 |
-
******************************************************************************/
|
| 53 |
-
|
| 54 |
-
/// Allocate a buffer of \p count elements of type \p T on the current CUDA device
|
| 55 |
-
template <typename T>
|
| 56 |
-
T* allocate(size_t count = 1) {
|
| 57 |
-
|
| 58 |
-
T* ptr = 0;
|
| 59 |
-
size_t bytes = count * sizeof_bits<T>::value / 8;
|
| 60 |
-
|
| 61 |
-
cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes);
|
| 62 |
-
|
| 63 |
-
if (cuda_error != cudaSuccess) {
|
| 64 |
-
#if (CUTLASS_DEBUG_TRACE_LEVEL > 0)
|
| 65 |
-
std::ostringstream os;
|
| 66 |
-
os << "cutlass::device_memory::allocate: cudaMalloc failed: bytes=" << bytes;
|
| 67 |
-
CUTLASS_TRACE_HOST(os.str());
|
| 68 |
-
#endif
|
| 69 |
-
throw cuda_exception("Failed to allocate memory", cuda_error);
|
| 70 |
-
}
|
| 71 |
-
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 72 |
-
else {
|
| 73 |
-
std::ostringstream os;
|
| 74 |
-
os << "cutlass::device_memory::allocate: Successful cudaMalloc: bytes=" << bytes;
|
| 75 |
-
CUTLASS_TRACE_HOST(os.str());
|
| 76 |
-
}
|
| 77 |
-
#endif
|
| 78 |
-
|
| 79 |
-
return ptr;
|
| 80 |
-
}
|
| 81 |
-
|
| 82 |
-
/// Free the buffer pointed to by \p ptr
|
| 83 |
-
template <typename T>
|
| 84 |
-
void free(T* ptr) {
|
| 85 |
-
if (ptr) {
|
| 86 |
-
cudaError_t cuda_error = (cudaFree(ptr));
|
| 87 |
-
if (cuda_error != cudaSuccess) {
|
| 88 |
-
throw cuda_exception("Failed to free device memory", cuda_error);
|
| 89 |
-
}
|
| 90 |
-
}
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
/******************************************************************************
|
| 94 |
-
* Data movement
|
| 95 |
-
******************************************************************************/
|
| 96 |
-
|
| 97 |
-
template <typename T>
|
| 98 |
-
void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) {
|
| 99 |
-
size_t bytes = count * sizeof_bits<T>::value / 8;
|
| 100 |
-
if (bytes == 0 && count > 0) {
|
| 101 |
-
bytes = 1;
|
| 102 |
-
}
|
| 103 |
-
cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind));
|
| 104 |
-
if (cuda_error != cudaSuccess) {
|
| 105 |
-
std::ostringstream os;
|
| 106 |
-
os << "cutlass::device_memory::copy: cudaMemcpy() failed: "
|
| 107 |
-
<< "dst=" << dst << ", src=" << src
|
| 108 |
-
<< ", bytes=" << bytes << ", count=" << count;
|
| 109 |
-
if (kind == cudaMemcpyHostToDevice) {
|
| 110 |
-
os << ", kind=cudaMemcpyHostToDevice";
|
| 111 |
-
}
|
| 112 |
-
else if (kind == cudaMemcpyDeviceToHost) {
|
| 113 |
-
os << ", kind=cudaMemcpyDeviceToHost";
|
| 114 |
-
}
|
| 115 |
-
else if (kind == cudaMemcpyDeviceToDevice) {
|
| 116 |
-
os << ", kind=cudaMemcpyDeviceToDevice";
|
| 117 |
-
}
|
| 118 |
-
else if (kind == cudaMemcpyHostToHost) {
|
| 119 |
-
os << ", kind=cudaMemcpyHostToHost";
|
| 120 |
-
}
|
| 121 |
-
else if (kind == cudaMemcpyDefault) {
|
| 122 |
-
os << ", kind=cudaMemcpyDefault";
|
| 123 |
-
}
|
| 124 |
-
else {
|
| 125 |
-
os << ", kind=Unknown";
|
| 126 |
-
}
|
| 127 |
-
os << ", error: " << cudaGetErrorString(cuda_error);
|
| 128 |
-
|
| 129 |
-
throw cuda_exception(os.str().c_str(), cuda_error);
|
| 130 |
-
}
|
| 131 |
-
}
|
| 132 |
-
|
| 133 |
-
template <typename T>
|
| 134 |
-
void copy_to_device(T* dst, T const* src, size_t count = 1) {
|
| 135 |
-
copy(dst, src, count, cudaMemcpyHostToDevice);
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
template <typename T>
|
| 139 |
-
void copy_to_host(T* dst, T const* src, size_t count = 1) {
|
| 140 |
-
copy(dst, src, count, cudaMemcpyDeviceToHost);
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
template <typename T>
|
| 144 |
-
void copy_device_to_device(T* dst, T const* src, size_t count = 1) {
|
| 145 |
-
copy(dst, src, count, cudaMemcpyDeviceToDevice);
|
| 146 |
-
}
|
| 147 |
-
|
| 148 |
-
template <typename T>
|
| 149 |
-
void copy_host_to_host(T* dst, T const* src, size_t count = 1) {
|
| 150 |
-
copy(dst, src, count, cudaMemcpyHostToHost);
|
| 151 |
-
}
|
| 152 |
-
|
| 153 |
-
/// Copies elements from device memory into the host-side range [begin, end).
///
/// The number of elements copied is `end - begin`, and the destination address is taken
/// from `&*begin`. An empty range is a no-op.
///
/// \param begin        iterator to the first host element to overwrite
/// \param end          iterator one past the last host element to overwrite
/// \param device_begin source pointer handed to copy_to_host
template <typename OutputIterator, typename T>
void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) {
  // Dereferencing `begin` is only valid when the range holds at least one element.
  if (begin == end) {
    return;
  }

  size_t elements = end - begin;
  copy_to_host(&*begin, device_begin, elements);
}
|
| 159 |
-
|
| 160 |
-
/// Copies the host-side range [begin, end) into device memory starting at \p device_begin.
///
/// The number of elements copied is `end - begin`, and the source address is taken
/// from `&*begin`. An empty range is a no-op.
///
/// \param device_begin destination pointer handed to copy_to_device
/// \param begin        iterator to the first host element to read
/// \param end          iterator one past the last host element to read
template <typename T, typename InputIterator>
void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) {
  // Dereferencing `begin` is only valid when the range holds at least one element.
  if (begin == end) {
    return;
  }

  size_t elements = end - begin;
  copy_to_device(device_begin, &*begin, elements);
}
|
| 166 |
-
|
| 167 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 168 |
-
|
| 169 |
-
} // namespace device_memory
|
| 170 |
-
|
| 171 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 172 |
-
|
| 173 |
-
template <typename T>
|
| 174 |
-
class DeviceAllocation {
|
| 175 |
-
public:
|
| 176 |
-
|
| 177 |
-
/// Delete functor for CUDA device memory
|
| 178 |
-
struct deleter {
|
| 179 |
-
void operator()(T* ptr) {
|
| 180 |
-
cudaError_t cuda_error = (cudaFree(ptr));
|
| 181 |
-
if (cuda_error != cudaSuccess) {
|
| 182 |
-
// noexcept
|
| 183 |
-
// throw cuda_exception("cudaFree() failed", cuda_error);
|
| 184 |
-
return;
|
| 185 |
-
}
|
| 186 |
-
}
|
| 187 |
-
};
|
| 188 |
-
|
| 189 |
-
public:
|
| 190 |
-
//
|
| 191 |
-
// Data members
|
| 192 |
-
//
|
| 193 |
-
|
| 194 |
-
/// Number of elements of T allocated on the current CUDA device
|
| 195 |
-
size_t capacity;
|
| 196 |
-
|
| 197 |
-
/// Smart pointer
|
| 198 |
-
platform::unique_ptr<T, deleter> smart_ptr;
|
| 199 |
-
|
| 200 |
-
public:
|
| 201 |
-
|
| 202 |
-
//
|
| 203 |
-
// Static methods
|
| 204 |
-
//
|
| 205 |
-
|
| 206 |
-
/// Static member to compute the number of bytes needed for a given number of elements
|
| 207 |
-
static size_t bytes(size_t elements) {
|
| 208 |
-
if (sizeof_bits<T>::value < 8) {
|
| 209 |
-
size_t const kElementsPerByte = 8 / sizeof_bits<T>::value;
|
| 210 |
-
return elements / kElementsPerByte;
|
| 211 |
-
}
|
| 212 |
-
else {
|
| 213 |
-
size_t const kBytesPerElement = sizeof_bits<T>::value / 8;
|
| 214 |
-
return elements * kBytesPerElement;
|
| 215 |
-
}
|
| 216 |
-
}
|
| 217 |
-
|
| 218 |
-
public:
|
| 219 |
-
|
| 220 |
-
//
|
| 221 |
-
// Methods
|
| 222 |
-
//
|
| 223 |
-
|
| 224 |
-
/// Constructor: allocates no memory
|
| 225 |
-
DeviceAllocation() : capacity(0) {}
|
| 226 |
-
|
| 227 |
-
/// Constructor: allocates \p capacity elements on the current CUDA device
|
| 228 |
-
DeviceAllocation(size_t _capacity) :
|
| 229 |
-
smart_ptr(device_memory::allocate<T>(_capacity)), capacity(_capacity) {}
|
| 230 |
-
|
| 231 |
-
/// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation
|
| 232 |
-
DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {}
|
| 233 |
-
|
| 234 |
-
/// Copy constructor
|
| 235 |
-
DeviceAllocation(DeviceAllocation const &p):
|
| 236 |
-
smart_ptr(device_memory::allocate<T>(p.capacity)), capacity(p.capacity) {
|
| 237 |
-
|
| 238 |
-
device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity);
|
| 239 |
-
}
|
| 240 |
-
|
| 241 |
-
/// Move constructor
|
| 242 |
-
DeviceAllocation(DeviceAllocation &&p): capacity(0) {
|
| 243 |
-
std::swap(smart_ptr, p.smart_ptr);
|
| 244 |
-
std::swap(capacity, p.capacity);
|
| 245 |
-
}
|
| 246 |
-
|
| 247 |
-
/// Destructor
|
| 248 |
-
~DeviceAllocation() { reset(); }
|
| 249 |
-
|
| 250 |
-
/// Returns a pointer to the managed object
|
| 251 |
-
T* get() const { return smart_ptr.get(); }
|
| 252 |
-
|
| 253 |
-
/// Releases the ownership of the managed object (without deleting) and resets capacity to zero
|
| 254 |
-
T* release() {
|
| 255 |
-
capacity = 0;
|
| 256 |
-
return smart_ptr.release();
|
| 257 |
-
}
|
| 258 |
-
|
| 259 |
-
/// Deletes the managed object and resets capacity to zero
|
| 260 |
-
void reset() {
|
| 261 |
-
capacity = 0;
|
| 262 |
-
smart_ptr.reset();
|
| 263 |
-
}
|
| 264 |
-
|
| 265 |
-
/// Deletes managed object, if owned, and allocates a new object
|
| 266 |
-
void reset(size_t _capacity) {
|
| 267 |
-
reset(device_memory::allocate<T>(_capacity), _capacity);
|
| 268 |
-
}
|
| 269 |
-
|
| 270 |
-
/// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity
|
| 271 |
-
void reset(T* _ptr, size_t _capacity) {
|
| 272 |
-
smart_ptr.reset(_ptr);
|
| 273 |
-
capacity = _capacity;
|
| 274 |
-
}
|
| 275 |
-
|
| 276 |
-
/// Allocates a new buffer and copies the old buffer into it. The old buffer is then released.
|
| 277 |
-
void reallocate(size_t new_capacity) {
|
| 278 |
-
|
| 279 |
-
platform::unique_ptr<T, deleter> new_allocation(device_memory::allocate<T>(new_capacity));
|
| 280 |
-
|
| 281 |
-
device_memory::copy_device_to_device(
|
| 282 |
-
new_allocation.get(),
|
| 283 |
-
smart_ptr.get(),
|
| 284 |
-
std::min(new_capacity, capacity));
|
| 285 |
-
|
| 286 |
-
std::swap(smart_ptr, new_allocation);
|
| 287 |
-
std::swap(new_capacity, capacity);
|
| 288 |
-
}
|
| 289 |
-
|
| 290 |
-
/// Returns the number of elements
|
| 291 |
-
size_t size() const {
|
| 292 |
-
return capacity;
|
| 293 |
-
}
|
| 294 |
-
|
| 295 |
-
/// Returns the number of bytes needed to store the allocation
|
| 296 |
-
size_t bytes() const {
|
| 297 |
-
return bytes(capacity);
|
| 298 |
-
}
|
| 299 |
-
|
| 300 |
-
/// Returns a pointer to the object owned by *this
|
| 301 |
-
T* operator->() const { return smart_ptr.get(); }
|
| 302 |
-
|
| 303 |
-
/// Returns the deleter object which would be used for destruction of the managed object.
|
| 304 |
-
deleter& get_deleter() { return smart_ptr.get_deleter(); }
|
| 305 |
-
|
| 306 |
-
/// Returns the deleter object which would be used for destruction of the managed object (const)
|
| 307 |
-
const deleter& get_deleter() const { return smart_ptr.get_deleter(); }
|
| 308 |
-
|
| 309 |
-
/// Copies a device-side memory allocation
|
| 310 |
-
DeviceAllocation & operator=(DeviceAllocation const &p) {
|
| 311 |
-
if (capacity != p.capacity) {
|
| 312 |
-
smart_ptr.reset(device_memory::allocate<T>(p.capacity));
|
| 313 |
-
capacity = p.capacity;
|
| 314 |
-
}
|
| 315 |
-
device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity);
|
| 316 |
-
return *this;
|
| 317 |
-
}
|
| 318 |
-
|
| 319 |
-
/// Move assignment
|
| 320 |
-
DeviceAllocation & operator=(DeviceAllocation && p) {
|
| 321 |
-
std::swap(smart_ptr, p.smart_ptr);
|
| 322 |
-
std::swap(capacity, p.capacity);
|
| 323 |
-
return *this;
|
| 324 |
-
}
|
| 325 |
-
|
| 326 |
-
/// Copies the entire allocation from another location in device memory.
|
| 327 |
-
void copy_from_device(T const *ptr) const {
|
| 328 |
-
copy_from_device(ptr, capacity);
|
| 329 |
-
}
|
| 330 |
-
|
| 331 |
-
/// Copies a given number of elements from device memory
|
| 332 |
-
void copy_from_device(T const *ptr, size_t elements) const {
|
| 333 |
-
device_memory::copy_device_to_device(get(), ptr, elements);
|
| 334 |
-
}
|
| 335 |
-
|
| 336 |
-
void copy_to_device(T *ptr) const {
|
| 337 |
-
copy_to_device(ptr, capacity);
|
| 338 |
-
}
|
| 339 |
-
|
| 340 |
-
void copy_to_device(T *ptr, size_t elements) const {
|
| 341 |
-
device_memory::copy_device_to_device(ptr, get(), elements);
|
| 342 |
-
}
|
| 343 |
-
|
| 344 |
-
void copy_from_host(T const *ptr) const {
|
| 345 |
-
copy_from_host(ptr, capacity);
|
| 346 |
-
}
|
| 347 |
-
|
| 348 |
-
void copy_from_host(T const *ptr, size_t elements) const {
|
| 349 |
-
device_memory::copy_to_device(get(), ptr, elements);
|
| 350 |
-
}
|
| 351 |
-
|
| 352 |
-
void copy_to_host(T *ptr) const {
|
| 353 |
-
copy_to_host(ptr, capacity);
|
| 354 |
-
}
|
| 355 |
-
|
| 356 |
-
void copy_to_host(T *ptr, size_t elements) const {
|
| 357 |
-
device_memory::copy_to_host(ptr, get(), elements);
|
| 358 |
-
}
|
| 359 |
-
};
|
| 360 |
-
|
| 361 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 362 |
-
|
| 363 |
-
namespace device_memory {
|
| 364 |
-
|
| 365 |
-
/// Device allocation abstraction that tracks size and capacity
|
| 366 |
-
template <typename T>
|
| 367 |
-
using allocation = cutlass::DeviceAllocation<T>;
|
| 368 |
-
|
| 369 |
-
} // namespace device_memory
|
| 370 |
-
|
| 371 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 372 |
-
|
| 373 |
-
} // namespace cutlass
|
| 374 |
-
|
| 375 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h
DELETED
|
@@ -1,141 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
/**
|
| 35 |
-
* \file
|
| 36 |
-
* \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout.
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#include "cutlass/cutlass.h"
|
| 40 |
-
#include "cutlass/layout/tensor.h"
|
| 41 |
-
#include "cutlass/numeric_types.h"
|
| 42 |
-
#include "cutlass/tensor_coord.h"
|
| 43 |
-
#include "cutlass/tensor_ref.h"
|
| 44 |
-
|
| 45 |
-
namespace cutlass {
|
| 46 |
-
|
| 47 |
-
/** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout.
|
| 48 |
-
* \tparam T: data type
|
| 49 |
-
*/
|
| 50 |
-
template <typename T>
|
| 51 |
-
void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size,
|
| 52 |
-
cutlass::Tensor4DCoord output_tensor_size,
|
| 53 |
-
TensorRef<T, layout::TensorNCHW> ref_input,
|
| 54 |
-
TensorRef<T, layout::TensorNHWC> ref_output,
|
| 55 |
-
cudaStream_t stream);
|
| 56 |
-
|
| 57 |
-
template <typename T>
|
| 58 |
-
__global__ void nchw_to_nhwc_kernel(T *output,
|
| 59 |
-
const T *input,
|
| 60 |
-
const int n,
|
| 61 |
-
const int h,
|
| 62 |
-
const int w,
|
| 63 |
-
const int c) {
|
| 64 |
-
const int hw = h*w;
|
| 65 |
-
const int chw = c*hw;
|
| 66 |
-
__shared__ T shbuf[32 * (32 + 1)];
|
| 67 |
-
const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x;
|
| 68 |
-
const int32_t wid = tid / 32;
|
| 69 |
-
const int32_t lid = tid % 32;
|
| 70 |
-
const int32_t ni = blockIdx.z;
|
| 71 |
-
const int32_t ci0 = blockIdx.y * 32;
|
| 72 |
-
const int32_t hwi0 = blockIdx.x * 32;
|
| 73 |
-
|
| 74 |
-
const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0;
|
| 75 |
-
const T *A = input + input_idx;
|
| 76 |
-
if (hwi0 + lid < hw) {
|
| 77 |
-
const int lid_x_33 = lid * 33;
|
| 78 |
-
if ((ci0 + 32) <= c) {
|
| 79 |
-
int ci = wid; // between 0 and 7
|
| 80 |
-
CUTLASS_PRAGMA_UNROLL
|
| 81 |
-
for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) {
|
| 82 |
-
shbuf[lid_x_33 + ci] = A[lid];
|
| 83 |
-
A = &A[8 * hw];
|
| 84 |
-
ci += 8;
|
| 85 |
-
}
|
| 86 |
-
} else {
|
| 87 |
-
for (int ci = wid; ci < 32; ci += 8) {
|
| 88 |
-
if ((ci + ci0) < c) {
|
| 89 |
-
shbuf[lid_x_33 + ci] = A[lid];
|
| 90 |
-
}
|
| 91 |
-
A = &A[8 * hw];
|
| 92 |
-
}
|
| 93 |
-
}
|
| 94 |
-
}
|
| 95 |
-
__syncthreads();
|
| 96 |
-
|
| 97 |
-
const int32_t ciOut = ci0 + lid;
|
| 98 |
-
output = &output[ni * chw + ciOut];
|
| 99 |
-
if (ciOut < c) {
|
| 100 |
-
if (hwi0 + 32 < hw) {
|
| 101 |
-
int hwI = wid;
|
| 102 |
-
CUTLASS_PRAGMA_UNROLL
|
| 103 |
-
for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) {
|
| 104 |
-
output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid];
|
| 105 |
-
hwI += 8;
|
| 106 |
-
}
|
| 107 |
-
} else {
|
| 108 |
-
for (int hwI = wid; hwI < 32; hwI += 8) {
|
| 109 |
-
if (hwi0 + hwI < hw) {
|
| 110 |
-
output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid];
|
| 111 |
-
}
|
| 112 |
-
}
|
| 113 |
-
}
|
| 114 |
-
}
|
| 115 |
-
}
|
| 116 |
-
|
| 117 |
-
template <typename T>
|
| 118 |
-
void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size,
|
| 119 |
-
cutlass::Tensor4DCoord output_tensor_size,
|
| 120 |
-
TensorRef<T, layout::TensorNCHW> ref_input,
|
| 121 |
-
TensorRef<T, layout::TensorNHWC> ref_output,
|
| 122 |
-
cudaStream_t stream) {
|
| 123 |
-
|
| 124 |
-
assert(
|
| 125 |
-
input_tensor_size.n() == output_tensor_size.n() &&
|
| 126 |
-
input_tensor_size.c() == output_tensor_size.h() &&
|
| 127 |
-
input_tensor_size.h() == output_tensor_size.w() &&
|
| 128 |
-
input_tensor_size.w() == output_tensor_size.c());
|
| 129 |
-
|
| 130 |
-
int n = output_tensor_size.n();
|
| 131 |
-
int h = output_tensor_size.h();
|
| 132 |
-
int w = output_tensor_size.w();
|
| 133 |
-
int c = output_tensor_size.c();
|
| 134 |
-
|
| 135 |
-
dim3 grid((h*w + 31)/32, (c + 31)/32, n);
|
| 136 |
-
dim3 block(32, 8);
|
| 137 |
-
nchw_to_nhwc_kernel<<<grid, block, 0, stream>>>(ref_output.data(), ref_input.data(),
|
| 138 |
-
n, h, w, c);
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
} //namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h
DELETED
|
@@ -1,276 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
/**
|
| 35 |
-
* \file
|
| 36 |
-
* \brief cuda kernels for padding in device memory with NHWC layout.
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#include "cutlass/cutlass.h"
|
| 40 |
-
#include "cutlass/layout/tensor.h"
|
| 41 |
-
#include "cutlass/numeric_types.h"
|
| 42 |
-
#include "cutlass/tensor_coord.h"
|
| 43 |
-
#include "cutlass/tensor_ref.h"
|
| 44 |
-
|
| 45 |
-
namespace cutlass {
|
| 46 |
-
|
| 47 |
-
/** \brief interface for padding in a device memory tensor with NHWC layout
|
| 48 |
-
* \tparam T: data type
|
| 49 |
-
*/
|
| 50 |
-
template <typename T>
|
| 51 |
-
void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size,
|
| 52 |
-
cutlass::Tensor4DCoord output_tensor_size,
|
| 53 |
-
TensorRef<T, layout::TensorNHWC> ref_input,
|
| 54 |
-
TensorRef<T, layout::TensorNHWC> ref_output,
|
| 55 |
-
cudaStream_t stream);
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
template <typename T>
|
| 59 |
-
__global__ void nhwc_padding_kernel(const int32_t n,
|
| 60 |
-
const int32_t h,
|
| 61 |
-
const int32_t w,
|
| 62 |
-
const int32_t c_in,
|
| 63 |
-
const int32_t c_out,
|
| 64 |
-
const T zero,
|
| 65 |
-
const T *input,
|
| 66 |
-
T *output){
|
| 67 |
-
|
| 68 |
-
const int32_t idx_jump = blockDim.x * gridDim.x;
|
| 69 |
-
const int32_t total_elements = n * h * w * c_out;
|
| 70 |
-
|
| 71 |
-
int32_t c_idx, w_idx, h_idx, n_idx, resudial;
|
| 72 |
-
|
| 73 |
-
T value;
|
| 74 |
-
for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) {
|
| 75 |
-
|
| 76 |
-
c_idx = idx%c_out;
|
| 77 |
-
if (c_idx >= c_in){
|
| 78 |
-
value = zero;
|
| 79 |
-
}
|
| 80 |
-
else{
|
| 81 |
-
resudial = idx/c_out;
|
| 82 |
-
w_idx = resudial%w;
|
| 83 |
-
resudial = resudial/w;
|
| 84 |
-
h_idx = resudial%h;
|
| 85 |
-
n_idx = resudial/h;
|
| 86 |
-
resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx;
|
| 87 |
-
value = input[resudial];
|
| 88 |
-
}
|
| 89 |
-
output[idx] = value;
|
| 90 |
-
}
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
// fast kernel for c_in = 3 & c_out = 4
|
| 95 |
-
template <typename Tio, typename Telement, int element_in_Tio>
|
| 96 |
-
__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n,
|
| 97 |
-
const int32_t h,
|
| 98 |
-
const int32_t w,
|
| 99 |
-
const Tio *input,
|
| 100 |
-
Tio *output,
|
| 101 |
-
const int32_t max_output_element,
|
| 102 |
-
const int32_t max_input_element,
|
| 103 |
-
const Tio zero_io,
|
| 104 |
-
const Telement zero_element){
|
| 105 |
-
__shared__ Tio shm[192];
|
| 106 |
-
const int tidx = blockIdx.x * 192 + threadIdx.x;
|
| 107 |
-
const int threadidx = threadIdx.x;
|
| 108 |
-
|
| 109 |
-
shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx];
|
| 110 |
-
__syncthreads();
|
| 111 |
-
|
| 112 |
-
const int output_offset = blockIdx.x * 256;
|
| 113 |
-
const int lower_bound = max_output_element < output_offset + 256 ? max_output_element : output_offset + 256;
|
| 114 |
-
for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192)
|
| 115 |
-
{
|
| 116 |
-
const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4;
|
| 117 |
-
Telement array[element_in_Tio];
|
| 118 |
-
CUTLASS_PRAGMA_UNROLL
|
| 119 |
-
for (int k = 0 ; k < element_in_Tio ; k++)
|
| 120 |
-
array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k];
|
| 121 |
-
output[i] = *((const Tio *)array);
|
| 122 |
-
}
|
| 123 |
-
}
|
| 124 |
-
|
| 125 |
-
// fast kernel for c_in = 3 & c_out = 8
|
| 126 |
-
template <typename Tio, typename Telement, int element_in_Tio>
|
| 127 |
-
__global__ void nhwc_padding_channel_3To8_kernel(const int32_t n,
|
| 128 |
-
const int32_t h,
|
| 129 |
-
const int32_t w,
|
| 130 |
-
const Tio *input,
|
| 131 |
-
Tio *output,
|
| 132 |
-
const int32_t max_output_element,
|
| 133 |
-
const int32_t max_input_element,
|
| 134 |
-
const Tio zero_io,
|
| 135 |
-
const Telement zero_element){
|
| 136 |
-
__shared__ Tio shm[192];
|
| 137 |
-
const int tidx = blockIdx.x * 192 + threadIdx.x;
|
| 138 |
-
const int threadidx = threadIdx.x;
|
| 139 |
-
|
| 140 |
-
shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx];
|
| 141 |
-
__syncthreads();
|
| 142 |
-
|
| 143 |
-
const int output_offset = blockIdx.x * 512;
|
| 144 |
-
const int lower_bound = max_output_element < output_offset + 512 ? max_output_element : output_offset + 512;
|
| 145 |
-
for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192)
|
| 146 |
-
{
|
| 147 |
-
const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3;
|
| 148 |
-
Telement array[element_in_Tio];
|
| 149 |
-
//float
|
| 150 |
-
if (element_in_Tio == 4){
|
| 151 |
-
CUTLASS_PRAGMA_UNROLL
|
| 152 |
-
for (int k = 0 ; k < element_in_Tio ; k++)
|
| 153 |
-
array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? zero_element : shm_element[k]);
|
| 154 |
-
}
|
| 155 |
-
//half
|
| 156 |
-
else{
|
| 157 |
-
CUTLASS_PRAGMA_UNROLL
|
| 158 |
-
for (int k = 0 ; k < element_in_Tio ; k++)
|
| 159 |
-
array[k] = (k >= 3) ? zero_element : shm_element[k];
|
| 160 |
-
}
|
| 161 |
-
output[i] = *((const Tio *)array);
|
| 162 |
-
}
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
-
template <typename T>
|
| 166 |
-
void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size,
|
| 167 |
-
cutlass::Tensor4DCoord output_tensor_size,
|
| 168 |
-
TensorRef<T, layout::TensorNHWC> ref_input,
|
| 169 |
-
TensorRef<T, layout::TensorNHWC> ref_output,
|
| 170 |
-
cudaStream_t stream){
|
| 171 |
-
assert(
|
| 172 |
-
input_tensor_size.n() == output_tensor_size.n() &&
|
| 173 |
-
input_tensor_size.h() == output_tensor_size.h() &&
|
| 174 |
-
input_tensor_size.w() == output_tensor_size.w() &&
|
| 175 |
-
input_tensor_size.c() <= output_tensor_size.c());
|
| 176 |
-
|
| 177 |
-
int n = input_tensor_size.n();
|
| 178 |
-
int h = input_tensor_size.h();
|
| 179 |
-
int w = input_tensor_size.w();
|
| 180 |
-
int c_in = input_tensor_size.c();
|
| 181 |
-
int c_out = output_tensor_size.c();
|
| 182 |
-
|
| 183 |
-
//case 1 : channel == 3 padding to 4 or 8
|
| 184 |
-
if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){
|
| 185 |
-
dim3 block(192);
|
| 186 |
-
const int nhw = n*h*w;
|
| 187 |
-
const int nhwc = nhw*c_in;
|
| 188 |
-
//for half_t
|
| 189 |
-
if (cutlass::sizeof_bits<T>::value == 16){
|
| 190 |
-
const int element_in_Tio = 8;
|
| 191 |
-
const int max_input_element = nhwc/element_in_Tio;
|
| 192 |
-
const int max_output_element = nhw*c_out/element_in_Tio;
|
| 193 |
-
const int4 zero_io = {0, 0, 0, 0};
|
| 194 |
-
const half_t zero_element = static_cast<half_t>(0.0f);
|
| 195 |
-
dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio));
|
| 196 |
-
if (c_out == 4){
|
| 197 |
-
nhwc_padding_channel_3To4_kernel<int4, half_t, element_in_Tio><<<grid, block, 0, stream>>>
|
| 198 |
-
(n, h, w,
|
| 199 |
-
(const int4 *)ref_input.data(),
|
| 200 |
-
(int4 *)ref_output.data(),
|
| 201 |
-
max_output_element,
|
| 202 |
-
max_input_element,
|
| 203 |
-
zero_io,
|
| 204 |
-
zero_element);
|
| 205 |
-
}
|
| 206 |
-
else if (c_out == 8){
|
| 207 |
-
nhwc_padding_channel_3To8_kernel<int4, half_t, element_in_Tio><<<grid, block, 0, stream>>>
|
| 208 |
-
(n, h, w,
|
| 209 |
-
(const int4 *)ref_input.data(),
|
| 210 |
-
(int4 *)ref_output.data(),
|
| 211 |
-
max_output_element,
|
| 212 |
-
max_input_element,
|
| 213 |
-
zero_io,
|
| 214 |
-
zero_element);
|
| 215 |
-
}
|
| 216 |
-
}
|
| 217 |
-
//for float
|
| 218 |
-
else{
|
| 219 |
-
const int element_in_Tio = 4;
|
| 220 |
-
const int max_input_element = nhwc/element_in_Tio;
|
| 221 |
-
const int max_output_element = nhw*c_out/element_in_Tio;
|
| 222 |
-
const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f};
|
| 223 |
-
const float zero_element = 0.0f;
|
| 224 |
-
dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio));
|
| 225 |
-
if (c_out == 4){
|
| 226 |
-
nhwc_padding_channel_3To4_kernel<float4, float, element_in_Tio><<<grid, block, 0, stream>>>
|
| 227 |
-
(n, h, w,
|
| 228 |
-
(const float4 *)ref_input.data(),
|
| 229 |
-
(float4 *)ref_output.data(),
|
| 230 |
-
max_output_element,
|
| 231 |
-
max_input_element,
|
| 232 |
-
zero_io,
|
| 233 |
-
zero_element);
|
| 234 |
-
}
|
| 235 |
-
else if (c_out == 8){
|
| 236 |
-
nhwc_padding_channel_3To8_kernel<float4, float, element_in_Tio><<<grid, block, 0, stream>>>
|
| 237 |
-
(n, h, w,
|
| 238 |
-
(const float4 *)ref_input.data(),
|
| 239 |
-
(float4 *)ref_output.data(),
|
| 240 |
-
max_output_element,
|
| 241 |
-
max_input_element,
|
| 242 |
-
zero_io,
|
| 243 |
-
zero_element);
|
| 244 |
-
}
|
| 245 |
-
}
|
| 246 |
-
}
|
| 247 |
-
//case 2 : even channel
|
| 248 |
-
else if ((c_out % 2) == 0 && (c_in % 2) == 0){
|
| 249 |
-
int32_t total_elements = n * h * w * c_out / 2;
|
| 250 |
-
int block_size = 256;
|
| 251 |
-
dim3 grid((total_elements + 255)/256);
|
| 252 |
-
dim3 block(block_size);
|
| 253 |
-
//for half_t
|
| 254 |
-
if (cutlass::sizeof_bits<T>::value == 16){
|
| 255 |
-
const __half2 zero = {0.0f, 0.0f};
|
| 256 |
-
nhwc_padding_kernel<<<grid, block, 0, stream>>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data());
|
| 257 |
-
}
|
| 258 |
-
//for float
|
| 259 |
-
else{
|
| 260 |
-
const float2 zero = {0.0f, 0.0f};
|
| 261 |
-
nhwc_padding_kernel<<<grid, block, 0, stream>>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data());
|
| 262 |
-
}
|
| 263 |
-
}
|
| 264 |
-
//case 3 : odd channel
|
| 265 |
-
else{
|
| 266 |
-
int32_t total_elements = n * h * w * c_out;
|
| 267 |
-
int block_size = 256;
|
| 268 |
-
dim3 grid((total_elements + 255)/256);
|
| 269 |
-
dim3 block(block_size);
|
| 270 |
-
const T zero = static_cast<T>(0.0f);
|
| 271 |
-
nhwc_padding_kernel<<<grid, block, 0, stream>>>(n, h, w, c_in, c_out, zero, ref_input.data(), ref_output.data());
|
| 272 |
-
}
|
| 273 |
-
}
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
} //namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h
DELETED
|
@@ -1,573 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
/**
|
| 35 |
-
* \file
|
| 36 |
-
* \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout.
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#include "cutlass/cutlass.h"
|
| 40 |
-
#include "cutlass/layout/tensor.h"
|
| 41 |
-
#include "cutlass/numeric_types.h"
|
| 42 |
-
#include "cutlass/tensor_coord.h"
|
| 43 |
-
#include "cutlass/tensor_ref.h"
|
| 44 |
-
#include "device_utils.h"
|
| 45 |
-
#include <cfloat>
|
| 46 |
-
|
| 47 |
-
namespace cutlass {
|
| 48 |
-
|
| 49 |
-
/** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout.
|
| 50 |
-
* \tparam T: data type
|
| 51 |
-
*/
|
| 52 |
-
template <typename T>
|
| 53 |
-
void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size,
|
| 54 |
-
cutlass::Tensor4DCoord filter_tensor_size,
|
| 55 |
-
cutlass::Tensor4DCoord output_tensor_size,
|
| 56 |
-
cutlass::MatrixCoord padding,
|
| 57 |
-
cutlass::MatrixCoord stride,
|
| 58 |
-
TensorRef<T, layout::TensorNHWC> ref_input,
|
| 59 |
-
TensorRef<T, layout::TensorNHWC> ref_output,
|
| 60 |
-
int poolingType, //0 for avg pooling ; 1 for max pooling
|
| 61 |
-
cudaStream_t stream);
|
| 62 |
-
|
| 63 |
-
/** get the output size of pooling along one spatial dimension
 *
 * \param H_W          input extent along the dimension (height or width)
 * \param padding      padding applied on each side of the dimension
 * \param kernel_size  pooling window extent along the dimension
 * \param stride       step between consecutive windows
 * \return (H_W + 2 * padding - kernel_size) / stride + 1, or 0 when no window
 *         fits (kernel larger than the padded extent) or stride is not positive
 */
inline int getOutputSize(int H_W, int padding, int kernel_size, int stride)
{
  // A non-positive stride would divide by zero (or walk backwards).
  if (stride <= 0) {
    return 0;
  }
  const int span = H_W + 2 * padding - kernel_size;
  // With a negative span, truncating division would report a bogus size
  // (e.g. -1 / 2 + 1 == 1) although not a single window fits.
  if (span < 0) {
    return 0;
  }
  return span / stride + 1;
}
|
| 69 |
-
|
| 70 |
-
/**
|
| 71 |
-
* input is [N, H, W, C]
|
| 72 |
-
* assume stride == kernel_size
|
| 73 |
-
* output_h = (H + 2*padding_H - kernel_H)/stride_H
|
| 74 |
-
* output_w = (W + 2*padding_W - kernel_W)/stride_W
|
| 75 |
-
* output is [N, output_h, output_w, C]
|
| 76 |
-
* grid(N, output_h, output_w)
|
| 77 |
-
* block(min(C, 256)) :
|
| 78 |
-
* each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output)
|
| 79 |
-
*/
|
| 80 |
-
template<typename T, bool IS_AVG_POOLING>
|
| 81 |
-
__global__ void pooling_nhwc_element1_kernel(T* output,
|
| 82 |
-
const T* input,
|
| 83 |
-
const int N,
|
| 84 |
-
const int H,
|
| 85 |
-
const int W,
|
| 86 |
-
const int C,
|
| 87 |
-
const int output_H,
|
| 88 |
-
const int output_W,
|
| 89 |
-
const int kernel_H,
|
| 90 |
-
const int kernel_W,
|
| 91 |
-
const int stride_H,
|
| 92 |
-
const int stride_W,
|
| 93 |
-
const int padding_H,
|
| 94 |
-
const int padding_W)
|
| 95 |
-
{
|
| 96 |
-
const int tid = threadIdx.x;
|
| 97 |
-
const int n_idx = blockIdx.x;
|
| 98 |
-
const int output_h_idx = blockIdx.y;
|
| 99 |
-
const int output_w_idx = blockIdx.z;
|
| 100 |
-
|
| 101 |
-
int h_start_idx = output_h_idx * stride_H - padding_H;
|
| 102 |
-
int h_end_idx = h_start_idx + kernel_H;
|
| 103 |
-
h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx;
|
| 104 |
-
h_end_idx = h_end_idx > H ? H : h_end_idx;
|
| 105 |
-
|
| 106 |
-
int w_start_idx = output_w_idx * stride_W - padding_W;
|
| 107 |
-
int w_end_idx = w_start_idx + kernel_W;
|
| 108 |
-
w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx;
|
| 109 |
-
w_end_idx = w_end_idx > W ? W : w_end_idx;
|
| 110 |
-
|
| 111 |
-
input += n_idx * H * W * C;
|
| 112 |
-
output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C;
|
| 113 |
-
const int kernel_size2 = kernel_H * kernel_W;
|
| 114 |
-
for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) {
|
| 115 |
-
float pooling;
|
| 116 |
-
if (IS_AVG_POOLING){
|
| 117 |
-
pooling = 0.0f;
|
| 118 |
-
}
|
| 119 |
-
else{
|
| 120 |
-
pooling = -FLT_MAX;
|
| 121 |
-
}
|
| 122 |
-
for (int h = h_start_idx; h < h_end_idx; h++) {
|
| 123 |
-
for (int w = w_start_idx; w < w_end_idx; w++) {
|
| 124 |
-
const int idx = (h * W + w) * C;
|
| 125 |
-
const float tmp = static_cast<float>(input[idx + c_idx]);
|
| 126 |
-
if (IS_AVG_POOLING){
|
| 127 |
-
pooling = pooling + tmp;
|
| 128 |
-
}
|
| 129 |
-
else{
|
| 130 |
-
pooling = pooling > tmp ? pooling : tmp;
|
| 131 |
-
}
|
| 132 |
-
}
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
T output_val;
|
| 136 |
-
if (IS_AVG_POOLING){
|
| 137 |
-
output_val = T(pooling/kernel_size2);
|
| 138 |
-
}
|
| 139 |
-
else{
|
| 140 |
-
output_val = T(pooling);
|
| 141 |
-
}
|
| 142 |
-
output[c_idx] = output_val;
|
| 143 |
-
}
|
| 144 |
-
}
|
| 145 |
-
|
| 146 |
-
template<typename T2, typename T, bool IS_AVG_POOLING>
|
| 147 |
-
__global__ void pooling_nhwc_element2_kernel(T2* output,
|
| 148 |
-
const T2* input,
|
| 149 |
-
const int N,
|
| 150 |
-
const int H,
|
| 151 |
-
const int W,
|
| 152 |
-
const int C,
|
| 153 |
-
const int output_H,
|
| 154 |
-
const int output_W,
|
| 155 |
-
const int kernel_H,
|
| 156 |
-
const int kernel_W,
|
| 157 |
-
const int stride_H,
|
| 158 |
-
const int stride_W,
|
| 159 |
-
const int padding_H,
|
| 160 |
-
const int padding_W)
|
| 161 |
-
{
|
| 162 |
-
const int tid = threadIdx.x;
|
| 163 |
-
const int n_idx = blockIdx.x;
|
| 164 |
-
const int output_h_idx = blockIdx.y;
|
| 165 |
-
const int output_w_idx = blockIdx.z;
|
| 166 |
-
|
| 167 |
-
int h_start_idx = output_h_idx * stride_H - padding_H;
|
| 168 |
-
int h_end_idx = h_start_idx + kernel_H;
|
| 169 |
-
h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx;
|
| 170 |
-
h_end_idx = h_end_idx > H ? H : h_end_idx;
|
| 171 |
-
|
| 172 |
-
int w_start_idx = output_w_idx * stride_W - padding_W;
|
| 173 |
-
int w_end_idx = w_start_idx + kernel_W;
|
| 174 |
-
w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx;
|
| 175 |
-
w_end_idx = w_end_idx > W ? W : w_end_idx;
|
| 176 |
-
|
| 177 |
-
input += n_idx * H * W * C;
|
| 178 |
-
output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C;
|
| 179 |
-
const int kernel_size2 = kernel_H * kernel_W;
|
| 180 |
-
for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) {
|
| 181 |
-
float2 pooling;
|
| 182 |
-
if (IS_AVG_POOLING) {
|
| 183 |
-
pooling = {0.0f, 0.0f};
|
| 184 |
-
}
|
| 185 |
-
else {
|
| 186 |
-
pooling = {-FLT_MAX, -FLT_MAX};
|
| 187 |
-
}
|
| 188 |
-
for (int h = h_start_idx; h < h_end_idx; h++) {
|
| 189 |
-
for (int w = w_start_idx; w < w_end_idx; w++) {
|
| 190 |
-
const int idx = (h * W + w) * C;
|
| 191 |
-
const T2 tmp = input[idx + c_idx];
|
| 192 |
-
const float2 tmp_flt2 = {static_cast<float>(tmp.x), static_cast<float>(tmp.y)};
|
| 193 |
-
if (IS_AVG_POOLING) {
|
| 194 |
-
pooling.x += tmp_flt2.x;
|
| 195 |
-
pooling.y += tmp_flt2.y;
|
| 196 |
-
}
|
| 197 |
-
else {
|
| 198 |
-
pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x;
|
| 199 |
-
pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y;
|
| 200 |
-
}
|
| 201 |
-
}
|
| 202 |
-
}
|
| 203 |
-
|
| 204 |
-
T2 output_val;
|
| 205 |
-
if (IS_AVG_POOLING) {
|
| 206 |
-
output_val.x = T(pooling.x/kernel_size2);
|
| 207 |
-
output_val.y = T(pooling.y/kernel_size2);
|
| 208 |
-
}
|
| 209 |
-
else {
|
| 210 |
-
output_val.x = T(pooling.x);
|
| 211 |
-
output_val.y = T(pooling.y);
|
| 212 |
-
}
|
| 213 |
-
output[c_idx] = output_val;
|
| 214 |
-
}
|
| 215 |
-
}
|
| 216 |
-
|
| 217 |
-
/**
|
| 218 |
-
* output [N, 1, 1, C]
|
| 219 |
-
* input [N, H, W, C]
|
| 220 |
-
* grid(C, N)
|
| 221 |
-
* block(block_size) -- each block deals with H*W/block_size elements;
|
| 222 |
-
*/
|
| 223 |
-
template<typename T, bool IS_AVG_POOLING>
|
| 224 |
-
__global__ void pooling_nxhTo1x1_element1_kernel(
|
| 225 |
-
T* output, const T* input, const int N, const int HW, const int C)
|
| 226 |
-
{
|
| 227 |
-
const int c_idx = blockIdx.x;
|
| 228 |
-
const int n_idx = blockIdx.y;
|
| 229 |
-
float pooling[1];
|
| 230 |
-
if (IS_AVG_POOLING) {
|
| 231 |
-
pooling[0] = 0.0f;
|
| 232 |
-
}
|
| 233 |
-
else {
|
| 234 |
-
pooling[0] = -FLT_MAX;
|
| 235 |
-
}
|
| 236 |
-
const size_t input_offset = n_idx * HW * C + c_idx;
|
| 237 |
-
input += input_offset;
|
| 238 |
-
const size_t output_offset = n_idx * C + c_idx;
|
| 239 |
-
output += output_offset;
|
| 240 |
-
int tid = threadIdx.x;
|
| 241 |
-
|
| 242 |
-
for (int index = tid; index < HW; index += blockDim.x) {
|
| 243 |
-
float val = static_cast<float>(input[index * C]);
|
| 244 |
-
if (IS_AVG_POOLING) {
|
| 245 |
-
pooling[0] += val;
|
| 246 |
-
}
|
| 247 |
-
else {
|
| 248 |
-
pooling[0] = pooling[0] > val ? pooling[0] : val;
|
| 249 |
-
}
|
| 250 |
-
}
|
| 251 |
-
if (blockDim.x <= 32) {
|
| 252 |
-
if (IS_AVG_POOLING) {
|
| 253 |
-
warpReduceSum<float, 1>(pooling);
|
| 254 |
-
}
|
| 255 |
-
else {
|
| 256 |
-
warpReduceMax<float, 1>(pooling);
|
| 257 |
-
}
|
| 258 |
-
}
|
| 259 |
-
else {
|
| 260 |
-
if (IS_AVG_POOLING) {
|
| 261 |
-
blockReduceSum<float, 1>(pooling);
|
| 262 |
-
}
|
| 263 |
-
else {
|
| 264 |
-
blockReduceMax<float, 1>(pooling);
|
| 265 |
-
}
|
| 266 |
-
}
|
| 267 |
-
__syncthreads();
|
| 268 |
-
if (threadIdx.x == 0) {
|
| 269 |
-
T output_val;
|
| 270 |
-
if (IS_AVG_POOLING) {
|
| 271 |
-
output_val = T(pooling[0] / HW);
|
| 272 |
-
}
|
| 273 |
-
else {
|
| 274 |
-
output_val = T(pooling[0]);
|
| 275 |
-
}
|
| 276 |
-
output[0] = output_val;
|
| 277 |
-
}
|
| 278 |
-
}
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
/**
|
| 282 |
-
* output [N, 1, 1, C]
|
| 283 |
-
* input [N, H, W, C]
|
| 284 |
-
* grid(C/2, N)
|
| 285 |
-
* block(block_size) -- each thread deals with H*W/block_size * 2 elements;
|
| 286 |
-
*/
|
| 287 |
-
template<typename T2, typename T, bool IS_AVG_POOLING>
|
| 288 |
-
__global__ void pooling_nxhTo1x1_element2_kernel(
|
| 289 |
-
T2* output, const T2* input, const int N, const int HW, const int C)
|
| 290 |
-
{
|
| 291 |
-
const int c_idx = blockIdx.x;
|
| 292 |
-
const int n_idx = blockIdx.y;
|
| 293 |
-
float pooling[2];
|
| 294 |
-
if (IS_AVG_POOLING) {
|
| 295 |
-
pooling[0] = pooling[1] = 0.0f;
|
| 296 |
-
}
|
| 297 |
-
else {
|
| 298 |
-
pooling[0] = pooling[1] = -FLT_MAX;
|
| 299 |
-
}
|
| 300 |
-
const int C_2 = C / 2;
|
| 301 |
-
const size_t input_offset = n_idx * HW * C_2 + c_idx;
|
| 302 |
-
input += input_offset;
|
| 303 |
-
const size_t output_offset = n_idx * C_2 + c_idx;
|
| 304 |
-
output += output_offset;
|
| 305 |
-
int tid = threadIdx.x;
|
| 306 |
-
|
| 307 |
-
for (int index = tid; index < HW; index += blockDim.x) {
|
| 308 |
-
T2 val = input[index * C_2];
|
| 309 |
-
float2 val_flt2 = {static_cast<float>(val.x), static_cast<float>(val.y)};
|
| 310 |
-
if (IS_AVG_POOLING) {
|
| 311 |
-
pooling[0] += val_flt2.x;
|
| 312 |
-
pooling[1] += val_flt2.y;
|
| 313 |
-
}
|
| 314 |
-
else {
|
| 315 |
-
pooling[0] = pooling[0] > val_flt2.x ? pooling[0] : val_flt2.x;
|
| 316 |
-
pooling[1] = pooling[1] > val_flt2.y ? pooling[1] : val_flt2.y;
|
| 317 |
-
}
|
| 318 |
-
}
|
| 319 |
-
if (blockDim.x <= 32) {
|
| 320 |
-
if (IS_AVG_POOLING) {
|
| 321 |
-
warpReduceSum<float, 2>(pooling);
|
| 322 |
-
}
|
| 323 |
-
else {
|
| 324 |
-
warpReduceMax<float, 2>(pooling);
|
| 325 |
-
}
|
| 326 |
-
}
|
| 327 |
-
else {
|
| 328 |
-
if (IS_AVG_POOLING) {
|
| 329 |
-
blockReduceSum<float, 2>(pooling);
|
| 330 |
-
}
|
| 331 |
-
else {
|
| 332 |
-
blockReduceMax<float, 2>(pooling);
|
| 333 |
-
}
|
| 334 |
-
}
|
| 335 |
-
__syncthreads();
|
| 336 |
-
if (threadIdx.x == 0) {
|
| 337 |
-
T2 output_val;
|
| 338 |
-
if (IS_AVG_POOLING) {
|
| 339 |
-
output_val.x = T(pooling[0] / HW);
|
| 340 |
-
output_val.y = T(pooling[1] / HW);
|
| 341 |
-
}
|
| 342 |
-
else {
|
| 343 |
-
output_val.x = T(pooling[0]);
|
| 344 |
-
output_val.y = T(pooling[1]);
|
| 345 |
-
}
|
| 346 |
-
output[0] = output_val;
|
| 347 |
-
}
|
| 348 |
-
}
|
| 349 |
-
|
| 350 |
-
template <typename T>
|
| 351 |
-
void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size,
|
| 352 |
-
cutlass::Tensor4DCoord filter_tensor_size,
|
| 353 |
-
cutlass::Tensor4DCoord output_tensor_size,
|
| 354 |
-
cutlass::Tensor4DCoord padding,
|
| 355 |
-
cutlass::MatrixCoord stride,
|
| 356 |
-
TensorRef<T, layout::TensorNHWC> ref_input,
|
| 357 |
-
TensorRef<T, layout::TensorNHWC> ref_output,
|
| 358 |
-
int poolingType, //0 for avg pooling ; 1 for max pooling
|
| 359 |
-
cudaStream_t stream) {
|
| 360 |
-
|
| 361 |
-
assert(input_tensor_size.n() == output_tensor_size.n() &&
|
| 362 |
-
input_tensor_size.c() == output_tensor_size.c());
|
| 363 |
-
|
| 364 |
-
const int N = input_tensor_size.n();
|
| 365 |
-
const int H = input_tensor_size.h();
|
| 366 |
-
const int W = input_tensor_size.w();
|
| 367 |
-
const int C = input_tensor_size.c();
|
| 368 |
-
const int padding_H = padding.h();
|
| 369 |
-
const int padding_W = padding.w();
|
| 370 |
-
const int kernel_H = filter_tensor_size.h();
|
| 371 |
-
const int kernel_W = filter_tensor_size.w();
|
| 372 |
-
const int stride_H = stride.row();
|
| 373 |
-
const int stride_W = stride.column();
|
| 374 |
-
|
| 375 |
-
const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H);
|
| 376 |
-
const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W);
|
| 377 |
-
|
| 378 |
-
assert(output_tensor_size.h() == output_H &&
|
| 379 |
-
output_tensor_size.w() == output_W);
|
| 380 |
-
|
| 381 |
-
if (C % 2 != 0) {
|
| 382 |
-
if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) {
|
| 383 |
-
dim3 grid(C, N);
|
| 384 |
-
dim3 block(256);
|
| 385 |
-
if (H*W < block.x){
|
| 386 |
-
block.x = (H*W + 31)/32*32;
|
| 387 |
-
}
|
| 388 |
-
if (poolingType == 0) {
|
| 389 |
-
pooling_nxhTo1x1_element1_kernel<T, true><<<grid, block, 0, stream>>>(
|
| 390 |
-
ref_output.data(),
|
| 391 |
-
ref_input.data(),
|
| 392 |
-
N,
|
| 393 |
-
H*W,
|
| 394 |
-
C);
|
| 395 |
-
} // if (poolingType == 0)
|
| 396 |
-
else {
|
| 397 |
-
pooling_nxhTo1x1_element1_kernel<T, false><<<grid, block, 0, stream>>>(
|
| 398 |
-
ref_output.data(),
|
| 399 |
-
ref_input.data(),
|
| 400 |
-
N,
|
| 401 |
-
H*W,
|
| 402 |
-
C);
|
| 403 |
-
}
|
| 404 |
-
} // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0))
|
| 405 |
-
else {
|
| 406 |
-
dim3 grid(N, output_H, output_W);
|
| 407 |
-
dim3 block(256);
|
| 408 |
-
if (C < block.x) {
|
| 409 |
-
block.x = C;
|
| 410 |
-
}
|
| 411 |
-
if (poolingType == 0) {
|
| 412 |
-
pooling_nhwc_element1_kernel<T, true><<<grid, block, 0, stream>>>(
|
| 413 |
-
ref_output.data(),
|
| 414 |
-
ref_input.data(),
|
| 415 |
-
N,
|
| 416 |
-
H,
|
| 417 |
-
W,
|
| 418 |
-
C,
|
| 419 |
-
output_H,
|
| 420 |
-
output_W,
|
| 421 |
-
kernel_H,
|
| 422 |
-
kernel_W,
|
| 423 |
-
stride_H,
|
| 424 |
-
stride_W,
|
| 425 |
-
padding_H,
|
| 426 |
-
padding_W);
|
| 427 |
-
} // if (poolingType == 0)
|
| 428 |
-
else {
|
| 429 |
-
pooling_nhwc_element1_kernel<T, false><<<grid, block, 0, stream>>>(
|
| 430 |
-
ref_output.data(),
|
| 431 |
-
ref_input.data(),
|
| 432 |
-
N,
|
| 433 |
-
H,
|
| 434 |
-
W,
|
| 435 |
-
C,
|
| 436 |
-
output_H,
|
| 437 |
-
output_W,
|
| 438 |
-
kernel_H,
|
| 439 |
-
kernel_W,
|
| 440 |
-
stride_H,
|
| 441 |
-
stride_W,
|
| 442 |
-
padding_H,
|
| 443 |
-
padding_W);
|
| 444 |
-
}
|
| 445 |
-
}
|
| 446 |
-
} // if (C % 2 != 0))
|
| 447 |
-
else {
|
| 448 |
-
if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) {
|
| 449 |
-
dim3 grid(C/2, N);
|
| 450 |
-
dim3 block(256);
|
| 451 |
-
if (H*W < block.x){
|
| 452 |
-
block.x = (H*W + 31)/32*32;
|
| 453 |
-
}
|
| 454 |
-
if (poolingType == 0) {
|
| 455 |
-
if (std::is_same<T, float>::value) {
|
| 456 |
-
pooling_nxhTo1x1_element2_kernel<float2, float, true><<<grid, block, 0, stream>>>(
|
| 457 |
-
(float2*)(ref_output.data()),
|
| 458 |
-
(const float2*)(ref_input.data()),
|
| 459 |
-
N,
|
| 460 |
-
H*W,
|
| 461 |
-
C);
|
| 462 |
-
} // if (std::is_same<T, float>::value)
|
| 463 |
-
else {
|
| 464 |
-
pooling_nxhTo1x1_element2_kernel<half2, half, true><<<grid, block, 0, stream>>>(
|
| 465 |
-
(half2*)(ref_output.data()),
|
| 466 |
-
(const half2*)(ref_input.data()),
|
| 467 |
-
N,
|
| 468 |
-
H*W,
|
| 469 |
-
C);
|
| 470 |
-
}
|
| 471 |
-
} // if (poolingType == 0)
|
| 472 |
-
else {
|
| 473 |
-
if (std::is_same<T, float>::value) {
|
| 474 |
-
pooling_nxhTo1x1_element2_kernel<float2, float, false><<<grid, block, 0, stream>>>(
|
| 475 |
-
(float2*)(ref_output.data()),
|
| 476 |
-
(const float2*)(ref_input.data()),
|
| 477 |
-
N,
|
| 478 |
-
H*W,
|
| 479 |
-
C);
|
| 480 |
-
} // if (std::is_same<T, float>::value)
|
| 481 |
-
else {
|
| 482 |
-
pooling_nxhTo1x1_element2_kernel<half2, half, false><<<grid, block, 0, stream>>>(
|
| 483 |
-
(half2*)(ref_output.data()),
|
| 484 |
-
(const half2*)(ref_input.data()),
|
| 485 |
-
N,
|
| 486 |
-
H*W,
|
| 487 |
-
C);
|
| 488 |
-
}
|
| 489 |
-
}
|
| 490 |
-
} // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0))
|
| 491 |
-
else {
|
| 492 |
-
dim3 grid(N, output_H, output_W);
|
| 493 |
-
dim3 block(256);
|
| 494 |
-
if (C/2 < block.x) {
|
| 495 |
-
block.x = C/2;
|
| 496 |
-
}
|
| 497 |
-
if (poolingType == 0) {
|
| 498 |
-
if (std::is_same<T, float>::value) {
|
| 499 |
-
pooling_nhwc_element2_kernel<float2, float, true><<<grid, block, 0, stream>>>(
|
| 500 |
-
(float2*)(ref_output.data()),
|
| 501 |
-
(const float2*)(ref_input.data()),
|
| 502 |
-
N,
|
| 503 |
-
H,
|
| 504 |
-
W,
|
| 505 |
-
C/2,
|
| 506 |
-
output_H,
|
| 507 |
-
output_W,
|
| 508 |
-
kernel_H,
|
| 509 |
-
kernel_W,
|
| 510 |
-
stride_H,
|
| 511 |
-
stride_W,
|
| 512 |
-
padding_H,
|
| 513 |
-
padding_W);
|
| 514 |
-
} // if (std::is_same<T, float>::value)
|
| 515 |
-
else {
|
| 516 |
-
pooling_nhwc_element2_kernel<half2, half, true><<<grid, block, 0, stream>>>(
|
| 517 |
-
(half2*)(ref_output.data()),
|
| 518 |
-
(const half2*)(ref_input.data()),
|
| 519 |
-
N,
|
| 520 |
-
H,
|
| 521 |
-
W,
|
| 522 |
-
C/2,
|
| 523 |
-
output_H,
|
| 524 |
-
output_W,
|
| 525 |
-
kernel_H,
|
| 526 |
-
kernel_W,
|
| 527 |
-
stride_H,
|
| 528 |
-
stride_W,
|
| 529 |
-
padding_H,
|
| 530 |
-
padding_W);
|
| 531 |
-
}
|
| 532 |
-
} // if (poolingType == 0)
|
| 533 |
-
else {
|
| 534 |
-
if (std::is_same<T, float>::value) {
|
| 535 |
-
pooling_nhwc_element2_kernel<float2, float, false><<<grid, block, 0, stream>>>(
|
| 536 |
-
(float2*)(ref_output.data()),
|
| 537 |
-
(const float2*)(ref_input.data()),
|
| 538 |
-
N,
|
| 539 |
-
H,
|
| 540 |
-
W,
|
| 541 |
-
C/2,
|
| 542 |
-
output_H,
|
| 543 |
-
output_W,
|
| 544 |
-
kernel_H,
|
| 545 |
-
kernel_W,
|
| 546 |
-
stride_H,
|
| 547 |
-
stride_W,
|
| 548 |
-
padding_H,
|
| 549 |
-
padding_W);
|
| 550 |
-
} // if (std::is_same<T, float>::value)
|
| 551 |
-
else {
|
| 552 |
-
pooling_nhwc_element2_kernel<half2, half, false><<<grid, block, 0, stream>>>(
|
| 553 |
-
(half2*)(ref_output.data()),
|
| 554 |
-
(const half2*)(ref_input.data()),
|
| 555 |
-
N,
|
| 556 |
-
H,
|
| 557 |
-
W,
|
| 558 |
-
C/2,
|
| 559 |
-
output_H,
|
| 560 |
-
output_W,
|
| 561 |
-
kernel_H,
|
| 562 |
-
kernel_W,
|
| 563 |
-
stride_H,
|
| 564 |
-
stride_W,
|
| 565 |
-
padding_H,
|
| 566 |
-
padding_W);
|
| 567 |
-
}
|
| 568 |
-
}
|
| 569 |
-
}
|
| 570 |
-
}
|
| 571 |
-
}
|
| 572 |
-
|
| 573 |
-
} //namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h
DELETED
|
@@ -1,144 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
/**
|
| 35 |
-
* \file
|
| 36 |
-
* \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout.
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#include "cutlass/cutlass.h"
|
| 40 |
-
#include "cutlass/layout/tensor.h"
|
| 41 |
-
#include "cutlass/numeric_types.h"
|
| 42 |
-
#include "cutlass/tensor_coord.h"
|
| 43 |
-
#include "cutlass/tensor_ref.h"
|
| 44 |
-
|
| 45 |
-
namespace cutlass {
|
| 46 |
-
|
| 47 |
-
/** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout.
|
| 48 |
-
* \tparam T: data type
|
| 49 |
-
*/
|
| 50 |
-
// Forward declaration; the definition appears further down in this header.
// input_tensor_size / output_tensor_size: extents the definition checks against
//   each other with an assert before launching.
// ref_input / ref_output: references whose data() pointers are handed to the kernel.
// stream: CUDA stream the kernel is launched on.
template <typename T>
|
| 51 |
-
void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size,
|
| 52 |
-
cutlass::Tensor4DCoord output_tensor_size,
|
| 53 |
-
TensorRef<T, layout::TensorNHWC> ref_input,
|
| 54 |
-
TensorRef<T, layout::TensorNCHW> ref_output,
|
| 55 |
-
cudaStream_t stream);
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
template <typename T>
|
| 59 |
-
__global__ void nhwc_to_nchw_kernel(T *output,
|
| 60 |
-
const T *input,
|
| 61 |
-
const int n,
|
| 62 |
-
const int h,
|
| 63 |
-
const int w,
|
| 64 |
-
const int c) {
|
| 65 |
-
|
| 66 |
-
const int hw = h*w;
|
| 67 |
-
const int hwc = hw*c;
|
| 68 |
-
__shared__ T shbuf[32 * (32 + 1)];
|
| 69 |
-
const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x;
|
| 70 |
-
const int32_t wid = tid / 32;
|
| 71 |
-
const int32_t lid = tid % 32;
|
| 72 |
-
const int32_t ni = blockIdx.z;
|
| 73 |
-
const int32_t hwi0 = blockIdx.y * 32;
|
| 74 |
-
const int32_t ci0 = blockIdx.x * 32;
|
| 75 |
-
|
| 76 |
-
const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0;
|
| 77 |
-
const T *A = input + input_idx;
|
| 78 |
-
if (ci0 + lid < c) {
|
| 79 |
-
const int lid_x_33 = lid * 33;
|
| 80 |
-
if ((hwi0 + 32) <= hw) {
|
| 81 |
-
int hwi = wid; // between 0 and 7
|
| 82 |
-
CUTLASS_PRAGMA_UNROLL
|
| 83 |
-
for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) {
|
| 84 |
-
shbuf[lid_x_33 + hwi] = A[lid];
|
| 85 |
-
A = &A[8 * c];
|
| 86 |
-
hwi += 8;
|
| 87 |
-
}
|
| 88 |
-
} else {
|
| 89 |
-
for (int hwi = wid; hwi < 32; hwi += 8) {
|
| 90 |
-
if ((hwi + hwi0) < hw) {
|
| 91 |
-
shbuf[lid_x_33 + hwi] = A[lid];
|
| 92 |
-
}
|
| 93 |
-
A = &A[8 * c];
|
| 94 |
-
}
|
| 95 |
-
}
|
| 96 |
-
}
|
| 97 |
-
__syncthreads();
|
| 98 |
-
|
| 99 |
-
const int32_t hwiOut = hwi0 + lid;
|
| 100 |
-
output = &output[ni * hwc + hwiOut];
|
| 101 |
-
if (hwiOut < hw) {
|
| 102 |
-
if (ci0 + 32 < c) {
|
| 103 |
-
int cI = wid;
|
| 104 |
-
CUTLASS_PRAGMA_UNROLL
|
| 105 |
-
for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) {
|
| 106 |
-
output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid];
|
| 107 |
-
cI += 8;
|
| 108 |
-
}
|
| 109 |
-
} else {
|
| 110 |
-
for (int cI = wid; cI < 32; cI += 8) {
|
| 111 |
-
if (ci0 + cI < c) {
|
| 112 |
-
output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid];
|
| 113 |
-
}
|
| 114 |
-
}
|
| 115 |
-
}
|
| 116 |
-
}
|
| 117 |
-
}
|
| 118 |
-
|
| 119 |
-
template <typename T>
|
| 120 |
-
void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size,
|
| 121 |
-
cutlass::Tensor4DCoord output_tensor_size,
|
| 122 |
-
TensorRef<T, layout::TensorNHWC> ref_input,
|
| 123 |
-
TensorRef<T, layout::TensorNCHW> ref_output,
|
| 124 |
-
cudaStream_t stream) {
|
| 125 |
-
|
| 126 |
-
assert(
|
| 127 |
-
input_tensor_size.n() == output_tensor_size.n() &&
|
| 128 |
-
input_tensor_size.h() == output_tensor_size.c() &&
|
| 129 |
-
input_tensor_size.w() == output_tensor_size.h() &&
|
| 130 |
-
input_tensor_size.c() == output_tensor_size.w());
|
| 131 |
-
|
| 132 |
-
int n = input_tensor_size.n();
|
| 133 |
-
int h = input_tensor_size.h();
|
| 134 |
-
int w = input_tensor_size.w();
|
| 135 |
-
int c = input_tensor_size.c();
|
| 136 |
-
|
| 137 |
-
dim3 grid((c + 31)/32, (h*w + 31)/32, n);
|
| 138 |
-
dim3 block(32, 8);
|
| 139 |
-
nhwc_to_nchw_kernel<<<grid, block, 0, stream>>>(ref_output.data(), ref_input.data(),
|
| 140 |
-
n, h, w, c);
|
| 141 |
-
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
} //namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h
DELETED
|
@@ -1,186 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include "cutlass/cutlass.h"
|
| 35 |
-
#include "cutlass/layout/tensor.h"
|
| 36 |
-
#include "cutlass/numeric_types.h"
|
| 37 |
-
#include "cutlass/tensor_coord.h"
|
| 38 |
-
#include "cutlass/tensor_ref.h"
|
| 39 |
-
#include "cutlass/util/device_utils.h"
|
| 40 |
-
#include <cfloat>
|
| 41 |
-
|
| 42 |
-
namespace cutlass {
|
| 43 |
-
|
| 44 |
-
/// RMS-normalizes the rows of an m x n half-precision matrix, eight elements per access.
///
/// Every float4 is treated as four half2 values (eight halves). Block blockIdx.x handles one
/// row; its threads stride over the row's n / 8 packets. Pass 1 accumulates the sum of
/// squares in float, pass 2 writes x * rsqrt(sum / n + epsilon) * weight, where weight is
/// indexed by column only.
///
/// \param output destination rows, viewed as float4 packets
/// \param input  source rows, viewed as float4 packets
/// \param weight per-column scale, viewed as float4 packets
/// \param m row count (not read by the kernel)
/// \param n column count; only n / 8 whole packets per row are processed
/// \param epsilon added to the mean of squares before the reciprocal square root
__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
                                       const float4 *weight,
                                       const int m, const int n, float epsilon) {
  // Despite its name this holds rsqrt(mean(x^2) + epsilon), published by thread 0.
  __shared__ float s_mean;

  const int packets  = n / 8;
  const int first    = threadIdx.x;
  const int step     = blockDim.x;
  const int row_base = blockIdx.x * packets;

  const float4 *src_row = input + row_base;
  float4 *dst_row       = output + row_base;

  // Pass 1: per-thread partial sum of squares.
  float local_sums[1] = {0.0f};
  for (int p = first; p < packets; p += step) {
    const float4 packet = src_row[p];
    const half2 *v = reinterpret_cast<const half2 *>(&packet);
    local_sums[0] += static_cast<float>(v[0].x) * static_cast<float>(v[0].x) +
                     static_cast<float>(v[0].y) * static_cast<float>(v[0].y) +
                     static_cast<float>(v[1].x) * static_cast<float>(v[1].x) +
                     static_cast<float>(v[1].y) * static_cast<float>(v[1].y) +
                     static_cast<float>(v[2].x) * static_cast<float>(v[2].x) +
                     static_cast<float>(v[2].y) * static_cast<float>(v[2].y) +
                     static_cast<float>(v[3].x) * static_cast<float>(v[3].x) +
                     static_cast<float>(v[3].y) * static_cast<float>(v[3].y);
  }

  // Combine the partial sums; a single warp needs no shared-memory stage.
  if (step > 32) {
    blockReduceSum<float, 1>(local_sums);
  } else {
    warpReduceSum<float, 1>(local_sums);
  }

  if (first == 0) {
    s_mean = rsqrtf(local_sums[0] / n + epsilon);
  }
  __syncthreads();

  // Pass 2: scale every element by the row factor and its column weight.
  for (int p = first; p < packets; p += step) {
    const float4 packet = src_row[p];
    const float4 gain   = weight[p];
    const half2 *x = reinterpret_cast<const half2 *>(&packet);
    const half2 *g = reinterpret_cast<const half2 *>(&gain);

    float4 packed;
    half2 *y = reinterpret_cast<half2 *>(&packed);
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      y[k].x = half(static_cast<float>(x[k].x) * s_mean * static_cast<float>(g[k].x));
      y[k].y = half(static_cast<float>(x[k].y) * s_mean * static_cast<float>(g[k].y));
    }

    dst_row[p] = packed;
  }
}
|
| 115 |
-
|
| 116 |
-
template<typename T>
|
| 117 |
-
__global__ void rmsnorm_twoPassAlgo_e1(T* output,
|
| 118 |
-
const T* input,
|
| 119 |
-
const T* weight,
|
| 120 |
-
const int m, const int n,
|
| 121 |
-
float epsilon)
|
| 122 |
-
{
|
| 123 |
-
const int m_idx = blockIdx.x;
|
| 124 |
-
const int tid = threadIdx.x;
|
| 125 |
-
const int bdimx = blockDim.x;
|
| 126 |
-
__shared__ float s_mean;
|
| 127 |
-
float local_sums[1] = {0.0f};
|
| 128 |
-
int offset = m_idx * n;
|
| 129 |
-
input += offset;
|
| 130 |
-
output += offset;
|
| 131 |
-
|
| 132 |
-
for (int index = tid ; index < n ; index += bdimx){
|
| 133 |
-
float local_val = static_cast<float>(input[index]);
|
| 134 |
-
local_sums[0] += local_val * local_val;
|
| 135 |
-
}
|
| 136 |
-
if (blockDim.x <= 32) {
|
| 137 |
-
warpReduceSum<float, 1>(local_sums);
|
| 138 |
-
}
|
| 139 |
-
else {
|
| 140 |
-
blockReduceSum<float, 1>(local_sums);
|
| 141 |
-
}
|
| 142 |
-
if (threadIdx.x == 0) {
|
| 143 |
-
s_mean = rsqrtf(local_sums[0] / n + epsilon);
|
| 144 |
-
}
|
| 145 |
-
__syncthreads();
|
| 146 |
-
|
| 147 |
-
for (int index = tid ; index < n ; index += bdimx){
|
| 148 |
-
const T weight_val = weight[index];
|
| 149 |
-
const T local_val = input[index];
|
| 150 |
-
output[index] = T(static_cast<float>(local_val) * s_mean * static_cast<float>(weight_val));
|
| 151 |
-
}
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
template <typename T>
|
| 155 |
-
void rmsnorm(cutlass::MatrixCoord tensor_size,
|
| 156 |
-
TensorRef<T, layout::RowMajor> ref_output,
|
| 157 |
-
TensorRef<T, layout::RowMajor> ref_input,
|
| 158 |
-
TensorRef<T, layout::RowMajor> ref_weight,
|
| 159 |
-
cudaStream_t stream, float epsilon = 1e-5f){
|
| 160 |
-
const int m = tensor_size.row();
|
| 161 |
-
const int n = tensor_size.column();
|
| 162 |
-
T* output = ref_output.data();
|
| 163 |
-
const T* input = ref_input.data();
|
| 164 |
-
const T* weight = ref_weight.data();
|
| 165 |
-
dim3 grid(m);
|
| 166 |
-
|
| 167 |
-
if (n % 8 == 0 && std::is_same<T, cutlass::half_t>::value) {
|
| 168 |
-
dim3 block(cutlass::platform::min(1024, (n / 8 + 31) / 32 * 32));
|
| 169 |
-
|
| 170 |
-
rmsnorm_twoPassAlgo_e8<<<grid, block, 0, stream>>>(
|
| 171 |
-
(float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon);
|
| 172 |
-
} else {
|
| 173 |
-
dim3 block(cutlass::platform::min(1024, ((n + 31)/32 + 31)/32*32));
|
| 174 |
-
|
| 175 |
-
rmsnorm_twoPassAlgo_e1<<<grid, block, 0, stream>>>(
|
| 176 |
-
output, input, weight, m, n, epsilon);
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
auto result = cudaGetLastError();
|
| 180 |
-
if (result != cudaSuccess) {
|
| 181 |
-
std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl;
|
| 182 |
-
abort();
|
| 183 |
-
}
|
| 184 |
-
}
|
| 185 |
-
|
| 186 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h
DELETED
|
@@ -1,127 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief utils code for device cutlass code
|
| 34 |
-
*/
|
| 35 |
-
|
| 36 |
-
#pragma once
|
| 37 |
-
|
| 38 |
-
#include <cuda_fp16.h>
|
| 39 |
-
#include <cfloat>
|
| 40 |
-
#define FINAL_MASK 0xffffffff
|
| 41 |
-
|
| 42 |
-
struct half4 {
|
| 43 |
-
half x, y, z, w;
|
| 44 |
-
};
|
| 45 |
-
|
| 46 |
-
template<typename T, int NUM>
|
| 47 |
-
__inline__ __device__ T warpReduceSum(T* val)
|
| 48 |
-
{
|
| 49 |
-
#pragma unroll
|
| 50 |
-
for (int i = 0; i < NUM; i++) {
|
| 51 |
-
#pragma unroll
|
| 52 |
-
for (int mask = 16; mask > 0; mask >>= 1)
|
| 53 |
-
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
|
| 54 |
-
}
|
| 55 |
-
return (T)(0.0f);
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
template<typename T, int NUM>
|
| 59 |
-
__inline__ __device__ T blockReduceSum(T* val)
|
| 60 |
-
{
|
| 61 |
-
__shared__ T shared[NUM][33];
|
| 62 |
-
int lane = threadIdx.x & 0x1f;
|
| 63 |
-
int wid = threadIdx.x >> 5;
|
| 64 |
-
|
| 65 |
-
warpReduceSum<T, NUM>(val);
|
| 66 |
-
|
| 67 |
-
if (lane == 0) {
|
| 68 |
-
#pragma unroll
|
| 69 |
-
for (int i = 0; i < NUM; i++) {
|
| 70 |
-
shared[i][wid] = val[i];
|
| 71 |
-
}
|
| 72 |
-
}
|
| 73 |
-
|
| 74 |
-
__syncthreads();
|
| 75 |
-
|
| 76 |
-
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
|
| 77 |
-
#pragma unroll
|
| 78 |
-
for (int i = 0; i < NUM; i++) {
|
| 79 |
-
val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
|
| 80 |
-
}
|
| 81 |
-
warpReduceSum<T, NUM>(val);
|
| 82 |
-
return (T)0.0f;
|
| 83 |
-
}
|
| 84 |
-
|
| 85 |
-
template<typename T, int NUM>
|
| 86 |
-
__inline__ __device__ T warpReduceMax(T* val)
|
| 87 |
-
{
|
| 88 |
-
#pragma unroll
|
| 89 |
-
for (int i = 0; i < NUM; i++) {
|
| 90 |
-
#pragma unroll
|
| 91 |
-
for (int mask = 16; mask > 0; mask >>= 1)
|
| 92 |
-
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
|
| 93 |
-
}
|
| 94 |
-
return (T)(0.0f);
|
| 95 |
-
}
|
| 96 |
-
|
| 97 |
-
template<typename T, int NUM>
|
| 98 |
-
__inline__ __device__ T blockReduceMax(T* val)
|
| 99 |
-
{
|
| 100 |
-
static __shared__ T shared[32][NUM];
|
| 101 |
-
int lane = threadIdx.x & 0x1f; // in-warp idx
|
| 102 |
-
int wid = threadIdx.x >> 5; // warp idx
|
| 103 |
-
|
| 104 |
-
warpReduceMax<T, NUM>(val); // get maxx in each warp
|
| 105 |
-
|
| 106 |
-
if (lane == 0) // record in-warp maxx by warp Idx
|
| 107 |
-
{
|
| 108 |
-
#pragma unroll
|
| 109 |
-
for (int i = 0; i < NUM; i++) {
|
| 110 |
-
shared[wid][i] = val[i];
|
| 111 |
-
}
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
__syncthreads();
|
| 115 |
-
|
| 116 |
-
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
| 117 |
-
// blockDim.x is not divided by 32
|
| 118 |
-
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
|
| 119 |
-
#pragma unroll
|
| 120 |
-
for (int i = 0; i < NUM; i++) {
|
| 121 |
-
val[i] = is_mask ? shared[lane][i] : (T)(-FLT_MAX);
|
| 122 |
-
}
|
| 123 |
-
warpReduceMax<T, NUM>(val);
|
| 124 |
-
|
| 125 |
-
return (T)0.0f;
|
| 126 |
-
}
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h
DELETED
|
@@ -1,157 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
/*! \file
|
| 34 |
-
\brief This header contains a class to parametrize a statistical distribution function.
|
| 35 |
-
*/
|
| 36 |
-
|
| 37 |
-
#include <ostream>
|
| 38 |
-
|
| 39 |
-
namespace cutlass {
|
| 40 |
-
|
| 41 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
-
|
| 43 |
-
/// Describes how tensor elements are generated: a kind tag plus the parameters of that kind.
///
/// Only the union member matching \p kind is written by the corresponding setter; the default
/// constructor leaves the union uninitialized.
struct Distribution {
  /// Variant types
  enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes };

  /// Distribution state
  union {
    /// Uniform distribution
    struct {
      double min;
      double max;
      double pnan;  ///< Percent elements set to NaN
    } uniform;

    /// Gaussian distribution
    struct {
      double mean;
      double stddev;
      double pnz;
      double pnzA;
      double pnzB;
      double pnzC;
    } gaussian;

    /// Elements are linear combination of row and column index
    struct {
      double start;
      double delta;
    } sequential;
  };

  /// Active variant kind
  Kind kind;

  /// Random values are cast to integer after scaling by this power of two
  int int_scale;

  //
  // Methods
  //

  /// Starts out as Invalid with no integer scaling
  Distribution() : kind(Invalid), int_scale(0) {}

  /// Configures distribution as uniform random over [_min, _max] parameters
  Distribution &set_uniform(double _min, double _max, int _int_scale = 0, double _pnan = 0) {
    uniform.min  = _min;
    uniform.max  = _max;
    uniform.pnan = _pnan;
    int_scale    = _int_scale;
    kind         = Uniform;
    return *this;
  }

  /// Configures distribution as Gaussian distribution; _pnz seeds pnz, pnzA, pnzB and pnzC
  Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0, double _pnz = 1.0) {
    gaussian.mean   = _mean;
    gaussian.stddev = _stddev;
    gaussian.pnz = gaussian.pnzA = gaussian.pnzB = gaussian.pnzC = _pnz;
    int_scale = _int_scale;
    kind      = Gaussian;
    return *this;
  }

  /// Sets identity; int_scale and the union are left untouched
  Distribution &set_identity() {
    kind = Identity;
    return *this;
  }

  /// Sets sequential with the given start value and per-element delta
  Distribution &set_sequential(double start, double delta, int _int_scale = 0) {
    sequential.start = start;
    sequential.delta = delta;
    int_scale        = _int_scale;
    kind             = Sequential;
    return *this;
  }
};
|
| 125 |
-
|
| 126 |
-
} // namespace cutlass
|
| 127 |
-
|
| 128 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 129 |
-
|
| 130 |
-
/// Prints a Distribution to ostream
|
| 131 |
-
inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) {
|
| 132 |
-
switch (dist.kind) {
|
| 133 |
-
case cutlass::Distribution::Uniform:
|
| 134 |
-
out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max
|
| 135 |
-
<< ", pnan: " << dist.uniform.pnan;
|
| 136 |
-
break;
|
| 137 |
-
case cutlass::Distribution::Gaussian:
|
| 138 |
-
out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev
|
| 139 |
-
<< ", pnzA: " << dist.gaussian.pnzA << ", pnzB: "
|
| 140 |
-
<< dist.gaussian.pnzB << ", pnzC: " << dist.gaussian.pnzC;
|
| 141 |
-
break;
|
| 142 |
-
case cutlass::Distribution::Identity:
|
| 143 |
-
out << "identity";
|
| 144 |
-
break;
|
| 145 |
-
case cutlass::Distribution::Sequential:
|
| 146 |
-
out << "sequential";
|
| 147 |
-
break;
|
| 148 |
-
default:
|
| 149 |
-
out << "unknown";
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
out << ", int_scale: " << dist.int_scale;
|
| 153 |
-
|
| 154 |
-
return out;
|
| 155 |
-
}
|
| 156 |
-
|
| 157 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h
DELETED
|
@@ -1,69 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
******************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
/**
|
| 35 |
-
* \file
|
| 36 |
-
* \brief C++ exception semantics for CUDA error codes
|
| 37 |
-
*/
|
| 38 |
-
|
| 39 |
-
#include <cuda_runtime.h>
|
| 40 |
-
#include <iosfwd>
|
| 41 |
-
#include <stdexcept>
|
| 42 |
-
|
| 43 |
-
#include "cutlass/platform/platform.h"
|
| 44 |
-
|
| 45 |
-
namespace cutlass {
|
| 46 |
-
|
| 47 |
-
/// C++ exception wrapper for CUDA \p cudaError_t
|
| 48 |
-
class cuda_exception : public std::exception {
|
| 49 |
-
public:
|
| 50 |
-
/// Constructor
|
| 51 |
-
cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {}
|
| 52 |
-
|
| 53 |
-
/// Returns the underlying CUDA \p cudaError_t
|
| 54 |
-
cudaError_t cudaError() const { return err; }
|
| 55 |
-
|
| 56 |
-
protected:
|
| 57 |
-
/// Explanatory string
|
| 58 |
-
const char* msg;
|
| 59 |
-
|
| 60 |
-
/// Underlying CUDA \p cudaError_t
|
| 61 |
-
cudaError_t err;
|
| 62 |
-
};
|
| 63 |
-
|
| 64 |
-
/// Writes a cuda_exception instance to an output stream
|
| 65 |
-
inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) {
|
| 66 |
-
return out << e.what() << ": " << cudaGetErrorString(e.cudaError());
|
| 67 |
-
}
|
| 68 |
-
|
| 69 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp
DELETED
|
@@ -1,369 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief GETT command line parser to gather semantic modes, their stride order, and extents.
|
| 33 |
-
*/
|
| 34 |
-
#pragma once
|
| 35 |
-
|
| 36 |
-
#include <iostream>
|
| 37 |
-
#include <iomanip>
|
| 38 |
-
#include <utility>
|
| 39 |
-
#include <type_traits>
|
| 40 |
-
#include <vector>
|
| 41 |
-
#include <map>
|
| 42 |
-
#include <algorithm>
|
| 43 |
-
#include <numeric>
|
| 44 |
-
|
| 45 |
-
#include "cutlass/util/command_line.h"
|
| 46 |
-
|
| 47 |
-
namespace cutlass {
|
| 48 |
-
|
| 49 |
-
// Output shortcuts
|
| 50 |
-
/// Streams every character of `data` to `os` back-to-back, with no separator.
///
/// Declared `inline` because this is a non-template function defined in a header: without it,
/// every translation unit including the header emits its own definition and they collide at
/// link time. The vector is taken by const reference so that printing does not copy it.
///
/// \param os    Destination stream.
/// \param data  Characters to print.
/// \return      `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, std::vector<char> const& data) {
  for (char c : data) os << c;
  return os;
}
|
| 54 |
-
|
| 55 |
-
/// Streams every element of `data` to `os`, each followed by a single space.
///
/// The vector is taken by const reference so that printing does not copy the container
/// (and its elements) on every call.
///
/// \param os    Destination stream.
/// \param data  Elements to print; `T` must itself be streamable to `std::ostream`.
/// \return      `os`, to allow chaining.
template <class T>
std::ostream& operator<<(std::ostream& os, std::vector<T> const& data) {
  for (auto const& a : data) os << a << " ";
  return os;
}
|
| 60 |
-
|
| 61 |
-
struct GettCommandLine {
|
| 62 |
-
struct GettProblem {
|
| 63 |
-
using extent_type = int;
|
| 64 |
-
using stride_type = int64_t;
|
| 65 |
-
|
| 66 |
-
// Row modes: appear in A and C/D
|
| 67 |
-
std::vector<extent_type> M;
|
| 68 |
-
std::vector<stride_type> ldAm;
|
| 69 |
-
std::vector<stride_type> ldCm;
|
| 70 |
-
|
| 71 |
-
// Column modes: appear in B and C/D
|
| 72 |
-
std::vector<extent_type> N;
|
| 73 |
-
std::vector<stride_type> ldBn;
|
| 74 |
-
std::vector<stride_type> ldCn;
|
| 75 |
-
|
| 76 |
-
// Reduction modes: appear in A and B
|
| 77 |
-
std::vector<extent_type> K;
|
| 78 |
-
std::vector<stride_type> ldAk;
|
| 79 |
-
std::vector<stride_type> ldBk;
|
| 80 |
-
|
| 81 |
-
// Batch modes: appear in all in/out tensors
|
| 82 |
-
std::vector<extent_type> L;
|
| 83 |
-
std::vector<stride_type> ldAl;
|
| 84 |
-
std::vector<stride_type> ldBl;
|
| 85 |
-
std::vector<stride_type> ldCl;
|
| 86 |
-
};
|
| 87 |
-
|
| 88 |
-
static GettProblem
|
| 89 |
-
parse(int argc, char const* argv[], bool parse_verbose = false) {
|
| 90 |
-
using extent_type = typename GettProblem::extent_type;
|
| 91 |
-
using stride_type = typename GettProblem::stride_type;
|
| 92 |
-
|
| 93 |
-
cutlass::CommandLine cmd(argc, argv);
|
| 94 |
-
|
| 95 |
-
// modeA
|
| 96 |
-
std::vector<char> a_mode;
|
| 97 |
-
cmd.get_cmd_line_arguments("modeA", a_mode);
|
| 98 |
-
|
| 99 |
-
// modeB
|
| 100 |
-
std::vector<char> b_mode;
|
| 101 |
-
cmd.get_cmd_line_arguments("modeB", b_mode);
|
| 102 |
-
|
| 103 |
-
// modeC
|
| 104 |
-
std::vector<char> c_mode;
|
| 105 |
-
cmd.get_cmd_line_arguments("modeC", c_mode);
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
// mode_sizes
|
| 109 |
-
std::map<char,extent_type> mode_size;
|
| 110 |
-
// First, initialize all modes in a, b, c to make sure they're in map
|
| 111 |
-
for (char a : a_mode) mode_size[a] = 1;
|
| 112 |
-
for (char b : b_mode) mode_size[b] = 1;
|
| 113 |
-
for (char c : c_mode) mode_size[c] = 1;
|
| 114 |
-
|
| 115 |
-
// Then, overwrite the ones in -extent
|
| 116 |
-
std::vector<std::pair<std::string, std::string> > extent_tokens;
|
| 117 |
-
cmd.get_cmd_line_argument_pairs("extents", extent_tokens);
|
| 118 |
-
for (auto e : extent_tokens) {
|
| 119 |
-
if (std::get<0>(e).size() > 1) {
|
| 120 |
-
std::cerr << "ERROR: Mode name must only be 1 character long.\n";
|
| 121 |
-
print_usage();
|
| 122 |
-
exit(1);
|
| 123 |
-
}
|
| 124 |
-
char label = std::get<0>(e)[0];
|
| 125 |
-
int size = std::stoi(std::get<1>(e));
|
| 126 |
-
mode_size[label] = size;
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
// Print out symbolic modes and their extents
|
| 130 |
-
if (parse_verbose) {
|
| 131 |
-
std::cout << "C_" << c_mode << " = A_" << a_mode << " * B_" << b_mode << "\n";
|
| 132 |
-
for (auto e : mode_size) std::cout << " " << std::get<0>(e) << " : " << std::get<1>(e) << "\n";
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
//
|
| 136 |
-
// Collect/Compute strides
|
| 137 |
-
//
|
| 138 |
-
|
| 139 |
-
std::map<char,stride_type> mode_ldA;
|
| 140 |
-
std::map<char,stride_type> mode_ldB;
|
| 141 |
-
std::map<char,stride_type> mode_ldC;
|
| 142 |
-
|
| 143 |
-
{
|
| 144 |
-
stride_type current;
|
| 145 |
-
|
| 146 |
-
current = 1;
|
| 147 |
-
for (char a : a_mode) { mode_ldA[a] = current; current *= mode_size[a]; }
|
| 148 |
-
|
| 149 |
-
current = 1;
|
| 150 |
-
for (char b : b_mode) { mode_ldB[b] = current; current *= mode_size[b]; }
|
| 151 |
-
|
| 152 |
-
current = 1;
|
| 153 |
-
for (char c : c_mode) { mode_ldC[c] = current; current *= mode_size[c]; }
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
//
|
| 157 |
-
// Collect mode categories
|
| 158 |
-
//
|
| 159 |
-
|
| 160 |
-
std::vector<char> row_mode; // rows
|
| 161 |
-
std::vector<char> col_mode; // columns
|
| 162 |
-
std::vector<char> red_mode; // reductions
|
| 163 |
-
std::vector<char> bat_mode; // batches
|
| 164 |
-
|
| 165 |
-
{
|
| 166 |
-
std::vector<char> a_label = a_mode;
|
| 167 |
-
std::vector<char> b_label = b_mode;
|
| 168 |
-
std::vector<char> c_label = c_mode;
|
| 169 |
-
|
| 170 |
-
std::sort(std::begin(a_label), std::end(a_label));
|
| 171 |
-
std::sort(std::begin(b_label), std::end(b_label));
|
| 172 |
-
std::sort(std::begin(c_label), std::end(c_label));
|
| 173 |
-
|
| 174 |
-
// std::set_intersections to find semantic category of each symbolic mode
|
| 175 |
-
std::set_intersection(std::begin(a_label), std::end(a_label),
|
| 176 |
-
std::begin(c_label), std::end(c_label),
|
| 177 |
-
std::back_inserter(row_mode));
|
| 178 |
-
|
| 179 |
-
std::set_intersection(std::begin(b_label), std::end(b_label),
|
| 180 |
-
std::begin(c_label), std::end(c_label),
|
| 181 |
-
std::back_inserter(col_mode));
|
| 182 |
-
|
| 183 |
-
std::set_intersection(std::begin(a_label), std::end(a_label),
|
| 184 |
-
std::begin(b_label), std::end(b_label),
|
| 185 |
-
std::back_inserter(red_mode));
|
| 186 |
-
|
| 187 |
-
std::set_intersection(std::begin(row_mode), std::end(row_mode),
|
| 188 |
-
std::begin(col_mode), std::end(col_mode),
|
| 189 |
-
std::back_inserter(bat_mode));
|
| 190 |
-
|
| 191 |
-
// std::set_difference to remove batch modes from other semantic modes
|
| 192 |
-
for (char l : bat_mode) {
|
| 193 |
-
row_mode.erase(std::remove(std::begin(row_mode), std::end(row_mode), l), std::end(row_mode));
|
| 194 |
-
col_mode.erase(std::remove(std::begin(col_mode), std::end(col_mode), l), std::end(col_mode));
|
| 195 |
-
red_mode.erase(std::remove(std::begin(red_mode), std::end(red_mode), l), std::end(red_mode));
|
| 196 |
-
}
|
| 197 |
-
}
|
| 198 |
-
|
| 199 |
-
// Print out the semantic association of each symbolic mode
|
| 200 |
-
if (parse_verbose) {
|
| 201 |
-
std::cout << " rows : " << row_mode << '\n';
|
| 202 |
-
std::cout << " cols : " << col_mode << '\n';
|
| 203 |
-
std::cout << " reds : " << red_mode << '\n';
|
| 204 |
-
std::cout << " bats : " << bat_mode << '\n';
|
| 205 |
-
}
|
| 206 |
-
|
| 207 |
-
//
|
| 208 |
-
// Permute modes
|
| 209 |
-
//
|
| 210 |
-
|
| 211 |
-
// Permute the batched modes to promote coalescing
|
| 212 |
-
// Sort the batched modes by min(ldAl,ldBl) and in case of a tie by the size
|
| 213 |
-
std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) {
|
| 214 |
-
return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1])
|
| 215 |
-
< std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]);
|
| 216 |
-
});
|
| 217 |
-
// Compute sizes and strides of ordered reduction modes
|
| 218 |
-
std::vector<extent_type> L;
|
| 219 |
-
std::vector<stride_type> ldAl;
|
| 220 |
-
std::vector<stride_type> ldBl;
|
| 221 |
-
std::vector<stride_type> ldCl;
|
| 222 |
-
for (char l : bat_mode) {
|
| 223 |
-
L.push_back(mode_size[l]);
|
| 224 |
-
ldAl.push_back(mode_ldA[l]);
|
| 225 |
-
ldBl.push_back(mode_ldB[l]);
|
| 226 |
-
ldCl.push_back(mode_ldC[l]);
|
| 227 |
-
}
|
| 228 |
-
|
| 229 |
-
// Permute the reduction modes to promote coalescing
|
| 230 |
-
// Sort the reduction modes by min(ldAk,ldBk) and in case of a tie by the size
|
| 231 |
-
std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) {
|
| 232 |
-
return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1])
|
| 233 |
-
< std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]);
|
| 234 |
-
});
|
| 235 |
-
// Compute sizes and strides of ordered reduction modes
|
| 236 |
-
std::vector<extent_type> K;
|
| 237 |
-
std::vector<stride_type> ldAk;
|
| 238 |
-
std::vector<stride_type> ldBk;
|
| 239 |
-
for (char k : red_mode) {
|
| 240 |
-
K.push_back(mode_size[k]);
|
| 241 |
-
ldAk.push_back(mode_ldA[k]);
|
| 242 |
-
ldBk.push_back(mode_ldB[k]);
|
| 243 |
-
}
|
| 244 |
-
|
| 245 |
-
// Permute the row modes to promote coalescing
|
| 246 |
-
// Sort the row modes by min(ldAm,ldCm) and in case of a tie by ldAm
|
| 247 |
-
std::sort(std::begin(row_mode), std::end(row_mode), [&](char m1, char m2) {
|
| 248 |
-
return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1])
|
| 249 |
-
< std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]);
|
| 250 |
-
});
|
| 251 |
-
// Compute sizes and strides of ordered row modes
|
| 252 |
-
std::vector<extent_type> M;
|
| 253 |
-
std::vector<stride_type> ldAm;
|
| 254 |
-
std::vector<stride_type> ldCm;
|
| 255 |
-
for (char m : row_mode) {
|
| 256 |
-
M.push_back(mode_size[m]);
|
| 257 |
-
ldAm.push_back(mode_ldA[m]);
|
| 258 |
-
ldCm.push_back(mode_ldC[m]);
|
| 259 |
-
}
|
| 260 |
-
|
| 261 |
-
// Permute the col modes to promote coalescing
|
| 262 |
-
// Sort the col modes by min(ldBn,ldCn) and in case of a tie by ldBn
|
| 263 |
-
std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) {
|
| 264 |
-
return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1])
|
| 265 |
-
< std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]);
|
| 266 |
-
});
|
| 267 |
-
// Compute sizes and strides of ordered col modes
|
| 268 |
-
std::vector<extent_type> N;
|
| 269 |
-
std::vector<stride_type> ldBn;
|
| 270 |
-
std::vector<stride_type> ldCn;
|
| 271 |
-
for (char n : col_mode) {
|
| 272 |
-
N.push_back(mode_size[n]);
|
| 273 |
-
ldBn.push_back(mode_ldB[n]);
|
| 274 |
-
ldCn.push_back(mode_ldC[n]);
|
| 275 |
-
}
|
| 276 |
-
|
| 277 |
-
if (parse_verbose) {
|
| 278 |
-
std::cout << "C_";
|
| 279 |
-
if (! row_mode.empty()) {
|
| 280 |
-
std::cout << "(" << row_mode << ")";
|
| 281 |
-
}
|
| 282 |
-
if (! col_mode.empty()) {
|
| 283 |
-
std::cout << "(" << col_mode << ")";
|
| 284 |
-
}
|
| 285 |
-
if (! bat_mode.empty()) {
|
| 286 |
-
std::cout << "(" << bat_mode << ")";
|
| 287 |
-
}
|
| 288 |
-
std::cout << " = A_";
|
| 289 |
-
if (! row_mode.empty()) {
|
| 290 |
-
std::cout << "(" << row_mode << ")";
|
| 291 |
-
}
|
| 292 |
-
if (! red_mode.empty()) {
|
| 293 |
-
std::cout << "(" << red_mode << ")";
|
| 294 |
-
}
|
| 295 |
-
if (! bat_mode.empty()) {
|
| 296 |
-
std::cout << "(" << bat_mode << ")";
|
| 297 |
-
}
|
| 298 |
-
std::cout << " * B_";
|
| 299 |
-
if (! col_mode.empty()) {
|
| 300 |
-
std::cout << "(" << col_mode << ")";
|
| 301 |
-
}
|
| 302 |
-
if (! red_mode.empty()) {
|
| 303 |
-
std::cout << "(" << red_mode << ")";
|
| 304 |
-
}
|
| 305 |
-
if (! bat_mode.empty()) {
|
| 306 |
-
std::cout << "(" << bat_mode << ")";
|
| 307 |
-
}
|
| 308 |
-
std::cout << '\n';
|
| 309 |
-
|
| 310 |
-
int M_size = std::accumulate(std::begin(M), std::end(M), 1, std::multiplies<>{});
|
| 311 |
-
int N_size = std::accumulate(std::begin(N), std::end(N), 1, std::multiplies<>{});
|
| 312 |
-
int K_size = std::accumulate(std::begin(K), std::end(K), 1, std::multiplies<>{});
|
| 313 |
-
int L_size = std::accumulate(std::begin(L), std::end(L), 1, std::multiplies<>{});
|
| 314 |
-
|
| 315 |
-
std::cout << " M : (" << M_size << ") ";
|
| 316 |
-
for (char m : row_mode) std::cout << m << ":" << mode_size[m] << " ";
|
| 317 |
-
std::cout << '\n';
|
| 318 |
-
std::cout << " N : (" << N_size << ") ";
|
| 319 |
-
for (char n : col_mode) std::cout << n << ":" << mode_size[n] << " ";
|
| 320 |
-
std::cout << '\n';
|
| 321 |
-
std::cout << " K : (" << K_size << ") ";
|
| 322 |
-
for (char k : red_mode) std::cout << k << ":" << mode_size[k] << " ";
|
| 323 |
-
std::cout << '\n';
|
| 324 |
-
std::cout << " L : (" << L_size << ") ";
|
| 325 |
-
for (char l : bat_mode) std::cout << l << ":" << mode_size[l] << " ";
|
| 326 |
-
std::cout << '\n';
|
| 327 |
-
|
| 328 |
-
std::cout << " ldAm : " << ldAm << '\n';
|
| 329 |
-
std::cout << " ldAk : " << ldAk << '\n';
|
| 330 |
-
std::cout << " ldAl : " << ldAl << '\n';
|
| 331 |
-
std::cout << " ldBn : " << ldBn << '\n';
|
| 332 |
-
std::cout << " ldBk : " << ldBk << '\n';
|
| 333 |
-
std::cout << " ldBl : " << ldBl << '\n';
|
| 334 |
-
std::cout << " ldCm : " << ldCm << '\n';
|
| 335 |
-
std::cout << " ldCn : " << ldCn << '\n';
|
| 336 |
-
std::cout << " ldCl : " << ldCl << '\n';
|
| 337 |
-
}
|
| 338 |
-
|
| 339 |
-
return {M, ldAm, ldCm,
|
| 340 |
-
N, ldBn, ldCn,
|
| 341 |
-
K, ldAk, ldBk,
|
| 342 |
-
L, ldAl, ldBl, ldCl};
|
| 343 |
-
}
|
| 344 |
-
|
| 345 |
-
static void
|
| 346 |
-
print_usage() {
|
| 347 |
-
std::cout <<
|
| 348 |
-
"GETT problem command line parser:\n"
|
| 349 |
-
" --modeA=<m0,...>\n"
|
| 350 |
-
" A comma delimited list of characters that correspond to the row, reduction, and batch modes in A tensor.\n"
|
| 351 |
-
" The semantic association of each symbolic mode is determined automatically.\n\n"
|
| 352 |
-
|
| 353 |
-
" --modeB=<m0,...>\n"
|
| 354 |
-
" A comma delimited list of characters that correspond to the column, reduction, and batch modes in B tensor.\n"
|
| 355 |
-
" The semantic association of each symbolic mode is determined automatically.\n\n"
|
| 356 |
-
|
| 357 |
-
" --modeC=<m0,...>\n"
|
| 358 |
-
" A comma delimited list of characters that correspond to the row, column, and batch modes in B tensor.\n"
|
| 359 |
-
" The semantic association of each symbolic mode is determined automatically.\n\n"
|
| 360 |
-
|
| 361 |
-
" --extents=<mode:extent,....>\n"
|
| 362 |
-
" A command delimited list of symbolic mode and its corresponding extent.\n"
|
| 363 |
-
" Extents are defaulted to 1 if any are not provided.\n\n"
|
| 364 |
-
|
| 365 |
-
"Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096\n";
|
| 366 |
-
}
|
| 367 |
-
};
|
| 368 |
-
|
| 369 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp
DELETED
|
@@ -1,116 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include <cuda.h>
|
| 35 |
-
|
| 36 |
-
#include <cute/util/debug.hpp>
|
| 37 |
-
|
| 38 |
-
namespace cute
|
| 39 |
-
{
|
| 40 |
-
|
| 41 |
-
void
|
| 42 |
-
device_init(int device_id, bool quiet = false)
|
| 43 |
-
{
|
| 44 |
-
cudaDeviceProp device_prop;
|
| 45 |
-
std::size_t device_free_physmem;
|
| 46 |
-
std::size_t device_total_physmem;
|
| 47 |
-
|
| 48 |
-
CUTE_CHECK_ERROR(cudaSetDevice(device_id));
|
| 49 |
-
CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem));
|
| 50 |
-
CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id));
|
| 51 |
-
|
| 52 |
-
if (device_prop.major < 1) {
|
| 53 |
-
fprintf(stderr, "Device does not support CUDA.\n");
|
| 54 |
-
exit(1);
|
| 55 |
-
}
|
| 56 |
-
|
| 57 |
-
//float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000;
|
| 58 |
-
|
| 59 |
-
if (!quiet) {
|
| 60 |
-
printf("Using device %d: %s (SM%d, %d SMs)\n",
|
| 61 |
-
device_id, device_prop.name,
|
| 62 |
-
device_prop.major * 10 + device_prop.minor,
|
| 63 |
-
device_prop.multiProcessorCount);
|
| 64 |
-
fflush(stdout);
|
| 65 |
-
}
|
| 66 |
-
}
|
| 67 |
-
|
| 68 |
-
/**
 * Look up the number of cores per SM for an SM version given as major/minor (e.g. 7.0, 7.5).
 *
 * Versions that are not listed in the table fall back to the value of the last table entry,
 * and a notice is printed to stdout.
 */
inline int
_ConvertSMVer2Cores(int major, int minor)
{
  // One table row: the SM version encoded as 0xMm (M = SM major version, m = SM minor
  // version) and the number of cores per SM for that version.
  struct ArchCores {
    int sm;
    int cores;
  };

  ArchCores const arch_table[] = {
      {0x30, 192},
      {0x32, 192},
      {0x35, 192},
      {0x37, 192},
      {0x50, 128},
      {0x52, 128},
      {0x53, 128},
      {0x60, 64},
      {0x61, 128},
      {0x62, 128},
      {0x70, 64},
      {0x72, 64},
      {0x75, 64}};

  for (ArchCores const& row : arch_table) {
    if (row.sm == ((major << 4) + minor)) {
      return row.cores;
    }
  }

  // Unknown version: report it and use the last known entry so callers can still run
  int const fallback_cores = arch_table[sizeof(arch_table) / sizeof(arch_table[0]) - 1].cores;

  printf("MapSMtoCores for SM %d.%d is undefined."
         " Default to use %d Cores/SM\n",
         major, minor, fallback_cores);

  return fallback_cores;
}
|
| 115 |
-
|
| 116 |
-
} // end namespace cute
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief reorder data from the host side
|
| 34 |
-
*/
|
| 35 |
-
|
| 36 |
-
#pragma once
|
| 37 |
-
|
| 38 |
-
#include "cutlass/coord.h"
|
| 39 |
-
#include "cutlass/util/host_tensor.h"
|
| 40 |
-
#include "cutlass/tensor_view.h"
|
| 41 |
-
#include "cutlass/util/tensor_view_io.h"
|
| 42 |
-
#include "cutlass/util/reference/host/gemm.h"
|
| 43 |
-
|
| 44 |
-
namespace cutlass {
|
| 45 |
-
|
| 46 |
-
/// This is needed for the interleaved integer tensor core kernels. The purpose
|
| 47 |
-
/// is to use skip the shared memory part in the epilogue.
|
| 48 |
-
template <int Interleaved, typename Element, typename Layout>
|
| 49 |
-
void reorder_column(TensorRef<Element, Layout> dest,
|
| 50 |
-
TensorRef<Element, Layout> src,
|
| 51 |
-
cutlass::gemm::GemmCoord problem_size) {
|
| 52 |
-
const int InstructionShapeCol = 8;
|
| 53 |
-
// 4 threads per Quad
|
| 54 |
-
const int ElementsPerThread = InstructionShapeCol / 4;
|
| 55 |
-
// 4 threads per Quad
|
| 56 |
-
const int ReorderedElementsPerThread =
|
| 57 |
-
Interleaved / 4;
|
| 58 |
-
|
| 59 |
-
for (int n = 0; n < problem_size.n(); n++) {
|
| 60 |
-
for (int k = 0; k < problem_size.k(); k++) {
|
| 61 |
-
dest.at({k, (n / Interleaved) * Interleaved +
|
| 62 |
-
((n % ReorderedElementsPerThread) / ElementsPerThread) *
|
| 63 |
-
InstructionShapeCol +
|
| 64 |
-
((n % Interleaved) / ReorderedElementsPerThread) *
|
| 65 |
-
ElementsPerThread +
|
| 66 |
-
(n % ElementsPerThread)}) = src.at({k, n});
|
| 67 |
-
}
|
| 68 |
-
}
|
| 69 |
-
}
|
| 70 |
-
|
| 71 |
-
template <int ColumnInterleaved, int LayoutInterleaved = ColumnInterleaved, typename Element, typename Layout>
|
| 72 |
-
void reorder_convK(TensorRef<Element, Layout> dest,
|
| 73 |
-
TensorRef<Element, Layout> src,
|
| 74 |
-
cutlass::gemm::GemmCoord problem_size) {
|
| 75 |
-
|
| 76 |
-
TensorRef<Element, layout::RowMajorInterleaved<LayoutInterleaved>> mappedDest(dest.data(), dest.stride(0));
|
| 77 |
-
TensorRef<Element, layout::RowMajorInterleaved<LayoutInterleaved>> mappedSrc(src.data(), src.stride(0));
|
| 78 |
-
|
| 79 |
-
reorder_column<ColumnInterleaved>(
|
| 80 |
-
mappedDest, mappedSrc, problem_size);
|
| 81 |
-
}
|
| 82 |
-
|
| 83 |
-
/// This is needed for the sparse tensor core kernels. The purpose
|
| 84 |
-
/// is to use ldmatrix to load from shared memory to the register file.
|
| 85 |
-
template <typename Element, typename LayoutDest, typename LayoutSrc>
|
| 86 |
-
void reorder_meta(TensorRef<Element, LayoutDest> dest,
|
| 87 |
-
TensorRef<Element, LayoutSrc> src,
|
| 88 |
-
cutlass::gemm::GemmCoord problem_size) {
|
| 89 |
-
for (int m = 0; m < problem_size.m(); m++) {
|
| 90 |
-
for (int k = 0; k < problem_size.k(); k++) {
|
| 91 |
-
// First reorder the rows.
|
| 92 |
-
int group = (sizeof(Element) == 2) ? 32 : 16;
|
| 93 |
-
int interweave = (sizeof(Element) == 2) ? 4 : 2;
|
| 94 |
-
|
| 95 |
-
int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
|
| 96 |
-
int dest_col = k;
|
| 97 |
-
|
| 98 |
-
// Next swizzle the 2x2 blocks from Z to N.
|
| 99 |
-
if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) {
|
| 100 |
-
++dest_row;
|
| 101 |
-
--dest_col;
|
| 102 |
-
} else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) {
|
| 103 |
-
--dest_row;
|
| 104 |
-
++dest_col;
|
| 105 |
-
}
|
| 106 |
-
|
| 107 |
-
dest.at({dest_row, dest_col}) = src.at({m, k});
|
| 108 |
-
}
|
| 109 |
-
}
|
| 110 |
-
}
|
| 111 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h
DELETED
|
@@ -1,541 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
/*! \file
|
| 34 |
-
\brief HostTensor provides management for both host and device memory.
|
| 35 |
-
|
| 36 |
-
HostTensor allocates host and device memory upon construction. Basic element-wise operations on
|
| 37 |
-
host memory synchronize device memory automatically. Explicit copy operations provide abstractions
|
| 38 |
-
for CUDA memcpy operations.
|
| 39 |
-
|
| 40 |
-
Call {host, device}_{data, ref, view}() for accessing host or device memory.
|
| 41 |
-
|
| 42 |
-
See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
|
| 43 |
-
*/
|
| 44 |
-
|
| 45 |
-
#include <vector>
|
| 46 |
-
|
| 47 |
-
#include "cutlass/cutlass.h"
|
| 48 |
-
#include "cutlass/tensor_ref.h"
|
| 49 |
-
#include "cutlass/tensor_view.h"
|
| 50 |
-
#include "cutlass/fast_math.h"
|
| 51 |
-
|
| 52 |
-
#include "device_memory.h"
|
| 53 |
-
|
| 54 |
-
namespace cutlass {
|
| 55 |
-
|
| 56 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
-
|
| 58 |
-
/// Host tensor
|
| 59 |
-
template <
|
| 60 |
-
/// Data type of element stored within tensor (concept: NumericType)
|
| 61 |
-
typename Element_,
|
| 62 |
-
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
|
| 63 |
-
typename Layout_
|
| 64 |
-
>
|
| 65 |
-
class HostTensor {
|
| 66 |
-
public:
|
| 67 |
-
|
| 68 |
-
/// Data type of individual access
|
| 69 |
-
using Element = Element_;
|
| 70 |
-
|
| 71 |
-
/// Mapping function from logical coordinate to linear memory
|
| 72 |
-
using Layout = Layout_;
|
| 73 |
-
|
| 74 |
-
/// Logical rank of tensor index space
|
| 75 |
-
static int const kRank = Layout::kRank;
|
| 76 |
-
|
| 77 |
-
/// Index type
|
| 78 |
-
using Index = typename Layout::Index;
|
| 79 |
-
|
| 80 |
-
/// Long index used for pointer offsets
|
| 81 |
-
using LongIndex = typename Layout::LongIndex;
|
| 82 |
-
|
| 83 |
-
/// Coordinate in logical tensor space
|
| 84 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 85 |
-
|
| 86 |
-
/// Layout's stride vector
|
| 87 |
-
using Stride = typename Layout::Stride;
|
| 88 |
-
|
| 89 |
-
/// Tensor reference to device memory
|
| 90 |
-
using TensorRef = TensorRef<Element, Layout>;
|
| 91 |
-
|
| 92 |
-
/// Tensor reference to constant device memory
|
| 93 |
-
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
| 94 |
-
|
| 95 |
-
/// Tensor reference to device memory
|
| 96 |
-
using TensorView = TensorView<Element, Layout>;
|
| 97 |
-
|
| 98 |
-
/// Tensor reference to constant device memory
|
| 99 |
-
using ConstTensorView = typename TensorView::ConstTensorView;
|
| 100 |
-
|
| 101 |
-
/// Reference to element in tensor
|
| 102 |
-
using Reference = typename TensorRef::Reference;
|
| 103 |
-
|
| 104 |
-
/// Constant reference to element in tensor
|
| 105 |
-
using ConstReference = typename ConstTensorRef::Reference;
|
| 106 |
-
|
| 107 |
-
private:
|
| 108 |
-
using StorageUnit = typename platform::conditional_t<std::is_same_v<Element, bool>, uint8_t, // Avoid the std::vector<bool> specialization
|
| 109 |
-
typename platform::conditional_t<sizeof_bits<Element>::value % 8 == 0, // Handle subbyte types
|
| 110 |
-
Element, uint8_t>>;
|
| 111 |
-
using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator<Element, StorageUnit>;
|
| 112 |
-
static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits;
|
| 113 |
-
static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements;
|
| 114 |
-
static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes;
|
| 115 |
-
static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit;
|
| 116 |
-
|
| 117 |
-
//
|
| 118 |
-
// Data members
|
| 119 |
-
//
|
| 120 |
-
|
| 121 |
-
/// Extent of tensor in logical dimensions
|
| 122 |
-
TensorCoord extent_;
|
| 123 |
-
|
| 124 |
-
/// Layout object
|
| 125 |
-
Layout layout_;
|
| 126 |
-
|
| 127 |
-
/// Host-side memory allocation
|
| 128 |
-
std::vector<StorageUnit> host_;
|
| 129 |
-
|
| 130 |
-
/// Device-side memory
|
| 131 |
-
device_memory::allocation<StorageUnit> device_;
|
| 132 |
-
|
| 133 |
-
/// number of containers
|
| 134 |
-
size_t count_to_container_storage_unit_count(size_t count) {
|
| 135 |
-
return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit;
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
public:
|
| 139 |
-
//
|
| 140 |
-
// Device and Host Methods
|
| 141 |
-
//
|
| 142 |
-
|
| 143 |
-
/// Default constructor
|
| 144 |
-
HostTensor() {}
|
| 145 |
-
|
| 146 |
-
/// Constructs a tensor given an extent. Assumes a packed layout
|
| 147 |
-
HostTensor(
|
| 148 |
-
TensorCoord const &extent,
|
| 149 |
-
bool device_backed = true
|
| 150 |
-
) {
|
| 151 |
-
|
| 152 |
-
this->reset(extent, Layout::packed(extent), device_backed);
|
| 153 |
-
}
|
| 154 |
-
|
| 155 |
-
/// Constructs a tensor given an extent and layout
|
| 156 |
-
HostTensor(
|
| 157 |
-
TensorCoord const &extent,
|
| 158 |
-
Layout const &layout,
|
| 159 |
-
bool device_backed = true
|
| 160 |
-
) {
|
| 161 |
-
|
| 162 |
-
this->reset(extent, layout, device_backed);
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
-
~HostTensor() { }
|
| 166 |
-
|
| 167 |
-
/// Clears the HostTensor allocation to size/capacity = 0
|
| 168 |
-
void reset() {
|
| 169 |
-
extent_ = TensorCoord();
|
| 170 |
-
layout_ = Layout::packed(extent_);
|
| 171 |
-
|
| 172 |
-
host_.clear();
|
| 173 |
-
device_.reset();
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
/// Resizes internal memory allocations without affecting layout or extent
|
| 177 |
-
void reserve(
|
| 178 |
-
size_t count, ///< size of tensor in elements
|
| 179 |
-
bool device_backed_ = true) { ///< if true, device memory is also allocated
|
| 180 |
-
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 181 |
-
CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")");
|
| 182 |
-
#endif
|
| 183 |
-
|
| 184 |
-
device_.reset();
|
| 185 |
-
host_.clear();
|
| 186 |
-
|
| 187 |
-
size_t count_container = count_to_container_storage_unit_count(count);
|
| 188 |
-
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 189 |
-
CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")");
|
| 190 |
-
#endif
|
| 191 |
-
host_.resize(count_container);
|
| 192 |
-
|
| 193 |
-
// Allocate memory
|
| 194 |
-
StorageUnit* device_memory = nullptr;
|
| 195 |
-
if (device_backed_) {
|
| 196 |
-
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 197 |
-
CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")");
|
| 198 |
-
#endif
|
| 199 |
-
device_memory = device_memory::allocate<StorageUnit>(count_container);
|
| 200 |
-
}
|
| 201 |
-
device_.reset(device_memory, device_backed_ ? count_container : 0);
|
| 202 |
-
}
|
| 203 |
-
|
| 204 |
-
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
| 205 |
-
/// extent and layout.
|
| 206 |
-
void reset(
|
| 207 |
-
TensorCoord const &extent, ///< extent of logical tensor
|
| 208 |
-
Layout const &layout, ///< layout object of tensor
|
| 209 |
-
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 210 |
-
|
| 211 |
-
extent_ = extent;
|
| 212 |
-
layout_ = layout;
|
| 213 |
-
|
| 214 |
-
reserve(size_t(layout_.capacity(extent_)), device_backed_);
|
| 215 |
-
}
|
| 216 |
-
|
| 217 |
-
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
| 218 |
-
/// extent and layout. Assumes a packed tensor configuration.
|
| 219 |
-
void reset(
|
| 220 |
-
TensorCoord const &extent, ///< extent of logical tensor
|
| 221 |
-
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 222 |
-
|
| 223 |
-
reset(extent, Layout::packed(extent), device_backed_);
|
| 224 |
-
}
|
| 225 |
-
|
| 226 |
-
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
| 227 |
-
/// To force allocation, call reset().
|
| 228 |
-
void resize(
|
| 229 |
-
TensorCoord const &extent, ///< extent of logical tensor
|
| 230 |
-
Layout const &layout, ///< layout object of tensor
|
| 231 |
-
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 232 |
-
|
| 233 |
-
extent_ = extent;
|
| 234 |
-
layout_ = layout;
|
| 235 |
-
|
| 236 |
-
LongIndex new_size = size_t(layout_.capacity(extent_));
|
| 237 |
-
LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_)));
|
| 238 |
-
|
| 239 |
-
if (static_cast<decltype(host_.size())>(new_size_container) > host_.size()) {
|
| 240 |
-
reserve(new_size, device_backed_);
|
| 241 |
-
}
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
| 245 |
-
/// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
|
| 246 |
-
void resize(
|
| 247 |
-
TensorCoord const &extent, ///< extent of logical tensor
|
| 248 |
-
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 249 |
-
|
| 250 |
-
resize(extent, Layout::packed(extent), device_backed_);
|
| 251 |
-
}
|
| 252 |
-
|
| 253 |
-
/// Returns the logical number of elements stored in the host tensor
|
| 254 |
-
size_t size() const {
|
| 255 |
-
return layout_.capacity(extent_);
|
| 256 |
-
}
|
| 257 |
-
|
| 258 |
-
/// Returns the logical capacity in terms of number of elements. May be larger than the size().
|
| 259 |
-
LongIndex capacity() const {
|
| 260 |
-
return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements;
|
| 261 |
-
}
|
| 262 |
-
|
| 263 |
-
/// Gets pointer to host data
|
| 264 |
-
Element * host_data() { return reinterpret_cast<Element *>(host_.data()); }
|
| 265 |
-
|
| 266 |
-
/// Gets pointer to host data with a pointer offset
|
| 267 |
-
Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(host_data(), ptr_element_offset); }
|
| 268 |
-
|
| 269 |
-
/// Gets a reference to an element in host memory
|
| 270 |
-
Reference host_data(LongIndex idx) {
|
| 271 |
-
return ReferenceFactory<Element>::get(host_data(), idx);
|
| 272 |
-
}
|
| 273 |
-
|
| 274 |
-
/// Gets pointer to host data
|
| 275 |
-
Element const * host_data() const { return reinterpret_cast<Element const *>(host_.data()); }
|
| 276 |
-
|
| 277 |
-
/// Gets pointer to host data with a pointer offset
|
| 278 |
-
Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory<Element>::get(host_data(), ptr_element_offset); }
|
| 279 |
-
|
| 280 |
-
/// Gets a constant reference to an element in host memory
|
| 281 |
-
ConstReference host_data(LongIndex idx) const {
|
| 282 |
-
return ReferenceFactory<Element const>::get(host_data(), idx);
|
| 283 |
-
}
|
| 284 |
-
|
| 285 |
-
/// Gets pointer to device data
|
| 286 |
-
Element * device_data() { return reinterpret_cast<Element *>(device_.get()); }
|
| 287 |
-
|
| 288 |
-
/// Gets pointer to device data
|
| 289 |
-
Element const * device_data() const { return reinterpret_cast<Element const *>(device_.get()); }
|
| 290 |
-
|
| 291 |
-
/// Gets pointer to device data with a pointer offset
|
| 292 |
-
Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
|
| 293 |
-
|
| 294 |
-
/// Gets pointer to device data with a pointer offset
|
| 295 |
-
Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
|
| 296 |
-
|
| 297 |
-
/// Accesses the tensor reference pointing to data
|
| 298 |
-
TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); }
|
| 299 |
-
|
| 300 |
-
/// Accesses the tensor reference pointing to data
|
| 301 |
-
ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); }
|
| 302 |
-
|
| 303 |
-
/// Accesses the tensor reference pointing to data
|
| 304 |
-
TensorRef device_ref(LongIndex ptr_element_offset=0) {
|
| 305 |
-
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_);
|
| 306 |
-
}
|
| 307 |
-
|
| 308 |
-
/// Accesses the tensor reference pointing to data
|
| 309 |
-
ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
|
| 310 |
-
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_);
|
| 311 |
-
}
|
| 312 |
-
|
| 313 |
-
/// Accesses the tensor reference pointing to data
|
| 314 |
-
TensorView host_view(LongIndex ptr_element_offset=0) {
|
| 315 |
-
return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_);
|
| 316 |
-
}
|
| 317 |
-
|
| 318 |
-
/// Accesses the tensor reference pointing to data
|
| 319 |
-
ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
|
| 320 |
-
return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_);
|
| 321 |
-
}
|
| 322 |
-
|
| 323 |
-
/// Accesses the tensor reference pointing to data
|
| 324 |
-
TensorView device_view(LongIndex ptr_element_offset=0) {
|
| 325 |
-
return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_);
|
| 326 |
-
}
|
| 327 |
-
|
| 328 |
-
/// Accesses the tensor reference pointing to data
|
| 329 |
-
ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
|
| 330 |
-
return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_);
|
| 331 |
-
}
|
| 332 |
-
|
| 333 |
-
/// Returns true if device memory is allocated
|
| 334 |
-
bool device_backed() const {
|
| 335 |
-
return (device_.get() == nullptr) ? false : true;
|
| 336 |
-
}
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
/// Returns the layout object
|
| 340 |
-
Layout & layout() {
|
| 341 |
-
return layout_;
|
| 342 |
-
}
|
| 343 |
-
|
| 344 |
-
/// Returns the layout object
|
| 345 |
-
Layout layout() const {
|
| 346 |
-
return layout_;
|
| 347 |
-
}
|
| 348 |
-
|
| 349 |
-
/// Returns the layout object's stride vector
|
| 350 |
-
Stride stride() const {
|
| 351 |
-
return layout_.stride();
|
| 352 |
-
}
|
| 353 |
-
|
| 354 |
-
/// Returns the layout object's stride vector
|
| 355 |
-
Stride & stride() {
|
| 356 |
-
return layout_.stride();
|
| 357 |
-
}
|
| 358 |
-
|
| 359 |
-
/// Returns the layout object's stride in a given physical dimension
|
| 360 |
-
LongIndex stride(int dim) const {
|
| 361 |
-
return layout_.stride().at(dim);
|
| 362 |
-
}
|
| 363 |
-
|
| 364 |
-
/// Returns the layout object's stride in a given physical dimension
|
| 365 |
-
LongIndex & stride(int dim) {
|
| 366 |
-
return layout_.stride().at(dim);
|
| 367 |
-
}
|
| 368 |
-
|
| 369 |
-
/// Computes the offset of an index from the origin of the tensor
|
| 370 |
-
LongIndex offset(TensorCoord const& coord) const {
|
| 371 |
-
return layout_(coord);
|
| 372 |
-
}
|
| 373 |
-
|
| 374 |
-
/// Returns a reference to the element at the logical Coord in host memory
|
| 375 |
-
Reference at(TensorCoord const& coord) {
|
| 376 |
-
return host_data(offset(coord));
|
| 377 |
-
}
|
| 378 |
-
|
| 379 |
-
/// Returns a const reference to the element at the logical Coord in host memory
|
| 380 |
-
ConstReference at(TensorCoord const& coord) const {
|
| 381 |
-
return host_data(offset(coord));
|
| 382 |
-
}
|
| 383 |
-
|
| 384 |
-
/// Returns the extent of the tensor
|
| 385 |
-
TensorCoord extent() const {
|
| 386 |
-
return extent_;
|
| 387 |
-
}
|
| 388 |
-
|
| 389 |
-
/// Returns the extent of the tensor
|
| 390 |
-
TensorCoord & extent() {
|
| 391 |
-
return extent_;
|
| 392 |
-
}
|
| 393 |
-
|
| 394 |
-
/// Copies data from device to host
|
| 395 |
-
void sync_host() {
|
| 396 |
-
if (device_backed()) {
|
| 397 |
-
device_memory::copy_to_host(
|
| 398 |
-
host_.data(), device_.get(), device_.size());
|
| 399 |
-
}
|
| 400 |
-
}
|
| 401 |
-
|
| 402 |
-
/// Copies data from host to device
|
| 403 |
-
void sync_device() {
|
| 404 |
-
if (device_backed()) {
|
| 405 |
-
device_memory::copy_to_device(
|
| 406 |
-
device_.get(), host_.data(), host_.size());
|
| 407 |
-
}
|
| 408 |
-
}
|
| 409 |
-
|
| 410 |
-
/// Copy data from a caller-supplied device pointer into host memory.
|
| 411 |
-
void copy_in_device_to_host(
|
| 412 |
-
Element const* ptr_device, ///< source device memory
|
| 413 |
-
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 414 |
-
|
| 415 |
-
if (count < 0) {
|
| 416 |
-
count = capacity();
|
| 417 |
-
}
|
| 418 |
-
else {
|
| 419 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 420 |
-
}
|
| 421 |
-
size_t container_count = count_to_container_storage_unit_count(count);
|
| 422 |
-
device_memory::copy_to_host(
|
| 423 |
-
host_.data(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
|
| 424 |
-
}
|
| 425 |
-
|
| 426 |
-
/// Copy data from a caller-supplied device pointer into host memory.
|
| 427 |
-
void copy_in_device_to_device(
|
| 428 |
-
Element const* ptr_device, ///< source device memory
|
| 429 |
-
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 430 |
-
|
| 431 |
-
if (count < 0) {
|
| 432 |
-
count = capacity();
|
| 433 |
-
}
|
| 434 |
-
else {
|
| 435 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 436 |
-
}
|
| 437 |
-
size_t container_count = count_to_container_storage_unit_count(count);
|
| 438 |
-
device_memory::copy_device_to_device(
|
| 439 |
-
device_.get(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
|
| 440 |
-
}
|
| 441 |
-
|
| 442 |
-
/// Copy data from a caller-supplied device pointer into host memory.
|
| 443 |
-
void copy_in_host_to_device(
|
| 444 |
-
Element const* ptr_host, ///< source host memory
|
| 445 |
-
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 446 |
-
|
| 447 |
-
if (count < 0) {
|
| 448 |
-
count = capacity();
|
| 449 |
-
}
|
| 450 |
-
else {
|
| 451 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 452 |
-
}
|
| 453 |
-
size_t container_count = count_to_container_storage_unit_count(count);
|
| 454 |
-
device_memory::copy_to_device(
|
| 455 |
-
device_.get(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
|
| 456 |
-
}
|
| 457 |
-
|
| 458 |
-
/// Copy data from a caller-supplied device pointer into host memory.
|
| 459 |
-
void copy_in_host_to_host(
|
| 460 |
-
Element const* ptr_host, ///< source host memory
|
| 461 |
-
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 462 |
-
|
| 463 |
-
if (count < 0) {
|
| 464 |
-
count = capacity();
|
| 465 |
-
}
|
| 466 |
-
else {
|
| 467 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 468 |
-
}
|
| 469 |
-
size_t container_count = count_to_container_storage_unit_count(count);
|
| 470 |
-
device_memory::copy_host_to_host(
|
| 471 |
-
host_.data(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
|
| 472 |
-
}
|
| 473 |
-
|
| 474 |
-
/// Copy data from a caller-supplied device pointer into host memory.
|
| 475 |
-
void copy_out_device_to_host(
|
| 476 |
-
Element * ptr_host, ///< source device memory
|
| 477 |
-
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 478 |
-
|
| 479 |
-
if (count < 0) {
|
| 480 |
-
count = capacity();
|
| 481 |
-
}
|
| 482 |
-
else {
|
| 483 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 484 |
-
}
|
| 485 |
-
size_t container_count = count_to_container_storage_unit_count(count);
|
| 486 |
-
device_memory::copy_to_host(
|
| 487 |
-
reinterpret_cast<StorageUnit *>(ptr_host), device_.get(), container_count);
|
| 488 |
-
}
|
| 489 |
-
|
| 490 |
-
/// Copy data from a caller-supplied device pointer into host memory.
|
| 491 |
-
void copy_out_device_to_device(
|
| 492 |
-
Element * ptr_device, ///< source device memory
|
| 493 |
-
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 494 |
-
|
| 495 |
-
if (count < 0) {
|
| 496 |
-
count = capacity();
|
| 497 |
-
}
|
| 498 |
-
else {
|
| 499 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 500 |
-
}
|
| 501 |
-
size_t container_count = count_to_container_storage_unit_count(count);
|
| 502 |
-
device_memory::copy_device_to_device(
|
| 503 |
-
reinterpret_cast<StorageUnit *>(ptr_device), device_.get(), container_count);
|
| 504 |
-
}
|
| 505 |
-
|
| 506 |
-
/// Copy data from a caller-supplied device pointer into host memory.
|
| 507 |
-
void copy_out_host_to_device(
|
| 508 |
-
Element * ptr_device, ///< source host memory
|
| 509 |
-
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 510 |
-
|
| 511 |
-
if (count < 0) {
|
| 512 |
-
count = capacity();
|
| 513 |
-
}
|
| 514 |
-
else {
|
| 515 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 516 |
-
}
|
| 517 |
-
size_t container_count = count_to_container_storage_unit_count(count);
|
| 518 |
-
device_memory::copy_to_device(
|
| 519 |
-
reinterpret_cast<StorageUnit *>(ptr_device), host_.data(), container_count);
|
| 520 |
-
}
|
| 521 |
-
|
| 522 |
-
/// Copy data from a caller-supplied device pointer into host memory.
|
| 523 |
-
void copy_out_host_to_host(
|
| 524 |
-
Element * ptr_host, ///< source host memory
|
| 525 |
-
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 526 |
-
|
| 527 |
-
if (count < 0) {
|
| 528 |
-
count = capacity();
|
| 529 |
-
}
|
| 530 |
-
else {
|
| 531 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 532 |
-
}
|
| 533 |
-
size_t container_count = count_to_container_storage_unit_count(count);
|
| 534 |
-
device_memory::copy_host_to_host(
|
| 535 |
-
reinterpret_cast<StorageUnit *>(ptr_host), host_.data(), container_count);
|
| 536 |
-
}
|
| 537 |
-
};
|
| 538 |
-
|
| 539 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 540 |
-
|
| 541 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h
DELETED
|
@@ -1,591 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
/*! \file
|
| 34 |
-
\brief HostTensorPlanarComplex provides management for both host and device memory.
|
| 35 |
-
|
| 36 |
-
HostTensor allocates host and device memory upon construction. Basic element-wise operations on
|
| 37 |
-
host memory synchronize device memory automatically. Explicit copy operations provide abstractions
|
| 38 |
-
for CUDA memcpy operations.
|
| 39 |
-
|
| 40 |
-
Call {host, device}_{data, ref, view}() for accessing host or device memory.
|
| 41 |
-
|
| 42 |
-
See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
|
| 43 |
-
*/
|
| 44 |
-
|
| 45 |
-
#include <vector>
|
| 46 |
-
|
| 47 |
-
#include "cutlass/cutlass.h"
|
| 48 |
-
|
| 49 |
-
#include "cutlass/tensor_ref_planar_complex.h"
|
| 50 |
-
#include "cutlass/tensor_view_planar_complex.h"
|
| 51 |
-
|
| 52 |
-
#include "device_memory.h"
|
| 53 |
-
|
| 54 |
-
namespace cutlass {
|
| 55 |
-
|
| 56 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
-
|
| 58 |
-
/// Host tensor
|
| 59 |
-
template <
|
| 60 |
-
/// Data type of element stored within tensor (concept: NumericType)
|
| 61 |
-
typename Element_,
|
| 62 |
-
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
|
| 63 |
-
typename Layout_
|
| 64 |
-
>
|
| 65 |
-
class HostTensorPlanarComplex {
|
| 66 |
-
public:
|
| 67 |
-
|
| 68 |
-
/// Data type of individual access
|
| 69 |
-
using Element = Element_;
|
| 70 |
-
|
| 71 |
-
/// Mapping function from logical coordinate to linear memory
|
| 72 |
-
using Layout = Layout_;
|
| 73 |
-
|
| 74 |
-
/// Logical rank of tensor index space
|
| 75 |
-
static int const kRank = Layout::kRank;
|
| 76 |
-
|
| 77 |
-
/// Index type
|
| 78 |
-
using Index = typename Layout::Index;
|
| 79 |
-
|
| 80 |
-
/// Long index used for pointer offsets
|
| 81 |
-
using LongIndex = typename Layout::LongIndex;
|
| 82 |
-
|
| 83 |
-
/// Coordinate in logical tensor space
|
| 84 |
-
using TensorCoord = typename Layout::TensorCoord;
|
| 85 |
-
|
| 86 |
-
/// Layout's stride vector
|
| 87 |
-
using Stride = typename Layout::Stride;
|
| 88 |
-
|
| 89 |
-
/// Tensor reference to device memory
|
| 90 |
-
using TensorRef = TensorRefPlanarComplex<Element, Layout>;
|
| 91 |
-
|
| 92 |
-
/// Tensor reference to constant device memory
|
| 93 |
-
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
| 94 |
-
|
| 95 |
-
/// Tensor reference to device memory
|
| 96 |
-
using TensorView = TensorViewPlanarComplex<Element, Layout>;
|
| 97 |
-
|
| 98 |
-
/// Tensor reference to constant device memory
|
| 99 |
-
using ConstTensorView = typename TensorView::ConstTensorView;
|
| 100 |
-
|
| 101 |
-
/// Reference to element in tensor
|
| 102 |
-
using Reference = typename TensorRef::Reference;
|
| 103 |
-
|
| 104 |
-
/// Constant reference to element in tensor
|
| 105 |
-
using ConstReference = typename ConstTensorRef::Reference;
|
| 106 |
-
|
| 107 |
-
private:
|
| 108 |
-
|
| 109 |
-
//
|
| 110 |
-
// Data members
|
| 111 |
-
//
|
| 112 |
-
|
| 113 |
-
/// Extent of tensor in logical dimensions
|
| 114 |
-
TensorCoord extent_;
|
| 115 |
-
|
| 116 |
-
/// Layout object
|
| 117 |
-
Layout layout_;
|
| 118 |
-
|
| 119 |
-
/// Host-side memory allocation
|
| 120 |
-
std::vector<Element> host_;
|
| 121 |
-
|
| 122 |
-
/// Device-side memory
|
| 123 |
-
device_memory::allocation<Element> device_;
|
| 124 |
-
|
| 125 |
-
public:
|
| 126 |
-
//
|
| 127 |
-
// Device and Host Methods
|
| 128 |
-
//
|
| 129 |
-
|
| 130 |
-
/// Default constructor
|
| 131 |
-
HostTensorPlanarComplex() {}
|
| 132 |
-
|
| 133 |
-
/// Constructs a tensor given an extent. Assumes a packed layout
|
| 134 |
-
HostTensorPlanarComplex(
|
| 135 |
-
TensorCoord const &extent,
|
| 136 |
-
bool device_backed = true
|
| 137 |
-
) {
|
| 138 |
-
|
| 139 |
-
this->reset(extent, Layout::packed(extent), device_backed);
|
| 140 |
-
}
|
| 141 |
-
|
| 142 |
-
/// Constructs a tensor given an extent and layout
|
| 143 |
-
HostTensorPlanarComplex(
|
| 144 |
-
TensorCoord const &extent,
|
| 145 |
-
Layout const &layout,
|
| 146 |
-
bool device_backed = true
|
| 147 |
-
) {
|
| 148 |
-
|
| 149 |
-
this->reset(extent, layout, device_backed);
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
~HostTensorPlanarComplex() { }
|
| 153 |
-
|
| 154 |
-
/// Clears the HostTensor allocation to size/capacity = 0
|
| 155 |
-
void reset() {
|
| 156 |
-
extent_ = TensorCoord();
|
| 157 |
-
layout_ = Layout::packed(extent_);
|
| 158 |
-
|
| 159 |
-
host_.clear();
|
| 160 |
-
device_.reset();
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
/// Resizes internal memory allocations without affecting layout or extent
|
| 164 |
-
void reserve(
|
| 165 |
-
size_t count, ///< size of tensor in elements
|
| 166 |
-
bool device_backed_ = true) { ///< if true, device memory is also allocated
|
| 167 |
-
|
| 168 |
-
device_.reset();
|
| 169 |
-
host_.clear();
|
| 170 |
-
|
| 171 |
-
host_.resize(count * 2);
|
| 172 |
-
|
| 173 |
-
// Allocate memory
|
| 174 |
-
Element* device_memory = nullptr;
|
| 175 |
-
if (device_backed_) {
|
| 176 |
-
device_memory = device_memory::allocate<Element>(count * 2);
|
| 177 |
-
}
|
| 178 |
-
device_.reset(device_memory, device_backed_ ? count * 2 : 0);
|
| 179 |
-
}
|
| 180 |
-
|
| 181 |
-
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
| 182 |
-
/// extent and layout.
|
| 183 |
-
void reset(
|
| 184 |
-
TensorCoord const &extent, ///< extent of logical tensor
|
| 185 |
-
Layout const &layout, ///< layout object of tensor
|
| 186 |
-
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 187 |
-
|
| 188 |
-
extent_ = extent;
|
| 189 |
-
layout_ = layout;
|
| 190 |
-
|
| 191 |
-
reserve(size_t(layout_.capacity(extent_)), device_backed_);
|
| 192 |
-
}
|
| 193 |
-
|
| 194 |
-
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
| 195 |
-
/// extent and layout. Assumes a packed tensor configuration.
|
| 196 |
-
void reset(
|
| 197 |
-
TensorCoord const &extent, ///< extent of logical tensor
|
| 198 |
-
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 199 |
-
|
| 200 |
-
reset(extent, Layout::packed(extent), device_backed_);
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
-
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
| 204 |
-
/// To force allocation, call reset().
|
| 205 |
-
void resize(
    TensorCoord const &extent,        ///< extent of logical tensor
    Layout const &layout,             ///< layout object of tensor
    bool device_backed_ = true) {     ///< if true, device memory is also allocated.

  // Record the new logical shape first; capacity is derived from it below.
  extent_ = extent;
  layout_ = layout;

  LongIndex new_size = size_t(layout_.capacity(extent_));

  // host_ stores the real and imaginary planes back to back (see reserve()),
  // so the required host storage is twice the logical capacity.
  if (static_cast<decltype(host_.size())>(new_size * 2) > host_.size()) {
    // Forward the caller's choice: previously reserve() was called with its
    // default argument, which always allocated device memory even when the
    // caller asked for a host-only tensor.
    reserve(static_cast<size_t>(new_size), device_backed_);
  }
}
|
| 219 |
-
|
| 220 |
-
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
| 221 |
-
/// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
|
| 222 |
-
void resize(
|
| 223 |
-
TensorCoord const &extent, ///< extent of logical tensor
|
| 224 |
-
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 225 |
-
|
| 226 |
-
resize(extent, Layout::packed(extent), device_backed_);
|
| 227 |
-
}
|
| 228 |
-
|
| 229 |
-
/// Returns the number of elements stored in the host tensor
|
| 230 |
-
size_t size() const {
|
| 231 |
-
return host_.size() / 2;
|
| 232 |
-
}
|
| 233 |
-
|
| 234 |
-
/// Returns the logical capacity based on extent and layout. May differ from size().
|
| 235 |
-
LongIndex capacity() const {
|
| 236 |
-
return layout_.capacity(extent_);
|
| 237 |
-
}
|
| 238 |
-
|
| 239 |
-
/// Stride between real and imaginary parts
|
| 240 |
-
LongIndex imaginary_stride() const {
|
| 241 |
-
return host_.size() / 2;
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
/// Gets pointer to host data
|
| 245 |
-
Element * host_data() { return host_.data(); }
|
| 246 |
-
|
| 247 |
-
/// Gets pointer to host data imaginary part
|
| 248 |
-
Element * host_data_imag() { return host_.data() + imaginary_stride(); }
|
| 249 |
-
|
| 250 |
-
/// Gets pointer to host data with a pointer offset
|
| 251 |
-
Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; }
|
| 252 |
-
|
| 253 |
-
/// Gets pointer to host data imaginary part with a pointer offset
|
| 254 |
-
Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; }
|
| 255 |
-
|
| 256 |
-
/// Gets a reference to an element in host memory
|
| 257 |
-
Reference host_data(LongIndex idx) {
|
| 258 |
-
return PlanarComplexReference<Element>(host_data() + idx, host_data_imag() + idx);
|
| 259 |
-
}
|
| 260 |
-
|
| 261 |
-
/// Gets pointer to host data
|
| 262 |
-
Element const * host_data() const { return host_.data(); }
|
| 263 |
-
|
| 264 |
-
/// Gets pointer to host data imaginary part
|
| 265 |
-
Element const * host_data_imag() const { return host_.data() + imaginary_stride(); }
|
| 266 |
-
|
| 267 |
-
/// Gets a constant reference to an element in host memory
|
| 268 |
-
ConstReference host_data(LongIndex idx) const {
|
| 269 |
-
return PlanarComplexReference<Element const>(host_data() + idx, host_data_imag() + idx);
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
/// Gets pointer to device data
|
| 273 |
-
Element * device_data() { return device_.get(); }
|
| 274 |
-
|
| 275 |
-
/// Gets pointer to device data with a pointer offset
|
| 276 |
-
Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; }
|
| 277 |
-
|
| 278 |
-
/// Gets pointer to device data
|
| 279 |
-
Element const * device_data() const { return device_.get(); }
|
| 280 |
-
|
| 281 |
-
/// Gets pointer to device data with a pointer offset
|
| 282 |
-
Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; }
|
| 283 |
-
|
| 284 |
-
/// Gets a pointer to the device data imaginary part
|
| 285 |
-
Element * device_data_imag() { return device_.get() + imaginary_stride(); }
|
| 286 |
-
|
| 287 |
-
/// Accesses the tensor reference pointing to data
|
| 288 |
-
TensorRef host_ref(LongIndex ptr_element_offset=0) {
|
| 289 |
-
return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
| 290 |
-
}
|
| 291 |
-
|
| 292 |
-
/// Returns a tensor reference to the real part of the tensor
|
| 293 |
-
cutlass::TensorRef<Element, Layout> host_ref_real() {
|
| 294 |
-
return cutlass::TensorRef<Element, Layout>(host_data(), layout_);
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
/// Returns a tensor reference to the imaginary part of the tensor
|
| 298 |
-
cutlass::TensorRef<Element, Layout> host_ref_imag() {
|
| 299 |
-
return cutlass::TensorRef<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_);
|
| 300 |
-
}
|
| 301 |
-
|
| 302 |
-
/// Accesses the tensor reference pointing to data
|
| 303 |
-
ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const {
  // Build the offset pointer from the const host_data() accessor. The
  // host_data_ptr_offset() helper visible in this class is non-const only,
  // so calling it from this const member would not compile once instantiated.
  return ConstTensorRef(host_data() + ptr_element_offset, layout_, imaginary_stride());
}
|
| 306 |
-
|
| 307 |
-
/// Accesses the tensor reference pointing to data
|
| 308 |
-
TensorRef device_ref(LongIndex ptr_element_offset=0) {
|
| 309 |
-
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
| 310 |
-
}
|
| 311 |
-
|
| 312 |
-
/// Accesses the tensor reference pointing to data
|
| 313 |
-
ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
  // The const overload of device_data_ptr_offset() yields `Element const *`,
  // so the reference must be built as a ConstTensorRef (the declared return
  // type) rather than a mutable TensorRef.
  return ConstTensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
}
|
| 316 |
-
|
| 317 |
-
/// Returns a tensor reference to the real part of the tensor
|
| 318 |
-
cutlass::TensorRef<Element, Layout> device_ref_real() {
|
| 319 |
-
return cutlass::TensorRef<Element, Layout>(device_data(), layout_);
|
| 320 |
-
}
|
| 321 |
-
|
| 322 |
-
/// Returns a tensor reference to the imaginary part of the tensor
|
| 323 |
-
cutlass::TensorRef<Element, Layout> device_ref_imag() {
|
| 324 |
-
return cutlass::TensorRef<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_);
|
| 325 |
-
}
|
| 326 |
-
|
| 327 |
-
/// Accesses the tensor reference pointing to data
|
| 328 |
-
TensorView host_view(LongIndex ptr_element_offset=0) {
|
| 329 |
-
return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
| 330 |
-
}
|
| 331 |
-
|
| 332 |
-
/// Accesses the tensor reference pointing to data
|
| 333 |
-
ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
  // Use the const host_data() accessor plus the offset. The
  // host_data_ptr_offset() helper visible in this class is non-const only
  // and therefore cannot be called from this const member.
  return ConstTensorView(host_data() + ptr_element_offset, layout_, imaginary_stride(), extent_);
}
|
| 336 |
-
|
| 337 |
-
/// Accesses the tensor reference pointing to data
|
| 338 |
-
cutlass::TensorView<Element, Layout> host_view_real() {
|
| 339 |
-
return cutlass::TensorView<Element, Layout>(host_data(), layout_, extent_);
|
| 340 |
-
}
|
| 341 |
-
|
| 342 |
-
/// Accesses the tensor reference pointing to data
|
| 343 |
-
cutlass::TensorView<Element, Layout> host_view_imag() {
|
| 344 |
-
return cutlass::TensorView<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_, extent_);
|
| 345 |
-
}
|
| 346 |
-
|
| 347 |
-
/// Accesses the tensor reference pointing to data
|
| 348 |
-
TensorView device_view(LongIndex ptr_element_offset=0) {
|
| 349 |
-
return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
| 350 |
-
}
|
| 351 |
-
|
| 352 |
-
/// Accesses the tensor reference pointing to data
|
| 353 |
-
ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
|
| 354 |
-
return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
| 355 |
-
}
|
| 356 |
-
|
| 357 |
-
/// Accesses the tensor reference pointing to data
|
| 358 |
-
cutlass::TensorView<Element, Layout> device_view_real() {
|
| 359 |
-
return cutlass::TensorView<Element, Layout>(device_data(), layout_, extent_);
|
| 360 |
-
}
|
| 361 |
-
|
| 362 |
-
/// Accesses the tensor reference pointing to data
|
| 363 |
-
cutlass::TensorView<Element, Layout> device_view_imag() {
|
| 364 |
-
return cutlass::TensorView<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_, extent_);
|
| 365 |
-
}
|
| 366 |
-
|
| 367 |
-
/// Returns true if device memory is allocated
|
| 368 |
-
bool device_backed() const {
|
| 369 |
-
return (device_.get() == nullptr) ? false : true;
|
| 370 |
-
}
|
| 371 |
-
|
| 372 |
-
/// Returns the layout object
|
| 373 |
-
Layout layout() const {
|
| 374 |
-
return layout_;
|
| 375 |
-
}
|
| 376 |
-
|
| 377 |
-
/// Returns the layout object's stride vector
|
| 378 |
-
Stride stride() const {
|
| 379 |
-
return layout_.stride();
|
| 380 |
-
}
|
| 381 |
-
|
| 382 |
-
/// Returns the layout object's stride in a given physical dimension
|
| 383 |
-
Index stride(int dim) const {
|
| 384 |
-
return layout_.stride().at(dim);
|
| 385 |
-
}
|
| 386 |
-
|
| 387 |
-
/// Computes the offset of an index from the origin of the tensor
|
| 388 |
-
LongIndex offset(TensorCoord const& coord) const {
|
| 389 |
-
return layout_(coord);
|
| 390 |
-
}
|
| 391 |
-
|
| 392 |
-
/// Returns a reference to the element at the logical Coord in host memory
|
| 393 |
-
Reference at(TensorCoord const& coord) {
|
| 394 |
-
return host_data(offset(coord));
|
| 395 |
-
}
|
| 396 |
-
|
| 397 |
-
/// Returns a const reference to the element at the logical Coord in host memory
|
| 398 |
-
ConstReference at(TensorCoord const& coord) const {
|
| 399 |
-
return host_data(offset(coord));
|
| 400 |
-
}
|
| 401 |
-
|
| 402 |
-
/// Returns the extent of the tensor
|
| 403 |
-
TensorCoord extent() const {
|
| 404 |
-
return extent_;
|
| 405 |
-
}
|
| 406 |
-
|
| 407 |
-
/// Returns the extent of the tensor
|
| 408 |
-
TensorCoord & extent() {
|
| 409 |
-
return extent_;
|
| 410 |
-
}
|
| 411 |
-
|
| 412 |
-
/// Copies data from device to host
|
| 413 |
-
void sync_host() {
|
| 414 |
-
if (device_backed()) {
|
| 415 |
-
device_memory::copy_to_host(
|
| 416 |
-
host_data(), device_data(), imaginary_stride() * 2);
|
| 417 |
-
}
|
| 418 |
-
}
|
| 419 |
-
|
| 420 |
-
/// Copies data from host to device
|
| 421 |
-
void sync_device() {
|
| 422 |
-
if (device_backed()) {
|
| 423 |
-
device_memory::copy_to_device(
|
| 424 |
-
device_data(), host_data(), imaginary_stride() * 2);
|
| 425 |
-
}
|
| 426 |
-
}
|
| 427 |
-
|
| 428 |
-
/// Copy data from a caller-supplied device pointer into host memory.
|
| 429 |
-
void copy_in_device_to_host(
|
| 430 |
-
Element const* ptr_device_real, ///< source device memory
|
| 431 |
-
Element const* ptr_device_imag, ///< source device memory
|
| 432 |
-
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 433 |
-
|
| 434 |
-
if (count < 0) {
|
| 435 |
-
count = capacity();
|
| 436 |
-
}
|
| 437 |
-
else {
|
| 438 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 439 |
-
}
|
| 440 |
-
|
| 441 |
-
device_memory::copy_to_host(
|
| 442 |
-
host_data(), ptr_device_real, count);
|
| 443 |
-
|
| 444 |
-
device_memory::copy_to_host(
|
| 445 |
-
host_data_imag(), ptr_device_imag, count);
|
| 446 |
-
}
|
| 447 |
-
|
| 448 |
-
/// Copy data from a caller-supplied device pointer into device memory.
|
| 449 |
-
void copy_in_device_to_device(
|
| 450 |
-
Element const* ptr_device_real, ///< source device memory
|
| 451 |
-
Element const* ptr_device_imag, ///< source device memory
|
| 452 |
-
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 453 |
-
|
| 454 |
-
if (count < 0) {
|
| 455 |
-
count = capacity();
|
| 456 |
-
}
|
| 457 |
-
else {
|
| 458 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 459 |
-
}
|
| 460 |
-
|
| 461 |
-
device_memory::copy_device_to_device(
|
| 462 |
-
device_data(), ptr_device_real, count);
|
| 463 |
-
|
| 464 |
-
device_memory::copy_device_to_device(
|
| 465 |
-
device_data_imag(), ptr_device_imag, count);
|
| 466 |
-
}
|
| 467 |
-
|
| 468 |
-
/// Copy data from a caller-supplied host pointer into device memory.
|
| 469 |
-
void copy_in_host_to_device(
|
| 470 |
-
Element const* ptr_host_real, ///< source host memory
|
| 471 |
-
Element const* ptr_host_imag, ///< source host memory
|
| 472 |
-
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 473 |
-
|
| 474 |
-
if (count < 0) {
|
| 475 |
-
count = capacity();
|
| 476 |
-
}
|
| 477 |
-
else {
|
| 478 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 479 |
-
}
|
| 480 |
-
|
| 481 |
-
device_memory::copy_to_device(
|
| 482 |
-
device_data(), ptr_host_real, count);
|
| 483 |
-
|
| 484 |
-
device_memory::copy_to_device(
|
| 485 |
-
device_data_imag(), ptr_host_imag, count);
|
| 486 |
-
}
|
| 487 |
-
|
| 488 |
-
/// Copy data from a caller-supplied host pointer into host memory.
|
| 489 |
-
void copy_in_host_to_host(
|
| 490 |
-
Element const* ptr_host_real, ///< source host memory
|
| 491 |
-
Element const* ptr_host_imag, ///< source host memory
|
| 492 |
-
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 493 |
-
|
| 494 |
-
if (count < 0) {
|
| 495 |
-
count = capacity();
|
| 496 |
-
}
|
| 497 |
-
else {
|
| 498 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 499 |
-
}
|
| 500 |
-
|
| 501 |
-
device_memory::copy_host_to_host(
|
| 502 |
-
host_data(), ptr_host_real, count);
|
| 503 |
-
|
| 504 |
-
device_memory::copy_host_to_host(
|
| 505 |
-
host_data_imag(), ptr_host_imag, count);
|
| 506 |
-
}
|
| 507 |
-
|
| 508 |
-
/// Copy data from device memory out to a caller-supplied host pointer.
|
| 509 |
-
void copy_out_device_to_host(
    Element * ptr_host_real,          ///< destination host memory for the real part
    Element * ptr_host_imag,          ///< destination host memory for the imaginary part
    LongIndex count = -1) const {     ///< number of elements to transfer; if negative, the entire tensor is copied.

  // Clamp the transfer to the tensor's logical capacity; a negative count
  // selects the whole tensor.
  if (count < 0) {
    count = capacity();
  }
  else {
    count = __NV_STD_MIN(capacity(), count);
  }

  device_memory::copy_to_host(
      ptr_host_real, device_data(), count);

  // The imaginary plane starts imaginary_stride() elements after the real
  // plane. The pointer is computed from the const device_data() accessor
  // because the device_data_imag() helper visible in this class is non-const
  // only and cannot be called from this const member.
  device_memory::copy_to_host(
      ptr_host_imag, device_data() + imaginary_stride(), count);
}
|
| 527 |
-
|
| 528 |
-
/// Copy data from device memory out to a caller-supplied device pointer.
|
| 529 |
-
void copy_out_device_to_device(
    Element * ptr_device_real,        ///< destination device memory for the real part
    Element * ptr_device_imag,        ///< destination device memory for the imaginary part
    LongIndex count = -1) const {     ///< number of elements to transfer; if negative, the entire tensor is copied.

  // Clamp the transfer to the tensor's logical capacity; a negative count
  // selects the whole tensor.
  if (count < 0) {
    count = capacity();
  }
  else {
    count = __NV_STD_MIN(capacity(), count);
  }

  device_memory::copy_device_to_device(
      ptr_device_real, device_data(), count);

  // The imaginary plane starts imaginary_stride() elements after the real
  // plane. The pointer is computed from the const device_data() accessor
  // because the device_data_imag() helper visible in this class is non-const
  // only and cannot be called from this const member.
  device_memory::copy_device_to_device(
      ptr_device_imag, device_data() + imaginary_stride(), count);
}
|
| 547 |
-
|
| 548 |
-
/// Copy data from host memory out to a caller-supplied device pointer.
|
| 549 |
-
void copy_out_host_to_device(
|
| 550 |
-
Element * ptr_device_real, ///< source device memory
|
| 551 |
-
Element * ptr_device_imag, ///< source device memory
|
| 552 |
-
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 553 |
-
|
| 554 |
-
if (count < 0) {
|
| 555 |
-
count = capacity();
|
| 556 |
-
}
|
| 557 |
-
else {
|
| 558 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 559 |
-
}
|
| 560 |
-
|
| 561 |
-
device_memory::copy_to_device(
|
| 562 |
-
ptr_device_real, host_data(), count);
|
| 563 |
-
|
| 564 |
-
device_memory::copy_to_device(
|
| 565 |
-
ptr_device_imag, host_data_imag(), count);
|
| 566 |
-
}
|
| 567 |
-
|
| 568 |
-
/// Copy data from host memory out to a caller-supplied host pointer.
|
| 569 |
-
void copy_out_host_to_host(
|
| 570 |
-
Element * ptr_host_real, ///< source host memory
|
| 571 |
-
Element * ptr_host_imag, ///< source host memory
|
| 572 |
-
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 573 |
-
|
| 574 |
-
if (count < 0) {
|
| 575 |
-
count = capacity();
|
| 576 |
-
}
|
| 577 |
-
else {
|
| 578 |
-
count = __NV_STD_MIN(capacity(), count);
|
| 579 |
-
}
|
| 580 |
-
|
| 581 |
-
device_memory::copy_host_to_host(
|
| 582 |
-
ptr_host_real, host_data(), count);
|
| 583 |
-
|
| 584 |
-
device_memory::copy_host_to_host(
|
| 585 |
-
ptr_host_imag, host_data_imag(), count);
|
| 586 |
-
}
|
| 587 |
-
};
|
| 588 |
-
|
| 589 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 590 |
-
|
| 591 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h
DELETED
|
@@ -1,157 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief uncompress sparse matrix from the host side
|
| 34 |
-
*/
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
#include "cutlass/util/host_tensor.h"
|
| 39 |
-
#include "cutlass/tensor_view.h"
|
| 40 |
-
#include "cutlass/util/tensor_view_io.h"
|
| 41 |
-
#include "cutlass/util/reference/host/gemm.h"
|
| 42 |
-
|
| 43 |
-
namespace cutlass {
|
| 44 |
-
|
| 45 |
-
// uncompress sparse tensor core A matrix
|
| 46 |
-
template <typename ElementA, typename LayoutA, typename ElementE,
|
| 47 |
-
typename LayoutE>
|
| 48 |
-
void uncompress(TensorRef<ElementA, LayoutA> uncompressed_tensor_a,
|
| 49 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 50 |
-
TensorRef<ElementE, LayoutE> tensor_e, int row, int col) {
|
| 51 |
-
// How many uncompressed data we can get with ElementE meta data
|
| 52 |
-
int DecompressedElementsPerElementE =
|
| 53 |
-
256 / cutlass::sizeof_bits<ElementA>::value;
|
| 54 |
-
|
| 55 |
-
// Process 4bit meta data a time
|
| 56 |
-
int step;
|
| 57 |
-
|
| 58 |
-
// 1:2 or 2:4 or 4:8
|
| 59 |
-
int a, b;
|
| 60 |
-
|
| 61 |
-
if (cutlass::sizeof_bits<ElementA>::value == 4) {
|
| 62 |
-
step = 8;
|
| 63 |
-
a = 4;
|
| 64 |
-
b = 8;
|
| 65 |
-
} else if (cutlass::sizeof_bits<ElementA>::value == 8) {
|
| 66 |
-
step = 4;
|
| 67 |
-
a = 2;
|
| 68 |
-
b = 4;
|
| 69 |
-
} else if (cutlass::sizeof_bits<ElementA>::value == 16) {
|
| 70 |
-
step = 4;
|
| 71 |
-
a = 2;
|
| 72 |
-
b = 4;
|
| 73 |
-
} else if (cutlass::sizeof_bits<ElementA>::value == 32) {
|
| 74 |
-
step = 2;
|
| 75 |
-
a = 1;
|
| 76 |
-
b = 2;
|
| 77 |
-
}
|
| 78 |
-
|
| 79 |
-
int ElementsPerE = (cutlass::sizeof_bits<ElementA>::value == 4) ? 2 : 1;
|
| 80 |
-
|
| 81 |
-
for (int r = 0; r < row; ++r) {
|
| 82 |
-
for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) {
|
| 83 |
-
|
| 84 |
-
ElementE meta = tensor_e.at(MatrixCoord(r, c));
|
| 85 |
-
|
| 86 |
-
for (int i = 0; i < DecompressedElementsPerElementE; i += step) {
|
| 87 |
-
int e = (meta >> (i / step * 4)) & 0xf;
|
| 88 |
-
int idx0 = e & 0x3;
|
| 89 |
-
int idx1 = e >> 2;
|
| 90 |
-
|
| 91 |
-
if (a == 1) idx0 = idx0 / 2;
|
| 92 |
-
|
| 93 |
-
for (int ii = 0; ii < step; ii += ElementsPerE) {
|
| 94 |
-
int real_col =
|
| 95 |
-
c * DecompressedElementsPerElementE + i + ii;
|
| 96 |
-
int compressed_col = (real_col / b) * a;
|
| 97 |
-
|
| 98 |
-
if (ii == (idx0 * ElementsPerE)) {
|
| 99 |
-
uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
|
| 100 |
-
tensor_a.at(MatrixCoord(r, compressed_col));
|
| 101 |
-
if (ElementsPerE == 2)
|
| 102 |
-
uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
|
| 103 |
-
tensor_a.at(MatrixCoord(r, compressed_col + 1));
|
| 104 |
-
} else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) {
|
| 105 |
-
uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
|
| 106 |
-
tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE));
|
| 107 |
-
if (ElementsPerE == 2)
|
| 108 |
-
uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
|
| 109 |
-
tensor_a.at(
|
| 110 |
-
MatrixCoord(r, compressed_col + ElementsPerE + 1));
|
| 111 |
-
} else {
|
| 112 |
-
uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
|
| 113 |
-
ElementA(0);
|
| 114 |
-
if (ElementsPerE == 2)
|
| 115 |
-
uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
|
| 116 |
-
ElementA(0);
|
| 117 |
-
}
|
| 118 |
-
}
|
| 119 |
-
}
|
| 120 |
-
}
|
| 121 |
-
}
|
| 122 |
-
}
|
| 123 |
-
|
| 124 |
-
// uncompress ELL block sparse matrix
|
| 125 |
-
template <typename ElementA, typename LayoutA,
|
| 126 |
-
typename ElementE, typename LayoutE>
|
| 127 |
-
void uncompress_ell_block_sparse(
|
| 128 |
-
TensorRef<ElementA, LayoutA> uncompressed_tensor_a,
|
| 129 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 130 |
-
TensorRef<ElementE, LayoutE> ell_idx,
|
| 131 |
-
int rows, int cols,
|
| 132 |
-
int ell_num_cols, int ell_blocksize) {
|
| 133 |
-
|
| 134 |
-
for (int r = 0; r < rows / ell_blocksize; ++r) {
|
| 135 |
-
for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) {
|
| 136 |
-
|
| 137 |
-
ElementE idx = ell_idx.at(MatrixCoord(r, c));
|
| 138 |
-
|
| 139 |
-
if (idx != -1) {
|
| 140 |
-
int row_begin = r * ell_blocksize;
|
| 141 |
-
int col_begin_real = idx * ell_blocksize;
|
| 142 |
-
int col_begin = c * ell_blocksize;
|
| 143 |
-
|
| 144 |
-
for (int i = 0; i < ell_blocksize; ++i) {
|
| 145 |
-
for (int j = 0; j < ell_blocksize; ++j) {
|
| 146 |
-
uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) =
|
| 147 |
-
tensor_a.at(
|
| 148 |
-
MatrixCoord(row_begin + i, col_begin +j));
|
| 149 |
-
}
|
| 150 |
-
}
|
| 151 |
-
}
|
| 152 |
-
}
|
| 153 |
-
}
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
} // namespace cutlass
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h
DELETED
|
@@ -1,38 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include "cutlass/cutlass.h"
|
| 35 |
-
#include "cutlass/numeric_types.h"
|
| 36 |
-
|
| 37 |
-
// integer_sequence moved to cutlass/numeric_types.h
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp
DELETED
|
@@ -1,472 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Utilities for mixed input data type kernels.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include <cuda.h>
|
| 38 |
-
#include "cute/layout.hpp"
|
| 39 |
-
#include "cute/tensor.hpp"
|
| 40 |
-
#include "cute/arch/mma_sm90.hpp"
|
| 41 |
-
#include "cutlass/cutlass.h"
|
| 42 |
-
#include "cutlass/util/device_memory.h"
|
| 43 |
-
#include "cutlass/util/reference/device/tensor_fill.h"
|
| 44 |
-
#include "cute/util/type_traits.hpp"
|
| 45 |
-
|
| 46 |
-
namespace cutlass {
|
| 47 |
-
|
| 48 |
-
// Evaluates a CUDA runtime expression exactly once. If the result is not
// cudaSuccess, prints the error string and the source line to std::cerr and
// terminates the process with EXIT_FAILURE.
//
// The body is wrapped in do { ... } while (0) so that `CUDA_CHECK(expr);`
// behaves as a single statement (safe inside an unbraced if/else), and the
// argument is parenthesized so that any expression can be passed.
#define CUDA_CHECK(status)                                              \
  do {                                                                  \
    cudaError_t error = (status);                                       \
    if (error != cudaSuccess) {                                         \
      std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
                << " at line: " << __LINE__ << std::endl;               \
      exit(EXIT_FAILURE);                                               \
    }                                                                   \
  } while (0)
|
| 57 |
-
|
| 58 |
-
template <
|
| 59 |
-
class QuantizedElement,
|
| 60 |
-
class DequantizedElement,
|
| 61 |
-
class OperandLayout,
|
| 62 |
-
class ElementScale,
|
| 63 |
-
class ElementZero,
|
| 64 |
-
class ScaleBroadCastLayout,
|
| 65 |
-
class ThrLayout>
|
| 66 |
-
__global__ void dequantize_kernel(DequantizedElement* dq_buffer,
|
| 67 |
-
QuantizedElement const* q_buffer,
|
| 68 |
-
OperandLayout const operand_layout,
|
| 69 |
-
ElementScale const* scale_buffer,
|
| 70 |
-
ElementZero const* zero_buffer,
|
| 71 |
-
ScaleBroadCastLayout const broadcasted_scale_layout,
|
| 72 |
-
ThrLayout thr_layout) {
|
| 73 |
-
using namespace cute;
|
| 74 |
-
|
| 75 |
-
// Represent the full tensors to gmem elements.
|
| 76 |
-
// These are expected to have shape [MN, K, L]
|
| 77 |
-
cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
|
| 78 |
-
cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr<QuantizedElement const>(q_buffer), operand_layout);
|
| 79 |
-
// While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
|
| 80 |
-
// It is expected that K % G == 0
|
| 81 |
-
cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
|
| 82 |
-
cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
|
| 83 |
-
|
| 84 |
-
// Assign 1 thread per element in the thread block
|
| 85 |
-
auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); //
|
| 86 |
-
auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
|
| 87 |
-
|
| 88 |
-
// Tile across the block
|
| 89 |
-
auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
|
| 90 |
-
auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
|
| 91 |
-
auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
|
| 92 |
-
auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);
|
| 93 |
-
|
| 94 |
-
auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
|
| 95 |
-
auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
|
| 96 |
-
auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
|
| 97 |
-
auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);
|
| 98 |
-
|
| 99 |
-
// Make a fragment of registers to hold gmem loads
|
| 100 |
-
cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
|
| 101 |
-
cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
|
| 102 |
-
cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
|
| 103 |
-
cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
|
| 104 |
-
cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
|
| 105 |
-
cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);
|
| 106 |
-
|
| 107 |
-
cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
|
| 108 |
-
auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
|
| 109 |
-
auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);
|
| 110 |
-
|
| 111 |
-
const auto num_iters = cute::size<3>(tOpDq_gOpDq);
|
| 112 |
-
|
| 113 |
-
for (int ii = 0; ii < num_iters; ++ii) {
|
| 114 |
-
const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
|
| 115 |
-
if (thread_offset < cute::size<0>(operand_layout)) {
|
| 116 |
-
cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
|
| 117 |
-
cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
|
| 118 |
-
cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
|
| 119 |
-
cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
|
| 120 |
-
cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
|
| 121 |
-
cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{});
|
| 122 |
-
cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{});
|
| 123 |
-
cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
|
| 124 |
-
cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
|
| 125 |
-
}
|
| 126 |
-
}
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
template <
|
| 130 |
-
class QuantizedElement,
|
| 131 |
-
class DequantizedElement,
|
| 132 |
-
class OperandLayout,
|
| 133 |
-
class ElementScale,
|
| 134 |
-
class ElementZero,
|
| 135 |
-
class ScaleLayout>
|
| 136 |
-
static void dequantize(DequantizedElement* dq_buffer,
|
| 137 |
-
QuantizedElement const* q_buffer,
|
| 138 |
-
OperandLayout const operand_layout,
|
| 139 |
-
ElementScale const* scale_buffer,
|
| 140 |
-
ElementZero const* zero_buffer,
|
| 141 |
-
ScaleLayout const scale_layout,
|
| 142 |
-
int const group_size,
|
| 143 |
-
cudaStream_t &stream) {
|
| 144 |
-
using namespace cute;
|
| 145 |
-
|
| 146 |
-
constexpr int tpb = 128;
|
| 147 |
-
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
|
| 148 |
-
|
| 149 |
-
const auto num_rows = get<0>(shape(operand_layout));
|
| 150 |
-
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
|
| 151 |
-
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
|
| 152 |
-
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
|
| 153 |
-
|
| 154 |
-
if (num_rows != size<0>(scale_layout)) {
|
| 155 |
-
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
|
| 156 |
-
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
|
| 157 |
-
<< std::endl;
|
| 158 |
-
exit(-1);
|
| 159 |
-
}
|
| 160 |
-
|
| 161 |
-
const auto scale_stride0 = get<0>(stride(scale_layout));
|
| 162 |
-
const auto scale_stride1 = get<1>(stride(scale_layout));
|
| 163 |
-
const auto scale_stride2 = get<2>(stride(scale_layout));
|
| 164 |
-
|
| 165 |
-
auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
|
| 166 |
-
auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
|
| 167 |
-
auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
|
| 168 |
-
|
| 169 |
-
const auto blocks_x = gemm_k;
|
| 170 |
-
const auto blocks_y = batches;
|
| 171 |
-
|
| 172 |
-
dim3 blocks(blocks_x, blocks_y, 1);
|
| 173 |
-
dequantize_kernel<<<blocks, tpb, 0, stream>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
|
| 174 |
-
CUDA_CHECK(cudaStreamSynchronize(stream));
|
| 175 |
-
}
|
| 176 |
-
|
| 177 |
-
template <typename T>
|
| 178 |
-
class packed_scale_t {
|
| 179 |
-
public:
|
| 180 |
-
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
|
| 181 |
-
cute::is_same_v<T, cutlass::uint8_t> ||
|
| 182 |
-
cute::is_same_v<T, cutlass::float_e4m3_t> ||
|
| 183 |
-
cute::is_same_v<T, cutlass::float_e5m2_t>,
|
| 184 |
-
"only 8 bit arithmetic types are supported.");
|
| 185 |
-
CUTLASS_HOST_DEVICE
|
| 186 |
-
explicit packed_scale_t(T val) {
|
| 187 |
-
if constexpr (!cute::is_unsigned_v<T>) {
|
| 188 |
-
// Only pack negative values. The positive values are generated in flight in the mainloop.
|
| 189 |
-
storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
|
| 190 |
-
storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
|
| 191 |
-
}
|
| 192 |
-
else {
|
| 193 |
-
storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
|
| 194 |
-
storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
|
| 195 |
-
}
|
| 196 |
-
}
|
| 197 |
-
CUTLASS_HOST_DEVICE
|
| 198 |
-
packed_scale_t() = default;
|
| 199 |
-
CUTLASS_HOST_DEVICE
|
| 200 |
-
explicit operator float() const {
|
| 201 |
-
return float(get());
|
| 202 |
-
}
|
| 203 |
-
CUTLASS_HOST_DEVICE
|
| 204 |
-
bool operator==(packed_scale_t const& rhs) const {
|
| 205 |
-
return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
|
| 206 |
-
}
|
| 207 |
-
CUTLASS_HOST_DEVICE
|
| 208 |
-
bool operator!=(packed_scale_t const& rhs) const {
|
| 209 |
-
return !(*this == rhs);
|
| 210 |
-
}
|
| 211 |
-
CUTLASS_HOST_DEVICE
|
| 212 |
-
friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
| 213 |
-
return packed_scale_t(lhs.get() + rhs.get());
|
| 214 |
-
}
|
| 215 |
-
CUTLASS_HOST_DEVICE
|
| 216 |
-
friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
| 217 |
-
return packed_scale_t(lhs.get() - rhs.get());
|
| 218 |
-
}
|
| 219 |
-
CUTLASS_HOST_DEVICE
|
| 220 |
-
friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
| 221 |
-
return packed_scale_t(lhs.get() * rhs.get());
|
| 222 |
-
}
|
| 223 |
-
CUTLASS_HOST_DEVICE
|
| 224 |
-
friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
| 225 |
-
return packed_scale_t(lhs.get() / rhs.get());
|
| 226 |
-
}
|
| 227 |
-
|
| 228 |
-
private:
|
| 229 |
-
using Storage = uint32_t;
|
| 230 |
-
using Stage = uint8_t;
|
| 231 |
-
|
| 232 |
-
Storage storage[2] {};
|
| 233 |
-
|
| 234 |
-
CUTLASS_HOST_DEVICE
|
| 235 |
-
static Storage pack4(T c1, T c2, T c3, T c4) {
|
| 236 |
-
Storage result = 0;
|
| 237 |
-
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
|
| 238 |
-
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
|
| 239 |
-
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
|
| 240 |
-
result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
|
| 241 |
-
return result;
|
| 242 |
-
}
|
| 243 |
-
CUTLASS_HOST_DEVICE
|
| 244 |
-
T get() const {
|
| 245 |
-
auto stage = static_cast<Stage>(storage[0] >> 8);
|
| 246 |
-
#if defined(__CUDA_ARCH__)
|
| 247 |
-
return reinterpret_cast<T const&>(stage);
|
| 248 |
-
#else
|
| 249 |
-
T tmp;
|
| 250 |
-
std::memcpy(&tmp, &stage, sizeof(Stage));
|
| 251 |
-
return tmp;
|
| 252 |
-
#endif
|
| 253 |
-
}
|
| 254 |
-
CUTLASS_HOST_DEVICE
|
| 255 |
-
T get(int idx) const {
|
| 256 |
-
Stage stage;
|
| 257 |
-
if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
|
| 258 |
-
else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
|
| 259 |
-
#if defined(__CUDA_ARCH__)
|
| 260 |
-
return reinterpret_cast<T const&>(stage);
|
| 261 |
-
#else
|
| 262 |
-
T tmp;
|
| 263 |
-
std::memcpy(&tmp, &stage, sizeof(Stage));
|
| 264 |
-
return tmp;
|
| 265 |
-
#endif
|
| 266 |
-
}
|
| 267 |
-
};
|
| 268 |
-
|
| 269 |
-
// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
|
| 270 |
-
// Here the encodings of positive values and negative values are unified (except for the sign bit).
|
| 271 |
-
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
|
| 272 |
-
static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) {
|
| 273 |
-
|
| 274 |
-
using StorageType = cutlass::int4b_t::Storage;
|
| 275 |
-
constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
|
| 276 |
-
const size_t host_buf_size = block_size / pack;
|
| 277 |
-
std::vector<StorageType> host_buf(host_buf_size);
|
| 278 |
-
cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size);
|
| 279 |
-
|
| 280 |
-
for (auto&& d : host_buf) {
|
| 281 |
-
StorageType out = 0;
|
| 282 |
-
StorageType mask = 0x0f;
|
| 283 |
-
for (int i = 0; i < pack; i++) {
|
| 284 |
-
cutlass::int4b_t curr;
|
| 285 |
-
curr.storage = (d >> (i * 4)) & 0x0f;
|
| 286 |
-
switch (curr) {
|
| 287 |
-
case 1: curr.storage = StorageType(0b0111); break; // 2's complement
|
| 288 |
-
case 2: curr.storage = StorageType(0b0110); break; // 2's complement
|
| 289 |
-
case 3: curr.storage = StorageType(0b0101); break; // 2's complement
|
| 290 |
-
case 4: curr.storage = StorageType(0b0100); break; // 2's complement
|
| 291 |
-
case 5: curr.storage = StorageType(0b0011); break; // 2's complement
|
| 292 |
-
case 6: curr.storage = StorageType(0b0010); break; // 2's complement
|
| 293 |
-
case 7: curr.storage = StorageType(0b0001); break; // 2's complement
|
| 294 |
-
default: break;
|
| 295 |
-
}
|
| 296 |
-
out |= (curr.storage << (4 * i)) & mask;
|
| 297 |
-
mask <<= 4;
|
| 298 |
-
}
|
| 299 |
-
d = out;
|
| 300 |
-
}
|
| 301 |
-
|
| 302 |
-
cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size);
|
| 303 |
-
return true;
|
| 304 |
-
}
|
| 305 |
-
|
| 306 |
-
template <class ElementScale>
|
| 307 |
-
static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array<ElementScale, 8> *block_out, const size_t block_size) {
|
| 308 |
-
std::vector<ElementScale> data_in(block_size);
|
| 309 |
-
std::vector<cutlass::Array<ElementScale, 8>> data_out(block_size);
|
| 310 |
-
|
| 311 |
-
try {
|
| 312 |
-
cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size);
|
| 313 |
-
}
|
| 314 |
-
catch (cutlass::cuda_exception const& e) {
|
| 315 |
-
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
| 316 |
-
return false;
|
| 317 |
-
}
|
| 318 |
-
|
| 319 |
-
for (size_t i = 0; i < block_size; i++) {
|
| 320 |
-
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
|
| 321 |
-
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
|
| 322 |
-
}
|
| 323 |
-
|
| 324 |
-
try {
|
| 325 |
-
cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size);
|
| 326 |
-
}
|
| 327 |
-
catch (cutlass::cuda_exception const& e) {
|
| 328 |
-
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
| 329 |
-
return false;
|
| 330 |
-
}
|
| 331 |
-
return true;
|
| 332 |
-
}
|
| 333 |
-
|
| 334 |
-
template <class T, class = void>
|
| 335 |
-
struct UnderlyingElement {
|
| 336 |
-
using type = T;
|
| 337 |
-
};
|
| 338 |
-
|
| 339 |
-
template <class T>
|
| 340 |
-
struct UnderlyingElement<T, cute::void_t<typename T::Element>> {
|
| 341 |
-
using type = typename T::Element;
|
| 342 |
-
};
|
| 343 |
-
|
| 344 |
-
// Given a type of MMA instruction, compute a memory reordering atom that places all values
|
| 345 |
-
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
|
| 346 |
-
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
|
| 347 |
-
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
|
| 348 |
-
// In addition, we can reorder the values across several MMA instructions to get even wider
|
| 349 |
-
// vectorization (AtomLayout parameter) and permute the values within each instruction to get
|
| 350 |
-
// more optimal conversion instruction sequences (ValLayout parameter).
|
| 351 |
-
template <class ElementMma,
|
| 352 |
-
class AtomLayout = cute::Layout<cute::_1>,
|
| 353 |
-
class ValLayout = cute::Layout<cute::_1>>
|
| 354 |
-
constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
|
| 355 |
-
{
|
| 356 |
-
using namespace cute;
|
| 357 |
-
|
| 358 |
-
static_assert(is_static_v<ValLayout>, "ValLayout must be static");
|
| 359 |
-
static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");
|
| 360 |
-
|
| 361 |
-
// 1. Choose an MMA atom to access TV layout and MN shape
|
| 362 |
-
// Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
|
| 363 |
-
using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
|
| 364 |
-
using MmaTraits = MMA_Traits<MmaAtom>;
|
| 365 |
-
auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
|
| 366 |
-
auto tv_layout_mma = typename MmaTraits::ALayout{};
|
| 367 |
-
static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");
|
| 368 |
-
|
| 369 |
-
// 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
|
| 370 |
-
// Note: this assumes A is partitioned between warps along M mode
|
| 371 |
-
auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
|
| 372 |
-
auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
|
| 373 |
-
auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
|
| 374 |
-
auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);
|
| 375 |
-
|
| 376 |
-
// 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
|
| 377 |
-
auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);
|
| 378 |
-
|
| 379 |
-
// 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
|
| 380 |
-
auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
|
| 381 |
-
auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
|
| 382 |
-
auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
|
| 383 |
-
auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);
|
| 384 |
-
|
| 385 |
-
return layout_atom;
|
| 386 |
-
}
|
| 387 |
-
|
| 388 |
-
template <class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
|
| 389 |
-
__global__ void reorder_tensor_kernel(
|
| 390 |
-
cute::Tensor<EngineSrc, LayoutSrc> S,
|
| 391 |
-
cute::Tensor<EngineDst, LayoutDst> D,
|
| 392 |
-
TiledCopy tiled_copy)
|
| 393 |
-
{
|
| 394 |
-
using namespace cute;
|
| 395 |
-
|
| 396 |
-
using T = typename EngineDst::value_type;
|
| 397 |
-
|
| 398 |
-
Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
| 399 |
-
Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
| 400 |
-
|
| 401 |
-
auto thread_copy = tiled_copy.get_slice(threadIdx.x);
|
| 402 |
-
Tensor tS = thread_copy.partition_S(gS);
|
| 403 |
-
Tensor tD = thread_copy.partition_D(gD);
|
| 404 |
-
|
| 405 |
-
copy(tiled_copy, tS, tD);
|
| 406 |
-
}
|
| 407 |
-
|
| 408 |
-
template <class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
|
| 409 |
-
void reorder_tensor(
|
| 410 |
-
cute::Tensor<EngineSrc, LayoutSrc> S,
|
| 411 |
-
cute::Tensor<EngineDst, LayoutDst> D)
|
| 412 |
-
{
|
| 413 |
-
using namespace cute;
|
| 414 |
-
|
| 415 |
-
using T = typename EngineDst::value_type;
|
| 416 |
-
static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
|
| 417 |
-
|
| 418 |
-
// Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread
|
| 419 |
-
// This avoids a race condition when writing out subbyte types (e.g. int4b_t).
|
| 420 |
-
auto has_major_mode = [](auto s) {
|
| 421 |
-
return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; });
|
| 422 |
-
};
|
| 423 |
-
static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
|
| 424 |
-
"Could not find stride-1 mode in destination layout");
|
| 425 |
-
constexpr int N = shape_div(Int<8>{}, Int<sizeof_bits_v<T>>{});
|
| 426 |
-
auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
|
| 427 |
-
make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
|
| 428 |
-
make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));
|
| 429 |
-
|
| 430 |
-
// Make a tiled copy with a simple row-major thread order and above layout
|
| 431 |
-
int constexpr NumThreads = 128;
|
| 432 |
-
auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
|
| 433 |
-
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);
|
| 434 |
-
|
| 435 |
-
// Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
|
| 436 |
-
using TileShape = Shape<_16>;
|
| 437 |
-
auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
|
| 438 |
-
dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};
|
| 439 |
-
|
| 440 |
-
reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
|
| 441 |
-
CUDA_CHECK(cudaDeviceSynchronize());
|
| 442 |
-
}
|
| 443 |
-
|
| 444 |
-
// In-place version
|
| 445 |
-
template <class T, class LayoutSrc, class LayoutDst>
|
| 446 |
-
void reorder_tensor(
|
| 447 |
-
T const* src,
|
| 448 |
-
LayoutSrc const& layout_src,
|
| 449 |
-
T * dst,
|
| 450 |
-
LayoutDst const& layout_dst)
|
| 451 |
-
{
|
| 452 |
-
using namespace cute;
|
| 453 |
-
reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
|
| 454 |
-
make_tensor(make_gmem_ptr<T>(dst), layout_dst));
|
| 455 |
-
}
|
| 456 |
-
|
| 457 |
-
// In-place version
|
| 458 |
-
template <class T, class LayoutSrc, class LayoutDst>
|
| 459 |
-
void reorder_tensor(
|
| 460 |
-
T * data,
|
| 461 |
-
LayoutSrc const& layout_src,
|
| 462 |
-
LayoutDst const& layout_dst)
|
| 463 |
-
{
|
| 464 |
-
using namespace cute;
|
| 465 |
-
cutlass::DeviceAllocation<T> temp(size(layout_src));
|
| 466 |
-
reorder_tensor(data, layout_src, temp.get(), layout_dst);
|
| 467 |
-
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
|
| 468 |
-
}
|
| 469 |
-
|
| 470 |
-
#undef CUDA_CHECK
|
| 471 |
-
|
| 472 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp
DELETED
|
@@ -1,570 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Utilities for constructing canonical packed CuTe stride types for 3.x mainloop params.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cute/layout.hpp"
|
| 38 |
-
#include "cute/container/array.hpp" // cute::array
|
| 39 |
-
#include "cutlass/conv/convolution.h" // cutlass::conv::Operator
|
| 40 |
-
|
| 41 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
-
|
| 43 |
-
namespace cutlass {
|
| 44 |
-
|
| 45 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
-
|
| 47 |
-
// Strides without batch mode
|
| 48 |
-
|
| 49 |
-
template <class IntT>
|
| 50 |
-
CUTLASS_HOST_DEVICE
|
| 51 |
-
cute::Stride<IntT, cute::Int<1>>
|
| 52 |
-
make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int,int,int> shape_MKL) {
|
| 53 |
-
static_assert(std::is_integral_v<IntT>,
|
| 54 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 55 |
-
auto s_copy = s;
|
| 56 |
-
cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MKL));
|
| 57 |
-
return s_copy;
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
template <class IntT>
|
| 61 |
-
CUTLASS_HOST_DEVICE
|
| 62 |
-
cute::Stride<cute::Int<1>, IntT>
|
| 63 |
-
make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int,int,int> shape_MKL) {
|
| 64 |
-
static_assert(std::is_integral_v<IntT>,
|
| 65 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 66 |
-
auto s_copy = s;
|
| 67 |
-
cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL));
|
| 68 |
-
return s_copy;
|
| 69 |
-
}
|
| 70 |
-
|
| 71 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 72 |
-
|
| 73 |
-
// Strides with batch mode
|
| 74 |
-
|
| 75 |
-
template <class IntT>
|
| 76 |
-
CUTLASS_HOST_DEVICE
|
| 77 |
-
cute::Stride<IntT, cute::Int<1>, int64_t>
|
| 78 |
-
make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
|
| 79 |
-
static_assert(std::is_integral_v<IntT>,
|
| 80 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 81 |
-
auto s_copy = s;
|
| 82 |
-
cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MKL));
|
| 83 |
-
int batch_count = cute::get<2>(shape_MKL);
|
| 84 |
-
if (batch_count > 1) {
|
| 85 |
-
cute::get<2>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
|
| 86 |
-
}
|
| 87 |
-
else {
|
| 88 |
-
cute::get<2>(s_copy) = static_cast<IntT>(0);
|
| 89 |
-
}
|
| 90 |
-
return s_copy;
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
template <class IntT>
|
| 94 |
-
CUTLASS_HOST_DEVICE
|
| 95 |
-
cute::Stride<cute::Int<1>, IntT, int64_t>
|
| 96 |
-
make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
|
| 97 |
-
static_assert(std::is_integral_v<IntT>,
|
| 98 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 99 |
-
auto s_copy = s;
|
| 100 |
-
cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL));
|
| 101 |
-
int batch_count = cute::get<2>(shape_MKL);
|
| 102 |
-
if (batch_count > 1) {
|
| 103 |
-
cute::get<2>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
|
| 104 |
-
}
|
| 105 |
-
else {
|
| 106 |
-
cute::get<2>(s_copy) = static_cast<IntT>(0);
|
| 107 |
-
}
|
| 108 |
-
return s_copy;
|
| 109 |
-
}
|
| 110 |
-
|
| 111 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 112 |
-
|
| 113 |
-
// Strides with group mode
|
| 114 |
-
|
| 115 |
-
template <class StrideIntT>
|
| 116 |
-
CUTLASS_HOST_DEVICE
|
| 117 |
-
cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
|
| 118 |
-
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
|
| 119 |
-
static_assert(std::is_integral_v<StrideIntT>,
|
| 120 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 121 |
-
auto s_copy = s;
|
| 122 |
-
cute::get<0>(s_copy) = static_cast<StrideIntT>(cute::get<1>(shape_MKL));
|
| 123 |
-
return s_copy;
|
| 124 |
-
}
|
| 125 |
-
|
| 126 |
-
template <class StrideIntT>
|
| 127 |
-
CUTLASS_HOST_DEVICE
|
| 128 |
-
cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
|
| 129 |
-
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
|
| 130 |
-
static_assert(std::is_integral_v<StrideIntT>,
|
| 131 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 132 |
-
auto s_copy = s;
|
| 133 |
-
cute::get<1>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL));
|
| 134 |
-
return s_copy;
|
| 135 |
-
}
|
| 136 |
-
|
| 137 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 138 |
-
|
| 139 |
-
// Strides for convolutions
|
| 140 |
-
|
| 141 |
-
// Output cutlass::layout::TensorNDHWC -> rank-3 stride (IntT,_1,_0)
|
| 142 |
-
// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order
|
| 143 |
-
// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout
|
| 144 |
-
// right in KTRSC order and can be coalesced to just k.
|
| 145 |
-
// We enforce this condition here with asserts.
|
| 146 |
-
template <class IntT, size_t RankT_>
|
| 147 |
-
CUTLASS_HOST_DEVICE
|
| 148 |
-
cute::Stride<IntT, cute::Int<1>, cute::Int<0>>
|
| 149 |
-
make_cute_packed_stride(
|
| 150 |
-
cute::Stride<IntT, cute::Int<1>, cute::Int<0>> s,
|
| 151 |
-
cute::array<int32_t, RankT_> shape_output,
|
| 152 |
-
cute::array<IntT, RankT_> stride_output,
|
| 153 |
-
cutlass::conv::Operator conv_op) {
|
| 154 |
-
static_assert(std::is_integral_v<IntT>,
|
| 155 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 156 |
-
static_assert(RankT_ >= 3u);
|
| 157 |
-
constexpr static int RankT = static_cast<int>(RankT_);
|
| 158 |
-
|
| 159 |
-
assert(stride_output[RankT-1] == 1);
|
| 160 |
-
cute::for_each(cute::make_seq<RankT-2>{}, [&](auto i) {
|
| 161 |
-
assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]);
|
| 162 |
-
});
|
| 163 |
-
|
| 164 |
-
auto s_copy = s;
|
| 165 |
-
cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ?
|
| 166 |
-
stride_output[0] :
|
| 167 |
-
stride_output[RankT-2];
|
| 168 |
-
return s_copy;
|
| 169 |
-
}
|
| 170 |
-
|
| 171 |
-
//
|
| 172 |
-
// Activation tensor ((w, h, d, n), _1) for fprop kernel
|
| 173 |
-
//
|
| 174 |
-
|
| 175 |
-
// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1)
|
| 176 |
-
template <class IntT>
|
| 177 |
-
CUTLASS_HOST_DEVICE
|
| 178 |
-
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>>
|
| 179 |
-
make_cute_packed_stride(
|
| 180 |
-
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>> s,
|
| 181 |
-
cute::array<IntT, 3> stride_nwc,
|
| 182 |
-
conv::Operator ConvOp) {
|
| 183 |
-
static_assert(std::is_integral_v<IntT>,
|
| 184 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 185 |
-
assert(stride_nwc[2] == 1);
|
| 186 |
-
auto s_copy = s;
|
| 187 |
-
cute::get<0,0>(s_copy) = stride_nwc[1];
|
| 188 |
-
cute::get<0,1>(s_copy) = stride_nwc[0];
|
| 189 |
-
return s_copy;
|
| 190 |
-
}
|
| 191 |
-
|
| 192 |
-
// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1)
|
| 193 |
-
template <class IntT>
|
| 194 |
-
CUTLASS_HOST_DEVICE
|
| 195 |
-
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>>
|
| 196 |
-
make_cute_packed_stride(
|
| 197 |
-
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>> s,
|
| 198 |
-
cute::array<IntT, 4> stride_nhwc,
|
| 199 |
-
conv::Operator ConvOp) {
|
| 200 |
-
static_assert(std::is_integral_v<IntT>,
|
| 201 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 202 |
-
assert(stride_nhwc[3] == 1);
|
| 203 |
-
auto s_copy = s;
|
| 204 |
-
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 205 |
-
cute::get<0,i>(s_copy) = stride_nhwc[2-i];
|
| 206 |
-
});
|
| 207 |
-
return s_copy;
|
| 208 |
-
}
|
| 209 |
-
|
| 210 |
-
// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1)
|
| 211 |
-
template <class IntT>
|
| 212 |
-
CUTLASS_HOST_DEVICE
|
| 213 |
-
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>>
|
| 214 |
-
make_cute_packed_stride(
|
| 215 |
-
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>> s,
|
| 216 |
-
cute::array<IntT, 5> stride_ndhwc,
|
| 217 |
-
conv::Operator ConvOp) {
|
| 218 |
-
static_assert(std::is_integral_v<IntT>,
|
| 219 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 220 |
-
|
| 221 |
-
assert(stride_ndhwc[4] == 1);
|
| 222 |
-
auto s_copy = s;
|
| 223 |
-
cute::for_each(cute::make_seq<4>{}, [&](auto i) {
|
| 224 |
-
cute::get<0,i>(s_copy) = stride_ndhwc[3-i];
|
| 225 |
-
});
|
| 226 |
-
return s_copy;
|
| 227 |
-
}
|
| 228 |
-
|
| 229 |
-
//
|
| 230 |
-
// Filter tensor (k, (_1, s, r, t)) for fprop kernel
|
| 231 |
-
//
|
| 232 |
-
|
| 233 |
-
// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s))
|
| 234 |
-
template <class IntT>
|
| 235 |
-
CUTLASS_HOST_DEVICE
|
| 236 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>>
|
| 237 |
-
make_cute_packed_stride(
|
| 238 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>> s,
|
| 239 |
-
cute::array<IntT, 3> stride_ksc,
|
| 240 |
-
conv::Operator ConvOp) {
|
| 241 |
-
static_assert(std::is_integral_v<IntT>,
|
| 242 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 243 |
-
|
| 244 |
-
assert(stride_ksc[2] == 1);
|
| 245 |
-
auto s_copy = s;
|
| 246 |
-
cute::get<0,0>(s_copy) = stride_ksc[0];
|
| 247 |
-
cute::get<1,1>(s_copy) = stride_ksc[1];
|
| 248 |
-
return s_copy;
|
| 249 |
-
}
|
| 250 |
-
|
| 251 |
-
// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r))
|
| 252 |
-
template <class IntT>
|
| 253 |
-
CUTLASS_HOST_DEVICE
|
| 254 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>>
|
| 255 |
-
make_cute_packed_stride(
|
| 256 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>> s,
|
| 257 |
-
cute::array<IntT, 4> stride_krsc,
|
| 258 |
-
conv::Operator ConvOp) {
|
| 259 |
-
static_assert(std::is_integral_v<IntT>,
|
| 260 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 261 |
-
|
| 262 |
-
assert(stride_krsc[3] == 1);
|
| 263 |
-
auto s_copy = s;
|
| 264 |
-
cute::get<0,0>(s_copy) = stride_krsc[0];
|
| 265 |
-
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
|
| 266 |
-
cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
|
| 267 |
-
});
|
| 268 |
-
return s_copy;
|
| 269 |
-
}
|
| 270 |
-
|
| 271 |
-
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t))
|
| 272 |
-
template <class IntT>
|
| 273 |
-
CUTLASS_HOST_DEVICE
|
| 274 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>>
|
| 275 |
-
make_cute_packed_stride(
|
| 276 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>> s,
|
| 277 |
-
cute::array<IntT, 5> stride_ktrsc,
|
| 278 |
-
conv::Operator ConvOp) {
|
| 279 |
-
static_assert(std::is_integral_v<IntT>,
|
| 280 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 281 |
-
|
| 282 |
-
assert(stride_ktrsc[4] == 1);
|
| 283 |
-
auto s_copy = s;
|
| 284 |
-
cute::get<0,0>(s_copy) = stride_ktrsc[0];
|
| 285 |
-
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 286 |
-
cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
|
| 287 |
-
});
|
| 288 |
-
return s_copy;
|
| 289 |
-
}
|
| 290 |
-
|
| 291 |
-
//
|
| 292 |
-
// Activation tensor (_1, (w, h, d, n)) for wgrad kernel
|
| 293 |
-
//
|
| 294 |
-
// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel
|
| 295 |
-
//
|
| 296 |
-
|
| 297 |
-
// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad
|
| 298 |
-
// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad
|
| 299 |
-
template <class IntT>
|
| 300 |
-
CUTLASS_HOST_DEVICE
|
| 301 |
-
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>>
|
| 302 |
-
make_cute_packed_stride(
|
| 303 |
-
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>> s,
|
| 304 |
-
cute::array<IntT, 3> stride_nwc,
|
| 305 |
-
conv::Operator ConvOp) {
|
| 306 |
-
static_assert(std::is_integral_v<IntT>,
|
| 307 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 308 |
-
|
| 309 |
-
assert(stride_nwc[2] == 1);
|
| 310 |
-
auto s_copy = s;
|
| 311 |
-
if (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 312 |
-
cute::get<1,0>(s_copy) = stride_nwc[1];
|
| 313 |
-
cute::get<1,1>(s_copy) = stride_nwc[0];
|
| 314 |
-
}
|
| 315 |
-
else if (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 316 |
-
// stride_nwc in dgrad is ksc.
|
| 317 |
-
cute::get<1,0>(s_copy) = stride_nwc[0];
|
| 318 |
-
cute::get<1,1>(s_copy) = stride_nwc[1];
|
| 319 |
-
}
|
| 320 |
-
return s_copy;
|
| 321 |
-
}
|
| 322 |
-
|
| 323 |
-
// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad
|
| 324 |
-
// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad
|
| 325 |
-
template <class IntT>
|
| 326 |
-
CUTLASS_HOST_DEVICE
|
| 327 |
-
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>>
|
| 328 |
-
make_cute_packed_stride(
|
| 329 |
-
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>> s,
|
| 330 |
-
cute::array<IntT, 4> stride_nhwc,
|
| 331 |
-
conv::Operator ConvOp) {
|
| 332 |
-
static_assert(std::is_integral_v<IntT>,
|
| 333 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 334 |
-
|
| 335 |
-
assert(stride_nhwc[3] == 1);
|
| 336 |
-
auto s_copy = s;
|
| 337 |
-
if (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 338 |
-
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 339 |
-
cute::get<1,i>(s_copy) = stride_nhwc[2-i];
|
| 340 |
-
});
|
| 341 |
-
}
|
| 342 |
-
else if (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 343 |
-
// stride_nhwc in dgrad is krsc.
|
| 344 |
-
cute::get<1,0>(s_copy) = stride_nhwc[0];
|
| 345 |
-
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
|
| 346 |
-
cute::get<1,2-i>(s_copy) = stride_nhwc[i+1];
|
| 347 |
-
});
|
| 348 |
-
}
|
| 349 |
-
return s_copy;
|
| 350 |
-
}
|
| 351 |
-
|
| 352 |
-
// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad
|
| 353 |
-
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad
|
| 354 |
-
template <class IntT>
|
| 355 |
-
CUTLASS_HOST_DEVICE
|
| 356 |
-
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>>
|
| 357 |
-
make_cute_packed_stride(
|
| 358 |
-
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>> s,
|
| 359 |
-
cute::array<IntT, 5> stride_ndhwc,
|
| 360 |
-
conv::Operator ConvOp) {
|
| 361 |
-
static_assert(std::is_integral_v<IntT>,
|
| 362 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 363 |
-
|
| 364 |
-
assert(stride_ndhwc[4] == 1);
|
| 365 |
-
auto s_copy = s;
|
| 366 |
-
if (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 367 |
-
cute::for_each(cute::make_seq<4>{}, [&](auto i) {
|
| 368 |
-
cute::get<1,i>(s_copy) = stride_ndhwc[3-i];
|
| 369 |
-
});
|
| 370 |
-
}
|
| 371 |
-
else if (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 372 |
-
// stride_ndhwc in dgrad is ktrsc.
|
| 373 |
-
cute::get<1,0>(s_copy) = stride_ndhwc[0];
|
| 374 |
-
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 375 |
-
cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1];
|
| 376 |
-
});
|
| 377 |
-
}
|
| 378 |
-
return s_copy;
|
| 379 |
-
}
|
| 380 |
-
|
| 381 |
-
//
|
| 382 |
-
// NZPQ tensor (_1, nzpq) for wgrad kernel
|
| 383 |
-
//
|
| 384 |
-
|
| 385 |
-
// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq)
|
| 386 |
-
template <class IntT>
|
| 387 |
-
CUTLASS_HOST_DEVICE
|
| 388 |
-
cute::Stride<cute::Int<1>, IntT>
|
| 389 |
-
make_cute_packed_stride(
|
| 390 |
-
cute::Stride<cute::Int<1>, IntT> s,
|
| 391 |
-
cute::array<IntT, 3> stride_nqk,
|
| 392 |
-
conv::Operator ConvOp) {
|
| 393 |
-
static_assert(std::is_integral_v<IntT>,
|
| 394 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 395 |
-
|
| 396 |
-
assert(stride_nqk[2] == 1);
|
| 397 |
-
auto s_copy = s;
|
| 398 |
-
cute::get<1>(s_copy) = stride_nqk[1];
|
| 399 |
-
return s_copy;
|
| 400 |
-
}
|
| 401 |
-
|
| 402 |
-
// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq)
|
| 403 |
-
template <class IntT>
|
| 404 |
-
CUTLASS_HOST_DEVICE
|
| 405 |
-
cute::Stride<cute::Int<1>, IntT>
|
| 406 |
-
make_cute_packed_stride(
|
| 407 |
-
cute::Stride<cute::Int<1>, IntT> s,
|
| 408 |
-
cute::array<IntT, 4> stride_npqk,
|
| 409 |
-
conv::Operator ConvOp) {
|
| 410 |
-
static_assert(std::is_integral_v<IntT>,
|
| 411 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 412 |
-
|
| 413 |
-
assert(stride_npqk[3] == 1);
|
| 414 |
-
auto s_copy = s;
|
| 415 |
-
cute::get<1>(s_copy) = stride_npqk[2];
|
| 416 |
-
return s_copy;
|
| 417 |
-
}
|
| 418 |
-
|
| 419 |
-
// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq)
|
| 420 |
-
template <class IntT>
|
| 421 |
-
CUTLASS_HOST_DEVICE
|
| 422 |
-
cute::Stride<cute::Int<1>, IntT>
|
| 423 |
-
make_cute_packed_stride(
|
| 424 |
-
cute::Stride<cute::Int<1>, IntT> s,
|
| 425 |
-
cute::array<IntT, 5> stride_nzpqk,
|
| 426 |
-
conv::Operator ConvOp) {
|
| 427 |
-
static_assert(std::is_integral_v<IntT>,
|
| 428 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 429 |
-
|
| 430 |
-
assert(stride_nzpqk[4] == 1);
|
| 431 |
-
auto s_copy = s;
|
| 432 |
-
cute::get<1>(s_copy) = stride_nzpqk[3];
|
| 433 |
-
return s_copy;
|
| 434 |
-
}
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
//
|
| 439 |
-
// Wgrad output tensor (k, (_1, s, r, t), _0)
|
| 440 |
-
//
|
| 441 |
-
|
| 442 |
-
// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0)
|
| 443 |
-
template <class IntT>
|
| 444 |
-
CUTLASS_HOST_DEVICE
|
| 445 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>>
|
| 446 |
-
make_cute_packed_stride(
|
| 447 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>> s,
|
| 448 |
-
[[maybe_unused]] cute::array<int32_t, 3> shape_output,
|
| 449 |
-
cute::array<IntT, 3> stride_ksc,
|
| 450 |
-
conv::Operator ConvOp) {
|
| 451 |
-
static_assert(std::is_integral_v<IntT>,
|
| 452 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 453 |
-
|
| 454 |
-
assert(stride_ksc[2] == 1);
|
| 455 |
-
auto s_copy = s;
|
| 456 |
-
cute::get<0,0>(s_copy) = stride_ksc[0];
|
| 457 |
-
cute::get<1,1>(s_copy) = stride_ksc[1];
|
| 458 |
-
return s_copy;
|
| 459 |
-
}
|
| 460 |
-
|
| 461 |
-
// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0)
|
| 462 |
-
template <class IntT>
|
| 463 |
-
CUTLASS_HOST_DEVICE
|
| 464 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>>
|
| 465 |
-
make_cute_packed_stride(
|
| 466 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>> s,
|
| 467 |
-
[[maybe_unused]] cute::array<int32_t, 4> shape_output,
|
| 468 |
-
cute::array<IntT, 4> stride_krsc,
|
| 469 |
-
conv::Operator ConvOp) {
|
| 470 |
-
static_assert(std::is_integral_v<IntT>,
|
| 471 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 472 |
-
|
| 473 |
-
assert(stride_krsc[3] == 1);
|
| 474 |
-
auto s_copy = s;
|
| 475 |
-
cute::get<0,0>(s_copy) = stride_krsc[0];
|
| 476 |
-
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
|
| 477 |
-
cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
|
| 478 |
-
});
|
| 479 |
-
return s_copy;
|
| 480 |
-
}
|
| 481 |
-
|
| 482 |
-
// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0)
|
| 483 |
-
template <class IntT>
|
| 484 |
-
CUTLASS_HOST_DEVICE
|
| 485 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>>
|
| 486 |
-
make_cute_packed_stride(
|
| 487 |
-
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>> s,
|
| 488 |
-
[[maybe_unused]] cute::array<int32_t, 5> shape_output,
|
| 489 |
-
cute::array<IntT, 5> stride_ktrsc,
|
| 490 |
-
conv::Operator ConvOp) {
|
| 491 |
-
static_assert(std::is_integral_v<IntT>,
|
| 492 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 493 |
-
|
| 494 |
-
assert(stride_ktrsc[4] == 1);
|
| 495 |
-
auto s_copy = s;
|
| 496 |
-
cute::get<0,0>(s_copy) = stride_ktrsc[0];
|
| 497 |
-
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 498 |
-
cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
|
| 499 |
-
});
|
| 500 |
-
return s_copy;
|
| 501 |
-
}
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
//
|
| 505 |
-
// Wgrad output tensor ((_1, s, r, t), k, _0)
|
| 506 |
-
//
|
| 507 |
-
|
| 508 |
-
// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0)
|
| 509 |
-
template <class IntT>
|
| 510 |
-
CUTLASS_HOST_DEVICE
|
| 511 |
-
cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>>
|
| 512 |
-
make_cute_packed_stride(
|
| 513 |
-
cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>> s,
|
| 514 |
-
[[maybe_unused]] cute::array<int32_t, 3> shape_output,
|
| 515 |
-
cute::array<IntT, 3> stride_ksc,
|
| 516 |
-
conv::Operator ConvOp) {
|
| 517 |
-
static_assert(std::is_integral_v<IntT>,
|
| 518 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 519 |
-
|
| 520 |
-
assert(stride_ksc[2] == 1);
|
| 521 |
-
auto s_copy = s;
|
| 522 |
-
cute::get<1,0>(s_copy) = stride_ksc[0];
|
| 523 |
-
cute::get<0,1>(s_copy) = stride_ksc[1];
|
| 524 |
-
return s_copy;
|
| 525 |
-
}
|
| 526 |
-
|
| 527 |
-
// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0)
|
| 528 |
-
template <class IntT>
|
| 529 |
-
CUTLASS_HOST_DEVICE
|
| 530 |
-
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>>
|
| 531 |
-
make_cute_packed_stride(
|
| 532 |
-
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>> s,
|
| 533 |
-
[[maybe_unused]] cute::array<int32_t, 4> shape_output,
|
| 534 |
-
cute::array<IntT, 4> stride_krsc,
|
| 535 |
-
conv::Operator ConvOp) {
|
| 536 |
-
static_assert(std::is_integral_v<IntT>,
|
| 537 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 538 |
-
|
| 539 |
-
assert(stride_krsc[3] == 1);
|
| 540 |
-
auto s_copy = s;
|
| 541 |
-
cute::get<1,0>(s_copy) = stride_krsc[0];
|
| 542 |
-
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
|
| 543 |
-
cute::get<0,2-i>(s_copy) = stride_krsc[i+1];
|
| 544 |
-
});
|
| 545 |
-
return s_copy;
|
| 546 |
-
}
|
| 547 |
-
|
| 548 |
-
// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0)
|
| 549 |
-
template <class IntT>
|
| 550 |
-
CUTLASS_HOST_DEVICE
|
| 551 |
-
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>>
|
| 552 |
-
make_cute_packed_stride(
|
| 553 |
-
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>> s,
|
| 554 |
-
[[maybe_unused]] cute::array<int32_t, 5> shape_output,
|
| 555 |
-
cute::array<IntT, 5> stride_ktrsc,
|
| 556 |
-
conv::Operator ConvOp) {
|
| 557 |
-
static_assert(std::is_integral_v<IntT>,
|
| 558 |
-
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 559 |
-
|
| 560 |
-
assert(stride_ktrsc[4] == 1);
|
| 561 |
-
auto s_copy = s;
|
| 562 |
-
cute::get<1,0>(s_copy) = stride_ktrsc[0];
|
| 563 |
-
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 564 |
-
cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1];
|
| 565 |
-
});
|
| 566 |
-
return s_copy;
|
| 567 |
-
}
|
| 568 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 569 |
-
|
| 570 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp
DELETED
|
@@ -1,341 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include <array>
|
| 35 |
-
#include <cassert>
|
| 36 |
-
#include <cmath>
|
| 37 |
-
#include <iostream>
|
| 38 |
-
#include <type_traits>
|
| 39 |
-
|
| 40 |
-
#include <cute/util/type_traits.hpp>
|
| 41 |
-
#include <cute/tensor.hpp>
|
| 42 |
-
|
| 43 |
-
#include <cute/numeric/numeric_types.hpp>
|
| 44 |
-
#include <cute/numeric/complex.hpp>
|
| 45 |
-
|
| 46 |
-
#include <cutlass/layout/layout.h>
|
| 47 |
-
|
| 48 |
-
// The computed infinity norm does not include
|
| 49 |
-
// any NaN row absolute-value sums.
|
| 50 |
-
struct matrix_inf_norm_result {
  // Largest absolute-value sum found. Kept in double, which is generally
  // the highest precision the examples use.
  double inf_norm{0.0};
  // Set when at least one absolute-value sum evaluated to NaN.
  bool found_nan{false};
};
|
| 56 |
-
|
| 57 |
-
// In theory, cute::Tensor<ViewEngine<T*>, T> could be treated as a view type,
|
| 58 |
-
// and thus passed by value (as std::span or std::string_view would be).
|
| 59 |
-
// However, generic cute::Tensor are more like containers
|
| 60 |
-
// and thus are best passed by reference or const reference.
|
| 61 |
-
template <typename EngineType, typename LayoutType>
|
| 62 |
-
matrix_inf_norm_result
|
| 63 |
-
matrix_inf_norm(cute::Tensor<EngineType, LayoutType> const& host_matrix)
|
| 64 |
-
{
|
| 65 |
-
using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
|
| 66 |
-
using element_type = typename EngineType::value_type;
|
| 67 |
-
|
| 68 |
-
error_type inf_norm = 0.0;
|
| 69 |
-
bool found_nan = false;
|
| 70 |
-
|
| 71 |
-
// Computing the infinity norm requires that we be able
|
| 72 |
-
// to treat the input as a matrix, with rows and columns.
|
| 73 |
-
const int64_t num_rows = cute::size<0>(host_matrix);
|
| 74 |
-
const int64_t num_cols = cute::size<1>(host_matrix);
|
| 75 |
-
|
| 76 |
-
auto abs_fn = [] (element_type A_ij) {
|
| 77 |
-
if constexpr (not std::is_unsigned_v<element_type>) {
|
| 78 |
-
using std::abs;
|
| 79 |
-
return abs(A_ij);
|
| 80 |
-
}
|
| 81 |
-
else {
|
| 82 |
-
return A_ij;
|
| 83 |
-
}
|
| 84 |
-
};
|
| 85 |
-
|
| 86 |
-
for (int64_t i = 0; i < num_rows; ++i) {
|
| 87 |
-
error_type row_abs_sum = 0.0;
|
| 88 |
-
for(int64_t j = 0; j < num_cols; ++j) {
|
| 89 |
-
row_abs_sum += abs_fn(host_matrix(i, j));
|
| 90 |
-
}
|
| 91 |
-
if (std::isnan(row_abs_sum)) {
|
| 92 |
-
found_nan = true;
|
| 93 |
-
}
|
| 94 |
-
else {
|
| 95 |
-
inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
|
| 96 |
-
}
|
| 97 |
-
}
|
| 98 |
-
|
| 99 |
-
return {inf_norm, found_nan};
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
// Infinity norm of (X - Y).
|
| 103 |
-
template <typename EngineType, typename LayoutType>
|
| 104 |
-
matrix_inf_norm_result
|
| 105 |
-
matrix_diff_inf_norm(cute::Tensor<EngineType, LayoutType> const& X,
|
| 106 |
-
cute::Tensor<EngineType, LayoutType> const& Y)
|
| 107 |
-
{
|
| 108 |
-
using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
|
| 109 |
-
using element_type = typename EngineType::value_type;
|
| 110 |
-
|
| 111 |
-
auto abs_fn = [] (element_type A_ij) {
|
| 112 |
-
if constexpr (not std::is_unsigned_v<element_type>) {
|
| 113 |
-
using std::abs;
|
| 114 |
-
return abs(A_ij);
|
| 115 |
-
}
|
| 116 |
-
else {
|
| 117 |
-
return A_ij;
|
| 118 |
-
}
|
| 119 |
-
};
|
| 120 |
-
|
| 121 |
-
assert(cute::size<0>(X) == cute::size<0>(Y));
|
| 122 |
-
assert(cute::size<1>(X) == cute::size<1>(Y));
|
| 123 |
-
|
| 124 |
-
// Computing the infinity norm requires that we be able
|
| 125 |
-
// to treat the input as a matrix, with rows and columns.
|
| 126 |
-
const int64_t num_rows = cute::size<0>(X);
|
| 127 |
-
const int64_t num_cols = cute::size<1>(X);
|
| 128 |
-
|
| 129 |
-
error_type inf_norm = 0.0;
|
| 130 |
-
bool found_nan = false;
|
| 131 |
-
|
| 132 |
-
for (int64_t i = 0; i < num_rows; ++i) {
|
| 133 |
-
error_type row_abs_sum = 0.0;
|
| 134 |
-
for (int64_t j = 0; j < num_cols; ++j) {
|
| 135 |
-
row_abs_sum += error_type(abs_fn(element_type(X(i,j)) -
|
| 136 |
-
element_type(Y(i,j))));
|
| 137 |
-
}
|
| 138 |
-
if (std::isnan(row_abs_sum)) {
|
| 139 |
-
found_nan = true;
|
| 140 |
-
}
|
| 141 |
-
else {
|
| 142 |
-
inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
|
| 143 |
-
}
|
| 144 |
-
}
|
| 145 |
-
|
| 146 |
-
return {inf_norm, found_nan};
|
| 147 |
-
}
|
| 148 |
-
|
| 149 |
-
template <typename EngineType_A, typename LayoutType_A,
|
| 150 |
-
typename EngineType_B, typename LayoutType_B,
|
| 151 |
-
typename EngineType_C, typename LayoutType_C,
|
| 152 |
-
typename EngineType_C_ref, typename LayoutType_C_ref>
|
| 153 |
-
auto
|
| 154 |
-
print_matrix_multiply_mollified_relative_error(
|
| 155 |
-
char const A_value_type_name[],
|
| 156 |
-
cute::Tensor<EngineType_A, LayoutType_A> const& A,
|
| 157 |
-
char const B_value_type_name[],
|
| 158 |
-
cute::Tensor<EngineType_B, LayoutType_B> const& B,
|
| 159 |
-
char const C_value_type_name[],
|
| 160 |
-
cute::Tensor<EngineType_C, LayoutType_C> const& C,
|
| 161 |
-
cute::Tensor<EngineType_C_ref, LayoutType_C_ref> const& C_ref)
|
| 162 |
-
{
|
| 163 |
-
const auto [A_norm, A_has_nan] = matrix_inf_norm(A);
|
| 164 |
-
const auto [B_norm, B_has_nan] = matrix_inf_norm(B);
|
| 165 |
-
const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref);
|
| 166 |
-
const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref);
|
| 167 |
-
|
| 168 |
-
const auto A_norm_times_B_norm = A_norm * B_norm;
|
| 169 |
-
const auto relative_error = A_norm_times_B_norm == 0.0 ?
|
| 170 |
-
diff_norm : (diff_norm / A_norm_times_B_norm);
|
| 171 |
-
|
| 172 |
-
// For expected error bounds, please refer to the LAPACK Users' Guide,
|
| 173 |
-
// in particular https://netlib.org/lapack/lug/node108.html .
|
| 174 |
-
// Printing the infinity norm of C is a way to check
|
| 175 |
-
// that both the function being tested (C)
|
| 176 |
-
// and the reference implementation (C_ref)
|
| 177 |
-
// don't just do nothing (or fill with zeros).
|
| 178 |
-
using std::cout;
|
| 179 |
-
using cute::shape;
|
| 180 |
-
cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n'
|
| 181 |
-
<< "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n'
|
| 182 |
-
<< "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n'
|
| 183 |
-
<< std::scientific
|
| 184 |
-
<< "Infinity norm of A: " << A_norm << '\n'
|
| 185 |
-
<< "Infinity norm of B: " << B_norm << '\n'
|
| 186 |
-
<< "Infinity norm of C: " << C_norm << '\n'
|
| 187 |
-
<< "Infinity norm of (C - C_ref): " << diff_norm << '\n';
|
| 188 |
-
|
| 189 |
-
if(A_norm_times_B_norm == 0.0) {
|
| 190 |
-
cout << "Mollified relative error: " << relative_error << '\n';
|
| 191 |
-
} else {
|
| 192 |
-
cout << "Relative error: " << relative_error << '\n';
|
| 193 |
-
}
|
| 194 |
-
|
| 195 |
-
if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) {
|
| 196 |
-
cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
|
| 197 |
-
<< "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
|
| 198 |
-
<< "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n'
|
| 199 |
-
<< "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n';
|
| 200 |
-
}
|
| 201 |
-
return relative_error;
|
| 202 |
-
}
|
| 203 |
-
|
| 204 |
-
template <typename EngineType, typename LayoutType>
|
| 205 |
-
auto
|
| 206 |
-
print_matrix_multiply_mollified_relative_error(
|
| 207 |
-
const char value_type_name[],
|
| 208 |
-
const cute::Tensor<EngineType, LayoutType>& A,
|
| 209 |
-
const cute::Tensor<EngineType, LayoutType>& B,
|
| 210 |
-
const cute::Tensor<EngineType, LayoutType>& C_computed,
|
| 211 |
-
const cute::Tensor<EngineType, LayoutType>& C_expected)
|
| 212 |
-
{
|
| 213 |
-
return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B,
|
| 214 |
-
value_type_name, C_computed, C_expected);
|
| 215 |
-
}
|
| 216 |
-
|
| 217 |
-
// Take a CUTLASS HostTensor (or the like) as input,
// and return a CuTe Tensor that views the same host data.
// This is useful for use with the above error printing functions.
//
// The CuTe shape is (extent[0], extent[1]). The strides are chosen from the
// input's Layout type, so a RowMajor input is implicitly "transposed" into
// the CuTe (row, column) indexing rather than copied.
// Only cutlass::layout::ColumnMajor and cutlass::layout::RowMajor are
// accepted; any other Layout fails the static_assert below.
//
// Note that the HostTensor must be captured by nonconst reference
// in order for X.host_ref().data() to compile.
// (CUTLASS is a bit more container-y than CuTe.)
template<class CutlassHostTensorType>
auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X)
{
  // The tensors were created with post-transposed extents.
  const auto extents = X.extent();
  const auto shape = cute::Shape<int, int>{extents[0], extents[1]};
  // Both RowMajor and ColumnMajor only store one stride.
  const int LDX = X.stride(0);
  const auto strides = [&]() {
    using input_layout_type = typename std::decay_t<decltype(X)>::Layout;
    if constexpr (std::is_same_v<input_layout_type, cutlass::layout::ColumnMajor>) {
      // Unit stride along mode 0, LDX along mode 1.
      return cute::Stride<int, int>{1, LDX};
    }
    else {
      static_assert(std::is_same_v<input_layout_type, cutlass::layout::RowMajor>);
      // LDX along mode 0, unit stride along mode 1.
      return cute::Stride<int, int>{LDX, 1};
    }
  }();
  const auto layout = cute::make_layout(shape, strides);
  auto X_data = X.host_ref().data();
  // NOTE(review): std::add_const_t applied to a pointer type adds a
  // top-level const (T* const), not const on the pointee (T const*).
  // If X_data is a raw pointer, the returned tensor's elements are
  // therefore still mutable. Confirm whether pointer-to-const was intended.
  auto X_data_const = const_cast<std::add_const_t< decltype(X_data)> >(X_data);
  return cute::make_tensor(X_data_const, layout);
};
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
// Returns EXIT_SUCCESS if the 2-norm relative error is exactly zero, else returns EXIT_FAILURE.
|
| 250 |
-
// This makes the return value suitable as the return value of main().
|
| 251 |
-
template <typename T1, typename T2>
|
| 252 |
-
int
|
| 253 |
-
print_relative_error(
|
| 254 |
-
std::size_t n,
|
| 255 |
-
T1 const& data,
|
| 256 |
-
T2 const& reference,
|
| 257 |
-
bool print_verbose = false,
|
| 258 |
-
bool print_error = true,
|
| 259 |
-
double error_margin = 0.00001) {
|
| 260 |
-
using std::abs; using std::sqrt;
|
| 261 |
-
|
| 262 |
-
// Use either double or complex<double> for error computation
|
| 263 |
-
using value_type = cute::remove_cvref_t<decltype(reference[0])>;
|
| 264 |
-
using error_type = std::conditional_t<cute::is_complex<value_type>::value,
|
| 265 |
-
cute::complex<double>,
|
| 266 |
-
double>;
|
| 267 |
-
|
| 268 |
-
if (print_verbose) {
|
| 269 |
-
std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl;
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
double eps = 1e-200;
|
| 273 |
-
|
| 274 |
-
double tot_error_sq = 0;
|
| 275 |
-
double tot_norm_sq = 0;
|
| 276 |
-
double tot_ind_rel_err = 0;
|
| 277 |
-
double max_ind_rel_err = 0;
|
| 278 |
-
double max_diff = 0;
|
| 279 |
-
for (std::size_t i = 0; i < n; ++i) {
|
| 280 |
-
error_type val = data[i];
|
| 281 |
-
error_type ref = reference[i];
|
| 282 |
-
|
| 283 |
-
double aref = abs(ref);
|
| 284 |
-
double diff = abs(ref - val);
|
| 285 |
-
double rel_error = diff / (aref + eps);
|
| 286 |
-
|
| 287 |
-
// Individual relative error
|
| 288 |
-
tot_ind_rel_err += rel_error;
|
| 289 |
-
|
| 290 |
-
// Maximum relative error
|
| 291 |
-
max_ind_rel_err = std::max(max_ind_rel_err, rel_error);
|
| 292 |
-
|
| 293 |
-
// Maximum delta in value error
|
| 294 |
-
max_diff = std::max(max_diff, diff);
|
| 295 |
-
|
| 296 |
-
// Total relative error
|
| 297 |
-
tot_error_sq += diff * diff;
|
| 298 |
-
tot_norm_sq += aref * aref;
|
| 299 |
-
|
| 300 |
-
if (print_verbose) {
|
| 301 |
-
std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl;
|
| 302 |
-
}
|
| 303 |
-
}
|
| 304 |
-
|
| 305 |
-
double ave_rel_err = tot_ind_rel_err / double(n);
|
| 306 |
-
if (print_error) {
|
| 307 |
-
printf("Average relative error: %.3e\n", ave_rel_err);
|
| 308 |
-
}
|
| 309 |
-
|
| 310 |
-
if (print_error) {
|
| 311 |
-
printf("Maximum relative error: %.3e\n", max_ind_rel_err);
|
| 312 |
-
}
|
| 313 |
-
|
| 314 |
-
if (print_error) {
|
| 315 |
-
printf("Maximum difference : %.3e\n", max_diff);
|
| 316 |
-
}
|
| 317 |
-
|
| 318 |
-
double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps));
|
| 319 |
-
if (print_error) {
|
| 320 |
-
printf("Vector relative error: %.3e\n", tot_rel_err);
|
| 321 |
-
}
|
| 322 |
-
|
| 323 |
-
printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq));
|
| 324 |
-
|
| 325 |
-
return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE;
|
| 326 |
-
}
|
| 327 |
-
|
| 328 |
-
// Overload for cute::Tensor<>
|
| 329 |
-
template <class Engine, class Layout>
|
| 330 |
-
int
|
| 331 |
-
print_relative_error(
|
| 332 |
-
cute::Tensor<Engine, Layout> data,
|
| 333 |
-
cute::Tensor<Engine, Layout> reference,
|
| 334 |
-
bool print_verbose = false,
|
| 335 |
-
bool print_error = true,
|
| 336 |
-
double error_margin = 0.00001) {
|
| 337 |
-
assert(size(data) == size(reference));
|
| 338 |
-
return print_relative_error(static_cast<std::size_t>(size(data)),
|
| 339 |
-
data, reference,
|
| 340 |
-
print_verbose, print_error, error_margin);
|
| 341 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h
DELETED
|
@@ -1,135 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for GEMM in host-side code.
|
| 33 |
-
*/
|
| 34 |
-
#pragma once
|
| 35 |
-
|
| 36 |
-
#include "cutlass/cutlass.h"
|
| 37 |
-
#include "cutlass/array.h"
|
| 38 |
-
|
| 39 |
-
namespace cutlass {
|
| 40 |
-
namespace reference {
|
| 41 |
-
namespace detail {
|
| 42 |
-
|
| 43 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
-
|
| 45 |
-
/// Template function to compute an inner product.
|
| 46 |
-
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a
|
| 47 |
-
// host-only type
|
| 48 |
-
template <typename Atype, typename Btype, typename Ctype>
|
| 49 |
-
CUTLASS_HOST_DEVICE
|
| 50 |
-
Ctype inner_product(Atype a, Btype b, Ctype c) {
|
| 51 |
-
return Ctype(a) * Ctype(b) + c;
|
| 52 |
-
}
|
| 53 |
-
|
| 54 |
-
/// Specialization for matrix multiplication with binary operands
|
| 55 |
-
template <>
|
| 56 |
-
CUTLASS_HOST_DEVICE
|
| 57 |
-
int inner_product<Array<bin1_t, 32>, Array<bin1_t, 32>, int>(
|
| 58 |
-
Array<bin1_t, 32> a,
|
| 59 |
-
Array<bin1_t, 32> b,
|
| 60 |
-
int c) {
|
| 61 |
-
|
| 62 |
-
int accum = 0;
|
| 63 |
-
for (int bit = 0; bit < 32; bit++) {
|
| 64 |
-
accum += a[bit] ^ b[bit];
|
| 65 |
-
}
|
| 66 |
-
return accum + c;
|
| 67 |
-
}
|
| 68 |
-
|
| 69 |
-
/*
|
| 70 |
-
/// Specialization for matrix multiplication with signed 4-bit integer operands
|
| 71 |
-
template <>
|
| 72 |
-
CUTLASS_HOST_DEVICE
|
| 73 |
-
int inner_product<Array<int4b_t, 8>, Array<int4b_t, 8>, int>(
|
| 74 |
-
Array<int4b_t, 8> a,
|
| 75 |
-
Array<int4b_t, 8> b,
|
| 76 |
-
int c) {
|
| 77 |
-
|
| 78 |
-
int accum = 0;
|
| 79 |
-
for (int k = 0; k < 8; k++) {
|
| 80 |
-
accum += a[k] * b[k];
|
| 81 |
-
}
|
| 82 |
-
return accum + c;
|
| 83 |
-
}
|
| 84 |
-
|
| 85 |
-
/// Specialization for matrix multiplication with unsigned 4-bit integer operands
|
| 86 |
-
template <>
|
| 87 |
-
CUTLASS_HOST_DEVICE
|
| 88 |
-
int inner_product<Array<uint4b_t, 8>, Array<uint4b_t, 8>, int>(
|
| 89 |
-
Array<uint4b_t, 8> a,
|
| 90 |
-
Array<uint4b_t, 8> b,
|
| 91 |
-
int c) {
|
| 92 |
-
|
| 93 |
-
int accum = 0;
|
| 94 |
-
for (int k = 0; k < 8; k++) {
|
| 95 |
-
accum += a[k] * b[k];
|
| 96 |
-
}
|
| 97 |
-
return accum + c;
|
| 98 |
-
}
|
| 99 |
-
*/
|
| 100 |
-
|
| 101 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 102 |
-
|
| 103 |
-
template <typename SrcType, typename DstType>
|
| 104 |
-
struct Cast {
|
| 105 |
-
// Default behavior: convert to the destination type
|
| 106 |
-
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
| 107 |
-
// host-only type
|
| 108 |
-
CUTLASS_HOST_DEVICE
|
| 109 |
-
static DstType apply(SrcType src) { return static_cast<DstType>(src); };
|
| 110 |
-
};
|
| 111 |
-
|
| 112 |
-
template <>
|
| 113 |
-
struct Cast<float, int8_t> {
|
| 114 |
-
CUTLASS_HOST_DEVICE
|
| 115 |
-
static int8_t apply(float src) {
|
| 116 |
-
// Clamp to the range of signed 8-bit integers.
|
| 117 |
-
return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
|
| 118 |
-
};
|
| 119 |
-
};
|
| 120 |
-
|
| 121 |
-
template <>
|
| 122 |
-
struct Cast<float, uint8_t> {
|
| 123 |
-
CUTLASS_HOST_DEVICE
|
| 124 |
-
static uint8_t apply(float src) {
|
| 125 |
-
// Clamp to the range of signed 8-bit integers.
|
| 126 |
-
return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
|
| 127 |
-
};
|
| 128 |
-
};
|
| 129 |
-
|
| 130 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 131 |
-
|
| 132 |
-
} // namespace detail
|
| 133 |
-
} // namespace reference
|
| 134 |
-
} // namespace cutlass
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Helpers that convert a linear index into a multi-dimensional coordinate.
|
| 33 |
-
*/
|
| 34 |
-
#pragma once
|
| 35 |
-
|
| 36 |
-
#include "cutlass/cutlass.h"
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
|
| 39 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
-
|
| 41 |
-
namespace cutlass {
|
| 42 |
-
namespace reference {
|
| 43 |
-
namespace detail {
|
| 44 |
-
|
| 45 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
-
|
| 47 |
-
template <int Rank, int Index>
|
| 48 |
-
struct LinearToCoordinateHelper {
|
| 49 |
-
|
| 50 |
-
CUTLASS_HOST_DEVICE
|
| 51 |
-
void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &extent) const {
|
| 52 |
-
|
| 53 |
-
int64_t prod = 1;
|
| 54 |
-
|
| 55 |
-
CUTLASS_PRAGMA_UNROLL
|
| 56 |
-
for (int i = Rank - Index; i < Rank; ++i) {
|
| 57 |
-
prod *= int64_t(extent[i]);
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
coord[Rank - Index - 1] = int(idx / prod);
|
| 61 |
-
|
| 62 |
-
int64_t residual = idx % prod;
|
| 63 |
-
LinearToCoordinateHelper<Rank, Index - 1>()(coord, residual, extent);
|
| 64 |
-
}
|
| 65 |
-
};
|
| 66 |
-
|
| 67 |
-
template <int Rank>
|
| 68 |
-
struct LinearToCoordinateHelper<Rank, 0> {
|
| 69 |
-
|
| 70 |
-
CUTLASS_HOST_DEVICE
|
| 71 |
-
void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &) const {
|
| 72 |
-
coord[Rank - 1] = int(idx);
|
| 73 |
-
}
|
| 74 |
-
};
|
| 75 |
-
|
| 76 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 77 |
-
|
| 78 |
-
template <int Rank>
|
| 79 |
-
struct LinearToCoordinate {
|
| 80 |
-
|
| 81 |
-
CUTLASS_HOST_DEVICE
|
| 82 |
-
void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &extent) const {
|
| 83 |
-
LinearToCoordinateHelper<Rank, Rank - 1>()(coord, idx, extent);
|
| 84 |
-
}
|
| 85 |
-
};
|
| 86 |
-
|
| 87 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 88 |
-
|
| 89 |
-
} // namespace detail
|
| 90 |
-
} // namespace reference
|
| 91 |
-
} // namespace cutlass
|
| 92 |
-
|
| 93 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h
DELETED
|
@@ -1,1549 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief Reference implementation for convolution in device-side code.
|
| 34 |
-
*/
|
| 35 |
-
|
| 36 |
-
#pragma once
|
| 37 |
-
|
| 38 |
-
#include "cutlass/coord.h"
|
| 39 |
-
#include "cutlass/functional.h"
|
| 40 |
-
#include "cutlass/layout/tensor.h"
|
| 41 |
-
#include "cutlass/matrix_shape.h"
|
| 42 |
-
#include "cutlass/numeric_conversion.h"
|
| 43 |
-
#include "cutlass/numeric_types.h"
|
| 44 |
-
#include "cutlass/tensor_ref.h"
|
| 45 |
-
#include "cutlass/conv/convolution.h"
|
| 46 |
-
#include "cutlass/conv/conv2d_problem_size.h"
|
| 47 |
-
#include "cutlass/conv/conv3d_problem_size.h"
|
| 48 |
-
|
| 49 |
-
namespace cutlass {
|
| 50 |
-
namespace reference {
|
| 51 |
-
namespace device {
|
| 52 |
-
|
| 53 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
-
|
| 55 |
-
namespace kernel {
|
| 56 |
-
|
| 57 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 58 |
-
/// Conv2d device reference kernel
|
| 59 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 60 |
-
|
| 61 |
-
// Conv2d Fprop kernel - y = fprop(x, w)
//
// Each thread computes a kThreadM x kThreadN tile of outputs. The tile's
// rows are consecutive flattened (n, p, q) output positions, selected by
// blockIdx.x / threadIdx.x; its columns are consecutive output channels k,
// selected by blockIdx.y / threadIdx.y.
//
// For every in-range output element the kernel writes
//   tensor_y_out(n, p, q, k) = convert_op(alpha * accum + beta * c_ref)
// where accum is the accumulated sum over (R, S, C) and c_ref is read from
// tensor_y_in only when beta is nonzero (otherwise it is ElementCompute()).
//
// Grouped convolution: a filter element contributes only when the output
// channel's group equals the input channel's group.
// NOTE(review): c_per_group and k_per_group use integer division; this
// presumably assumes C and K are multiples of problem_size.groups - confirm.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>,
  int kThreadM = 2,       // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 16,    // shape of a threadblock in units of threads
  int kCtaShapeN = 8      // shape of a threadblock in units of threads
>
__global__ void Conv2dFprop(
  conv::Conv2dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_x,
  TensorRef<ElementB, LayoutB> tensor_w,
  TensorRef<ElementC, LayoutC> tensor_y_in,
  TensorRef<ElementC, LayoutC> tensor_y_out,
  ElementCompute alpha,
  ElementCompute beta
  ) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Per-thread operand fragments and the accumulator tile.
  ElementAccumulator element_A[kThreadM];
  ElementAccumulator element_B[kThreadN];
  ElementAccumulator accum[kThreadM][kThreadN];

  // First flattened (n, p, q) row and first output channel owned by this
  // thread. npq_start is computed in 64 bits; k_start is a plain int.
  int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_n[kThreadM];
  int thread_p[kThreadM];
  int thread_q[kThreadM];

  // Compute N, P, Q coordinates for each row of a thread's tile
  int64_t PQ = int64_t(problem_size.P) * problem_size.Q;

  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    int64_t npq = npq_start + m;

    // Rows past the end of the problem yield thread_n[m] >= N and are
    // filtered out by the bounds checks further down.
    thread_n[m] = int(npq / PQ);

    int64_t residual = npq % PQ;
    thread_p[m] = int(residual / problem_size.Q);
    thread_q[m] = int(residual % problem_size.Q);
  }

  // Clear accumulators
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
      accum[m][n] = ElementAccumulator();
    }
  }

  // Number of input / output channels in each group.
  int c_per_group = problem_size.C / problem_size.groups;
  int k_per_group = problem_size.K / problem_size.groups;

  // Compute convolution
  for (int R = 0; R < problem_size.R; ++R) {
    for (int S = 0; S < problem_size.S; ++S) {
      for (int C = 0; C < problem_size.C; ++C) {

        // Get group id of current channel
        int c_group_idx = C / c_per_group;

        // Load from activations tensor
        int filter_r = R;
        int filter_s = S;

        // In kConvolution mode the filter offsets used to locate the
        // activation are mirrored; the filter tensor itself is still read
        // at (R, S) below.
        if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
          filter_r = problem_size.R - 1 - R;
          filter_s = problem_size.S - 1 - S;
        }

        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < kThreadM; ++m) {
          // Input coordinates that feed output (p, q) for this filter tap.
          int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
          int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

          // Positions outside the input (padding) or outside the batch
          // contribute a default-constructed (zero) value.
          if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) {
            element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C}));
          }
          else {
            element_A[m] = ElementAccumulator();
          }
        }

        // Load from filters tensor
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < kThreadN; ++n) {
          int thread_k = k_start + n;
          int k_group_idx = thread_k / k_per_group;

          // Only filters whose output-channel group matches the input
          // channel's group contribute; the channel index into the filter
          // is taken relative to the group.
          if (thread_k < problem_size.K && k_group_idx == c_group_idx) {
            element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group}));
          }
          else {
            element_B[n] = ElementAccumulator();
          }
        }

        // Accumulate matrix product
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < kThreadM; ++m) {
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < kThreadN; ++n) {
            accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
          }
        }
      }
    }
  }

  // Write out the results
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {
        int thread_k = k_start + n;
        if (thread_k < problem_size.K) {

          // The source tensor is read only when beta is nonzero.
          ElementCompute c_ref = ElementCompute();
          if (beta != ElementCompute()) {
            c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
          }

          tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      }
    }
  }
}
|
| 206 |
-
|
| 207 |
-
// Conv3d Fprop kernel - y = fprop(x, w)
|
| 208 |
-
template <
|
| 209 |
-
typename ElementA,
|
| 210 |
-
typename LayoutA,
|
| 211 |
-
typename ElementB,
|
| 212 |
-
typename LayoutB,
|
| 213 |
-
typename ElementC,
|
| 214 |
-
typename LayoutC,
|
| 215 |
-
typename ElementCompute,
|
| 216 |
-
typename ElementAccumulator = ElementCompute,
|
| 217 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 218 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>,
|
| 219 |
-
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
|
| 220 |
-
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
|
| 221 |
-
int kCtaShapeM = 16, // shape of a threadblock in units of threads
|
| 222 |
-
int kCtaShapeN = 8 // shape of a threadblock in units of threads
|
| 223 |
-
>
|
| 224 |
-
__global__ void Conv3dFprop(
|
| 225 |
-
conv::Conv3dProblemSize problem_size,
|
| 226 |
-
TensorRef<ElementA, LayoutA> tensor_x,
|
| 227 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 228 |
-
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 229 |
-
TensorRef<ElementC, LayoutC> tensor_y_out,
|
| 230 |
-
ElementCompute alpha,
|
| 231 |
-
ElementCompute beta
|
| 232 |
-
) {
|
| 233 |
-
|
| 234 |
-
ConvertOp convert_op;
|
| 235 |
-
InnerProductOp inner_product_op;
|
| 236 |
-
|
| 237 |
-
ElementAccumulator element_A[kThreadM];
|
| 238 |
-
ElementAccumulator element_B[kThreadN];
|
| 239 |
-
ElementAccumulator accum[kThreadM][kThreadN];
|
| 240 |
-
|
| 241 |
-
int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
|
| 242 |
-
int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
|
| 243 |
-
|
| 244 |
-
int thread_n[kThreadM];
|
| 245 |
-
int thread_z[kThreadM];
|
| 246 |
-
int thread_p[kThreadM];
|
| 247 |
-
int thread_q[kThreadM];
|
| 248 |
-
|
| 249 |
-
// Compute N, Z, P, Q coordinates for each row of a thread's tile
|
| 250 |
-
int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
|
| 251 |
-
int64_t ZPQ = PQ * problem_size.Z;
|
| 252 |
-
|
| 253 |
-
CUTLASS_PRAGMA_UNROLL
|
| 254 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 255 |
-
|
| 256 |
-
int64_t nzpq = nzpq_start + m;
|
| 257 |
-
|
| 258 |
-
thread_n[m] = int(nzpq / ZPQ);
|
| 259 |
-
|
| 260 |
-
int64_t residual = nzpq % ZPQ;
|
| 261 |
-
thread_z[m] = int(residual / PQ);
|
| 262 |
-
|
| 263 |
-
residual = residual % PQ;
|
| 264 |
-
thread_p[m] = int(residual / problem_size.Q);
|
| 265 |
-
thread_q[m] = int(residual % problem_size.Q);
|
| 266 |
-
}
|
| 267 |
-
|
| 268 |
-
// Clear accumulators
|
| 269 |
-
CUTLASS_PRAGMA_UNROLL
|
| 270 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 271 |
-
CUTLASS_PRAGMA_UNROLL
|
| 272 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 273 |
-
accum[m][n] = ElementAccumulator();
|
| 274 |
-
}
|
| 275 |
-
}
|
| 276 |
-
|
| 277 |
-
// Compute convolution
|
| 278 |
-
for (int T = 0; T < problem_size.T; ++T) {
|
| 279 |
-
for (int R = 0; R < problem_size.R; ++R) {
|
| 280 |
-
for (int S = 0; S < problem_size.S; ++S) {
|
| 281 |
-
for (int C = 0; C < problem_size.C; ++C) {
|
| 282 |
-
|
| 283 |
-
// Load from activations tensor
|
| 284 |
-
int filter_t = T;
|
| 285 |
-
int filter_r = R;
|
| 286 |
-
int filter_s = S;
|
| 287 |
-
|
| 288 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 289 |
-
filter_t = problem_size.T - 1 - T;
|
| 290 |
-
filter_r = problem_size.R - 1 - R;
|
| 291 |
-
filter_s = problem_size.S - 1 - S;
|
| 292 |
-
}
|
| 293 |
-
|
| 294 |
-
CUTLASS_PRAGMA_UNROLL
|
| 295 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 296 |
-
int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
|
| 297 |
-
int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
|
| 298 |
-
int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
|
| 299 |
-
|
| 300 |
-
if (thread_n[m] < problem_size.N &&
|
| 301 |
-
d >= 0 && d < problem_size.D &&
|
| 302 |
-
h >= 0 && h < problem_size.H &&
|
| 303 |
-
w >= 0 && w < problem_size.W) {
|
| 304 |
-
|
| 305 |
-
element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C}));
|
| 306 |
-
}
|
| 307 |
-
else {
|
| 308 |
-
element_A[m] = ElementAccumulator();
|
| 309 |
-
}
|
| 310 |
-
}
|
| 311 |
-
|
| 312 |
-
// Load from filters tensor
|
| 313 |
-
CUTLASS_PRAGMA_UNROLL
|
| 314 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 315 |
-
int thread_k = k_start + n;
|
| 316 |
-
|
| 317 |
-
if (thread_k < problem_size.K) {
|
| 318 |
-
element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C}));
|
| 319 |
-
}
|
| 320 |
-
else {
|
| 321 |
-
element_B[n] = ElementAccumulator();
|
| 322 |
-
}
|
| 323 |
-
}
|
| 324 |
-
|
| 325 |
-
// Accumulate matrix product
|
| 326 |
-
CUTLASS_PRAGMA_UNROLL
|
| 327 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 328 |
-
CUTLASS_PRAGMA_UNROLL
|
| 329 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 330 |
-
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
|
| 331 |
-
}
|
| 332 |
-
}
|
| 333 |
-
|
| 334 |
-
} // for (C)
|
| 335 |
-
} // for (S)
|
| 336 |
-
} // for (R)
|
| 337 |
-
} // for (T)
|
| 338 |
-
|
| 339 |
-
// Write out the results
|
| 340 |
-
CUTLASS_PRAGMA_UNROLL
|
| 341 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 342 |
-
|
| 343 |
-
if (thread_n[m] < problem_size.N &&
|
| 344 |
-
thread_z[m] < problem_size.Z &&
|
| 345 |
-
thread_p[m] < problem_size.P &&
|
| 346 |
-
thread_q[m] < problem_size.Q) {
|
| 347 |
-
|
| 348 |
-
CUTLASS_PRAGMA_UNROLL
|
| 349 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 350 |
-
int thread_k = k_start + n;
|
| 351 |
-
if (thread_k < problem_size.K) {
|
| 352 |
-
|
| 353 |
-
ElementCompute c_ref = ElementCompute();
|
| 354 |
-
if (beta != ElementCompute()) {
|
| 355 |
-
c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}));
|
| 356 |
-
}
|
| 357 |
-
|
| 358 |
-
tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
|
| 359 |
-
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
|
| 360 |
-
}
|
| 361 |
-
} // for (n)
|
| 362 |
-
|
| 363 |
-
}
|
| 364 |
-
} // for (m)
|
| 365 |
-
}
|
| 366 |
-
|
| 367 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 368 |
-
|
| 369 |
-
// Conv2d dgrad kernel - dx = dgrad(dy, w)
|
| 370 |
-
template <
|
| 371 |
-
typename ElementA,
|
| 372 |
-
typename LayoutA,
|
| 373 |
-
typename ElementB,
|
| 374 |
-
typename LayoutB,
|
| 375 |
-
typename ElementC,
|
| 376 |
-
typename LayoutC,
|
| 377 |
-
typename ElementCompute,
|
| 378 |
-
typename ElementAccumulator = ElementCompute,
|
| 379 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 380 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>,
|
| 381 |
-
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
|
| 382 |
-
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
|
| 383 |
-
int kCtaShapeM = 16, // shape of a threadblock in units of threads
|
| 384 |
-
int kCtaShapeN = 8 // shape of a threadblock in units of threads
|
| 385 |
-
>
|
| 386 |
-
__global__ void Conv2dDgrad(
|
| 387 |
-
conv::Conv2dProblemSize problem_size,
|
| 388 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 389 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 390 |
-
TensorRef<ElementC, LayoutC> tensor_dx_in,
|
| 391 |
-
TensorRef<ElementC, LayoutC> tensor_dx_out,
|
| 392 |
-
ElementCompute alpha,
|
| 393 |
-
ElementCompute beta
|
| 394 |
-
) {
|
| 395 |
-
|
| 396 |
-
ConvertOp convert_op;
|
| 397 |
-
InnerProductOp inner_product_op;
|
| 398 |
-
|
| 399 |
-
ElementAccumulator element_A[kThreadM];
|
| 400 |
-
ElementAccumulator element_B[kThreadN];
|
| 401 |
-
ElementAccumulator accum[kThreadM][kThreadN];
|
| 402 |
-
|
| 403 |
-
int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
|
| 404 |
-
int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
|
| 405 |
-
|
| 406 |
-
int thread_n[kThreadM];
|
| 407 |
-
int thread_h[kThreadM];
|
| 408 |
-
int thread_w[kThreadM];
|
| 409 |
-
|
| 410 |
-
// Compute N, H, W coordinates for each row of a thread's tile
|
| 411 |
-
int64_t HW = int64_t(problem_size.H) * problem_size.W;
|
| 412 |
-
|
| 413 |
-
CUTLASS_PRAGMA_UNROLL
|
| 414 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 415 |
-
|
| 416 |
-
int64_t nhw = nhw_start + m;
|
| 417 |
-
|
| 418 |
-
thread_n[m] = int(nhw / HW);
|
| 419 |
-
|
| 420 |
-
int64_t residual = nhw % HW;
|
| 421 |
-
thread_h[m] = int(residual / problem_size.W);
|
| 422 |
-
thread_w[m] = int(residual % problem_size.W);
|
| 423 |
-
}
|
| 424 |
-
|
| 425 |
-
// Clear accumulators
|
| 426 |
-
CUTLASS_PRAGMA_UNROLL
|
| 427 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 428 |
-
CUTLASS_PRAGMA_UNROLL
|
| 429 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 430 |
-
accum[m][n] = ElementAccumulator();
|
| 431 |
-
}
|
| 432 |
-
}
|
| 433 |
-
|
| 434 |
-
// Compute convolution
|
| 435 |
-
for (int R = 0; R < problem_size.R; ++R) {
|
| 436 |
-
for (int S = 0; S < problem_size.S; ++S) {
|
| 437 |
-
for (int K = 0; K < problem_size.K; ++K) {
|
| 438 |
-
|
| 439 |
-
// Load from activations tensor
|
| 440 |
-
int filter_r = R;
|
| 441 |
-
int filter_s = S;
|
| 442 |
-
|
| 443 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 444 |
-
filter_r = problem_size.R - 1 - R;
|
| 445 |
-
filter_s = problem_size.S - 1 - S;
|
| 446 |
-
}
|
| 447 |
-
|
| 448 |
-
CUTLASS_PRAGMA_UNROLL
|
| 449 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 450 |
-
|
| 451 |
-
int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
|
| 452 |
-
int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;
|
| 453 |
-
|
| 454 |
-
element_A[m] = ElementAccumulator();
|
| 455 |
-
|
| 456 |
-
if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) {
|
| 457 |
-
|
| 458 |
-
p = p / problem_size.stride_h;
|
| 459 |
-
q = q / problem_size.stride_w;
|
| 460 |
-
|
| 461 |
-
if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) {
|
| 462 |
-
element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K}));
|
| 463 |
-
}
|
| 464 |
-
}
|
| 465 |
-
}
|
| 466 |
-
|
| 467 |
-
// Load from filters tensor
|
| 468 |
-
CUTLASS_PRAGMA_UNROLL
|
| 469 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 470 |
-
int thread_c = c_start + n;
|
| 471 |
-
|
| 472 |
-
if (thread_c < problem_size.C) {
|
| 473 |
-
element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c}));
|
| 474 |
-
}
|
| 475 |
-
else {
|
| 476 |
-
element_B[n] = ElementAccumulator();
|
| 477 |
-
}
|
| 478 |
-
}
|
| 479 |
-
|
| 480 |
-
// Accumulate matrix product
|
| 481 |
-
CUTLASS_PRAGMA_UNROLL
|
| 482 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 483 |
-
CUTLASS_PRAGMA_UNROLL
|
| 484 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 485 |
-
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
|
| 486 |
-
}
|
| 487 |
-
}
|
| 488 |
-
}
|
| 489 |
-
}
|
| 490 |
-
}
|
| 491 |
-
|
| 492 |
-
// Write out the results
|
| 493 |
-
CUTLASS_PRAGMA_UNROLL
|
| 494 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 495 |
-
|
| 496 |
-
if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) {
|
| 497 |
-
|
| 498 |
-
CUTLASS_PRAGMA_UNROLL
|
| 499 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 500 |
-
int thread_c = c_start + n;
|
| 501 |
-
if (thread_c < problem_size.C) {
|
| 502 |
-
|
| 503 |
-
ElementCompute c_ref = ElementCompute();
|
| 504 |
-
if (beta != ElementCompute()) {
|
| 505 |
-
c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c}));
|
| 506 |
-
}
|
| 507 |
-
|
| 508 |
-
tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
|
| 509 |
-
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
|
| 510 |
-
}
|
| 511 |
-
}
|
| 512 |
-
}
|
| 513 |
-
}
|
| 514 |
-
}
|
| 515 |
-
|
| 516 |
-
// Conv3d dgrad kernel - dx = dgrad(dy, w)
|
| 517 |
-
template <
|
| 518 |
-
typename ElementA,
|
| 519 |
-
typename LayoutA,
|
| 520 |
-
typename ElementB,
|
| 521 |
-
typename LayoutB,
|
| 522 |
-
typename ElementC,
|
| 523 |
-
typename LayoutC,
|
| 524 |
-
typename ElementCompute,
|
| 525 |
-
typename ElementAccumulator = ElementCompute,
|
| 526 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 527 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>,
|
| 528 |
-
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
|
| 529 |
-
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
|
| 530 |
-
int kCtaShapeM = 16, // shape of a threadblock in units of threads
|
| 531 |
-
int kCtaShapeN = 8 // shape of a threadblock in units of threads
|
| 532 |
-
>
|
| 533 |
-
__global__ void Conv3dDgrad(
|
| 534 |
-
conv::Conv3dProblemSize problem_size,
|
| 535 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 536 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 537 |
-
TensorRef<ElementC, LayoutC> tensor_dx_in,
|
| 538 |
-
TensorRef<ElementC, LayoutC> tensor_dx_out,
|
| 539 |
-
ElementCompute alpha,
|
| 540 |
-
ElementCompute beta
|
| 541 |
-
) {
|
| 542 |
-
|
| 543 |
-
ConvertOp convert_op;
|
| 544 |
-
InnerProductOp inner_product_op;
|
| 545 |
-
|
| 546 |
-
ElementAccumulator element_A[kThreadM];
|
| 547 |
-
ElementAccumulator element_B[kThreadN];
|
| 548 |
-
ElementAccumulator accum[kThreadM][kThreadN];
|
| 549 |
-
|
| 550 |
-
int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
|
| 551 |
-
int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
|
| 552 |
-
|
| 553 |
-
int thread_n[kThreadM];
|
| 554 |
-
int thread_d[kThreadM];
|
| 555 |
-
int thread_h[kThreadM];
|
| 556 |
-
int thread_w[kThreadM];
|
| 557 |
-
|
| 558 |
-
// Compute N, H, W coordinates for each row of a thread's tile
|
| 559 |
-
int64_t HW = int64_t(problem_size.H) * problem_size.W;
|
| 560 |
-
int64_t DHW = HW * problem_size.D;
|
| 561 |
-
|
| 562 |
-
CUTLASS_PRAGMA_UNROLL
|
| 563 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 564 |
-
|
| 565 |
-
int64_t ndhw = ndhw_start + m;
|
| 566 |
-
|
| 567 |
-
thread_n[m] = int(ndhw / DHW);
|
| 568 |
-
|
| 569 |
-
int64_t residual = ndhw % DHW;
|
| 570 |
-
thread_d[m] = int(residual / HW);
|
| 571 |
-
|
| 572 |
-
residual = residual % HW;
|
| 573 |
-
thread_h[m] = int(residual / problem_size.W);
|
| 574 |
-
thread_w[m] = int(residual % problem_size.W);
|
| 575 |
-
}
|
| 576 |
-
|
| 577 |
-
// Clear accumulators
|
| 578 |
-
CUTLASS_PRAGMA_UNROLL
|
| 579 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 580 |
-
CUTLASS_PRAGMA_UNROLL
|
| 581 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 582 |
-
accum[m][n] = ElementAccumulator();
|
| 583 |
-
}
|
| 584 |
-
}
|
| 585 |
-
|
| 586 |
-
// Compute convolution
|
| 587 |
-
for (int T = 0; T < problem_size.T; ++T) {
|
| 588 |
-
for (int R = 0; R < problem_size.R; ++R) {
|
| 589 |
-
for (int S = 0; S < problem_size.S; ++S) {
|
| 590 |
-
for (int K = 0; K < problem_size.K; ++K) {
|
| 591 |
-
|
| 592 |
-
// Load from activations tensor
|
| 593 |
-
int filter_t = T;
|
| 594 |
-
int filter_r = R;
|
| 595 |
-
int filter_s = S;
|
| 596 |
-
|
| 597 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 598 |
-
filter_t = problem_size.T - 1 - T;
|
| 599 |
-
filter_r = problem_size.R - 1 - R;
|
| 600 |
-
filter_s = problem_size.S - 1 - S;
|
| 601 |
-
}
|
| 602 |
-
|
| 603 |
-
CUTLASS_PRAGMA_UNROLL
|
| 604 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 605 |
-
|
| 606 |
-
int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d;
|
| 607 |
-
int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
|
| 608 |
-
int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;
|
| 609 |
-
|
| 610 |
-
element_A[m] = ElementAccumulator();
|
| 611 |
-
|
| 612 |
-
if (z >= 0 && !(z % problem_size.stride_d) &&
|
| 613 |
-
p >= 0 && !(p % problem_size.stride_h) &&
|
| 614 |
-
q >= 0 && !(q % problem_size.stride_w)) {
|
| 615 |
-
|
| 616 |
-
z = z / problem_size.stride_d;
|
| 617 |
-
p = p / problem_size.stride_h;
|
| 618 |
-
q = q / problem_size.stride_w;
|
| 619 |
-
|
| 620 |
-
if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {
|
| 621 |
-
element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K}));
|
| 622 |
-
}
|
| 623 |
-
}
|
| 624 |
-
}
|
| 625 |
-
|
| 626 |
-
// Load from filters tensor
|
| 627 |
-
CUTLASS_PRAGMA_UNROLL
|
| 628 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 629 |
-
int thread_c = c_start + n;
|
| 630 |
-
|
| 631 |
-
if (thread_c < problem_size.C) {
|
| 632 |
-
element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c}));
|
| 633 |
-
}
|
| 634 |
-
else {
|
| 635 |
-
element_B[n] = ElementAccumulator();
|
| 636 |
-
}
|
| 637 |
-
}
|
| 638 |
-
|
| 639 |
-
// Accumulate matrix product
|
| 640 |
-
CUTLASS_PRAGMA_UNROLL
|
| 641 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 642 |
-
CUTLASS_PRAGMA_UNROLL
|
| 643 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 644 |
-
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
|
| 645 |
-
}
|
| 646 |
-
}
|
| 647 |
-
|
| 648 |
-
} // for (C)
|
| 649 |
-
} // for (S)
|
| 650 |
-
} // for (R)
|
| 651 |
-
} // for (T)
|
| 652 |
-
|
| 653 |
-
// Write out the results
|
| 654 |
-
CUTLASS_PRAGMA_UNROLL
|
| 655 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 656 |
-
|
| 657 |
-
if (thread_n[m] < problem_size.N &&
|
| 658 |
-
thread_d[m] < problem_size.D &&
|
| 659 |
-
thread_h[m] < problem_size.H &&
|
| 660 |
-
thread_w[m] < problem_size.W) {
|
| 661 |
-
|
| 662 |
-
CUTLASS_PRAGMA_UNROLL
|
| 663 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 664 |
-
int thread_c = c_start + n;
|
| 665 |
-
if (thread_c < problem_size.C) {
|
| 666 |
-
|
| 667 |
-
ElementCompute c_ref = ElementCompute();
|
| 668 |
-
if (beta != ElementCompute()) {
|
| 669 |
-
c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}));
|
| 670 |
-
}
|
| 671 |
-
|
| 672 |
-
tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
|
| 673 |
-
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
|
| 674 |
-
}
|
| 675 |
-
}
|
| 676 |
-
}
|
| 677 |
-
}
|
| 678 |
-
}
|
| 679 |
-
|
| 680 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 681 |
-
|
| 682 |
-
// Conv2d wgrad kernel - dw = wgrad(dy, x)
|
| 683 |
-
template <
|
| 684 |
-
typename ElementA,
|
| 685 |
-
typename LayoutA,
|
| 686 |
-
typename ElementB,
|
| 687 |
-
typename LayoutB,
|
| 688 |
-
typename ElementC,
|
| 689 |
-
typename LayoutC,
|
| 690 |
-
typename ElementCompute,
|
| 691 |
-
typename ElementAccumulator = ElementCompute,
|
| 692 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 693 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>,
|
| 694 |
-
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
|
| 695 |
-
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
|
| 696 |
-
int kCtaShapeM = 8, // shape of a threadblock in units of threads
|
| 697 |
-
int kCtaShapeN = 16 // shape of a threadblock in units of threads
|
| 698 |
-
>
|
| 699 |
-
__global__ void Conv2dWgrad(
|
| 700 |
-
conv::Conv2dProblemSize problem_size,
|
| 701 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 702 |
-
TensorRef<ElementB, LayoutB> tensor_x,
|
| 703 |
-
TensorRef<ElementC, LayoutC> tensor_dw_in,
|
| 704 |
-
TensorRef<ElementC, LayoutC> tensor_dw_out,
|
| 705 |
-
ElementCompute alpha,
|
| 706 |
-
ElementCompute beta
|
| 707 |
-
) {
|
| 708 |
-
|
| 709 |
-
ConvertOp convert_op;
|
| 710 |
-
InnerProductOp inner_product_op;
|
| 711 |
-
|
| 712 |
-
ElementAccumulator element_A[kThreadM];
|
| 713 |
-
ElementAccumulator element_B[kThreadN];
|
| 714 |
-
ElementAccumulator accum[kThreadM][kThreadN];
|
| 715 |
-
|
| 716 |
-
int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
|
| 717 |
-
int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
|
| 718 |
-
|
| 719 |
-
int thread_r[kThreadN];
|
| 720 |
-
int thread_s[kThreadN];
|
| 721 |
-
int thread_c[kThreadN];
|
| 722 |
-
|
| 723 |
-
// Compute R, S, C coordinates for each row of a thread's tile
|
| 724 |
-
int64_t SC = int64_t(problem_size.S) * problem_size.C;
|
| 725 |
-
|
| 726 |
-
CUTLASS_PRAGMA_UNROLL
|
| 727 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 728 |
-
|
| 729 |
-
int64_t rsc = rsc_start + n;
|
| 730 |
-
int64_t residual = rsc % SC;
|
| 731 |
-
|
| 732 |
-
thread_r[n] = int(rsc / SC);
|
| 733 |
-
thread_s[n] = int(residual / problem_size.C);
|
| 734 |
-
thread_c[n] = int(residual % problem_size.C);
|
| 735 |
-
}
|
| 736 |
-
|
| 737 |
-
// Clear accumulators
|
| 738 |
-
CUTLASS_PRAGMA_UNROLL
|
| 739 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 740 |
-
CUTLASS_PRAGMA_UNROLL
|
| 741 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 742 |
-
accum[m][n] = ElementAccumulator();
|
| 743 |
-
}
|
| 744 |
-
}
|
| 745 |
-
|
| 746 |
-
// Compute convolution
|
| 747 |
-
for (int N = 0; N < problem_size.N; ++N) {
|
| 748 |
-
for (int P = 0; P < problem_size.P; ++P) {
|
| 749 |
-
for (int Q = 0; Q < problem_size.Q; ++Q) {
|
| 750 |
-
|
| 751 |
-
CUTLASS_PRAGMA_UNROLL
|
| 752 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 753 |
-
int thread_k = k_start + m;
|
| 754 |
-
|
| 755 |
-
element_A[m] = ElementAccumulator();
|
| 756 |
-
|
| 757 |
-
if (thread_k < problem_size.K) {
|
| 758 |
-
element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k}));
|
| 759 |
-
}
|
| 760 |
-
}
|
| 761 |
-
|
| 762 |
-
// Load from filters tensor
|
| 763 |
-
CUTLASS_PRAGMA_UNROLL
|
| 764 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 765 |
-
|
| 766 |
-
// Load from activations tensor
|
| 767 |
-
int filter_r = thread_r[n];
|
| 768 |
-
int filter_s = thread_s[n];
|
| 769 |
-
|
| 770 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 771 |
-
filter_r = problem_size.R - 1 - filter_r;
|
| 772 |
-
filter_s = problem_size.S - 1 - filter_s;
|
| 773 |
-
}
|
| 774 |
-
|
| 775 |
-
int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
|
| 776 |
-
int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
|
| 777 |
-
|
| 778 |
-
element_B[n] = ElementAccumulator();
|
| 779 |
-
|
| 780 |
-
if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) {
|
| 781 |
-
element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]}));
|
| 782 |
-
}
|
| 783 |
-
}
|
| 784 |
-
|
| 785 |
-
// Accumulate matrix product
|
| 786 |
-
CUTLASS_PRAGMA_UNROLL
|
| 787 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 788 |
-
CUTLASS_PRAGMA_UNROLL
|
| 789 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 790 |
-
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
|
| 791 |
-
}
|
| 792 |
-
}
|
| 793 |
-
}
|
| 794 |
-
}
|
| 795 |
-
}
|
| 796 |
-
|
| 797 |
-
// Write out the results
|
| 798 |
-
CUTLASS_PRAGMA_UNROLL
|
| 799 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 800 |
-
int thread_k = k_start + m;
|
| 801 |
-
|
| 802 |
-
if (thread_k < problem_size.K) {
|
| 803 |
-
|
| 804 |
-
CUTLASS_PRAGMA_UNROLL
|
| 805 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 806 |
-
|
| 807 |
-
if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) {
|
| 808 |
-
|
| 809 |
-
ElementCompute c_ref = ElementCompute();
|
| 810 |
-
|
| 811 |
-
if (beta != ElementCompute()) {
|
| 812 |
-
c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}));
|
| 813 |
-
}
|
| 814 |
-
|
| 815 |
-
tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
|
| 816 |
-
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
|
| 817 |
-
}
|
| 818 |
-
}
|
| 819 |
-
}
|
| 820 |
-
}
|
| 821 |
-
}
|
| 822 |
-
|
| 823 |
-
// Conv3d wgrad kernel - dw = wgrad(dy, x)
|
| 824 |
-
template <
|
| 825 |
-
typename ElementA,
|
| 826 |
-
typename LayoutA,
|
| 827 |
-
typename ElementB,
|
| 828 |
-
typename LayoutB,
|
| 829 |
-
typename ElementC,
|
| 830 |
-
typename LayoutC,
|
| 831 |
-
typename ElementCompute,
|
| 832 |
-
typename ElementAccumulator = ElementCompute,
|
| 833 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 834 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>,
|
| 835 |
-
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
|
| 836 |
-
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
|
| 837 |
-
int kCtaShapeM = 8, // shape of a threadblock in units of threads
|
| 838 |
-
int kCtaShapeN = 16 // shape of a threadblock in units of threads
|
| 839 |
-
>
|
| 840 |
-
__global__ void Conv3dWgrad(
|
| 841 |
-
conv::Conv3dProblemSize problem_size,
|
| 842 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 843 |
-
TensorRef<ElementB, LayoutB> tensor_x,
|
| 844 |
-
TensorRef<ElementC, LayoutC> tensor_dw_in,
|
| 845 |
-
TensorRef<ElementC, LayoutC> tensor_dw_out,
|
| 846 |
-
ElementCompute alpha,
|
| 847 |
-
ElementCompute beta
|
| 848 |
-
) {
|
| 849 |
-
|
| 850 |
-
ConvertOp convert_op;
|
| 851 |
-
InnerProductOp inner_product_op;
|
| 852 |
-
|
| 853 |
-
ElementAccumulator element_A[kThreadM];
|
| 854 |
-
ElementAccumulator element_B[kThreadN];
|
| 855 |
-
ElementAccumulator accum[kThreadM][kThreadN];
|
| 856 |
-
|
| 857 |
-
int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
|
| 858 |
-
int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
|
| 859 |
-
|
| 860 |
-
int thread_t[kThreadN];
|
| 861 |
-
int thread_r[kThreadN];
|
| 862 |
-
int thread_s[kThreadN];
|
| 863 |
-
int thread_c[kThreadN];
|
| 864 |
-
|
| 865 |
-
// Compute R, S, C coordinates for each row of a thread's tile
|
| 866 |
-
int64_t SC = int64_t(problem_size.S) * problem_size.C;
|
| 867 |
-
int64_t RSC = SC * problem_size.R;
|
| 868 |
-
|
| 869 |
-
CUTLASS_PRAGMA_UNROLL
|
| 870 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 871 |
-
|
| 872 |
-
int64_t trsc = trsc_start + n;
|
| 873 |
-
|
| 874 |
-
thread_t[n] = int(trsc / RSC);
|
| 875 |
-
|
| 876 |
-
int64_t residual = trsc % RSC;
|
| 877 |
-
thread_r[n] = int(residual / SC);
|
| 878 |
-
|
| 879 |
-
residual = residual % SC;
|
| 880 |
-
thread_s[n] = int(residual / problem_size.C);
|
| 881 |
-
thread_c[n] = int(residual % problem_size.C);
|
| 882 |
-
}
|
| 883 |
-
|
| 884 |
-
// Clear accumulators
|
| 885 |
-
CUTLASS_PRAGMA_UNROLL
|
| 886 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 887 |
-
CUTLASS_PRAGMA_UNROLL
|
| 888 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 889 |
-
accum[m][n] = ElementAccumulator();
|
| 890 |
-
}
|
| 891 |
-
}
|
| 892 |
-
|
| 893 |
-
// Compute convolution
|
| 894 |
-
for (int N = 0; N < problem_size.N; ++N) {
|
| 895 |
-
for (int Z = 0; Z < problem_size.Z; ++Z) {
|
| 896 |
-
for (int P = 0; P < problem_size.P; ++P) {
|
| 897 |
-
for (int Q = 0; Q < problem_size.Q; ++Q) {
|
| 898 |
-
|
| 899 |
-
CUTLASS_PRAGMA_UNROLL
|
| 900 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 901 |
-
int thread_k = k_start + m;
|
| 902 |
-
|
| 903 |
-
element_A[m] = ElementAccumulator();
|
| 904 |
-
|
| 905 |
-
if (thread_k < problem_size.K) {
|
| 906 |
-
element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k}));
|
| 907 |
-
}
|
| 908 |
-
}
|
| 909 |
-
|
| 910 |
-
// Load from filters tensor
|
| 911 |
-
CUTLASS_PRAGMA_UNROLL
|
| 912 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 913 |
-
|
| 914 |
-
// Load from activations tensor
|
| 915 |
-
int filter_t = thread_t[n];
|
| 916 |
-
int filter_r = thread_r[n];
|
| 917 |
-
int filter_s = thread_s[n];
|
| 918 |
-
|
| 919 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 920 |
-
filter_t = problem_size.T - 1 - filter_t;
|
| 921 |
-
filter_r = problem_size.R - 1 - filter_r;
|
| 922 |
-
filter_s = problem_size.S - 1 - filter_s;
|
| 923 |
-
}
|
| 924 |
-
|
| 925 |
-
int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
|
| 926 |
-
int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
|
| 927 |
-
int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
|
| 928 |
-
|
| 929 |
-
element_B[n] = ElementAccumulator();
|
| 930 |
-
|
| 931 |
-
if (d >= 0 && d < problem_size.D &&
|
| 932 |
-
h >= 0 && h < problem_size.H &&
|
| 933 |
-
w >= 0 && w < problem_size.W &&
|
| 934 |
-
thread_c[n] < problem_size.C) {
|
| 935 |
-
|
| 936 |
-
element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]}));
|
| 937 |
-
}
|
| 938 |
-
}
|
| 939 |
-
|
| 940 |
-
// Accumulate matrix product
|
| 941 |
-
CUTLASS_PRAGMA_UNROLL
|
| 942 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 943 |
-
CUTLASS_PRAGMA_UNROLL
|
| 944 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 945 |
-
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
|
| 946 |
-
}
|
| 947 |
-
}
|
| 948 |
-
|
| 949 |
-
} // for (Q)
|
| 950 |
-
} // for (P)
|
| 951 |
-
} // for (Z)
|
| 952 |
-
} // for (N)
|
| 953 |
-
|
| 954 |
-
// Write out the results
|
| 955 |
-
CUTLASS_PRAGMA_UNROLL
|
| 956 |
-
for (int m = 0; m < kThreadM; ++m) {
|
| 957 |
-
int thread_k = k_start + m;
|
| 958 |
-
|
| 959 |
-
if (thread_k < problem_size.K) {
|
| 960 |
-
|
| 961 |
-
CUTLASS_PRAGMA_UNROLL
|
| 962 |
-
for (int n = 0; n < kThreadN; ++n) {
|
| 963 |
-
|
| 964 |
-
if (thread_t[n] < problem_size.T &&
|
| 965 |
-
thread_r[n] < problem_size.R &&
|
| 966 |
-
thread_s[n] < problem_size.S &&
|
| 967 |
-
thread_c[n] < problem_size.C) {
|
| 968 |
-
|
| 969 |
-
ElementCompute c_ref = ElementCompute();
|
| 970 |
-
|
| 971 |
-
if (beta != ElementCompute()) {
|
| 972 |
-
c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}));
|
| 973 |
-
}
|
| 974 |
-
|
| 975 |
-
tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
|
| 976 |
-
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
|
| 977 |
-
}
|
| 978 |
-
}
|
| 979 |
-
}
|
| 980 |
-
}
|
| 981 |
-
}
|
| 982 |
-
|
| 983 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 984 |
-
|
| 985 |
-
} // namespace kernel
|
| 986 |
-
|
| 987 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 988 |
-
|
| 989 |
-
/// Conv2d Fprop dispatcher - y = fprop(x, w)
|
| 990 |
-
template <
|
| 991 |
-
typename ElementA,
|
| 992 |
-
typename LayoutA,
|
| 993 |
-
typename ElementB,
|
| 994 |
-
typename LayoutB,
|
| 995 |
-
typename ElementC,
|
| 996 |
-
typename LayoutC,
|
| 997 |
-
typename ElementCompute,
|
| 998 |
-
typename ElementAccumulator = ElementCompute,
|
| 999 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1000 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1001 |
-
>
|
| 1002 |
-
Status Conv2dFprop(
|
| 1003 |
-
conv::Conv2dProblemSize problem_size,
|
| 1004 |
-
TensorRef<ElementA, LayoutA> tensor_x,
|
| 1005 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 1006 |
-
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 1007 |
-
TensorRef<ElementC, LayoutC> tensor_y_out,
|
| 1008 |
-
ElementCompute alpha,
|
| 1009 |
-
ElementCompute beta,
|
| 1010 |
-
cudaStream_t stream = nullptr) {
|
| 1011 |
-
|
| 1012 |
-
//
|
| 1013 |
-
// Blocking factors improve performance of reference implementation
|
| 1014 |
-
//
|
| 1015 |
-
|
| 1016 |
-
int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
|
| 1017 |
-
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1018 |
-
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
|
| 1019 |
-
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
|
| 1020 |
-
|
| 1021 |
-
int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q;
|
| 1022 |
-
int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
|
| 1023 |
-
|
| 1024 |
-
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1025 |
-
dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
|
| 1026 |
-
|
| 1027 |
-
kernel::Conv2dFprop<
|
| 1028 |
-
ElementA,
|
| 1029 |
-
LayoutA,
|
| 1030 |
-
ElementB,
|
| 1031 |
-
LayoutB,
|
| 1032 |
-
ElementC,
|
| 1033 |
-
LayoutC,
|
| 1034 |
-
ElementCompute,
|
| 1035 |
-
ElementAccumulator,
|
| 1036 |
-
ConvertOp,
|
| 1037 |
-
InnerProductOp,
|
| 1038 |
-
kThreadM,
|
| 1039 |
-
kThreadN,
|
| 1040 |
-
kCtaShapeM,
|
| 1041 |
-
kCtaShapeN
|
| 1042 |
-
><<< grid, block, 0, stream >>>(
|
| 1043 |
-
problem_size,
|
| 1044 |
-
tensor_x,
|
| 1045 |
-
tensor_w,
|
| 1046 |
-
tensor_y_in,
|
| 1047 |
-
tensor_y_out,
|
| 1048 |
-
alpha,
|
| 1049 |
-
beta
|
| 1050 |
-
);
|
| 1051 |
-
|
| 1052 |
-
cudaError_t result = cudaPeekAtLastError();
|
| 1053 |
-
if (result != cudaSuccess) {
|
| 1054 |
-
return Status::kErrorInternal;
|
| 1055 |
-
}
|
| 1056 |
-
|
| 1057 |
-
return Status::kSuccess;
|
| 1058 |
-
}
|
| 1059 |
-
|
| 1060 |
-
/// Conv3d Fprop dispatcher - y = fprop(x, w)
|
| 1061 |
-
template <
|
| 1062 |
-
typename ElementA,
|
| 1063 |
-
typename LayoutA,
|
| 1064 |
-
typename ElementB,
|
| 1065 |
-
typename LayoutB,
|
| 1066 |
-
typename ElementC,
|
| 1067 |
-
typename LayoutC,
|
| 1068 |
-
typename ElementCompute,
|
| 1069 |
-
typename ElementAccumulator = ElementCompute,
|
| 1070 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1071 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1072 |
-
>
|
| 1073 |
-
Status Conv3dFprop(
|
| 1074 |
-
conv::Conv3dProblemSize problem_size,
|
| 1075 |
-
TensorRef<ElementA, LayoutA> tensor_x,
|
| 1076 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 1077 |
-
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 1078 |
-
TensorRef<ElementC, LayoutC> tensor_y_out,
|
| 1079 |
-
ElementCompute alpha,
|
| 1080 |
-
ElementCompute beta,
|
| 1081 |
-
cudaStream_t stream = nullptr) {
|
| 1082 |
-
|
| 1083 |
-
//
|
| 1084 |
-
// Blocking factors improve performance of reference implementation
|
| 1085 |
-
//
|
| 1086 |
-
|
| 1087 |
-
int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
|
| 1088 |
-
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1089 |
-
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
|
| 1090 |
-
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
|
| 1091 |
-
|
| 1092 |
-
int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q;
|
| 1093 |
-
int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
|
| 1094 |
-
|
| 1095 |
-
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1096 |
-
dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
|
| 1097 |
-
|
| 1098 |
-
kernel::Conv3dFprop<
|
| 1099 |
-
ElementA,
|
| 1100 |
-
LayoutA,
|
| 1101 |
-
ElementB,
|
| 1102 |
-
LayoutB,
|
| 1103 |
-
ElementC,
|
| 1104 |
-
LayoutC,
|
| 1105 |
-
ElementCompute,
|
| 1106 |
-
ElementAccumulator,
|
| 1107 |
-
ConvertOp,
|
| 1108 |
-
InnerProductOp,
|
| 1109 |
-
kThreadM,
|
| 1110 |
-
kThreadN,
|
| 1111 |
-
kCtaShapeM,
|
| 1112 |
-
kCtaShapeN
|
| 1113 |
-
><<< grid, block, 0, stream >>>(
|
| 1114 |
-
problem_size,
|
| 1115 |
-
tensor_x,
|
| 1116 |
-
tensor_w,
|
| 1117 |
-
tensor_y_in,
|
| 1118 |
-
tensor_y_out,
|
| 1119 |
-
alpha,
|
| 1120 |
-
beta
|
| 1121 |
-
);
|
| 1122 |
-
|
| 1123 |
-
cudaError_t result = cudaPeekAtLastError();
|
| 1124 |
-
if (result != cudaSuccess) {
|
| 1125 |
-
return Status::kErrorInternal;
|
| 1126 |
-
}
|
| 1127 |
-
|
| 1128 |
-
return Status::kSuccess;
|
| 1129 |
-
}
|
| 1130 |
-
|
| 1131 |
-
/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w)
|
| 1132 |
-
template <
|
| 1133 |
-
typename ElementA,
|
| 1134 |
-
typename LayoutA,
|
| 1135 |
-
typename ElementB,
|
| 1136 |
-
typename LayoutB,
|
| 1137 |
-
typename ElementC,
|
| 1138 |
-
typename LayoutC,
|
| 1139 |
-
typename ElementCompute,
|
| 1140 |
-
typename ElementAccumulator = ElementCompute,
|
| 1141 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1142 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1143 |
-
>
|
| 1144 |
-
Status Conv2dDgrad(
|
| 1145 |
-
conv::Conv2dProblemSize problem_size,
|
| 1146 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 1147 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 1148 |
-
TensorRef<ElementC, LayoutC> tensor_dx_in,
|
| 1149 |
-
TensorRef<ElementC, LayoutC> tensor_dx_out,
|
| 1150 |
-
ElementCompute alpha,
|
| 1151 |
-
ElementCompute beta,
|
| 1152 |
-
cudaStream_t stream = nullptr) {
|
| 1153 |
-
|
| 1154 |
-
//
|
| 1155 |
-
// Blocking factors improve performance of reference implementation
|
| 1156 |
-
//
|
| 1157 |
-
|
| 1158 |
-
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
|
| 1159 |
-
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1160 |
-
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
|
| 1161 |
-
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
|
| 1162 |
-
|
| 1163 |
-
int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W;
|
| 1164 |
-
int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
|
| 1165 |
-
|
| 1166 |
-
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1167 |
-
dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
|
| 1168 |
-
|
| 1169 |
-
kernel::Conv2dDgrad<
|
| 1170 |
-
ElementA,
|
| 1171 |
-
LayoutA,
|
| 1172 |
-
ElementB,
|
| 1173 |
-
LayoutB,
|
| 1174 |
-
ElementC,
|
| 1175 |
-
LayoutC,
|
| 1176 |
-
ElementCompute,
|
| 1177 |
-
ElementAccumulator,
|
| 1178 |
-
ConvertOp,
|
| 1179 |
-
InnerProductOp,
|
| 1180 |
-
kThreadM,
|
| 1181 |
-
kThreadN,
|
| 1182 |
-
kCtaShapeM,
|
| 1183 |
-
kCtaShapeN
|
| 1184 |
-
><<< grid, block, 0, stream >>>(
|
| 1185 |
-
problem_size,
|
| 1186 |
-
tensor_dy,
|
| 1187 |
-
tensor_w,
|
| 1188 |
-
tensor_dx_in,
|
| 1189 |
-
tensor_dx_out,
|
| 1190 |
-
alpha,
|
| 1191 |
-
beta
|
| 1192 |
-
);
|
| 1193 |
-
|
| 1194 |
-
cudaError_t result = cudaPeekAtLastError();
|
| 1195 |
-
if (result != cudaSuccess) {
|
| 1196 |
-
return Status::kErrorInternal;
|
| 1197 |
-
}
|
| 1198 |
-
|
| 1199 |
-
return Status::kSuccess;
|
| 1200 |
-
}
|
| 1201 |
-
|
| 1202 |
-
/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w)
|
| 1203 |
-
template <
|
| 1204 |
-
typename ElementA,
|
| 1205 |
-
typename LayoutA,
|
| 1206 |
-
typename ElementB,
|
| 1207 |
-
typename LayoutB,
|
| 1208 |
-
typename ElementC,
|
| 1209 |
-
typename LayoutC,
|
| 1210 |
-
typename ElementCompute,
|
| 1211 |
-
typename ElementAccumulator = ElementCompute,
|
| 1212 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1213 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1214 |
-
>
|
| 1215 |
-
Status Conv3dDgrad(
|
| 1216 |
-
conv::Conv3dProblemSize problem_size,
|
| 1217 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 1218 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 1219 |
-
TensorRef<ElementC, LayoutC> tensor_dx_in,
|
| 1220 |
-
TensorRef<ElementC, LayoutC> tensor_dx_out,
|
| 1221 |
-
ElementCompute alpha,
|
| 1222 |
-
ElementCompute beta,
|
| 1223 |
-
cudaStream_t stream = nullptr) {
|
| 1224 |
-
|
| 1225 |
-
//
|
| 1226 |
-
// Blocking factors improve performance of reference implementation
|
| 1227 |
-
//
|
| 1228 |
-
|
| 1229 |
-
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
|
| 1230 |
-
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1231 |
-
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
|
| 1232 |
-
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
|
| 1233 |
-
|
| 1234 |
-
int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W;
|
| 1235 |
-
int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
|
| 1236 |
-
|
| 1237 |
-
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1238 |
-
dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
|
| 1239 |
-
|
| 1240 |
-
kernel::Conv3dDgrad<
|
| 1241 |
-
ElementA,
|
| 1242 |
-
LayoutA,
|
| 1243 |
-
ElementB,
|
| 1244 |
-
LayoutB,
|
| 1245 |
-
ElementC,
|
| 1246 |
-
LayoutC,
|
| 1247 |
-
ElementCompute,
|
| 1248 |
-
ElementAccumulator,
|
| 1249 |
-
ConvertOp,
|
| 1250 |
-
InnerProductOp,
|
| 1251 |
-
kThreadM,
|
| 1252 |
-
kThreadN,
|
| 1253 |
-
kCtaShapeM,
|
| 1254 |
-
kCtaShapeN
|
| 1255 |
-
><<< grid, block, 0, stream >>>(
|
| 1256 |
-
problem_size,
|
| 1257 |
-
tensor_dy,
|
| 1258 |
-
tensor_w,
|
| 1259 |
-
tensor_dx_in,
|
| 1260 |
-
tensor_dx_out,
|
| 1261 |
-
alpha,
|
| 1262 |
-
beta
|
| 1263 |
-
);
|
| 1264 |
-
|
| 1265 |
-
cudaError_t result = cudaPeekAtLastError();
|
| 1266 |
-
if (result != cudaSuccess) {
|
| 1267 |
-
return Status::kErrorInternal;
|
| 1268 |
-
}
|
| 1269 |
-
|
| 1270 |
-
return Status::kSuccess;
|
| 1271 |
-
}
|
| 1272 |
-
|
| 1273 |
-
/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x)
|
| 1274 |
-
template <
|
| 1275 |
-
typename ElementA,
|
| 1276 |
-
typename LayoutA,
|
| 1277 |
-
typename ElementB,
|
| 1278 |
-
typename LayoutB,
|
| 1279 |
-
typename ElementC,
|
| 1280 |
-
typename LayoutC,
|
| 1281 |
-
typename ElementCompute,
|
| 1282 |
-
typename ElementAccumulator = ElementCompute,
|
| 1283 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1284 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1285 |
-
>
|
| 1286 |
-
Status Conv2dWgrad(
|
| 1287 |
-
conv::Conv2dProblemSize problem_size,
|
| 1288 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 1289 |
-
TensorRef<ElementB, LayoutB> tensor_x,
|
| 1290 |
-
TensorRef<ElementC, LayoutC> tensor_dw_in,
|
| 1291 |
-
TensorRef<ElementC, LayoutC> tensor_dw_out,
|
| 1292 |
-
ElementCompute alpha,
|
| 1293 |
-
ElementCompute beta,
|
| 1294 |
-
cudaStream_t stream = nullptr) {
|
| 1295 |
-
|
| 1296 |
-
//
|
| 1297 |
-
// Blocking factors improve performance of reference implementation
|
| 1298 |
-
//
|
| 1299 |
-
|
| 1300 |
-
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
|
| 1301 |
-
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1302 |
-
int const kCtaShapeM = 8; // shape of a threadblock in units of threads
|
| 1303 |
-
int const kCtaShapeN = 16; // shape of a threadblock in units of threads
|
| 1304 |
-
|
| 1305 |
-
int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C;
|
| 1306 |
-
int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);
|
| 1307 |
-
|
| 1308 |
-
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1309 |
-
dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));
|
| 1310 |
-
|
| 1311 |
-
kernel::Conv2dWgrad<
|
| 1312 |
-
ElementA,
|
| 1313 |
-
LayoutA,
|
| 1314 |
-
ElementB,
|
| 1315 |
-
LayoutB,
|
| 1316 |
-
ElementC,
|
| 1317 |
-
LayoutC,
|
| 1318 |
-
ElementCompute,
|
| 1319 |
-
ElementAccumulator,
|
| 1320 |
-
ConvertOp,
|
| 1321 |
-
InnerProductOp,
|
| 1322 |
-
kThreadM,
|
| 1323 |
-
kThreadN,
|
| 1324 |
-
kCtaShapeM,
|
| 1325 |
-
kCtaShapeN
|
| 1326 |
-
><<< grid, block, 0, stream >>>(
|
| 1327 |
-
problem_size,
|
| 1328 |
-
tensor_dy,
|
| 1329 |
-
tensor_x,
|
| 1330 |
-
tensor_dw_in,
|
| 1331 |
-
tensor_dw_out,
|
| 1332 |
-
alpha,
|
| 1333 |
-
beta
|
| 1334 |
-
);
|
| 1335 |
-
|
| 1336 |
-
cudaError_t result = cudaPeekAtLastError();
|
| 1337 |
-
if (result != cudaSuccess) {
|
| 1338 |
-
return Status::kErrorInternal;
|
| 1339 |
-
}
|
| 1340 |
-
|
| 1341 |
-
return Status::kSuccess;
|
| 1342 |
-
}
|
| 1343 |
-
|
| 1344 |
-
/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x)
|
| 1345 |
-
template <
|
| 1346 |
-
typename ElementA,
|
| 1347 |
-
typename LayoutA,
|
| 1348 |
-
typename ElementB,
|
| 1349 |
-
typename LayoutB,
|
| 1350 |
-
typename ElementC,
|
| 1351 |
-
typename LayoutC,
|
| 1352 |
-
typename ElementCompute,
|
| 1353 |
-
typename ElementAccumulator = ElementCompute,
|
| 1354 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1355 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1356 |
-
>
|
| 1357 |
-
Status Conv3dWgrad(
|
| 1358 |
-
conv::Conv3dProblemSize problem_size,
|
| 1359 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 1360 |
-
TensorRef<ElementB, LayoutB> tensor_x,
|
| 1361 |
-
TensorRef<ElementC, LayoutC> tensor_dw_in,
|
| 1362 |
-
TensorRef<ElementC, LayoutC> tensor_dw_out,
|
| 1363 |
-
ElementCompute alpha,
|
| 1364 |
-
ElementCompute beta,
|
| 1365 |
-
cudaStream_t stream = nullptr) {
|
| 1366 |
-
|
| 1367 |
-
//
|
| 1368 |
-
// Blocking factors improve performance of reference implementation
|
| 1369 |
-
//
|
| 1370 |
-
|
| 1371 |
-
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
|
| 1372 |
-
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1373 |
-
int const kCtaShapeM = 8; // shape of a threadblock in units of threads
|
| 1374 |
-
int const kCtaShapeN = 16; // shape of a threadblock in units of threads
|
| 1375 |
-
|
| 1376 |
-
int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C;
|
| 1377 |
-
int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);
|
| 1378 |
-
|
| 1379 |
-
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1380 |
-
dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));
|
| 1381 |
-
|
| 1382 |
-
kernel::Conv3dWgrad<
|
| 1383 |
-
ElementA,
|
| 1384 |
-
LayoutA,
|
| 1385 |
-
ElementB,
|
| 1386 |
-
LayoutB,
|
| 1387 |
-
ElementC,
|
| 1388 |
-
LayoutC,
|
| 1389 |
-
ElementCompute,
|
| 1390 |
-
ElementAccumulator,
|
| 1391 |
-
ConvertOp,
|
| 1392 |
-
InnerProductOp,
|
| 1393 |
-
kThreadM,
|
| 1394 |
-
kThreadN,
|
| 1395 |
-
kCtaShapeM,
|
| 1396 |
-
kCtaShapeN
|
| 1397 |
-
><<< grid, block, 0, stream >>>(
|
| 1398 |
-
problem_size,
|
| 1399 |
-
tensor_dy,
|
| 1400 |
-
tensor_x,
|
| 1401 |
-
tensor_dw_in,
|
| 1402 |
-
tensor_dw_out,
|
| 1403 |
-
alpha,
|
| 1404 |
-
beta
|
| 1405 |
-
);
|
| 1406 |
-
|
| 1407 |
-
cudaError_t result = cudaPeekAtLastError();
|
| 1408 |
-
if (result != cudaSuccess) {
|
| 1409 |
-
return Status::kErrorInternal;
|
| 1410 |
-
}
|
| 1411 |
-
|
| 1412 |
-
return Status::kSuccess;
|
| 1413 |
-
}
|
| 1414 |
-
|
| 1415 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1416 |
-
|
| 1417 |
-
/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
|
| 1418 |
-
template <
|
| 1419 |
-
typename ElementA,
|
| 1420 |
-
typename LayoutA,
|
| 1421 |
-
typename ElementB,
|
| 1422 |
-
typename LayoutB,
|
| 1423 |
-
typename ElementC,
|
| 1424 |
-
typename LayoutC,
|
| 1425 |
-
typename ElementCompute,
|
| 1426 |
-
typename ElementAccumulator = ElementCompute,
|
| 1427 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1428 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1429 |
-
>
|
| 1430 |
-
Status Conv2d(
|
| 1431 |
-
conv::Operator convolutional_operator,
|
| 1432 |
-
conv::Conv2dProblemSize problem_size,
|
| 1433 |
-
TensorRef<ElementA, LayoutA> tensor_A,
|
| 1434 |
-
TensorRef<ElementB, LayoutB> tensor_B,
|
| 1435 |
-
TensorRef<ElementC, LayoutC> tensor_C,
|
| 1436 |
-
TensorRef<ElementC, LayoutC> tensor_D,
|
| 1437 |
-
ElementCompute alpha,
|
| 1438 |
-
ElementCompute beta,
|
| 1439 |
-
cudaStream_t stream = nullptr) {
|
| 1440 |
-
|
| 1441 |
-
switch (convolutional_operator) {
|
| 1442 |
-
case conv::Operator::kFprop:
|
| 1443 |
-
return Conv2dFprop<
|
| 1444 |
-
ElementA, LayoutA,
|
| 1445 |
-
ElementB, LayoutB,
|
| 1446 |
-
ElementC, LayoutC,
|
| 1447 |
-
ElementCompute,
|
| 1448 |
-
ElementAccumulator,
|
| 1449 |
-
ConvertOp, InnerProductOp
|
| 1450 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
|
| 1451 |
-
break;
|
| 1452 |
-
|
| 1453 |
-
case conv::Operator::kDgrad:
|
| 1454 |
-
return Conv2dDgrad<
|
| 1455 |
-
ElementA, LayoutA,
|
| 1456 |
-
ElementB, LayoutB,
|
| 1457 |
-
ElementC, LayoutC,
|
| 1458 |
-
ElementCompute,
|
| 1459 |
-
ElementAccumulator,
|
| 1460 |
-
ConvertOp, InnerProductOp
|
| 1461 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
|
| 1462 |
-
break;
|
| 1463 |
-
|
| 1464 |
-
case conv::Operator::kWgrad:
|
| 1465 |
-
return Conv2dWgrad<
|
| 1466 |
-
ElementA, LayoutA,
|
| 1467 |
-
ElementB, LayoutB,
|
| 1468 |
-
ElementC, LayoutC,
|
| 1469 |
-
ElementCompute,
|
| 1470 |
-
ElementAccumulator,
|
| 1471 |
-
ConvertOp, InnerProductOp
|
| 1472 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
|
| 1473 |
-
break;
|
| 1474 |
-
|
| 1475 |
-
default: break;
|
| 1476 |
-
}
|
| 1477 |
-
|
| 1478 |
-
return Status::kErrorNotSupported;
|
| 1479 |
-
}
|
| 1480 |
-
|
| 1481 |
-
/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad.
|
| 1482 |
-
template <
|
| 1483 |
-
typename ElementA,
|
| 1484 |
-
typename LayoutA,
|
| 1485 |
-
typename ElementB,
|
| 1486 |
-
typename LayoutB,
|
| 1487 |
-
typename ElementC,
|
| 1488 |
-
typename LayoutC,
|
| 1489 |
-
typename ElementCompute,
|
| 1490 |
-
typename ElementAccumulator = ElementCompute,
|
| 1491 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1492 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1493 |
-
>
|
| 1494 |
-
Status Conv3d(
|
| 1495 |
-
conv::Operator convolutional_operator,
|
| 1496 |
-
conv::Conv3dProblemSize problem_size,
|
| 1497 |
-
TensorRef<ElementA, LayoutA> tensor_A,
|
| 1498 |
-
TensorRef<ElementB, LayoutB> tensor_B,
|
| 1499 |
-
TensorRef<ElementC, LayoutC> tensor_C,
|
| 1500 |
-
TensorRef<ElementC, LayoutC> tensor_D,
|
| 1501 |
-
ElementCompute alpha,
|
| 1502 |
-
ElementCompute beta,
|
| 1503 |
-
cudaStream_t stream = nullptr) {
|
| 1504 |
-
|
| 1505 |
-
switch (convolutional_operator) {
|
| 1506 |
-
case conv::Operator::kFprop:
|
| 1507 |
-
return Conv3dFprop<
|
| 1508 |
-
ElementA, LayoutA,
|
| 1509 |
-
ElementB, LayoutB,
|
| 1510 |
-
ElementC, LayoutC,
|
| 1511 |
-
ElementCompute,
|
| 1512 |
-
ElementAccumulator,
|
| 1513 |
-
ConvertOp, InnerProductOp
|
| 1514 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
|
| 1515 |
-
|
| 1516 |
-
case conv::Operator::kDgrad:
|
| 1517 |
-
return Conv3dDgrad<
|
| 1518 |
-
ElementA, LayoutA,
|
| 1519 |
-
ElementB, LayoutB,
|
| 1520 |
-
ElementC, LayoutC,
|
| 1521 |
-
ElementCompute,
|
| 1522 |
-
ElementAccumulator,
|
| 1523 |
-
ConvertOp, InnerProductOp
|
| 1524 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
|
| 1525 |
-
|
| 1526 |
-
case conv::Operator::kWgrad:
|
| 1527 |
-
return Conv3dWgrad<
|
| 1528 |
-
ElementA, LayoutA,
|
| 1529 |
-
ElementB, LayoutB,
|
| 1530 |
-
ElementC, LayoutC,
|
| 1531 |
-
ElementCompute,
|
| 1532 |
-
ElementAccumulator,
|
| 1533 |
-
ConvertOp, InnerProductOp
|
| 1534 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
|
| 1535 |
-
|
| 1536 |
-
default: break;
|
| 1537 |
-
}
|
| 1538 |
-
|
| 1539 |
-
return Status::kErrorNotSupported;
|
| 1540 |
-
}
|
| 1541 |
-
|
| 1542 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1543 |
-
|
| 1544 |
-
} // namespace device
|
| 1545 |
-
} // namespace reference
|
| 1546 |
-
} // namespace cutlass
|
| 1547 |
-
|
| 1548 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1549 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h
DELETED
|
@@ -1,385 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for GEMM in device-side code.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
|
| 39 |
-
#include "cutlass/numeric_types.h"
|
| 40 |
-
#include "cutlass/functional.h"
|
| 41 |
-
#include "cutlass/numeric_conversion.h"
|
| 42 |
-
|
| 43 |
-
#include "cutlass/tensor_view.h"
|
| 44 |
-
#include "cutlass/gemm/gemm.h"
|
| 45 |
-
|
| 46 |
-
#include "cutlass/util/reference/device/kernel/gemm.h"
|
| 47 |
-
|
| 48 |
-
namespace cutlass {
|
| 49 |
-
namespace reference {
|
| 50 |
-
namespace device {
|
| 51 |
-
|
| 52 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
-
|
| 54 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 55 |
-
/// objects.
|
| 56 |
-
///
|
| 57 |
-
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 58 |
-
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 59 |
-
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 60 |
-
/// arguments explicitly.
|
| 61 |
-
template <
|
| 62 |
-
typename ElementA,
|
| 63 |
-
typename LayoutA,
|
| 64 |
-
typename ElementB,
|
| 65 |
-
typename LayoutB,
|
| 66 |
-
typename ElementC,
|
| 67 |
-
typename LayoutC,
|
| 68 |
-
typename ScalarType,
|
| 69 |
-
typename AccumulatorType,
|
| 70 |
-
typename InnerProductOp = multiply_add<AccumulatorType>,
|
| 71 |
-
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
| 72 |
-
>
|
| 73 |
-
void compute_gemm(
|
| 74 |
-
gemm::GemmCoord problem_size,
|
| 75 |
-
ScalarType alpha,
|
| 76 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 77 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 78 |
-
ScalarType beta,
|
| 79 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 80 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 81 |
-
AccumulatorType initial_accum) {
|
| 82 |
-
|
| 83 |
-
static_assert(
|
| 84 |
-
LayoutA::kRank == 2 &&
|
| 85 |
-
LayoutB::kRank == 2 &&
|
| 86 |
-
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 87 |
-
|
| 88 |
-
// Blocking structure potentially improves performance of reference implementation
|
| 89 |
-
// with a minor increase in complexity.
|
| 90 |
-
//
|
| 91 |
-
// Note, this reference implementation is NOT expected to approach peak performance.
|
| 92 |
-
using OutputTile = MatrixShape<4, 4>;
|
| 93 |
-
|
| 94 |
-
dim3 block(16, 8);
|
| 95 |
-
|
| 96 |
-
dim3 grid(
|
| 97 |
-
(problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
|
| 98 |
-
(problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
|
| 99 |
-
);
|
| 100 |
-
|
| 101 |
-
// Launch a GEMM kernel
|
| 102 |
-
kernel::Gemm<
|
| 103 |
-
TensorRef<ElementA, LayoutA>,
|
| 104 |
-
TensorRef<ElementB, LayoutB>,
|
| 105 |
-
TensorRef<ElementC, LayoutC>,
|
| 106 |
-
ScalarType,
|
| 107 |
-
AccumulatorType,
|
| 108 |
-
OutputTile,
|
| 109 |
-
InnerProductOp,
|
| 110 |
-
ConvertOp
|
| 111 |
-
><<< grid, block >>>(
|
| 112 |
-
problem_size,
|
| 113 |
-
alpha,
|
| 114 |
-
tensor_a,
|
| 115 |
-
tensor_b,
|
| 116 |
-
beta,
|
| 117 |
-
tensor_c,
|
| 118 |
-
tensor_d,
|
| 119 |
-
initial_accum
|
| 120 |
-
);
|
| 121 |
-
}
|
| 122 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 123 |
-
|
| 124 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 125 |
-
/// objects.
|
| 126 |
-
///
|
| 127 |
-
/// This assumes the accumulator type is the same type as the scalars.
|
| 128 |
-
template <
|
| 129 |
-
typename ElementA,
|
| 130 |
-
typename LayoutA,
|
| 131 |
-
typename ElementB,
|
| 132 |
-
typename LayoutB,
|
| 133 |
-
typename ElementC,
|
| 134 |
-
typename LayoutC,
|
| 135 |
-
typename ScalarType,
|
| 136 |
-
typename AccumulatorType,
|
| 137 |
-
typename InnerProductOp = multiply_add<AccumulatorType>,
|
| 138 |
-
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
| 139 |
-
>
|
| 140 |
-
void compute_gemm(
|
| 141 |
-
gemm::GemmCoord problem_size,
|
| 142 |
-
ScalarType alpha,
|
| 143 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 144 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 145 |
-
ScalarType beta,
|
| 146 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 147 |
-
AccumulatorType initial_accum) {
|
| 148 |
-
|
| 149 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 150 |
-
ScalarType, AccumulatorType, InnerProductOp, ConvertOp>(
|
| 151 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
|
| 152 |
-
initial_accum);
|
| 153 |
-
}
|
| 154 |
-
|
| 155 |
-
template <
|
| 156 |
-
typename ElementA,
|
| 157 |
-
typename LayoutA,
|
| 158 |
-
typename ElementB,
|
| 159 |
-
typename LayoutB,
|
| 160 |
-
typename ElementC,
|
| 161 |
-
typename LayoutC,
|
| 162 |
-
typename ScalarType,
|
| 163 |
-
typename AccumulatorType,
|
| 164 |
-
typename InnerProductOp = cutlass::arch::OpMultiplyAdd
|
| 165 |
-
>
|
| 166 |
-
struct Gemm;
|
| 167 |
-
|
| 168 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 169 |
-
|
| 170 |
-
/// Partial specialization for multiply-add
|
| 171 |
-
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 172 |
-
typename LayoutB, typename ElementC, typename LayoutC,
|
| 173 |
-
typename ScalarType, typename AccumulatorType>
|
| 174 |
-
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 175 |
-
ScalarType, AccumulatorType, arch::OpMultiplyAdd> {
|
| 176 |
-
|
| 177 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 178 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 179 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 180 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 181 |
-
AccumulatorType initial_accum = AccumulatorType(0)) {
|
| 182 |
-
|
| 183 |
-
static_assert(
|
| 184 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 185 |
-
"Tensors must be of rank 2");
|
| 186 |
-
|
| 187 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 188 |
-
ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
|
| 189 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 190 |
-
}
|
| 191 |
-
|
| 192 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 193 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 194 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 195 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 196 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 197 |
-
AccumulatorType initial_accum = AccumulatorType(0)) {
|
| 198 |
-
static_assert(
|
| 199 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 200 |
-
"Tensors must be of rank 2");
|
| 201 |
-
|
| 202 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 203 |
-
ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
|
| 204 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 205 |
-
}
|
| 206 |
-
};
|
| 207 |
-
|
| 208 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 209 |
-
|
| 210 |
-
/// Partial specialization for multiply-add-saturate
|
| 211 |
-
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 212 |
-
typename LayoutB, typename ElementC, typename LayoutC,
|
| 213 |
-
typename ScalarType, typename AccumulatorType>
|
| 214 |
-
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 215 |
-
AccumulatorType, arch::OpMultiplyAddSaturate> {
|
| 216 |
-
|
| 217 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 218 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 219 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 220 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 221 |
-
AccumulatorType initial_accum = AccumulatorType(0)) {
|
| 222 |
-
static_assert(
|
| 223 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 224 |
-
"Tensors must be of rank 2");
|
| 225 |
-
|
| 226 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 227 |
-
ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
|
| 228 |
-
NumericConverterClamp<ElementC, ScalarType>>(
|
| 229 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 230 |
-
}
|
| 231 |
-
|
| 232 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 233 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 234 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 235 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 236 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 237 |
-
AccumulatorType initial_accum = AccumulatorType(0)) {
|
| 238 |
-
static_assert(
|
| 239 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 240 |
-
"Tensors must be of rank 2");
|
| 241 |
-
|
| 242 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 243 |
-
ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
|
| 244 |
-
NumericConverterClamp<ElementC, ScalarType>>(
|
| 245 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 246 |
-
}
|
| 247 |
-
};
|
| 248 |
-
|
| 249 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 250 |
-
|
| 251 |
-
/// Partial specialization for XOR-popc
|
| 252 |
-
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 253 |
-
typename LayoutB, typename ElementC, typename LayoutC,
|
| 254 |
-
typename ScalarType, typename AccumulatorType>
|
| 255 |
-
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 256 |
-
AccumulatorType, arch::OpXorPopc> {
|
| 257 |
-
|
| 258 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 259 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 260 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 261 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 262 |
-
AccumulatorType initial_accum = AccumulatorType(0)) {
|
| 263 |
-
static_assert(
|
| 264 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 265 |
-
"Tensors must be of rank 2");
|
| 266 |
-
|
| 267 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 268 |
-
ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
|
| 269 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 273 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 274 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 275 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 276 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 277 |
-
AccumulatorType initial_accum = AccumulatorType(0)) {
|
| 278 |
-
static_assert(
|
| 279 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 280 |
-
"Tensors must be of rank 2");
|
| 281 |
-
|
| 282 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 283 |
-
ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
|
| 284 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 285 |
-
}
|
| 286 |
-
};
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 290 |
-
//
|
| 291 |
-
// Batched GEMM
|
| 292 |
-
//
|
| 293 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 294 |
-
|
| 295 |
-
/// Computes a batch of GEMMs over a set of matrices of common dimension.
|
| 296 |
-
//
|
| 297 |
-
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
| 298 |
-
//
|
| 299 |
-
template <
|
| 300 |
-
typename TensorRefCollectionA,
|
| 301 |
-
typename TensorRefCollectionB,
|
| 302 |
-
typename TensorRefCollectionC,
|
| 303 |
-
typename ScalarType,
|
| 304 |
-
typename AccumulatorType,
|
| 305 |
-
typename InnerProductOp,
|
| 306 |
-
typename ConvertOp
|
| 307 |
-
>
|
| 308 |
-
void BatchedGemm(
|
| 309 |
-
gemm::GemmCoord problem_size,
|
| 310 |
-
int batch_count,
|
| 311 |
-
ScalarType alpha,
|
| 312 |
-
TensorRefCollectionA const& tensor_a,
|
| 313 |
-
TensorRefCollectionB const& tensor_b,
|
| 314 |
-
ScalarType beta,
|
| 315 |
-
TensorRefCollectionC &tensor_c,
|
| 316 |
-
AccumulatorType initial_accum) {
|
| 317 |
-
|
| 318 |
-
static_assert(
|
| 319 |
-
TensorRefCollectionA::kRank == 2 &&
|
| 320 |
-
TensorRefCollectionB::kRank == 2 &&
|
| 321 |
-
TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2");
|
| 322 |
-
|
| 323 |
-
// Blocking structure potentially improves performance of reference implementation
|
| 324 |
-
// with a minor increase in complexity.
|
| 325 |
-
//
|
| 326 |
-
// Note, this reference implementation is NOT expected to approach peak performance.
|
| 327 |
-
using OutputTile = MatrixShape<4, 4>;
|
| 328 |
-
|
| 329 |
-
dim3 block(16, 8);
|
| 330 |
-
dim3 grid(
|
| 331 |
-
(problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
|
| 332 |
-
(problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn),
|
| 333 |
-
batch_count
|
| 334 |
-
);
|
| 335 |
-
|
| 336 |
-
// Launch a GEMM kernel
|
| 337 |
-
kernel::BatchedGemm<
|
| 338 |
-
TensorRefCollectionA,
|
| 339 |
-
TensorRefCollectionB,
|
| 340 |
-
TensorRefCollectionC,
|
| 341 |
-
ScalarType,
|
| 342 |
-
AccumulatorType,
|
| 343 |
-
OutputTile,
|
| 344 |
-
InnerProductOp,
|
| 345 |
-
ConvertOp
|
| 346 |
-
><<< grid, block >>>(
|
| 347 |
-
problem_size,
|
| 348 |
-
alpha,
|
| 349 |
-
tensor_a,
|
| 350 |
-
tensor_b,
|
| 351 |
-
beta,
|
| 352 |
-
tensor_c,
|
| 353 |
-
initial_accum
|
| 354 |
-
);
|
| 355 |
-
}
|
| 356 |
-
|
| 357 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 358 |
-
/// objects.
|
| 359 |
-
//
|
| 360 |
-
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
| 361 |
-
//
|
| 362 |
-
template <
|
| 363 |
-
typename TensorRefCollectionA,
|
| 364 |
-
typename TensorRefCollectionB,
|
| 365 |
-
typename TensorRefCollectionC,
|
| 366 |
-
typename ScalarType,
|
| 367 |
-
typename AccumulatorType
|
| 368 |
-
>
|
| 369 |
-
void BatchedGemm(
|
| 370 |
-
gemm::GemmCoord problem_size,
|
| 371 |
-
int batch_count,
|
| 372 |
-
ScalarType alpha,
|
| 373 |
-
TensorRefCollectionA const& tensor_a,
|
| 374 |
-
TensorRefCollectionB const& tensor_b,
|
| 375 |
-
ScalarType beta,
|
| 376 |
-
TensorRefCollectionC &tensor_c) {
|
| 377 |
-
|
| 378 |
-
BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
|
| 379 |
-
}
|
| 380 |
-
|
| 381 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 382 |
-
|
| 383 |
-
} // namespace device
|
| 384 |
-
} // namespace reference
|
| 385 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h
DELETED
|
@@ -1,350 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for complex-valued GEMM in device-side code.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
#include "cutlass/complex.h"
|
| 39 |
-
#include "cutlass/numeric_types.h"
|
| 40 |
-
#include "cutlass/functional.h"
|
| 41 |
-
#include "cutlass/numeric_conversion.h"
|
| 42 |
-
|
| 43 |
-
#include "cutlass/tensor_view.h"
|
| 44 |
-
#include "cutlass/gemm/gemm.h"
|
| 45 |
-
|
| 46 |
-
namespace cutlass {
|
| 47 |
-
namespace reference {
|
| 48 |
-
namespace device {
|
| 49 |
-
|
| 50 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
-
|
| 52 |
-
namespace kernel {
|
| 53 |
-
|
| 54 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 55 |
-
/// objects.
|
| 56 |
-
///
|
| 57 |
-
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 58 |
-
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 59 |
-
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 60 |
-
/// arguments explicitly.
|
| 61 |
-
template <
|
| 62 |
-
typename ElementA,
|
| 63 |
-
typename LayoutA,
|
| 64 |
-
typename ElementB,
|
| 65 |
-
typename LayoutB,
|
| 66 |
-
typename ElementC,
|
| 67 |
-
typename LayoutC,
|
| 68 |
-
typename ScalarType,
|
| 69 |
-
typename ComputeType,
|
| 70 |
-
typename ElementD = ElementC,
|
| 71 |
-
typename ConvertOp = NumericConverter<ElementD, ScalarType>,
|
| 72 |
-
typename InnerProductOp = multiply_add<ComputeType>,
|
| 73 |
-
int kMblock = 4,
|
| 74 |
-
int kNblock = 4
|
| 75 |
-
>
|
| 76 |
-
__global__ void GemmComplex(
|
| 77 |
-
gemm::GemmCoord problem_size,
|
| 78 |
-
ScalarType alpha,
|
| 79 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 80 |
-
ComplexTransform transform_a,
|
| 81 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 82 |
-
ComplexTransform transform_b,
|
| 83 |
-
ScalarType beta,
|
| 84 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 85 |
-
TensorRef<ElementD, LayoutC> tensor_d,
|
| 86 |
-
ComputeType initial_accum,
|
| 87 |
-
int batch_count = 1,
|
| 88 |
-
int64_t batch_stride_A = 0,
|
| 89 |
-
int64_t batch_stride_B = 0,
|
| 90 |
-
int64_t batch_stride_C = 0,
|
| 91 |
-
int64_t batch_stride_D = 0) {
|
| 92 |
-
|
| 93 |
-
static_assert(
|
| 94 |
-
LayoutA::kRank == 2 &&
|
| 95 |
-
LayoutB::kRank == 2 &&
|
| 96 |
-
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 97 |
-
|
| 98 |
-
int const M = problem_size.m();
|
| 99 |
-
int const N = problem_size.n();
|
| 100 |
-
int const K = problem_size.k();
|
| 101 |
-
|
| 102 |
-
ConvertOp convert_op;
|
| 103 |
-
InnerProductOp inner_product_op;
|
| 104 |
-
|
| 105 |
-
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
|
| 106 |
-
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
|
| 107 |
-
int batch_idx = blockIdx.z;
|
| 108 |
-
|
| 109 |
-
tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
|
| 110 |
-
tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
|
| 111 |
-
tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
|
| 112 |
-
tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
|
| 113 |
-
|
| 114 |
-
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
|
| 115 |
-
|
| 116 |
-
// Compute matrix product using blocks
|
| 117 |
-
ComputeType accum[kMblock][kNblock];
|
| 118 |
-
|
| 119 |
-
CUTLASS_PRAGMA_UNROLL
|
| 120 |
-
for (int j = 0; j < kNblock; j++) {
|
| 121 |
-
CUTLASS_PRAGMA_UNROLL
|
| 122 |
-
for (int i = 0; i < kMblock; i++) {
|
| 123 |
-
accum[i][j] = initial_accum;
|
| 124 |
-
}
|
| 125 |
-
}
|
| 126 |
-
|
| 127 |
-
for (int k_block = 0; k_block < K; ++k_block) {
|
| 128 |
-
CUTLASS_PRAGMA_UNROLL
|
| 129 |
-
for (int j = 0; j < kNblock; j++) {
|
| 130 |
-
CUTLASS_PRAGMA_UNROLL
|
| 131 |
-
for (int i = 0; i < kMblock; i++) {
|
| 132 |
-
int row = row_block + i;
|
| 133 |
-
int col = col_block + j;
|
| 134 |
-
|
| 135 |
-
if (row < M && col < N) {
|
| 136 |
-
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
| 137 |
-
ElementB b = tensor_b.at(MatrixCoord(k_block, col));
|
| 138 |
-
|
| 139 |
-
ComputeType a_ik = ComputeType(a);
|
| 140 |
-
ComputeType b_kj = ComputeType(b);
|
| 141 |
-
|
| 142 |
-
if (transform_a == ComplexTransform::kConjugate) {
|
| 143 |
-
a_ik = conj(a_ik);
|
| 144 |
-
}
|
| 145 |
-
|
| 146 |
-
if (transform_b == ComplexTransform::kConjugate) {
|
| 147 |
-
b_kj = conj(b_kj);
|
| 148 |
-
}
|
| 149 |
-
|
| 150 |
-
accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
|
| 151 |
-
}
|
| 152 |
-
}
|
| 153 |
-
}
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
CUTLASS_PRAGMA_UNROLL
|
| 157 |
-
for (int j = 0; j < kNblock; j++) {
|
| 158 |
-
CUTLASS_PRAGMA_UNROLL
|
| 159 |
-
for (int i = 0; i < kMblock; i++) {
|
| 160 |
-
int row = row_block + i;
|
| 161 |
-
int col = col_block + j;
|
| 162 |
-
|
| 163 |
-
MatrixCoord coord = MatrixCoord(row, col);
|
| 164 |
-
|
| 165 |
-
if (row < M && col < N) {
|
| 166 |
-
|
| 167 |
-
tensor_d.at(coord) = convert_op(
|
| 168 |
-
alpha * ScalarType(accum[i][j]) +
|
| 169 |
-
beta * ScalarType(tensor_c.at(coord)));
|
| 170 |
-
}
|
| 171 |
-
}
|
| 172 |
-
}
|
| 173 |
-
|
| 174 |
-
tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
|
| 175 |
-
tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
|
| 176 |
-
tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
|
| 177 |
-
tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
|
| 178 |
-
|
| 179 |
-
} // for (batch_idx)
|
| 180 |
-
}
|
| 181 |
-
|
| 182 |
-
} // namespace kernel
|
| 183 |
-
|
| 184 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 185 |
-
|
| 186 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 187 |
-
/// objects.
|
| 188 |
-
///
|
| 189 |
-
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 190 |
-
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 191 |
-
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 192 |
-
/// arguments explicitly.
|
| 193 |
-
template <
|
| 194 |
-
typename ElementA,
|
| 195 |
-
typename LayoutA,
|
| 196 |
-
typename ElementB,
|
| 197 |
-
typename LayoutB,
|
| 198 |
-
typename ElementC,
|
| 199 |
-
typename LayoutC,
|
| 200 |
-
typename ScalarType,
|
| 201 |
-
typename ComputeType,
|
| 202 |
-
typename ElementD = ElementC,
|
| 203 |
-
typename ConvertOp = NumericConverter<ElementD, ScalarType>,
|
| 204 |
-
typename InnerProductOp = multiply_add<ComputeType>
|
| 205 |
-
>
|
| 206 |
-
void GemmComplex(
|
| 207 |
-
gemm::GemmCoord problem_size,
|
| 208 |
-
ScalarType alpha,
|
| 209 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 210 |
-
ComplexTransform transform_a,
|
| 211 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 212 |
-
ComplexTransform transform_b,
|
| 213 |
-
ScalarType beta,
|
| 214 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 215 |
-
TensorRef<ElementD, LayoutC> tensor_d,
|
| 216 |
-
ComputeType initial_accum,
|
| 217 |
-
int batch_count = 1,
|
| 218 |
-
int64_t batch_stride_A = 0,
|
| 219 |
-
int64_t batch_stride_B = 0,
|
| 220 |
-
int64_t batch_stride_C = 0,
|
| 221 |
-
int64_t batch_stride_D = 0) {
|
| 222 |
-
|
| 223 |
-
static_assert(
|
| 224 |
-
LayoutA::kRank == 2 &&
|
| 225 |
-
LayoutB::kRank == 2 &&
|
| 226 |
-
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 227 |
-
|
| 228 |
-
int const kMblock = 4;
|
| 229 |
-
int const kNblock = 4;
|
| 230 |
-
|
| 231 |
-
dim3 block(16, 8);
|
| 232 |
-
dim3 grid(
|
| 233 |
-
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
| 234 |
-
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
| 235 |
-
batch_count % std::numeric_limits<uint16_t>::max()
|
| 236 |
-
);
|
| 237 |
-
|
| 238 |
-
if (grid.y <= std::numeric_limits<uint16_t>::max()) {
|
| 239 |
-
kernel::GemmComplex<
|
| 240 |
-
ElementA,
|
| 241 |
-
LayoutA,
|
| 242 |
-
ElementB,
|
| 243 |
-
LayoutB,
|
| 244 |
-
ElementC,
|
| 245 |
-
LayoutC,
|
| 246 |
-
ScalarType,
|
| 247 |
-
ComputeType,
|
| 248 |
-
ElementD,
|
| 249 |
-
ConvertOp,
|
| 250 |
-
InnerProductOp,
|
| 251 |
-
kMblock,
|
| 252 |
-
kNblock
|
| 253 |
-
><<< grid, block >>>(
|
| 254 |
-
problem_size,
|
| 255 |
-
alpha,
|
| 256 |
-
tensor_a,
|
| 257 |
-
transform_a,
|
| 258 |
-
tensor_b,
|
| 259 |
-
transform_b,
|
| 260 |
-
beta,
|
| 261 |
-
tensor_c,
|
| 262 |
-
tensor_d,
|
| 263 |
-
initial_accum,
|
| 264 |
-
batch_count,
|
| 265 |
-
batch_stride_A,
|
| 266 |
-
batch_stride_B,
|
| 267 |
-
batch_stride_C,
|
| 268 |
-
batch_stride_D
|
| 269 |
-
);
|
| 270 |
-
} else {
|
| 271 |
-
// Using bigger thread tile size
|
| 272 |
-
int const kBigMblock = 4;
|
| 273 |
-
int const kBigNblock = 16;
|
| 274 |
-
|
| 275 |
-
dim3 Bigblock(16, 8);
|
| 276 |
-
dim3 Biggrid(
|
| 277 |
-
(problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock),
|
| 278 |
-
(problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock),
|
| 279 |
-
batch_count % std::numeric_limits<uint16_t>::max()
|
| 280 |
-
);
|
| 281 |
-
|
| 282 |
-
kernel::GemmComplex<
|
| 283 |
-
ElementA,
|
| 284 |
-
LayoutA,
|
| 285 |
-
ElementB,
|
| 286 |
-
LayoutB,
|
| 287 |
-
ElementC,
|
| 288 |
-
LayoutC,
|
| 289 |
-
ScalarType,
|
| 290 |
-
ComputeType,
|
| 291 |
-
ElementD,
|
| 292 |
-
ConvertOp,
|
| 293 |
-
InnerProductOp,
|
| 294 |
-
kBigMblock,
|
| 295 |
-
kBigNblock
|
| 296 |
-
><<< Biggrid, Bigblock >>>(
|
| 297 |
-
problem_size,
|
| 298 |
-
alpha,
|
| 299 |
-
tensor_a,
|
| 300 |
-
transform_a,
|
| 301 |
-
tensor_b,
|
| 302 |
-
transform_b,
|
| 303 |
-
beta,
|
| 304 |
-
tensor_c,
|
| 305 |
-
tensor_d,
|
| 306 |
-
initial_accum,
|
| 307 |
-
batch_count,
|
| 308 |
-
batch_stride_A,
|
| 309 |
-
batch_stride_B,
|
| 310 |
-
batch_stride_C,
|
| 311 |
-
batch_stride_D
|
| 312 |
-
);
|
| 313 |
-
}
|
| 314 |
-
}
|
| 315 |
-
|
| 316 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 317 |
-
|
| 318 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 319 |
-
/// objects.
|
| 320 |
-
///
|
| 321 |
-
/// This assumes the accumulator type is the same type as the scalars.
|
| 322 |
-
template <
|
| 323 |
-
typename ElementA,
|
| 324 |
-
typename LayoutA,
|
| 325 |
-
typename ElementB,
|
| 326 |
-
typename LayoutB,
|
| 327 |
-
typename ElementC,
|
| 328 |
-
typename LayoutC,
|
| 329 |
-
typename ScalarType,
|
| 330 |
-
typename ElementD = ElementC
|
| 331 |
-
>
|
| 332 |
-
void GemmComplex(
|
| 333 |
-
gemm::GemmCoord problem_size,
|
| 334 |
-
ScalarType alpha,
|
| 335 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 336 |
-
ComplexTransform transform_a,
|
| 337 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 338 |
-
ComplexTransform transform_b,
|
| 339 |
-
ScalarType beta,
|
| 340 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 341 |
-
TensorRef<ElementD, LayoutC> tensor_d) {
|
| 342 |
-
|
| 343 |
-
GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
|
| 344 |
-
}
|
| 345 |
-
|
| 346 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 347 |
-
|
| 348 |
-
} // namespace device
|
| 349 |
-
} // namespace reference
|
| 350 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h
DELETED
|
@@ -1,311 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for complex-valued GEMM in device code.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
#include "cutlass/complex.h"
|
| 39 |
-
#include "cutlass/matrix_coord.h"
|
| 40 |
-
#include "cutlass/numeric_types.h"
|
| 41 |
-
#include "cutlass/functional.h"
|
| 42 |
-
#include "cutlass/numeric_conversion.h"
|
| 43 |
-
#include "cutlass/tensor_ref_planar_complex.h"
|
| 44 |
-
|
| 45 |
-
#include "cutlass/tensor_view.h"
|
| 46 |
-
#include "cutlass/gemm/gemm.h"
|
| 47 |
-
|
| 48 |
-
namespace cutlass {
|
| 49 |
-
namespace reference {
|
| 50 |
-
namespace device {
|
| 51 |
-
|
| 52 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
-
|
| 54 |
-
namespace kernel {
|
| 55 |
-
|
| 56 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
-
|
| 58 |
-
static int const kGemmPlanarComplexBlockSize = 4;
|
| 59 |
-
|
| 60 |
-
template <
|
| 61 |
-
typename ElementA,
|
| 62 |
-
typename LayoutA,
|
| 63 |
-
typename ElementB,
|
| 64 |
-
typename LayoutB,
|
| 65 |
-
typename ElementC,
|
| 66 |
-
typename LayoutC,
|
| 67 |
-
typename ScalarType,
|
| 68 |
-
typename ComputeType,
|
| 69 |
-
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 70 |
-
typename InnerProductOp = multiply_add<complex<ComputeType>>
|
| 71 |
-
>
|
| 72 |
-
__global__ void GemmPlanarComplex(
|
| 73 |
-
gemm::GemmCoord problem_size,
|
| 74 |
-
complex<ScalarType> alpha,
|
| 75 |
-
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
| 76 |
-
ComplexTransform transform_a,
|
| 77 |
-
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
| 78 |
-
ComplexTransform transform_b,
|
| 79 |
-
complex<ScalarType> beta,
|
| 80 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
| 81 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
|
| 82 |
-
complex<ComputeType> initial_accum) {
|
| 83 |
-
|
| 84 |
-
int const kMblock = kGemmPlanarComplexBlockSize;
|
| 85 |
-
int const kNblock = kGemmPlanarComplexBlockSize;
|
| 86 |
-
|
| 87 |
-
using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
|
| 88 |
-
using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
|
| 89 |
-
using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;
|
| 90 |
-
|
| 91 |
-
// Note: batch is ignored.
|
| 92 |
-
int const M = problem_size.m();
|
| 93 |
-
int const N = problem_size.n();
|
| 94 |
-
int const K = problem_size.k();
|
| 95 |
-
|
| 96 |
-
ConvertOp convert_op;
|
| 97 |
-
InnerProductOp inner_product_op;
|
| 98 |
-
|
| 99 |
-
complex<ComputeType> accum[kMblock][kNblock];
|
| 100 |
-
|
| 101 |
-
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
|
| 102 |
-
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
|
| 103 |
-
|
| 104 |
-
CUTLASS_PRAGMA_UNROLL
|
| 105 |
-
for (int j = 0; j < kNblock; j++) {
|
| 106 |
-
CUTLASS_PRAGMA_UNROLL
|
| 107 |
-
for (int i = 0; i < kMblock; i++) {
|
| 108 |
-
accum[i][j] = initial_accum;
|
| 109 |
-
}
|
| 110 |
-
}
|
| 111 |
-
|
| 112 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 113 |
-
for (int k_block = 0; k_block < K; ++k_block) {
|
| 114 |
-
|
| 115 |
-
CUTLASS_PRAGMA_UNROLL
|
| 116 |
-
for (int j = 0; j < kNblock; j++) {
|
| 117 |
-
|
| 118 |
-
CUTLASS_PRAGMA_UNROLL
|
| 119 |
-
for (int i = 0; i < kMblock; i++) {
|
| 120 |
-
|
| 121 |
-
int row = row_block + i;
|
| 122 |
-
int col = col_block + j;
|
| 123 |
-
|
| 124 |
-
if (row < M && col < N) {
|
| 125 |
-
|
| 126 |
-
ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
|
| 127 |
-
ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
|
| 128 |
-
|
| 129 |
-
complex<ComputeType> a = complex<ComputeType>{
|
| 130 |
-
ComputeType(a_ik.real()),
|
| 131 |
-
ComputeType(a_ik.imag())
|
| 132 |
-
};
|
| 133 |
-
|
| 134 |
-
complex<ComputeType> b = complex<ComputeType>{
|
| 135 |
-
ComputeType(b_kj.real()),
|
| 136 |
-
ComputeType(b_kj.imag())
|
| 137 |
-
};
|
| 138 |
-
|
| 139 |
-
if (transform_a == ComplexTransform::kConjugate) {
|
| 140 |
-
a = conj(a);
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
if (transform_b == ComplexTransform::kConjugate) {
|
| 144 |
-
b = conj(b);
|
| 145 |
-
}
|
| 146 |
-
|
| 147 |
-
accum[i][j] = inner_product_op(a, b, accum[i][j]);
|
| 148 |
-
}
|
| 149 |
-
}
|
| 150 |
-
}
|
| 151 |
-
}
|
| 152 |
-
|
| 153 |
-
CUTLASS_PRAGMA_UNROLL
|
| 154 |
-
for (int j = 0; j < kNblock; j++) {
|
| 155 |
-
CUTLASS_PRAGMA_UNROLL
|
| 156 |
-
for (int i = 0; i < kMblock; i++) {
|
| 157 |
-
|
| 158 |
-
int row = row_block + i;
|
| 159 |
-
int col = col_block + j;
|
| 160 |
-
|
| 161 |
-
MatrixCoord coord = MatrixCoord(row, col);
|
| 162 |
-
|
| 163 |
-
if (row < M && col < N) {
|
| 164 |
-
|
| 165 |
-
complex<ScalarType> acc{
|
| 166 |
-
ScalarType(accum[i][j].real()),
|
| 167 |
-
ScalarType(accum[i][j].imag())
|
| 168 |
-
};
|
| 169 |
-
|
| 170 |
-
ComplexC c_ij = ComplexC();
|
| 171 |
-
|
| 172 |
-
if (beta.real() != ScalarType() || beta.imag() != ScalarType()) {
|
| 173 |
-
c_ij = tensor_c.at(coord);
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
complex<ScalarType> src{
|
| 177 |
-
ScalarType(c_ij.real()),
|
| 178 |
-
ScalarType(c_ij.imag())
|
| 179 |
-
};
|
| 180 |
-
|
| 181 |
-
complex<ScalarType> result = alpha * acc + beta * src;
|
| 182 |
-
|
| 183 |
-
ComplexC d_ij;
|
| 184 |
-
|
| 185 |
-
d_ij.real() = convert_op(result.real());
|
| 186 |
-
d_ij.imag() = convert_op(result.imag());
|
| 187 |
-
|
| 188 |
-
tensor_d.at(coord) = d_ij;
|
| 189 |
-
}
|
| 190 |
-
}
|
| 191 |
-
}
|
| 192 |
-
}
|
| 193 |
-
|
| 194 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 195 |
-
|
| 196 |
-
} // namespace kernel
|
| 197 |
-
|
| 198 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 199 |
-
|
| 200 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 201 |
-
/// objects.
|
| 202 |
-
///
|
| 203 |
-
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 204 |
-
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 205 |
-
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 206 |
-
/// arguments explicitly.
|
| 207 |
-
template <
|
| 208 |
-
typename ElementA,
|
| 209 |
-
typename LayoutA,
|
| 210 |
-
typename ElementB,
|
| 211 |
-
typename LayoutB,
|
| 212 |
-
typename ElementC,
|
| 213 |
-
typename LayoutC,
|
| 214 |
-
typename ScalarType,
|
| 215 |
-
typename ComputeType,
|
| 216 |
-
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 217 |
-
typename InnerProductOp = multiply_add<complex<ComputeType>>
|
| 218 |
-
>
|
| 219 |
-
void GemmPlanarComplex(
|
| 220 |
-
gemm::GemmCoord problem_size,
|
| 221 |
-
complex<ScalarType> alpha,
|
| 222 |
-
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
| 223 |
-
ComplexTransform transform_a,
|
| 224 |
-
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
| 225 |
-
ComplexTransform transform_b,
|
| 226 |
-
complex<ScalarType> beta,
|
| 227 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
| 228 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
|
| 229 |
-
complex<ComputeType> initial_accum) {
|
| 230 |
-
|
| 231 |
-
static_assert(
|
| 232 |
-
LayoutA::kRank == 2 &&
|
| 233 |
-
LayoutB::kRank == 2 &&
|
| 234 |
-
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 235 |
-
|
| 236 |
-
int const kMblock = kernel::kGemmPlanarComplexBlockSize;
|
| 237 |
-
int const kNblock = kernel::kGemmPlanarComplexBlockSize;
|
| 238 |
-
|
| 239 |
-
dim3 block(16, 8);
|
| 240 |
-
|
| 241 |
-
dim3 grid(
|
| 242 |
-
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
| 243 |
-
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
| 244 |
-
1);
|
| 245 |
-
|
| 246 |
-
kernel::GemmPlanarComplex<
|
| 247 |
-
ElementA, LayoutA,
|
| 248 |
-
ElementB, LayoutB,
|
| 249 |
-
ElementC, LayoutC,
|
| 250 |
-
ScalarType,
|
| 251 |
-
ComputeType,
|
| 252 |
-
ConvertOp,
|
| 253 |
-
InnerProductOp
|
| 254 |
-
><<< grid, block >>>(
|
| 255 |
-
problem_size,
|
| 256 |
-
alpha,
|
| 257 |
-
tensor_a,
|
| 258 |
-
transform_a,
|
| 259 |
-
tensor_b,
|
| 260 |
-
transform_b,
|
| 261 |
-
beta,
|
| 262 |
-
tensor_c,
|
| 263 |
-
tensor_d,
|
| 264 |
-
initial_accum
|
| 265 |
-
);
|
| 266 |
-
}
|
| 267 |
-
|
| 268 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 269 |
-
|
| 270 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 271 |
-
/// objects.
|
| 272 |
-
///
|
| 273 |
-
/// This assumes the accumulator type is the same type as the scalars.
|
| 274 |
-
template <
|
| 275 |
-
typename ElementA,
|
| 276 |
-
typename LayoutA,
|
| 277 |
-
typename ElementB,
|
| 278 |
-
typename LayoutB,
|
| 279 |
-
typename ElementC,
|
| 280 |
-
typename LayoutC,
|
| 281 |
-
typename ScalarType
|
| 282 |
-
>
|
| 283 |
-
void GemmPlanarComplex(
|
| 284 |
-
gemm::GemmCoord problem_size,
|
| 285 |
-
complex<ScalarType> alpha,
|
| 286 |
-
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
| 287 |
-
ComplexTransform transform_a,
|
| 288 |
-
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
| 289 |
-
ComplexTransform transform_b,
|
| 290 |
-
complex<ScalarType> beta,
|
| 291 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
| 292 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
|
| 293 |
-
|
| 294 |
-
GemmPlanarComplex(
|
| 295 |
-
problem_size,
|
| 296 |
-
alpha,
|
| 297 |
-
tensor_a, transform_a,
|
| 298 |
-
tensor_b, transform_b,
|
| 299 |
-
beta,
|
| 300 |
-
tensor_c,
|
| 301 |
-
tensor_d,
|
| 302 |
-
complex<ScalarType>());
|
| 303 |
-
}
|
| 304 |
-
|
| 305 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 306 |
-
|
| 307 |
-
} // namespace device
|
| 308 |
-
} // namespace reference
|
| 309 |
-
} // namespace cutlass
|
| 310 |
-
|
| 311 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp
DELETED
|
@@ -1,146 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief GETT device reference code
|
| 33 |
-
*/
|
| 34 |
-
#pragma once
|
| 35 |
-
|
| 36 |
-
#include <cute/tensor.hpp>
|
| 37 |
-
|
| 38 |
-
namespace cutlass::reference::device {
|
| 39 |
-
|
| 40 |
-
template <
|
| 41 |
-
class ATensor,
|
| 42 |
-
class BTensor,
|
| 43 |
-
class CTensor,
|
| 44 |
-
class DTensor,
|
| 45 |
-
class ElementAccumulator,
|
| 46 |
-
class ElementEpilogue>
|
| 47 |
-
__global__ static
|
| 48 |
-
void
|
| 49 |
-
gett_kernel(
|
| 50 |
-
DTensor D,
|
| 51 |
-
ATensor const A,
|
| 52 |
-
BTensor const B,
|
| 53 |
-
CTensor const C,
|
| 54 |
-
ElementEpilogue alpha, ElementEpilogue beta,
|
| 55 |
-
ElementAccumulator acc_init)
|
| 56 |
-
{
|
| 57 |
-
using namespace cute;
|
| 58 |
-
|
| 59 |
-
static_assert(DTensor::rank == 3, "(M,N,L)");
|
| 60 |
-
static_assert(ATensor::rank == 3, "(M,K,L)");
|
| 61 |
-
static_assert(BTensor::rank == 3, "(N,K,L)");
|
| 62 |
-
static_assert(CTensor::rank == 3, "(M,N,L)");
|
| 63 |
-
|
| 64 |
-
assert(size<0>(A) == size<0>(D)); // M
|
| 65 |
-
assert(size<0>(C) == size<0>(D)); // M
|
| 66 |
-
assert(size<0>(B) == size<1>(D)); // N
|
| 67 |
-
assert(size<1>(C) == size<1>(D)); // N
|
| 68 |
-
assert(size<1>(A) == size<1>(B)); // K
|
| 69 |
-
assert(size<2>(A) == size<2>(D)); // L
|
| 70 |
-
assert(size<2>(B) == size<2>(D)); // L
|
| 71 |
-
assert(size<2>(C) == size<2>(D)); // L
|
| 72 |
-
|
| 73 |
-
NumericConverter<ElementAccumulator, typename ATensor::value_type> a_converter;
|
| 74 |
-
NumericConverter<ElementAccumulator, typename BTensor::value_type> b_converter;
|
| 75 |
-
NumericConverter<ElementEpilogue, ElementAccumulator> acc_converter;
|
| 76 |
-
NumericConverter<ElementEpilogue, typename CTensor::value_type> source_converter;
|
| 77 |
-
NumericConverter<typename DTensor::value_type, ElementEpilogue> output_converter;
|
| 78 |
-
|
| 79 |
-
// Thread id to each element of D
|
| 80 |
-
for (int tid = threadIdx.x + blockDim.x * blockIdx.x;
|
| 81 |
-
tid < size(D);
|
| 82 |
-
tid += blockDim.x * gridDim.x) {
|
| 83 |
-
// (m,n,l) coordinate
|
| 84 |
-
auto mnl_coord = idx2crd(tid, product_each(shape(D)));
|
| 85 |
-
auto m = get<0>(mnl_coord);
|
| 86 |
-
auto n = get<1>(mnl_coord);
|
| 87 |
-
auto l = get<2>(mnl_coord);
|
| 88 |
-
|
| 89 |
-
auto A_ml = A(m,_,l);
|
| 90 |
-
auto B_nl = B(n,_,l);
|
| 91 |
-
|
| 92 |
-
ElementAccumulator accum = ElementAccumulator(0);
|
| 93 |
-
for (int k = 0; k < size<1>(A); ++k) {
|
| 94 |
-
ElementAccumulator a = a_converter(A_ml(k));
|
| 95 |
-
ElementAccumulator b = b_converter(B_nl(k));
|
| 96 |
-
accum += a * b;
|
| 97 |
-
}
|
| 98 |
-
|
| 99 |
-
ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l)));
|
| 100 |
-
D(m,n,l) = output_converter(scaled_output);
|
| 101 |
-
}
|
| 102 |
-
}
|
| 103 |
-
|
| 104 |
-
// Most general version
|
| 105 |
-
template <
|
| 106 |
-
class ProblemShapeMNKL,
|
| 107 |
-
class ElementA,
|
| 108 |
-
class StrideA,
|
| 109 |
-
class ElementB,
|
| 110 |
-
class StrideB,
|
| 111 |
-
class ElementAccumulator,
|
| 112 |
-
class ElementC,
|
| 113 |
-
class StrideC,
|
| 114 |
-
class ElementD,
|
| 115 |
-
class StrideD,
|
| 116 |
-
class ElementEpilogue>
|
| 117 |
-
void
|
| 118 |
-
gett(
|
| 119 |
-
ProblemShapeMNKL problem_shape_mnkl,
|
| 120 |
-
ElementA const* ptr_A, StrideA stride_a_mkl,
|
| 121 |
-
ElementB const* ptr_B, StrideB stride_b_nkl,
|
| 122 |
-
ElementAccumulator _,
|
| 123 |
-
ElementC const* ptr_C, StrideC stride_c_mnl,
|
| 124 |
-
ElementD * ptr_D, StrideD stride_d_mnl,
|
| 125 |
-
ElementEpilogue alpha, ElementEpilogue beta,
|
| 126 |
-
cudaStream_t stream = 0) {
|
| 127 |
-
using namespace cute;
|
| 128 |
-
|
| 129 |
-
static_assert(cute::rank(ProblemShapeMNKL{}) == 4);
|
| 130 |
-
auto M = get<0>(problem_shape_mnkl);
|
| 131 |
-
auto N = get<1>(problem_shape_mnkl);
|
| 132 |
-
auto K = get<2>(problem_shape_mnkl);
|
| 133 |
-
auto L = get<3>(problem_shape_mnkl);
|
| 134 |
-
|
| 135 |
-
// Represent the full tensors
|
| 136 |
-
auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L)
|
| 137 |
-
auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L)
|
| 138 |
-
auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L)
|
| 139 |
-
auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L)
|
| 140 |
-
|
| 141 |
-
dim3 dimBlock(256);
|
| 142 |
-
dim3 dimGrid(240);
|
| 143 |
-
gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0));
|
| 144 |
-
}
|
| 145 |
-
|
| 146 |
-
} // namespace cutlass::reference::device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h
DELETED
|
@@ -1,162 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for GEMM in host-side code.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
#include "cutlass/tensor_view.h"
|
| 39 |
-
#include "cutlass/gemm/gemm.h"
|
| 40 |
-
|
| 41 |
-
#include "cutlass/util/reference/device/thread/gemm.h"
|
| 42 |
-
|
| 43 |
-
namespace cutlass {
|
| 44 |
-
namespace reference {
|
| 45 |
-
namespace device {
|
| 46 |
-
namespace kernel {
|
| 47 |
-
|
| 48 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
-
|
| 50 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 51 |
-
/// objects.
|
| 52 |
-
template <
|
| 53 |
-
typename TensorRefA,
|
| 54 |
-
typename TensorRefB,
|
| 55 |
-
typename TensorRefC,
|
| 56 |
-
typename ScalarType,
|
| 57 |
-
typename AccumulatorType,
|
| 58 |
-
typename OutputTile,
|
| 59 |
-
typename InnerProductOp,
|
| 60 |
-
typename ConvertOp
|
| 61 |
-
>
|
| 62 |
-
__global__ void Gemm(
|
| 63 |
-
gemm::GemmCoord problem_size,
|
| 64 |
-
ScalarType alpha,
|
| 65 |
-
TensorRefA tensor_a,
|
| 66 |
-
TensorRefB tensor_b,
|
| 67 |
-
ScalarType beta,
|
| 68 |
-
TensorRefC tensor_c,
|
| 69 |
-
TensorRefC tensor_d,
|
| 70 |
-
AccumulatorType initial_accum) {
|
| 71 |
-
|
| 72 |
-
// Map each thread to a unique tile of the output matrix
|
| 73 |
-
MatrixCoord output_coord(
|
| 74 |
-
MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow),
|
| 75 |
-
MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn)
|
| 76 |
-
);
|
| 77 |
-
|
| 78 |
-
// Compute the general matrix product
|
| 79 |
-
thread::Gemm<
|
| 80 |
-
TensorRefA,
|
| 81 |
-
TensorRefB,
|
| 82 |
-
TensorRefC,
|
| 83 |
-
ScalarType,
|
| 84 |
-
AccumulatorType,
|
| 85 |
-
OutputTile,
|
| 86 |
-
InnerProductOp,
|
| 87 |
-
ConvertOp
|
| 88 |
-
> gemm(initial_accum);
|
| 89 |
-
|
| 90 |
-
gemm.multiply_add(
|
| 91 |
-
problem_size,
|
| 92 |
-
tensor_a,
|
| 93 |
-
tensor_b,
|
| 94 |
-
output_coord);
|
| 95 |
-
|
| 96 |
-
gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord);
|
| 97 |
-
}
|
| 98 |
-
|
| 99 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 100 |
-
|
| 101 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 102 |
-
/// objects.
|
| 103 |
-
template <
|
| 104 |
-
typename TensorRefCollectionA,
|
| 105 |
-
typename TensorRefCollectionB,
|
| 106 |
-
typename TensorRefCollectionC,
|
| 107 |
-
typename ScalarType,
|
| 108 |
-
typename AccumulatorType,
|
| 109 |
-
typename OutputTile,
|
| 110 |
-
typename InnerProductOp,
|
| 111 |
-
typename ConvertOp
|
| 112 |
-
>
|
| 113 |
-
__global__ void BatchedGemm(
|
| 114 |
-
gemm::GemmCoord problem_size,
|
| 115 |
-
ScalarType alpha,
|
| 116 |
-
TensorRefCollectionA tensor_collection_a,
|
| 117 |
-
TensorRefCollectionB tensor_collection_b,
|
| 118 |
-
ScalarType beta,
|
| 119 |
-
TensorRefCollectionC tensor_collection_c,
|
| 120 |
-
AccumulatorType initial_accum) {
|
| 121 |
-
|
| 122 |
-
// Obtain batch ID
|
| 123 |
-
int batch_id = blockIdx.z;
|
| 124 |
-
|
| 125 |
-
// Dereference based on batch_id
|
| 126 |
-
typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id);
|
| 127 |
-
typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id);
|
| 128 |
-
typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id);
|
| 129 |
-
|
| 130 |
-
// Map each thread to a unique tile of the output matrix
|
| 131 |
-
MatrixCoord output_coord(
|
| 132 |
-
(threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn,
|
| 133 |
-
(threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow
|
| 134 |
-
);
|
| 135 |
-
|
| 136 |
-
// Compute the general matrix product
|
| 137 |
-
thread::Gemm<
|
| 138 |
-
typename TensorRefCollectionA::TensorRef,
|
| 139 |
-
typename TensorRefCollectionB::TensorRef,
|
| 140 |
-
typename TensorRefCollectionC::TensorRef,
|
| 141 |
-
ScalarType,
|
| 142 |
-
AccumulatorType,
|
| 143 |
-
OutputTile,
|
| 144 |
-
InnerProductOp,
|
| 145 |
-
ConvertOp
|
| 146 |
-
> gemm(initial_accum);
|
| 147 |
-
|
| 148 |
-
gemm.multiply_add(
|
| 149 |
-
problem_size,
|
| 150 |
-
tensor_a,
|
| 151 |
-
tensor_b,
|
| 152 |
-
output_coord);
|
| 153 |
-
|
| 154 |
-
gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
|
| 155 |
-
}
|
| 156 |
-
|
| 157 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 158 |
-
|
| 159 |
-
} // namespace kernel
|
| 160 |
-
} // namespace device
|
| 161 |
-
} // namespace reference
|
| 162 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h
DELETED
|
@@ -1,168 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include <curand_kernel.h>
|
| 35 |
-
|
| 36 |
-
#include "cutlass/cutlass.h"
|
| 37 |
-
|
| 38 |
-
namespace cutlass {
|
| 39 |
-
namespace reference {
|
| 40 |
-
namespace device {
|
| 41 |
-
namespace kernel {
|
| 42 |
-
|
| 43 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
-
|
| 45 |
-
/// Kernel to initialize tensor to uniform random distribution
|
| 46 |
-
template <typename T>
|
| 47 |
-
__global__ void TensorInitializeUniform(
|
| 48 |
-
Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
|
| 49 |
-
__shared__ curandState_t rng_state[1024];
|
| 50 |
-
|
| 51 |
-
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
|
| 52 |
-
|
| 53 |
-
curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
|
| 54 |
-
|
| 55 |
-
int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 56 |
-
int s_idx = blockIdx.y * blockDim.x;
|
| 57 |
-
|
| 58 |
-
tensor += s_idx * ldm + c_idx;
|
| 59 |
-
|
| 60 |
-
for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
|
| 61 |
-
if (s_idx < dim_strided && c_idx < dim_contiguous) {
|
| 62 |
-
double range = dist.uniform.max - dist.uniform.min;
|
| 63 |
-
|
| 64 |
-
double rnd = curand_uniform(&rng_state[threadIdx.x]);
|
| 65 |
-
|
| 66 |
-
rnd = dist.uniform.min + range * rnd;
|
| 67 |
-
|
| 68 |
-
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 69 |
-
// testing
|
| 70 |
-
if (dist.int_scale >= 0) {
|
| 71 |
-
rnd = double(int(rnd * double(1 << dist.int_scale)));
|
| 72 |
-
*tensor = T(rnd / double(1 << dist.int_scale));
|
| 73 |
-
} else {
|
| 74 |
-
*tensor = T(rnd);
|
| 75 |
-
}
|
| 76 |
-
|
| 77 |
-
tensor += ldm;
|
| 78 |
-
}
|
| 79 |
-
}
|
| 80 |
-
}
|
| 81 |
-
|
| 82 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 83 |
-
|
| 84 |
-
/// Kernel to initialize tensor to uniform distribution
|
| 85 |
-
template <typename T>
|
| 86 |
-
__global__ void TensorInitializeGaussian(
|
| 87 |
-
Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
|
| 88 |
-
__shared__ curandState_t rng_state[1024];
|
| 89 |
-
|
| 90 |
-
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
|
| 91 |
-
|
| 92 |
-
curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
|
| 93 |
-
|
| 94 |
-
int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 95 |
-
int s_idx = blockIdx.y * blockDim.x;
|
| 96 |
-
|
| 97 |
-
tensor += s_idx * ldm + c_idx;
|
| 98 |
-
|
| 99 |
-
for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
|
| 100 |
-
if (s_idx < dim_strided && c_idx < dim_contiguous) {
|
| 101 |
-
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 102 |
-
// testing
|
| 103 |
-
|
| 104 |
-
double rnd = curand_normal(&rng_state[threadIdx.x]);
|
| 105 |
-
|
| 106 |
-
rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd;
|
| 107 |
-
|
| 108 |
-
if (dist.int_scale >= 0) {
|
| 109 |
-
rnd = double(int(rnd * double(1 << dist.int_scale)));
|
| 110 |
-
*tensor = T(rnd / double(1 << dist.int_scale));
|
| 111 |
-
} else {
|
| 112 |
-
*tensor = T(rnd);
|
| 113 |
-
}
|
| 114 |
-
}
|
| 115 |
-
}
|
| 116 |
-
}
|
| 117 |
-
|
| 118 |
-
/// Kernel to initialize tensor to an identity matrix
|
| 119 |
-
template <typename T>
|
| 120 |
-
__global__ void TensorInitializeLinear(
|
| 121 |
-
Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
|
| 122 |
-
__shared__ curandState_t rng_state[1024];
|
| 123 |
-
|
| 124 |
-
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
|
| 125 |
-
|
| 126 |
-
curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
|
| 127 |
-
|
| 128 |
-
int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 129 |
-
int s_idx = blockIdx.y * blockDim.x;
|
| 130 |
-
|
| 131 |
-
tensor += s_idx * ldm + c_idx;
|
| 132 |
-
|
| 133 |
-
for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
|
| 134 |
-
if (s_idx < dim_strided && c_idx < dim_contiguous) {
|
| 135 |
-
*tensor =
|
| 136 |
-
dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx;
|
| 137 |
-
}
|
| 138 |
-
}
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
/// Kernel to initialize tensor to an identity matrix
|
| 142 |
-
template <typename T>
|
| 143 |
-
__global__ void TensorInitializeIdentity(
|
| 144 |
-
Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
|
| 145 |
-
__shared__ curandState_t rng_state[1024];
|
| 146 |
-
|
| 147 |
-
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
|
| 148 |
-
|
| 149 |
-
curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
|
| 150 |
-
|
| 151 |
-
int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 152 |
-
int s_idx = blockIdx.y * blockDim.x;
|
| 153 |
-
|
| 154 |
-
tensor += s_idx * ldm + c_idx;
|
| 155 |
-
|
| 156 |
-
for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
|
| 157 |
-
if (s_idx < dim_strided && c_idx < dim_contiguous) {
|
| 158 |
-
*tensor = (c_idx == s_idx ? T(1) : T(0));
|
| 159 |
-
}
|
| 160 |
-
}
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 164 |
-
|
| 165 |
-
} // namespace kernel
|
| 166 |
-
} // namespace device
|
| 167 |
-
} // namespace reference
|
| 168 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h
DELETED
|
@@ -1,159 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include "cutlass/cutlass.h"
|
| 35 |
-
#include "cutlass/coord.h"
|
| 36 |
-
#include "cutlass/subbyte_reference.h"
|
| 37 |
-
#include "cutlass/fast_math.h"
|
| 38 |
-
|
| 39 |
-
namespace cutlass {
|
| 40 |
-
namespace reference {
|
| 41 |
-
namespace device {
|
| 42 |
-
namespace kernel {
|
| 43 |
-
|
| 44 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
-
|
| 46 |
-
/// Defines several helpers
|
| 47 |
-
namespace detail {
|
| 48 |
-
|
| 49 |
-
/// Helper to perform for-each operation
|
| 50 |
-
template <typename Func, int Rank, int RankRemaining>
|
| 51 |
-
struct TensorForEachHelper {
|
| 52 |
-
|
| 53 |
-
/// Constructor for general rank
|
| 54 |
-
__inline__ __device__
|
| 55 |
-
TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {
|
| 56 |
-
|
| 57 |
-
int64_t product = 1;
|
| 58 |
-
|
| 59 |
-
CUTLASS_PRAGMA_UNROLL
|
| 60 |
-
for (int i = Rank - RankRemaining; i < Rank; ++i) {
|
| 61 |
-
product *= size[i];
|
| 62 |
-
}
|
| 63 |
-
|
| 64 |
-
coord[Rank - 1 - RankRemaining] = index / product;
|
| 65 |
-
int64_t remaining = index % product;
|
| 66 |
-
|
| 67 |
-
TensorForEachHelper<Func, Rank, RankRemaining-1>(func, size, coord, remaining);
|
| 68 |
-
}
|
| 69 |
-
};
|
| 70 |
-
|
| 71 |
-
/// Helper to perform for-each operation
|
| 72 |
-
template <typename Func, int Rank>
|
| 73 |
-
struct TensorForEachHelper<Func, Rank, 0> {
|
| 74 |
-
|
| 75 |
-
/// Constructor for fastest changing rank
|
| 76 |
-
__inline__ __device__
|
| 77 |
-
TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {
|
| 78 |
-
|
| 79 |
-
coord[Rank - 1] = index;
|
| 80 |
-
|
| 81 |
-
if (coord < size) {
|
| 82 |
-
func(coord);
|
| 83 |
-
}
|
| 84 |
-
}
|
| 85 |
-
};
|
| 86 |
-
|
| 87 |
-
} // namespace detail
|
| 88 |
-
|
| 89 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 90 |
-
|
| 91 |
-
/// Kernel calls a functor for each element in a tensor's index space
|
| 92 |
-
template <typename Func, int Rank, typename Params>
|
| 93 |
-
__global__ void TensorForEach(Coord<Rank> size, Params params = Params()) {
|
| 94 |
-
|
| 95 |
-
Func func(params);
|
| 96 |
-
|
| 97 |
-
int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
| 98 |
-
int64_t max_index = 1;
|
| 99 |
-
|
| 100 |
-
CUTLASS_PRAGMA_UNROLL
|
| 101 |
-
for (int i = 0; i < Rank; ++i) {
|
| 102 |
-
max_index *= size[i];
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 106 |
-
while (index < max_index) {
|
| 107 |
-
Coord<Rank> coord;
|
| 108 |
-
|
| 109 |
-
detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, size, coord, index);
|
| 110 |
-
index += blockDim.x * gridDim.x;
|
| 111 |
-
}
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 115 |
-
|
| 116 |
-
/// Kernel calls a functor for each element along a tensor's diagonal
|
| 117 |
-
template <typename Func, int Rank, typename Params>
|
| 118 |
-
__global__ void TensorDiagonalForEach(Coord<Rank> size, Params params, int start, int end) {
|
| 119 |
-
|
| 120 |
-
Func func(params);
|
| 121 |
-
|
| 122 |
-
int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start;
|
| 123 |
-
|
| 124 |
-
if (index < end) {
|
| 125 |
-
Coord<Rank> coord;
|
| 126 |
-
|
| 127 |
-
CUTLASS_PRAGMA_UNROLL
|
| 128 |
-
for (int i = 0; i < Rank; ++i) {
|
| 129 |
-
coord[i] = index;
|
| 130 |
-
}
|
| 131 |
-
|
| 132 |
-
func(coord);
|
| 133 |
-
}
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 137 |
-
|
| 138 |
-
template <typename Element, typename Func>
|
| 139 |
-
__global__ void BlockForEach(
|
| 140 |
-
Element *ptr,
|
| 141 |
-
size_t capacity,
|
| 142 |
-
typename Func::Params params) {
|
| 143 |
-
|
| 144 |
-
Func func(params);
|
| 145 |
-
|
| 146 |
-
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
| 147 |
-
|
| 148 |
-
for (; index < capacity; index += blockDim.x * gridDim.x) {
|
| 149 |
-
ReferenceFactory<Element>::get(ptr, index) = func();
|
| 150 |
-
}
|
| 151 |
-
}
|
| 152 |
-
|
| 153 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 154 |
-
|
| 155 |
-
} // namespace kernel
|
| 156 |
-
} // namespace device
|
| 157 |
-
} // namespace reference
|
| 158 |
-
} // namespace cutlass
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h
DELETED
|
@@ -1,355 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for complex-valued GEMM in device-side code.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/blas3.h"
|
| 38 |
-
#include "cutlass/complex.h"
|
| 39 |
-
#include "cutlass/numeric_conversion.h"
|
| 40 |
-
#include "cutlass/tensor_view.h"
|
| 41 |
-
#include "cutlass/gemm/gemm.h"
|
| 42 |
-
|
| 43 |
-
namespace cutlass {
|
| 44 |
-
namespace reference {
|
| 45 |
-
namespace device {
|
| 46 |
-
|
| 47 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
-
|
| 49 |
-
namespace kernel {
|
| 50 |
-
|
| 51 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 52 |
-
/// objects.
|
| 53 |
-
///
|
| 54 |
-
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 55 |
-
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 56 |
-
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 57 |
-
/// arguments explicitly.
|
| 58 |
-
template <
|
| 59 |
-
typename ElementA,
|
| 60 |
-
typename LayoutA,
|
| 61 |
-
typename ElementB,
|
| 62 |
-
typename LayoutB,
|
| 63 |
-
typename ElementC,
|
| 64 |
-
typename LayoutC,
|
| 65 |
-
typename ScalarType,
|
| 66 |
-
typename ComputeType,
|
| 67 |
-
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 68 |
-
typename InnerProductOp = multiply_add<ComputeType>,
|
| 69 |
-
int kMblock = 4,
|
| 70 |
-
int kNblock = 4
|
| 71 |
-
>
|
| 72 |
-
__global__ void Rank2KComplex(
|
| 73 |
-
gemm::GemmCoord problem_size,
|
| 74 |
-
ScalarType alpha,
|
| 75 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 76 |
-
ComplexTransform transform_a,
|
| 77 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 78 |
-
ComplexTransform transform_b,
|
| 79 |
-
ScalarType beta,
|
| 80 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 81 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 82 |
-
ComputeType initial_accum,
|
| 83 |
-
FillMode fill_mode_c,
|
| 84 |
-
BlasMode blas_mode,
|
| 85 |
-
int batch_count = 1,
|
| 86 |
-
int64_t batch_stride_A = 0,
|
| 87 |
-
int64_t batch_stride_B = 0,
|
| 88 |
-
int64_t batch_stride_C = 0,
|
| 89 |
-
int64_t batch_stride_D = 0) {
|
| 90 |
-
|
| 91 |
-
static_assert(
|
| 92 |
-
LayoutA::kRank == 2 &&
|
| 93 |
-
LayoutB::kRank == 2 &&
|
| 94 |
-
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 95 |
-
|
| 96 |
-
int const M = problem_size.m();
|
| 97 |
-
int const N = problem_size.n();
|
| 98 |
-
int const K = problem_size.k();
|
| 99 |
-
|
| 100 |
-
assert(M=N);
|
| 101 |
-
|
| 102 |
-
ConvertOp convert_op;
|
| 103 |
-
InnerProductOp inner_product_op;
|
| 104 |
-
|
| 105 |
-
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
|
| 106 |
-
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
|
| 107 |
-
int batch_idx = blockIdx.z;
|
| 108 |
-
|
| 109 |
-
tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
|
| 110 |
-
tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
|
| 111 |
-
tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
|
| 112 |
-
tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
|
| 113 |
-
|
| 114 |
-
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
|
| 115 |
-
|
| 116 |
-
// Compute matrix product using blocks
|
| 117 |
-
ComputeType accum[kMblock][kNblock];
|
| 118 |
-
|
| 119 |
-
CUTLASS_PRAGMA_UNROLL
|
| 120 |
-
for (int j = 0; j < kNblock; j++) {
|
| 121 |
-
CUTLASS_PRAGMA_UNROLL
|
| 122 |
-
for (int i = 0; i < kMblock; i++) {
|
| 123 |
-
accum[i][j] = initial_accum;
|
| 124 |
-
}
|
| 125 |
-
}
|
| 126 |
-
|
| 127 |
-
for (int k_block = 0; k_block < K; ++k_block) {
|
| 128 |
-
CUTLASS_PRAGMA_UNROLL
|
| 129 |
-
for (int j = 0; j < kNblock; j++) {
|
| 130 |
-
CUTLASS_PRAGMA_UNROLL
|
| 131 |
-
for (int i = 0; i < kMblock; i++) {
|
| 132 |
-
int row = row_block + i;
|
| 133 |
-
int col = col_block + j;
|
| 134 |
-
|
| 135 |
-
if (row < M && col < N &&
|
| 136 |
-
( (fill_mode_c == FillMode::kLower && row >= col) ||
|
| 137 |
-
(fill_mode_c == FillMode::kUpper && row <= col) )
|
| 138 |
-
) {
|
| 139 |
-
|
| 140 |
-
// A x B^T (Symmetric) or A x B^H (Hermitian)
|
| 141 |
-
// complex conjugation on operandB (b_t) is function of blas3 computation
|
| 142 |
-
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
| 143 |
-
ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
|
| 144 |
-
conj(tensor_b.at(MatrixCoord(col, k_block))) :
|
| 145 |
-
tensor_b.at(MatrixCoord(col, k_block));
|
| 146 |
-
|
| 147 |
-
ComputeType a_ik = ComputeType(a);
|
| 148 |
-
ComputeType b_jk = ComputeType(b_t);
|
| 149 |
-
|
| 150 |
-
// complex conjugation is a function of operand layouts
|
| 151 |
-
if (transform_a == ComplexTransform::kConjugate) {
|
| 152 |
-
a_ik = conj(a_ik);
|
| 153 |
-
}
|
| 154 |
-
// complex conjugation is a function of operand layouts
|
| 155 |
-
if (transform_b == ComplexTransform::kConjugate) {
|
| 156 |
-
b_jk = conj(b_jk);
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
|
| 160 |
-
|
| 161 |
-
// B x A^T (Symmetric) or B x A^H (Hermitian)
|
| 162 |
-
// complex conjugation on operandB (a_t) is function of blas3 computation
|
| 163 |
-
ElementB b = tensor_b.at(MatrixCoord(row, k_block));
|
| 164 |
-
ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
|
| 165 |
-
conj(tensor_a.at(MatrixCoord(col, k_block))):
|
| 166 |
-
tensor_a.at(MatrixCoord(col, k_block));
|
| 167 |
-
|
| 168 |
-
ComputeType b_ik = ComputeType(b);
|
| 169 |
-
ComputeType a_jk = ComputeType(a_t);
|
| 170 |
-
|
| 171 |
-
// complex conjugation here is a function of operand layouts
|
| 172 |
-
if (transform_b == ComplexTransform::kConjugate) {
|
| 173 |
-
b_ik = conj(b_ik);
|
| 174 |
-
}
|
| 175 |
-
// complex conjugation here is a function of operand layouts
|
| 176 |
-
if (transform_a == ComplexTransform::kConjugate) {
|
| 177 |
-
a_jk = conj(a_jk);
|
| 178 |
-
}
|
| 179 |
-
|
| 180 |
-
accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
|
| 181 |
-
}
|
| 182 |
-
}
|
| 183 |
-
}
|
| 184 |
-
}
|
| 185 |
-
|
| 186 |
-
CUTLASS_PRAGMA_UNROLL
|
| 187 |
-
for (int j = 0; j < kNblock; j++) {
|
| 188 |
-
CUTLASS_PRAGMA_UNROLL
|
| 189 |
-
for (int i = 0; i < kMblock; i++) {
|
| 190 |
-
int row = row_block + i;
|
| 191 |
-
int col = col_block + j;
|
| 192 |
-
|
| 193 |
-
MatrixCoord coord = MatrixCoord(row, col);
|
| 194 |
-
|
| 195 |
-
if (row < M && col < N &&
|
| 196 |
-
((fill_mode_c == FillMode::kLower && row >= col) ||
|
| 197 |
-
(fill_mode_c == FillMode::kUpper && row <= col))
|
| 198 |
-
) {
|
| 199 |
-
|
| 200 |
-
ScalarType c = tensor_c.at(coord);
|
| 201 |
-
// The imaginary parts of the diagonal elements of
|
| 202 |
-
// a complex data type are assumed and set to zero
|
| 203 |
-
if (blas_mode == BlasMode::kHermitian) {
|
| 204 |
-
c = (row == col) ? real(c) : c;
|
| 205 |
-
}
|
| 206 |
-
|
| 207 |
-
tensor_d.at(coord) = convert_op(
|
| 208 |
-
alpha * ScalarType(accum[i][j]) +
|
| 209 |
-
beta * c);
|
| 210 |
-
}
|
| 211 |
-
}
|
| 212 |
-
}
|
| 213 |
-
|
| 214 |
-
tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
|
| 215 |
-
tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
|
| 216 |
-
tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
|
| 217 |
-
tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
|
| 218 |
-
|
| 219 |
-
} // for (batch_idx)
|
| 220 |
-
}
|
| 221 |
-
|
| 222 |
-
} // namespace kernel
|
| 223 |
-
|
| 224 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 225 |
-
|
| 226 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 227 |
-
/// objects.
|
| 228 |
-
///
|
| 229 |
-
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 230 |
-
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 231 |
-
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 232 |
-
/// arguments explicitly.
|
| 233 |
-
template <
|
| 234 |
-
typename ElementA,
|
| 235 |
-
typename LayoutA,
|
| 236 |
-
typename ElementB,
|
| 237 |
-
typename LayoutB,
|
| 238 |
-
typename ElementC,
|
| 239 |
-
typename LayoutC,
|
| 240 |
-
typename ScalarType,
|
| 241 |
-
typename ComputeType,
|
| 242 |
-
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 243 |
-
typename InnerProductOp = multiply_add<ComputeType>
|
| 244 |
-
>
|
| 245 |
-
void Rank2KComplex(
|
| 246 |
-
gemm::GemmCoord problem_size,
|
| 247 |
-
ScalarType alpha,
|
| 248 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 249 |
-
ComplexTransform transform_a,
|
| 250 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 251 |
-
ComplexTransform transform_b,
|
| 252 |
-
ScalarType beta,
|
| 253 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 254 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 255 |
-
ComputeType initial_accum,
|
| 256 |
-
FillMode fill_mode_c,
|
| 257 |
-
BlasMode blas_mode,
|
| 258 |
-
int batch_count = 1,
|
| 259 |
-
int64_t batch_stride_A = 0,
|
| 260 |
-
int64_t batch_stride_B = 0,
|
| 261 |
-
int64_t batch_stride_C = 0,
|
| 262 |
-
int64_t batch_stride_D = 0) {
|
| 263 |
-
|
| 264 |
-
static_assert(
|
| 265 |
-
LayoutA::kRank == 2 &&
|
| 266 |
-
LayoutB::kRank == 2 &&
|
| 267 |
-
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 268 |
-
|
| 269 |
-
int const kMblock = 4;
|
| 270 |
-
int const kNblock = 4;
|
| 271 |
-
|
| 272 |
-
dim3 block(16, 8);
|
| 273 |
-
dim3 grid(
|
| 274 |
-
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
| 275 |
-
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
| 276 |
-
batch_count % std::numeric_limits<uint16_t>::max()
|
| 277 |
-
);
|
| 278 |
-
|
| 279 |
-
kernel::Rank2KComplex<
|
| 280 |
-
ElementA,
|
| 281 |
-
LayoutA,
|
| 282 |
-
ElementB,
|
| 283 |
-
LayoutB,
|
| 284 |
-
ElementC,
|
| 285 |
-
LayoutC,
|
| 286 |
-
ScalarType,
|
| 287 |
-
ComputeType,
|
| 288 |
-
ConvertOp,
|
| 289 |
-
InnerProductOp,
|
| 290 |
-
kMblock,
|
| 291 |
-
kNblock
|
| 292 |
-
><<< grid, block >>>(
|
| 293 |
-
problem_size,
|
| 294 |
-
alpha,
|
| 295 |
-
tensor_a,
|
| 296 |
-
transform_a,
|
| 297 |
-
tensor_b,
|
| 298 |
-
transform_b,
|
| 299 |
-
beta,
|
| 300 |
-
tensor_c,
|
| 301 |
-
tensor_d,
|
| 302 |
-
initial_accum,
|
| 303 |
-
fill_mode_c,
|
| 304 |
-
blas_mode,
|
| 305 |
-
batch_count,
|
| 306 |
-
batch_stride_A,
|
| 307 |
-
batch_stride_B,
|
| 308 |
-
batch_stride_C,
|
| 309 |
-
batch_stride_D
|
| 310 |
-
);
|
| 311 |
-
}
|
| 312 |
-
|
| 313 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 314 |
-
|
| 315 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 316 |
-
/// objects.
|
| 317 |
-
///
|
| 318 |
-
/// This assumes the accumulator type is the same type as the scalars.
|
| 319 |
-
template <
|
| 320 |
-
typename ElementA,
|
| 321 |
-
typename LayoutA,
|
| 322 |
-
typename ElementB,
|
| 323 |
-
typename LayoutB,
|
| 324 |
-
typename ElementC,
|
| 325 |
-
typename LayoutC,
|
| 326 |
-
typename ScalarType
|
| 327 |
-
>
|
| 328 |
-
void Rank2KComplex(
|
| 329 |
-
gemm::GemmCoord problem_size,
|
| 330 |
-
ScalarType alpha,
|
| 331 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 332 |
-
ComplexTransform transform_a,
|
| 333 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 334 |
-
ComplexTransform transform_b,
|
| 335 |
-
ScalarType beta,
|
| 336 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 337 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 338 |
-
FillMode fill_mode_c,
|
| 339 |
-
BlasMode blas_mode) {
|
| 340 |
-
|
| 341 |
-
Rank2KComplex(
|
| 342 |
-
problem_size, alpha,
|
| 343 |
-
tensor_a, transform_a,
|
| 344 |
-
tensor_b, transform_b,
|
| 345 |
-
beta, tensor_c, tensor_d,
|
| 346 |
-
ScalarType(0),
|
| 347 |
-
fill_mode_c,
|
| 348 |
-
blas_mode);
|
| 349 |
-
}
|
| 350 |
-
|
| 351 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 352 |
-
|
| 353 |
-
} // namespace device
|
| 354 |
-
} // namespace reference
|
| 355 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h
DELETED
|
@@ -1,250 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/* \file
|
| 32 |
-
\brief Defines host-side elementwise operations on TensorView.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
// Standard Library includes
|
| 37 |
-
#include <utility>
|
| 38 |
-
|
| 39 |
-
// Cutlass includes
|
| 40 |
-
#include "cutlass/cutlass.h"
|
| 41 |
-
#include "cutlass/relatively_equal.h"
|
| 42 |
-
|
| 43 |
-
#include "cutlass/util/distribution.h"
|
| 44 |
-
|
| 45 |
-
#include "tensor_foreach.h"
|
| 46 |
-
|
| 47 |
-
namespace cutlass {
|
| 48 |
-
namespace reference {
|
| 49 |
-
namespace device {
|
| 50 |
-
|
| 51 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
-
|
| 53 |
-
namespace kernel {
|
| 54 |
-
|
| 55 |
-
template <typename Element>
|
| 56 |
-
__global__ void BlockCompareEqual(
|
| 57 |
-
int *equal,
|
| 58 |
-
Element const *ptr_A,
|
| 59 |
-
Element const *ptr_B,
|
| 60 |
-
size_t capacity) {
|
| 61 |
-
|
| 62 |
-
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
|
| 63 |
-
|
| 64 |
-
for (; idx < capacity; idx += gridDim.x * blockDim.x) {
|
| 65 |
-
|
| 66 |
-
Element a = cutlass::ReferenceFactory<Element>::get(ptr_A, idx);
|
| 67 |
-
Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);
|
| 68 |
-
|
| 69 |
-
if (a != b) {
|
| 70 |
-
*equal = 0;
|
| 71 |
-
|
| 72 |
-
return;
|
| 73 |
-
}
|
| 74 |
-
}
|
| 75 |
-
}
|
| 76 |
-
|
| 77 |
-
template <typename Element>
|
| 78 |
-
__global__ void BlockCompareRelativelyEqual(
|
| 79 |
-
int *equal,
|
| 80 |
-
Element const *ptr_A,
|
| 81 |
-
Element const *ptr_B,
|
| 82 |
-
size_t capacity,
|
| 83 |
-
Element epsilon,
|
| 84 |
-
Element nonzero_floor) {
|
| 85 |
-
|
| 86 |
-
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
|
| 87 |
-
|
| 88 |
-
for (; idx < capacity; idx += gridDim.x * blockDim.x) {
|
| 89 |
-
|
| 90 |
-
Element a = cutlass::ReferenceFactory<Element>::get(ptr_A, idx);
|
| 91 |
-
Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);
|
| 92 |
-
|
| 93 |
-
if (!relatively_equal(a, b, epsilon, nonzero_floor)) {
|
| 94 |
-
*equal = 0;
|
| 95 |
-
return;
|
| 96 |
-
}
|
| 97 |
-
}
|
| 98 |
-
}
|
| 99 |
-
|
| 100 |
-
} // namespace kernel
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 104 |
-
|
| 105 |
-
/// Performs a bit-level equality check between two blocks
|
| 106 |
-
template <typename Element>
|
| 107 |
-
bool BlockCompareEqual(
|
| 108 |
-
Element const *ptr_A,
|
| 109 |
-
Element const *ptr_B,
|
| 110 |
-
size_t capacity,
|
| 111 |
-
int grid_size = 0,
|
| 112 |
-
int block_size = 0,
|
| 113 |
-
cudaStream_t stream = nullptr) {
|
| 114 |
-
|
| 115 |
-
int equal_flag = 1;
|
| 116 |
-
int *device_equal_flag = nullptr;
|
| 117 |
-
|
| 118 |
-
if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) {
|
| 119 |
-
throw std::runtime_error("Failed to allocate device flag.");
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
if (cudaMemcpy(
|
| 123 |
-
device_equal_flag,
|
| 124 |
-
&equal_flag,
|
| 125 |
-
sizeof(int),
|
| 126 |
-
cudaMemcpyHostToDevice) != cudaSuccess) {
|
| 127 |
-
|
| 128 |
-
throw std::runtime_error("Failed to copy equality flag to device.");
|
| 129 |
-
}
|
| 130 |
-
|
| 131 |
-
if (!grid_size || !block_size) {
|
| 132 |
-
|
| 133 |
-
// if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
|
| 134 |
-
cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
|
| 135 |
-
&grid_size,
|
| 136 |
-
&block_size,
|
| 137 |
-
reinterpret_cast<void const *>(kernel::BlockCompareEqual<Element>));
|
| 138 |
-
|
| 139 |
-
if (result != cudaSuccess) {
|
| 140 |
-
throw std::runtime_error("Failed to query occupancy.");
|
| 141 |
-
}
|
| 142 |
-
// Limit block size. This has the effect of increasing the number of items processed by a
|
| 143 |
-
// single thread and reduces the impact of initialization overhead.
|
| 144 |
-
block_size = (block_size < 128 ? block_size : 128);
|
| 145 |
-
}
|
| 146 |
-
|
| 147 |
-
dim3 grid(grid_size, 1, 1);
|
| 148 |
-
dim3 block(block_size, 1, 1);
|
| 149 |
-
|
| 150 |
-
kernel::BlockCompareEqual<Element><<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity);
|
| 151 |
-
|
| 152 |
-
cudaStreamSynchronize(stream);
|
| 153 |
-
|
| 154 |
-
if (cudaMemcpy(
|
| 155 |
-
&equal_flag,
|
| 156 |
-
device_equal_flag,
|
| 157 |
-
sizeof(int),
|
| 158 |
-
cudaMemcpyDeviceToHost) != cudaSuccess) {
|
| 159 |
-
|
| 160 |
-
cudaFree(device_equal_flag);
|
| 161 |
-
|
| 162 |
-
throw std::runtime_error("Failed to copy equality flag from device.");
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
-
cudaFree(device_equal_flag);
|
| 166 |
-
|
| 167 |
-
return equal_flag;
|
| 168 |
-
}
|
| 169 |
-
|
| 170 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 171 |
-
|
| 172 |
-
/// Performs a bit-level equality check between two blocks
|
| 173 |
-
template <typename Element>
|
| 174 |
-
bool BlockCompareRelativelyEqual(
|
| 175 |
-
Element const *ptr_A,
|
| 176 |
-
Element const *ptr_B,
|
| 177 |
-
size_t capacity,
|
| 178 |
-
Element epsilon,
|
| 179 |
-
Element nonzero_floor,
|
| 180 |
-
int grid_size = 0,
|
| 181 |
-
int block_size = 0,
|
| 182 |
-
cudaStream_t stream = nullptr) {
|
| 183 |
-
|
| 184 |
-
int equal_flag = 1;
|
| 185 |
-
int *device_equal_flag = nullptr;
|
| 186 |
-
|
| 187 |
-
if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) {
|
| 188 |
-
throw std::runtime_error("Failed to allocate device flag.");
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
if (cudaMemcpy(
|
| 192 |
-
device_equal_flag,
|
| 193 |
-
&equal_flag,
|
| 194 |
-
sizeof(int),
|
| 195 |
-
cudaMemcpyHostToDevice) != cudaSuccess) {
|
| 196 |
-
|
| 197 |
-
throw std::runtime_error("Failed to copy equality flag to device.");
|
| 198 |
-
}
|
| 199 |
-
|
| 200 |
-
if (!grid_size || !block_size) {
|
| 201 |
-
|
| 202 |
-
// if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
|
| 203 |
-
cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
|
| 204 |
-
&grid_size,
|
| 205 |
-
&block_size,
|
| 206 |
-
reinterpret_cast<void const *>(kernel::BlockCompareRelativelyEqual<Element>));
|
| 207 |
-
|
| 208 |
-
if (result != cudaSuccess) {
|
| 209 |
-
throw std::runtime_error("Failed to query occupancy.");
|
| 210 |
-
}
|
| 211 |
-
// Limit block size. This has the effect of increasing the number of items processed by a
|
| 212 |
-
// single thread and reduces the impact of initialization overhead.
|
| 213 |
-
block_size = (block_size < 128 ? block_size : 128);
|
| 214 |
-
}
|
| 215 |
-
|
| 216 |
-
dim3 grid(grid_size, 1, 1);
|
| 217 |
-
dim3 block(block_size, 1, 1);
|
| 218 |
-
|
| 219 |
-
kernel::BlockCompareRelativelyEqual<Element><<< grid, block, 0, stream >>>(
|
| 220 |
-
device_equal_flag,
|
| 221 |
-
ptr_A,
|
| 222 |
-
ptr_B,
|
| 223 |
-
capacity,
|
| 224 |
-
epsilon,
|
| 225 |
-
nonzero_floor
|
| 226 |
-
);
|
| 227 |
-
|
| 228 |
-
cudaStreamSynchronize(stream);
|
| 229 |
-
|
| 230 |
-
if (cudaMemcpy(
|
| 231 |
-
&equal_flag,
|
| 232 |
-
device_equal_flag,
|
| 233 |
-
sizeof(int),
|
| 234 |
-
cudaMemcpyDeviceToHost) != cudaSuccess) {
|
| 235 |
-
|
| 236 |
-
cudaFree(device_equal_flag);
|
| 237 |
-
|
| 238 |
-
throw std::runtime_error("Failed to copy equality flag from device.");
|
| 239 |
-
}
|
| 240 |
-
|
| 241 |
-
cudaFree(device_equal_flag);
|
| 242 |
-
|
| 243 |
-
return equal_flag;
|
| 244 |
-
}
|
| 245 |
-
|
| 246 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 247 |
-
|
| 248 |
-
} // device
|
| 249 |
-
} // reference
|
| 250 |
-
} // cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h
DELETED
|
@@ -1,2075 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/* \file
|
| 32 |
-
\brief Defines device-side elementwise operations on TensorView. Note, the operations defined
|
| 33 |
-
in this header are not specialized for any particular data layout and are therefore not
|
| 34 |
-
intended to offer the best possible performance. Rather, they are intended to be generic
|
| 35 |
-
reference implementations to support the CUTLASS unit tests.
|
| 36 |
-
*/
|
| 37 |
-
|
| 38 |
-
#pragma once
|
| 39 |
-
|
| 40 |
-
#if !defined(__CUDACC_RTC__)
|
| 41 |
-
|
| 42 |
-
// Standard Library includes
|
| 43 |
-
#include <utility>
|
| 44 |
-
#include <cstdlib>
|
| 45 |
-
#include <cmath>
|
| 46 |
-
#include <type_traits>
|
| 47 |
-
#include <cstdint>
|
| 48 |
-
|
| 49 |
-
#endif
|
| 50 |
-
|
| 51 |
-
// CUDA includes
|
| 52 |
-
#include <curand_kernel.h>
|
| 53 |
-
|
| 54 |
-
// Cutlass includes
|
| 55 |
-
#include "cutlass/cutlass.h"
|
| 56 |
-
#include "cutlass/array.h"
|
| 57 |
-
#include "cutlass/complex.h"
|
| 58 |
-
#include "cutlass/tensor_view.h"
|
| 59 |
-
#include "cutlass/blas3.h"
|
| 60 |
-
#include "cutlass/numeric_types.h"
|
| 61 |
-
|
| 62 |
-
#include "cutlass/layout/vector.h"
|
| 63 |
-
|
| 64 |
-
#include "cutlass/util/reference/device/tensor_foreach.h"
|
| 65 |
-
#include "cutlass/util/distribution.h"
|
| 66 |
-
|
| 67 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 68 |
-
|
| 69 |
-
namespace cutlass {
|
| 70 |
-
namespace reference {
|
| 71 |
-
namespace device {
|
| 72 |
-
|
| 73 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 74 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 75 |
-
|
| 76 |
-
namespace detail {
|
| 77 |
-
|
| 78 |
-
template <typename FloatType>
|
| 79 |
-
CUTLASS_DEVICE
|
| 80 |
-
FloatType random_normal_float(curandState_t *state) {
|
| 81 |
-
return curand_normal(state);
|
| 82 |
-
}
|
| 83 |
-
|
| 84 |
-
template <>
|
| 85 |
-
CUTLASS_DEVICE
|
| 86 |
-
double random_normal_float<double>(curandState_t *state) {
|
| 87 |
-
return curand_normal_double(state);
|
| 88 |
-
}
|
| 89 |
-
|
| 90 |
-
template <typename FloatType>
|
| 91 |
-
CUTLASS_DEVICE
|
| 92 |
-
FloatType random_uniform_float(curandState_t *state) {
|
| 93 |
-
return curand_uniform(state);
|
| 94 |
-
}
|
| 95 |
-
|
| 96 |
-
template <>
|
| 97 |
-
CUTLASS_DEVICE
|
| 98 |
-
double random_uniform_float<double>(curandState_t *state) {
|
| 99 |
-
return curand_uniform_double(state);
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
template <typename Element>
|
| 103 |
-
struct RandomGaussianFunc {
|
| 104 |
-
|
| 105 |
-
using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type;
|
| 106 |
-
using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type;
|
| 107 |
-
|
| 108 |
-
/// Parameters structure
|
| 109 |
-
struct Params {
|
| 110 |
-
|
| 111 |
-
//
|
| 112 |
-
// Data members
|
| 113 |
-
//
|
| 114 |
-
|
| 115 |
-
uint64_t seed;
|
| 116 |
-
FloatType mean;
|
| 117 |
-
FloatType stddev;
|
| 118 |
-
int int_scale;
|
| 119 |
-
FloatType float_scale_up;
|
| 120 |
-
FloatType float_scale_down;
|
| 121 |
-
int exclude_zero; ///< If non-negative, excludes zeros
|
| 122 |
-
|
| 123 |
-
//
|
| 124 |
-
// Methods
|
| 125 |
-
//
|
| 126 |
-
|
| 127 |
-
/// Construction of Gaussian RNG functor.
|
| 128 |
-
Params(
|
| 129 |
-
uint64_t seed_ = 0,
|
| 130 |
-
Element mean_ = 0,
|
| 131 |
-
Element stddev_ = 1,
|
| 132 |
-
int int_scale_ = -1,
|
| 133 |
-
int exclude_zero_ = -1
|
| 134 |
-
):
|
| 135 |
-
seed(seed_),
|
| 136 |
-
mean(static_cast<FloatType>(mean_)),
|
| 137 |
-
stddev(static_cast<FloatType>(stddev_)),
|
| 138 |
-
int_scale(int_scale_),
|
| 139 |
-
exclude_zero(exclude_zero_) {
|
| 140 |
-
|
| 141 |
-
float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits
|
| 142 |
-
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
|
| 143 |
-
}
|
| 144 |
-
};
|
| 145 |
-
|
| 146 |
-
//
|
| 147 |
-
// Data members
|
| 148 |
-
//
|
| 149 |
-
|
| 150 |
-
/// Parameters object
|
| 151 |
-
Params params;
|
| 152 |
-
|
| 153 |
-
/// RNG state object
|
| 154 |
-
curandState_t rng_state;
|
| 155 |
-
|
| 156 |
-
//
|
| 157 |
-
// Methods
|
| 158 |
-
//
|
| 159 |
-
|
| 160 |
-
/// Device-side initialization of RNG
|
| 161 |
-
CUTLASS_DEVICE
|
| 162 |
-
RandomGaussianFunc(Params const &params): params(params) {
|
| 163 |
-
|
| 164 |
-
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 165 |
-
|
| 166 |
-
curand_init(params.seed, gtid, 0, &rng_state);
|
| 167 |
-
}
|
| 168 |
-
|
| 169 |
-
/// Compute random value and update RNG state
|
| 170 |
-
CUTLASS_DEVICE
|
| 171 |
-
Element operator()() {
|
| 172 |
-
|
| 173 |
-
FloatType rnd = random_normal_float<FloatType>(&rng_state);
|
| 174 |
-
rnd = params.mean + params.stddev * rnd;
|
| 175 |
-
|
| 176 |
-
Element result;
|
| 177 |
-
if (params.int_scale >= 0) {
|
| 178 |
-
rnd = FloatType(std::llround(rnd * params.float_scale_up));
|
| 179 |
-
result = Element(rnd * params.float_scale_down);
|
| 180 |
-
}
|
| 181 |
-
else {
|
| 182 |
-
result = Element(rnd);
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
if (params.exclude_zero >=0 && result == Element(0.0)) {
|
| 186 |
-
if (rnd > FloatType(0)) {
|
| 187 |
-
rnd += FloatType(1);
|
| 188 |
-
} else {
|
| 189 |
-
rnd -= FloatType(1);
|
| 190 |
-
}
|
| 191 |
-
result = Element(rnd);
|
| 192 |
-
}
|
| 193 |
-
|
| 194 |
-
return result;
|
| 195 |
-
}
|
| 196 |
-
};
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
template <typename Real>
|
| 200 |
-
struct RandomGaussianFunc<complex<Real>> {
|
| 201 |
-
|
| 202 |
-
using Element = complex<Real>;
|
| 203 |
-
using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type;
|
| 204 |
-
using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type;
|
| 205 |
-
|
| 206 |
-
/// Parameters structure
|
| 207 |
-
struct Params {
|
| 208 |
-
|
| 209 |
-
//
|
| 210 |
-
// Data members
|
| 211 |
-
//
|
| 212 |
-
|
| 213 |
-
uint64_t seed;
|
| 214 |
-
FloatType mean;
|
| 215 |
-
FloatType stddev;
|
| 216 |
-
int int_scale;
|
| 217 |
-
FloatType float_scale_up;
|
| 218 |
-
FloatType float_scale_down;
|
| 219 |
-
int exclude_zero; ///< If non-negative, excludes zeros
|
| 220 |
-
|
| 221 |
-
//
|
| 222 |
-
// Methods
|
| 223 |
-
//
|
| 224 |
-
|
| 225 |
-
/// Construction of Gaussian RNG functor.
|
| 226 |
-
Params(
|
| 227 |
-
uint64_t seed_ = 0,
|
| 228 |
-
Real mean_ = 0,
|
| 229 |
-
Real stddev_ = 1,
|
| 230 |
-
int int_scale_ = -1,
|
| 231 |
-
int exclude_zero_ = -1
|
| 232 |
-
):
|
| 233 |
-
seed(seed_),
|
| 234 |
-
mean(static_cast<FloatType>(mean_)),
|
| 235 |
-
stddev(static_cast<FloatType>(stddev_)),
|
| 236 |
-
int_scale(int_scale_),
|
| 237 |
-
exclude_zero(exclude_zero_) {
|
| 238 |
-
|
| 239 |
-
float_scale_up = FloatType(IntType(1) << int_scale);
|
| 240 |
-
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
|
| 241 |
-
}
|
| 242 |
-
};
|
| 243 |
-
|
| 244 |
-
//
|
| 245 |
-
// Data members
|
| 246 |
-
//
|
| 247 |
-
|
| 248 |
-
/// Parameters object
|
| 249 |
-
Params params;
|
| 250 |
-
|
| 251 |
-
/// RNG state object
|
| 252 |
-
curandState_t rng_state;
|
| 253 |
-
|
| 254 |
-
//
|
| 255 |
-
// Methods
|
| 256 |
-
//
|
| 257 |
-
|
| 258 |
-
/// Device-side initialization of RNG
|
| 259 |
-
CUTLASS_DEVICE
|
| 260 |
-
RandomGaussianFunc(Params const &params): params(params) {
|
| 261 |
-
|
| 262 |
-
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 263 |
-
|
| 264 |
-
curand_init(params.seed, gtid, 0, &rng_state);
|
| 265 |
-
}
|
| 266 |
-
|
| 267 |
-
/// Compute random value and update RNG state
|
| 268 |
-
CUTLASS_DEVICE
|
| 269 |
-
Element operator()() {
|
| 270 |
-
|
| 271 |
-
FloatType rnd_r = random_normal_float<FloatType>(&rng_state);
|
| 272 |
-
FloatType rnd_i = random_normal_float<FloatType>(&rng_state);
|
| 273 |
-
rnd_r = params.mean + params.stddev * rnd_r;
|
| 274 |
-
rnd_i = params.mean + params.stddev * rnd_i;
|
| 275 |
-
|
| 276 |
-
Element result;
|
| 277 |
-
if (params.int_scale >= 0) {
|
| 278 |
-
rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up));
|
| 279 |
-
rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up));
|
| 280 |
-
|
| 281 |
-
result = {
|
| 282 |
-
Real(rnd_r * params.float_scale_down),
|
| 283 |
-
Real(rnd_i * params.float_scale_down)
|
| 284 |
-
};
|
| 285 |
-
}
|
| 286 |
-
else {
|
| 287 |
-
result = Element(Real(rnd_r), Real(rnd_i));
|
| 288 |
-
}
|
| 289 |
-
|
| 290 |
-
if (params.exclude_zero >= 0 &&
|
| 291 |
-
result.real() == Real(0.0) &&
|
| 292 |
-
result.imag() == Real(0.0)) {
|
| 293 |
-
|
| 294 |
-
if (rnd_r > FloatType(0)) {
|
| 295 |
-
rnd_r += FloatType(1);
|
| 296 |
-
} else {
|
| 297 |
-
rnd_r -= FloatType(1);
|
| 298 |
-
}
|
| 299 |
-
result = Element(Real(rnd_r), Real(rnd_i));
|
| 300 |
-
}
|
| 301 |
-
|
| 302 |
-
return result;
|
| 303 |
-
}
|
| 304 |
-
};
|
| 305 |
-
|
| 306 |
-
/// Computes a random Gaussian distribution
|
| 307 |
-
template <
|
| 308 |
-
typename Element, ///< Element type
|
| 309 |
-
typename Layout> ///< Layout function
|
| 310 |
-
struct TensorFillRandomGaussianFunc {
|
| 311 |
-
|
| 312 |
-
/// View type
|
| 313 |
-
using TensorView = TensorView<Element, Layout>;
|
| 314 |
-
|
| 315 |
-
/// Scalar type
|
| 316 |
-
typedef typename TensorView::Element T;
|
| 317 |
-
|
| 318 |
-
/// Coordinate in tensor's index space
|
| 319 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 320 |
-
|
| 321 |
-
using RandomFunc = RandomGaussianFunc<Element>;
|
| 322 |
-
|
| 323 |
-
/// Parameters structure
|
| 324 |
-
struct Params {
|
| 325 |
-
|
| 326 |
-
//
|
| 327 |
-
// Data members
|
| 328 |
-
//
|
| 329 |
-
|
| 330 |
-
TensorView view;
|
| 331 |
-
typename RandomFunc::Params random;
|
| 332 |
-
|
| 333 |
-
//
|
| 334 |
-
// Methods
|
| 335 |
-
//
|
| 336 |
-
|
| 337 |
-
/// Construction of Gaussian RNG functor.
|
| 338 |
-
Params(
|
| 339 |
-
TensorView view_ = TensorView(),
|
| 340 |
-
typename RandomFunc::Params random_ = typename RandomFunc::Params()
|
| 341 |
-
):
|
| 342 |
-
view(view_), random(random_) {
|
| 343 |
-
|
| 344 |
-
}
|
| 345 |
-
};
|
| 346 |
-
|
| 347 |
-
//
|
| 348 |
-
// Data members
|
| 349 |
-
//
|
| 350 |
-
|
| 351 |
-
Params params;
|
| 352 |
-
RandomFunc random;
|
| 353 |
-
|
| 354 |
-
//
|
| 355 |
-
// Methods
|
| 356 |
-
//
|
| 357 |
-
|
| 358 |
-
/// Device-side initialization of RNG
|
| 359 |
-
CUTLASS_DEVICE
|
| 360 |
-
TensorFillRandomGaussianFunc(Params const &params): params(params), random(params.random) {
|
| 361 |
-
|
| 362 |
-
}
|
| 363 |
-
|
| 364 |
-
/// Compute random value and update RNG state
|
| 365 |
-
CUTLASS_DEVICE
|
| 366 |
-
void operator()(TensorCoord const &coord) {
|
| 367 |
-
|
| 368 |
-
params.view.at(coord) = random();
|
| 369 |
-
}
|
| 370 |
-
};
|
| 371 |
-
|
| 372 |
-
} // namespace detail
|
| 373 |
-
|
| 374 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 375 |
-
|
| 376 |
-
/// Fills a tensor with random values with a Gaussian distribution.
|
| 377 |
-
template <
|
| 378 |
-
typename Element, ///< Element type
|
| 379 |
-
typename Layout> ///< Layout function
|
| 380 |
-
void TensorFillRandomGaussian(
|
| 381 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 382 |
-
uint64_t seed, ///< seed for RNG
|
| 383 |
-
typename RealType<Element>::Type mean = Element(0), ///< Gaussian distribution's mean
|
| 384 |
-
typename RealType<Element>::Type stddev = Element(1), ///< Gaussian distribution's standard deviation
|
| 385 |
-
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 386 |
-
/// are not truncated to zero. Permits reducing precision of
|
| 387 |
-
/// data.
|
| 388 |
-
int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init
|
| 389 |
-
cudaStream_t stream = nullptr) {
|
| 390 |
-
|
| 391 |
-
using RandomFunc = detail::RandomGaussianFunc<Element>;
|
| 392 |
-
using Func = detail::TensorFillRandomGaussianFunc<Element, Layout>;
|
| 393 |
-
using Params = typename Func::Params;
|
| 394 |
-
|
| 395 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 396 |
-
view.extent(),
|
| 397 |
-
Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)),
|
| 398 |
-
/*grid_size*/0, /*block_size*/0,
|
| 399 |
-
stream
|
| 400 |
-
);
|
| 401 |
-
}
|
| 402 |
-
|
| 403 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 404 |
-
|
| 405 |
-
/// Fills a tensor with random values with a Gaussian distribution.
|
| 406 |
-
template <typename Element> ///< Element type
|
| 407 |
-
void BlockFillRandomGaussian(
|
| 408 |
-
Element *ptr,
|
| 409 |
-
size_t capacity,
|
| 410 |
-
uint64_t seed, ///< seed for RNG
|
| 411 |
-
typename RealType<Element>::Type mean, ///< Gaussian distribution's mean
|
| 412 |
-
typename RealType<Element>::Type stddev, ///< Gaussian distribution's standard deviation
|
| 413 |
-
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 414 |
-
/// are not truncated to zero. Permits reducing precision of
|
| 415 |
-
/// data.
|
| 416 |
-
cudaStream_t stream = nullptr) {
|
| 417 |
-
|
| 418 |
-
using RandomFunc = detail::RandomGaussianFunc<Element>;
|
| 419 |
-
|
| 420 |
-
typename RandomFunc::Params params(seed, mean, stddev, bits);
|
| 421 |
-
|
| 422 |
-
BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
|
| 423 |
-
}
|
| 424 |
-
|
| 425 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 426 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 427 |
-
|
| 428 |
-
namespace detail {
|
| 429 |
-
|
| 430 |
-
/// Computes a random uniform distribution
|
| 431 |
-
template <typename Element> ///< Element type
|
| 432 |
-
struct RandomUniformFunc {
|
| 433 |
-
|
| 434 |
-
using FloatType = typename std::conditional<
|
| 435 |
-
(sizeof(Element) > 4),
|
| 436 |
-
double,
|
| 437 |
-
float>::type;
|
| 438 |
-
|
| 439 |
-
using IntType = typename std::conditional<
|
| 440 |
-
(sizeof(Element) > 4),
|
| 441 |
-
int64_t,
|
| 442 |
-
int>::type;
|
| 443 |
-
|
| 444 |
-
/// Parameters structure
|
| 445 |
-
struct Params {
|
| 446 |
-
|
| 447 |
-
//
|
| 448 |
-
// Data members
|
| 449 |
-
//
|
| 450 |
-
|
| 451 |
-
uint64_t seed;
|
| 452 |
-
FloatType range;
|
| 453 |
-
FloatType max;
|
| 454 |
-
int int_scale;
|
| 455 |
-
double pnan;
|
| 456 |
-
FloatType float_scale_up;
|
| 457 |
-
FloatType float_scale_down;
|
| 458 |
-
int exclude_zero; ///< If non-negative, excludes zeros
|
| 459 |
-
|
| 460 |
-
/// Default ctor
|
| 461 |
-
CUTLASS_HOST_DEVICE
|
| 462 |
-
Params() { }
|
| 463 |
-
|
| 464 |
-
//
|
| 465 |
-
// Methods
|
| 466 |
-
//
|
| 467 |
-
|
| 468 |
-
/// Construction of Gaussian RNG functor.
|
| 469 |
-
Params(
|
| 470 |
-
uint64_t seed_ = 0,
|
| 471 |
-
Element max_ = 1,
|
| 472 |
-
Element min = 0,
|
| 473 |
-
int int_scale_ = -1,
|
| 474 |
-
double pnan_ = 0,
|
| 475 |
-
int exclude_zero_ = -1
|
| 476 |
-
):
|
| 477 |
-
seed(seed_),
|
| 478 |
-
range(static_cast<FloatType>(max_) - static_cast<FloatType>(min)),
|
| 479 |
-
max(static_cast<FloatType>(max_)),
|
| 480 |
-
int_scale(int_scale_),
|
| 481 |
-
pnan(pnan_),
|
| 482 |
-
exclude_zero(exclude_zero_) {
|
| 483 |
-
|
| 484 |
-
float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits
|
| 485 |
-
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
|
| 486 |
-
|
| 487 |
-
// Handle cases where min = 0 or max = 0 for excluding zeros
|
| 488 |
-
if (exclude_zero >= 0) {
|
| 489 |
-
range = (min == Element(0)) ? range - FloatType(1): range;
|
| 490 |
-
max = (max_ == Element(0)) ? max - FloatType(1): max;
|
| 491 |
-
}
|
| 492 |
-
}
|
| 493 |
-
};
|
| 494 |
-
|
| 495 |
-
//
|
| 496 |
-
// Data members
|
| 497 |
-
//
|
| 498 |
-
|
| 499 |
-
/// Parameters object
|
| 500 |
-
Params params;
|
| 501 |
-
|
| 502 |
-
/// RNG state object
|
| 503 |
-
curandState_t rng_state;
|
| 504 |
-
|
| 505 |
-
//
|
| 506 |
-
// Methods
|
| 507 |
-
//
|
| 508 |
-
|
| 509 |
-
/// Device-side initialization of RNG
|
| 510 |
-
CUTLASS_DEVICE
|
| 511 |
-
RandomUniformFunc(Params const ¶ms): params(params) {
|
| 512 |
-
|
| 513 |
-
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 514 |
-
|
| 515 |
-
curand_init(params.seed, gtid, 0, &rng_state);
|
| 516 |
-
}
|
| 517 |
-
|
| 518 |
-
/// Compute random value and update RNG state
|
| 519 |
-
CUTLASS_DEVICE
|
| 520 |
-
Element operator()() {
|
| 521 |
-
|
| 522 |
-
// Draw random float in [0.0, 1.0] to determine if element should be NaN.
|
| 523 |
-
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
|
| 524 |
-
if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
|
| 525 |
-
return Element(NAN);
|
| 526 |
-
}
|
| 527 |
-
}
|
| 528 |
-
|
| 529 |
-
FloatType rnd = random_uniform_float<FloatType>(&rng_state);
|
| 530 |
-
rnd = params.max - params.range * rnd;
|
| 531 |
-
|
| 532 |
-
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 533 |
-
// testing
|
| 534 |
-
Element result;
|
| 535 |
-
|
| 536 |
-
if (params.int_scale >= 0) {
|
| 537 |
-
rnd = FloatType(std::llround(rnd * params.float_scale_up));
|
| 538 |
-
result = Element(rnd * params.float_scale_down);
|
| 539 |
-
}
|
| 540 |
-
else {
|
| 541 |
-
result = Element(rnd);
|
| 542 |
-
}
|
| 543 |
-
|
| 544 |
-
if (params.exclude_zero >=0 && result == Element(0.0)) {
|
| 545 |
-
if (rnd > FloatType(0)) {
|
| 546 |
-
rnd = std::min(params.max, rnd + FloatType(1));
|
| 547 |
-
} else {
|
| 548 |
-
rnd = std::max((params.max - params.range), rnd - FloatType(1));
|
| 549 |
-
}
|
| 550 |
-
result = Element(rnd);
|
| 551 |
-
}
|
| 552 |
-
|
| 553 |
-
return result;
|
| 554 |
-
}
|
| 555 |
-
};
|
| 556 |
-
|
| 557 |
-
/// Computes a random Gaussian distribution
|
| 558 |
-
template <typename Real>
|
| 559 |
-
struct RandomUniformFunc<complex<Real>> {
|
| 560 |
-
|
| 561 |
-
using Element = complex<Real>;
|
| 562 |
-
|
| 563 |
-
using FloatType = typename std::conditional<
|
| 564 |
-
(sizeof(Real) > 4),
|
| 565 |
-
double,
|
| 566 |
-
float>::type;
|
| 567 |
-
|
| 568 |
-
using IntType = typename std::conditional<
|
| 569 |
-
(sizeof(Real) > 4),
|
| 570 |
-
int64_t,
|
| 571 |
-
int>::type;
|
| 572 |
-
|
| 573 |
-
/// Parameters structure
|
| 574 |
-
struct Params {
|
| 575 |
-
|
| 576 |
-
//
|
| 577 |
-
// Data members
|
| 578 |
-
//
|
| 579 |
-
|
| 580 |
-
uint64_t seed;
|
| 581 |
-
FloatType range;
|
| 582 |
-
FloatType min;
|
| 583 |
-
int int_scale;
|
| 584 |
-
double pnan;
|
| 585 |
-
FloatType float_scale_up;
|
| 586 |
-
FloatType float_scale_down;
|
| 587 |
-
int exclude_zero; ///< If non-negative, excludes zeros
|
| 588 |
-
|
| 589 |
-
/// Default ctor
|
| 590 |
-
CUTLASS_HOST_DEVICE
|
| 591 |
-
Params() { }
|
| 592 |
-
|
| 593 |
-
//
|
| 594 |
-
// Methods
|
| 595 |
-
//
|
| 596 |
-
|
| 597 |
-
/// Construction of Gaussian RNG functor.
|
| 598 |
-
Params(
|
| 599 |
-
uint64_t seed_ = 0,
|
| 600 |
-
FloatType max = 1,
|
| 601 |
-
FloatType min_ = 0,
|
| 602 |
-
int int_scale_ = -1,
|
| 603 |
-
double pnan_ = 0,
|
| 604 |
-
int exclude_zero_ = -1
|
| 605 |
-
):
|
| 606 |
-
seed(seed_),
|
| 607 |
-
range(static_cast<FloatType>(max - min_)),
|
| 608 |
-
min(static_cast<FloatType>(min_)),
|
| 609 |
-
int_scale(int_scale_),
|
| 610 |
-
pnan(pnan_),
|
| 611 |
-
exclude_zero(exclude_zero_) {
|
| 612 |
-
|
| 613 |
-
float_scale_up = FloatType(IntType(1) << int_scale);
|
| 614 |
-
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
|
| 615 |
-
|
| 616 |
-
// Handle cases where min = 0 or max = 0 for excluding zeros
|
| 617 |
-
if (exclude_zero >= 0) {
|
| 618 |
-
min = (min == FloatType(0)) ? min + FloatType(1): min;
|
| 619 |
-
range = (max == FloatType(0)) ? range - FloatType(1): range;
|
| 620 |
-
}
|
| 621 |
-
}
|
| 622 |
-
};
|
| 623 |
-
|
| 624 |
-
//
|
| 625 |
-
// Data members
|
| 626 |
-
//
|
| 627 |
-
|
| 628 |
-
/// Parameters object
|
| 629 |
-
Params params;
|
| 630 |
-
|
| 631 |
-
/// RNG state object
|
| 632 |
-
curandState_t rng_state;
|
| 633 |
-
|
| 634 |
-
//
|
| 635 |
-
// Methods
|
| 636 |
-
//
|
| 637 |
-
|
| 638 |
-
/// Device-side initialization of RNG
|
| 639 |
-
CUTLASS_DEVICE
|
| 640 |
-
RandomUniformFunc(Params const ¶ms): params(params) {
|
| 641 |
-
|
| 642 |
-
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 643 |
-
|
| 644 |
-
curand_init(params.seed, gtid, 0, &rng_state);
|
| 645 |
-
}
|
| 646 |
-
|
| 647 |
-
/// Compute random value and update RNG state
|
| 648 |
-
CUTLASS_DEVICE
|
| 649 |
-
Element operator()() {
|
| 650 |
-
|
| 651 |
-
// Draw random float in [0.0, 1.0] to determine if element should be NaN.
|
| 652 |
-
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
|
| 653 |
-
if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
|
| 654 |
-
return Element(Real(NAN), Real(NAN));
|
| 655 |
-
}
|
| 656 |
-
}
|
| 657 |
-
|
| 658 |
-
FloatType rnd_r = random_uniform_float<FloatType>(&rng_state);
|
| 659 |
-
FloatType rnd_i = random_uniform_float<FloatType>(&rng_state);
|
| 660 |
-
|
| 661 |
-
rnd_r = params.min + params.range * rnd_r;
|
| 662 |
-
rnd_i = params.min + params.range * rnd_i;
|
| 663 |
-
|
| 664 |
-
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 665 |
-
// testing
|
| 666 |
-
Element result;
|
| 667 |
-
|
| 668 |
-
if (params.int_scale >= 0) {
|
| 669 |
-
rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up));
|
| 670 |
-
rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up));
|
| 671 |
-
|
| 672 |
-
result = {
|
| 673 |
-
Real(rnd_r * params.float_scale_down),
|
| 674 |
-
Real(rnd_i * params.float_scale_down)
|
| 675 |
-
};
|
| 676 |
-
}
|
| 677 |
-
else {
|
| 678 |
-
result = Element(Real(rnd_r), Real(rnd_i));
|
| 679 |
-
}
|
| 680 |
-
|
| 681 |
-
if (params.exclude_zero >= 0 &&
|
| 682 |
-
result.real() == Real(0.0) &&
|
| 683 |
-
result.imag() == Real(0.0)) {
|
| 684 |
-
|
| 685 |
-
if (rnd_r > FloatType(0)) {
|
| 686 |
-
rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1));
|
| 687 |
-
} else {
|
| 688 |
-
rnd_r = std::max((params.min), rnd_r - FloatType(1));
|
| 689 |
-
}
|
| 690 |
-
result = Element(Real(rnd_r), Real(rnd_i));
|
| 691 |
-
}
|
| 692 |
-
|
| 693 |
-
return result;
|
| 694 |
-
}
|
| 695 |
-
};
|
| 696 |
-
|
| 697 |
-
/// Computes a random uniform distribution
|
| 698 |
-
template <
|
| 699 |
-
typename Element, ///< Element type
|
| 700 |
-
typename Layout> ///< Layout function
|
| 701 |
-
struct TensorFillRandomUniformFunc {
|
| 702 |
-
|
| 703 |
-
/// View type
|
| 704 |
-
using TensorView = TensorView<Element, Layout>;
|
| 705 |
-
|
| 706 |
-
/// Scalar type
|
| 707 |
-
typedef typename TensorView::Element T;
|
| 708 |
-
|
| 709 |
-
/// Coordinate in tensor's index space
|
| 710 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 711 |
-
|
| 712 |
-
using RandomFunc = RandomUniformFunc<Element>;
|
| 713 |
-
|
| 714 |
-
/// Parameters structure
|
| 715 |
-
struct Params {
|
| 716 |
-
|
| 717 |
-
//
|
| 718 |
-
// Data members
|
| 719 |
-
//
|
| 720 |
-
|
| 721 |
-
TensorView view;
|
| 722 |
-
typename RandomFunc::Params random;
|
| 723 |
-
|
| 724 |
-
/// Default ctor
|
| 725 |
-
CUTLASS_HOST_DEVICE
|
| 726 |
-
Params() { }
|
| 727 |
-
|
| 728 |
-
//
|
| 729 |
-
// Methods
|
| 730 |
-
//
|
| 731 |
-
|
| 732 |
-
/// Construction of Gaussian RNG functor.
|
| 733 |
-
Params(
|
| 734 |
-
TensorView view_ = TensorView(),
|
| 735 |
-
typename RandomFunc::Params random_ = RandomFunc::Params()
|
| 736 |
-
):
|
| 737 |
-
view(view_), random(random_) {
|
| 738 |
-
|
| 739 |
-
}
|
| 740 |
-
};
|
| 741 |
-
|
| 742 |
-
//
|
| 743 |
-
// Data members
|
| 744 |
-
//
|
| 745 |
-
|
| 746 |
-
Params params;
|
| 747 |
-
RandomFunc random;
|
| 748 |
-
|
| 749 |
-
//
|
| 750 |
-
// Methods
|
| 751 |
-
//
|
| 752 |
-
|
| 753 |
-
/// Device-side initialization of RNG
|
| 754 |
-
CUTLASS_DEVICE
|
| 755 |
-
TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) {
|
| 756 |
-
}
|
| 757 |
-
|
| 758 |
-
/// Compute random value and update RNG state
|
| 759 |
-
CUTLASS_DEVICE
|
| 760 |
-
void operator()(TensorCoord const &coord) {
|
| 761 |
-
|
| 762 |
-
params.view.at(coord) = random();
|
| 763 |
-
}
|
| 764 |
-
};
|
| 765 |
-
|
| 766 |
-
} // namespace detail
|
| 767 |
-
|
| 768 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 769 |
-
|
| 770 |
-
/// Fills a tensor with random values with a uniform random distribution.
|
| 771 |
-
template <
|
| 772 |
-
typename Element, ///< Element type
|
| 773 |
-
typename Layout> ///< Layout function
|
| 774 |
-
void TensorFillRandomUniform(
|
| 775 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 776 |
-
uint64_t seed, ///< seed for RNG
|
| 777 |
-
typename RealType<Element>::Type max = Element(1), ///< upper bound of distribution
|
| 778 |
-
typename RealType<Element>::Type min = Element(0), ///< lower bound for distribution
|
| 779 |
-
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 780 |
-
/// are not truncated to zero. Permits reducing precision of
|
| 781 |
-
/// data.
|
| 782 |
-
double pnan = 0, ///< Percentage of NaN elements.
|
| 783 |
-
int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init
|
| 784 |
-
cudaStream_t stream = nullptr) {
|
| 785 |
-
|
| 786 |
-
using RandomFunc = detail::RandomUniformFunc<Element>;
|
| 787 |
-
using Func = detail::TensorFillRandomUniformFunc<Element, Layout>;
|
| 788 |
-
using Params = typename Func::Params;
|
| 789 |
-
|
| 790 |
-
typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero);
|
| 791 |
-
|
| 792 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 793 |
-
view.extent(),
|
| 794 |
-
Params(view, random),
|
| 795 |
-
/*grid_size*/0, /*block_size*/0,
|
| 796 |
-
stream
|
| 797 |
-
);
|
| 798 |
-
}
|
| 799 |
-
|
| 800 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 801 |
-
|
| 802 |
-
/// Fills a tensor with random values with a uniform random distribution.
|
| 803 |
-
template <typename Element>
|
| 804 |
-
void BlockFillRandomUniform(
|
| 805 |
-
Element *ptr,
|
| 806 |
-
size_t capacity,
|
| 807 |
-
uint64_t seed, ///< seed for RNG
|
| 808 |
-
typename RealType<Element>::Type max, ///< upper bound of distribution
|
| 809 |
-
typename RealType<Element>::Type min, ///< lower bound for distribution
|
| 810 |
-
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 811 |
-
/// are not truncated to zero. Permits reducing precision of
|
| 812 |
-
/// data.
|
| 813 |
-
double pnan = 0, ///< Percentage of NaN elements.
|
| 814 |
-
cudaStream_t stream = nullptr) {
|
| 815 |
-
|
| 816 |
-
using RandomFunc = detail::RandomUniformFunc<Element>;
|
| 817 |
-
|
| 818 |
-
typename RandomFunc::Params params(seed, max, min, bits, pnan);
|
| 819 |
-
|
| 820 |
-
BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
|
| 821 |
-
}
|
| 822 |
-
|
| 823 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 824 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 825 |
-
|
| 826 |
-
namespace detail {
|
| 827 |
-
|
| 828 |
-
/// Computes a random sparse meta
|
| 829 |
-
template <typename Element> ///< Element type
|
| 830 |
-
struct RandomSparseMetaFunc {
|
| 831 |
-
|
| 832 |
-
using FloatType = float;
|
| 833 |
-
|
| 834 |
-
using IntType = int32_t;
|
| 835 |
-
|
| 836 |
-
/// Parameters structure
|
| 837 |
-
struct Params {
|
| 838 |
-
|
| 839 |
-
//
|
| 840 |
-
// Data members
|
| 841 |
-
//
|
| 842 |
-
|
| 843 |
-
uint64_t seed;
|
| 844 |
-
FloatType range;
|
| 845 |
-
int MetaSizeInBits;
|
| 846 |
-
|
| 847 |
-
/// Default ctor
|
| 848 |
-
CUTLASS_HOST_DEVICE
|
| 849 |
-
Params() { }
|
| 850 |
-
|
| 851 |
-
//
|
| 852 |
-
// Methods
|
| 853 |
-
//
|
| 854 |
-
|
| 855 |
-
/// Construction of Gaussian RNG functor.
|
| 856 |
-
Params(
|
| 857 |
-
uint64_t seed_ = 0,
|
| 858 |
-
int MetaSizeInBits_ = 2
|
| 859 |
-
):
|
| 860 |
-
seed(seed_),
|
| 861 |
-
MetaSizeInBits(MetaSizeInBits_) {
|
| 862 |
-
if (MetaSizeInBits_ == 2) {
|
| 863 |
-
range = 6;
|
| 864 |
-
}
|
| 865 |
-
else if (MetaSizeInBits_ == 4) {
|
| 866 |
-
range = 2;
|
| 867 |
-
}
|
| 868 |
-
else {
|
| 869 |
-
throw std::invalid_argument("Invalid MetaSizeInBits");
|
| 870 |
-
}
|
| 871 |
-
}
|
| 872 |
-
};
|
| 873 |
-
|
| 874 |
-
//
|
| 875 |
-
// Data members
|
| 876 |
-
//
|
| 877 |
-
|
| 878 |
-
/// Parameters object
|
| 879 |
-
Params params;
|
| 880 |
-
|
| 881 |
-
/// RNG state object
|
| 882 |
-
curandState_t rng_state;
|
| 883 |
-
|
| 884 |
-
//
|
| 885 |
-
// Methods
|
| 886 |
-
//
|
| 887 |
-
|
| 888 |
-
/// Device-side initialization of RNG
|
| 889 |
-
CUTLASS_DEVICE
|
| 890 |
-
RandomSparseMetaFunc(Params const ¶ms): params(params) {
|
| 891 |
-
|
| 892 |
-
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 893 |
-
|
| 894 |
-
curand_init(params.seed, gtid, 0, &rng_state);
|
| 895 |
-
}
|
| 896 |
-
|
| 897 |
-
/// Compute random value and update RNG state
|
| 898 |
-
CUTLASS_DEVICE
|
| 899 |
-
Element operator()() {
|
| 900 |
-
Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe};
|
| 901 |
-
Element TwoToOneMeta[2] = {0x4, 0xe};
|
| 902 |
-
|
| 903 |
-
Element *MetaArray =
|
| 904 |
-
(params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta;
|
| 905 |
-
|
| 906 |
-
Element result = 0x0;
|
| 907 |
-
|
| 908 |
-
CUTLASS_PRAGMA_UNROLL
|
| 909 |
-
for (int i = 0; i < cutlass::sizeof_bits<Element>::value / 4; ++i) {
|
| 910 |
-
FloatType rnd = random_uniform_float<FloatType>(&rng_state);
|
| 911 |
-
rnd = params.range * rnd;
|
| 912 |
-
Element meta = MetaArray[(int)rnd];
|
| 913 |
-
|
| 914 |
-
result = (Element)(result | ((Element)(meta << (i * 4))));
|
| 915 |
-
}
|
| 916 |
-
|
| 917 |
-
return result;
|
| 918 |
-
}
|
| 919 |
-
};
|
| 920 |
-
|
| 921 |
-
/// Computes a random Gaussian distribution
|
| 922 |
-
template <
|
| 923 |
-
typename Element, ///< Element type
|
| 924 |
-
typename Layout> ///< Layout function
|
| 925 |
-
struct TensorFillRandomSparseMetaFunc {
|
| 926 |
-
|
| 927 |
-
/// View type
|
| 928 |
-
using TensorView = TensorView<Element, Layout>;
|
| 929 |
-
|
| 930 |
-
/// Scalar type
|
| 931 |
-
typedef typename TensorView::Element T;
|
| 932 |
-
|
| 933 |
-
/// Coordinate in tensor's index space
|
| 934 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 935 |
-
|
| 936 |
-
using RandomFunc = RandomSparseMetaFunc<Element>;
|
| 937 |
-
|
| 938 |
-
/// Parameters structure
|
| 939 |
-
struct Params {
|
| 940 |
-
|
| 941 |
-
//
|
| 942 |
-
// Data members
|
| 943 |
-
//
|
| 944 |
-
|
| 945 |
-
TensorView view;
|
| 946 |
-
typename RandomFunc::Params random;
|
| 947 |
-
|
| 948 |
-
/// Default ctor
|
| 949 |
-
CUTLASS_HOST_DEVICE
|
| 950 |
-
Params() { }
|
| 951 |
-
|
| 952 |
-
//
|
| 953 |
-
// Methods
|
| 954 |
-
//
|
| 955 |
-
|
| 956 |
-
/// Construction of Gaussian RNG functor.
|
| 957 |
-
Params(
|
| 958 |
-
TensorView view_ = TensorView(),
|
| 959 |
-
typename RandomFunc::Params random_ = RandomFunc::Params()
|
| 960 |
-
):
|
| 961 |
-
view(view_), random(random_) {
|
| 962 |
-
|
| 963 |
-
}
|
| 964 |
-
};
|
| 965 |
-
|
| 966 |
-
//
|
| 967 |
-
// Data members
|
| 968 |
-
//
|
| 969 |
-
|
| 970 |
-
Params params;
|
| 971 |
-
RandomFunc random;
|
| 972 |
-
|
| 973 |
-
//
|
| 974 |
-
// Methods
|
| 975 |
-
//
|
| 976 |
-
|
| 977 |
-
/// Device-side initialization of RNG
|
| 978 |
-
CUTLASS_DEVICE
|
| 979 |
-
TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) {
|
| 980 |
-
}
|
| 981 |
-
|
| 982 |
-
/// Compute random value and update RNG state
|
| 983 |
-
CUTLASS_DEVICE
|
| 984 |
-
void operator()(TensorCoord const &coord) {
|
| 985 |
-
|
| 986 |
-
params.view.at(coord) = random();
|
| 987 |
-
}
|
| 988 |
-
};
|
| 989 |
-
|
| 990 |
-
} // namespace detail
|
| 991 |
-
|
| 992 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 993 |
-
|
| 994 |
-
/// Fills a tensor with random values with a uniform random distribution.
|
| 995 |
-
template <
|
| 996 |
-
typename Element, ///< Element type
|
| 997 |
-
typename Layout> ///< Layout function
|
| 998 |
-
void TensorFillRandomSparseMeta(
|
| 999 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1000 |
-
uint64_t seed, ///< seed for RNG
|
| 1001 |
-
int MetaSizeInBits = 2, ///< meta data size
|
| 1002 |
-
cudaStream_t stream = nullptr) {
|
| 1003 |
-
|
| 1004 |
-
using RandomFunc = detail::RandomSparseMetaFunc<Element>;
|
| 1005 |
-
using Func = detail::TensorFillRandomUniformFunc<Element, Layout>;
|
| 1006 |
-
using Params = typename Func::Params;
|
| 1007 |
-
|
| 1008 |
-
typename RandomFunc::Params random(seed, MetaSizeInBits);
|
| 1009 |
-
|
| 1010 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 1011 |
-
view.extent(),
|
| 1012 |
-
Params(view, random),
|
| 1013 |
-
/*grid_size*/0, /*block_size*/0,
|
| 1014 |
-
stream
|
| 1015 |
-
);
|
| 1016 |
-
}
|
| 1017 |
-
|
| 1018 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1019 |
-
|
| 1020 |
-
/// Fills a tensor with random values with a uniform random distribution.
|
| 1021 |
-
template <typename Element>
|
| 1022 |
-
void BlockFillRandomSparseMeta(
|
| 1023 |
-
Element *ptr,
|
| 1024 |
-
size_t capacity,
|
| 1025 |
-
uint64_t seed, ///< seed for RNG
|
| 1026 |
-
int MetaSizeInBits = 2, ///< meta data size
|
| 1027 |
-
cudaStream_t stream = nullptr) {
|
| 1028 |
-
|
| 1029 |
-
using RandomFunc = detail::RandomSparseMetaFunc<Element>;
|
| 1030 |
-
|
| 1031 |
-
typename RandomFunc::Params params(seed, MetaSizeInBits);
|
| 1032 |
-
|
| 1033 |
-
BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
|
| 1034 |
-
}
|
| 1035 |
-
|
| 1036 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1037 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1038 |
-
|
| 1039 |
-
namespace detail {
|
| 1040 |
-
|
| 1041 |
-
/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal.
|
| 1042 |
-
template <
|
| 1043 |
-
typename Element, ///< Element type
|
| 1044 |
-
typename Layout> ///< Layout function
|
| 1045 |
-
struct TensorFillDiagonalFunc {
|
| 1046 |
-
|
| 1047 |
-
/// View type
|
| 1048 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1049 |
-
|
| 1050 |
-
/// Scalar type
|
| 1051 |
-
typedef typename TensorView::Element T;
|
| 1052 |
-
|
| 1053 |
-
/// Coordinate in tensor's index space
|
| 1054 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1055 |
-
|
| 1056 |
-
/// Parameters structure
|
| 1057 |
-
struct Params {
|
| 1058 |
-
|
| 1059 |
-
//
|
| 1060 |
-
// Data members
|
| 1061 |
-
//
|
| 1062 |
-
|
| 1063 |
-
TensorView view;
|
| 1064 |
-
Element diag;
|
| 1065 |
-
Element other;
|
| 1066 |
-
|
| 1067 |
-
/// Default ctor
|
| 1068 |
-
CUTLASS_HOST_DEVICE
|
| 1069 |
-
Params() { }
|
| 1070 |
-
|
| 1071 |
-
//
|
| 1072 |
-
// Methods
|
| 1073 |
-
//
|
| 1074 |
-
|
| 1075 |
-
Params(
|
| 1076 |
-
TensorView view_ = TensorView(),
|
| 1077 |
-
Element diag_ = Element(1),
|
| 1078 |
-
Element other_ = Element(0)
|
| 1079 |
-
):
|
| 1080 |
-
view(view_), diag(diag_), other(other_) {
|
| 1081 |
-
|
| 1082 |
-
}
|
| 1083 |
-
};
|
| 1084 |
-
|
| 1085 |
-
//
|
| 1086 |
-
// Data members
|
| 1087 |
-
//
|
| 1088 |
-
|
| 1089 |
-
/// Parameters object
|
| 1090 |
-
Params params;
|
| 1091 |
-
|
| 1092 |
-
//
|
| 1093 |
-
// Methods
|
| 1094 |
-
//
|
| 1095 |
-
|
| 1096 |
-
/// Device-side initialization of RNG
|
| 1097 |
-
CUTLASS_DEVICE
|
| 1098 |
-
TensorFillDiagonalFunc(Params const ¶ms): params(params) {
|
| 1099 |
-
|
| 1100 |
-
}
|
| 1101 |
-
|
| 1102 |
-
/// Updates the tensor
|
| 1103 |
-
CUTLASS_DEVICE
|
| 1104 |
-
void operator()(TensorCoord const &coord) {
|
| 1105 |
-
|
| 1106 |
-
bool is_diag = true;
|
| 1107 |
-
|
| 1108 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1109 |
-
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1110 |
-
if (coord[i] != coord[i - 1]) {
|
| 1111 |
-
is_diag = false;
|
| 1112 |
-
break;
|
| 1113 |
-
}
|
| 1114 |
-
}
|
| 1115 |
-
|
| 1116 |
-
params.view.at(coord) = (is_diag ? params.diag : params.other);
|
| 1117 |
-
}
|
| 1118 |
-
};
|
| 1119 |
-
|
| 1120 |
-
// Overwrites the elements of a tensor with a uniform value depending on fill mode
|
| 1121 |
-
template <
|
| 1122 |
-
typename Element, ///< Element type
|
| 1123 |
-
typename Layout> ///< Layout function
|
| 1124 |
-
struct TensorFillPartialFunc {
|
| 1125 |
-
|
| 1126 |
-
/// View type
|
| 1127 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1128 |
-
|
| 1129 |
-
/// Scalar type
|
| 1130 |
-
typedef typename TensorView::Element T;
|
| 1131 |
-
|
| 1132 |
-
/// Coordinate in tensor's index space
|
| 1133 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1134 |
-
|
| 1135 |
-
/// Parameters structure
|
| 1136 |
-
struct Params {
|
| 1137 |
-
|
| 1138 |
-
//
|
| 1139 |
-
// Data members
|
| 1140 |
-
//
|
| 1141 |
-
|
| 1142 |
-
TensorView view;
|
| 1143 |
-
Element element;
|
| 1144 |
-
FillMode fill_mode;
|
| 1145 |
-
|
| 1146 |
-
/// Default ctor
|
| 1147 |
-
CUTLASS_HOST_DEVICE
|
| 1148 |
-
Params(): fill_mode(FillMode::kNone) { }
|
| 1149 |
-
|
| 1150 |
-
//
|
| 1151 |
-
// Methods
|
| 1152 |
-
//
|
| 1153 |
-
|
| 1154 |
-
/// Construction of Gaussian RNG functor.
|
| 1155 |
-
Params(
|
| 1156 |
-
TensorView view_,
|
| 1157 |
-
Element element_,
|
| 1158 |
-
FillMode fill_mode_
|
| 1159 |
-
):
|
| 1160 |
-
view(view_), element(element_), fill_mode(fill_mode_) {
|
| 1161 |
-
|
| 1162 |
-
}
|
| 1163 |
-
};
|
| 1164 |
-
|
| 1165 |
-
//
|
| 1166 |
-
// Data members
|
| 1167 |
-
//
|
| 1168 |
-
|
| 1169 |
-
/// Parameters object
|
| 1170 |
-
Params params;
|
| 1171 |
-
|
| 1172 |
-
//
|
| 1173 |
-
// Methods
|
| 1174 |
-
//
|
| 1175 |
-
|
| 1176 |
-
CUTLASS_DEVICE
|
| 1177 |
-
TensorFillPartialFunc(Params const ¶ms): params(params) {
|
| 1178 |
-
|
| 1179 |
-
}
|
| 1180 |
-
|
| 1181 |
-
/// Overwrites the element if it is within the covered region.
|
| 1182 |
-
CUTLASS_DEVICE
|
| 1183 |
-
void operator()(TensorCoord const &coord) {
|
| 1184 |
-
|
| 1185 |
-
bool predicate = true;
|
| 1186 |
-
|
| 1187 |
-
switch (params.fill_mode) {
|
| 1188 |
-
case FillMode::kFull:
|
| 1189 |
-
predicate = true;
|
| 1190 |
-
break;
|
| 1191 |
-
|
| 1192 |
-
case FillMode::kLower:
|
| 1193 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1194 |
-
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1195 |
-
if (coord[i - 1] < coord[i]) {
|
| 1196 |
-
predicate = false;
|
| 1197 |
-
break;
|
| 1198 |
-
}
|
| 1199 |
-
}
|
| 1200 |
-
break;
|
| 1201 |
-
|
| 1202 |
-
case FillMode::kUpper:
|
| 1203 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1204 |
-
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1205 |
-
if (coord[i - 1] > coord[i]) {
|
| 1206 |
-
predicate = false;
|
| 1207 |
-
break;
|
| 1208 |
-
}
|
| 1209 |
-
}
|
| 1210 |
-
break;
|
| 1211 |
-
|
| 1212 |
-
case FillMode::kDiagonal:
|
| 1213 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1214 |
-
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1215 |
-
if (coord[i - 1] != coord[i]) {
|
| 1216 |
-
predicate = false;
|
| 1217 |
-
break;
|
| 1218 |
-
}
|
| 1219 |
-
}
|
| 1220 |
-
break;
|
| 1221 |
-
|
| 1222 |
-
case FillMode::kNone: // fall-through
|
| 1223 |
-
|
| 1224 |
-
default:
|
| 1225 |
-
predicate = false;
|
| 1226 |
-
break;
|
| 1227 |
-
}
|
| 1228 |
-
|
| 1229 |
-
if (predicate) {
|
| 1230 |
-
params.view.at(coord) = params.element;
|
| 1231 |
-
}
|
| 1232 |
-
}
|
| 1233 |
-
};
|
| 1234 |
-
|
| 1235 |
-
|
| 1236 |
-
template <
|
| 1237 |
-
typename Element, ///< Element type
|
| 1238 |
-
typename Layout> ///< Layout function
|
| 1239 |
-
struct TensorClearPartialFunc {
|
| 1240 |
-
|
| 1241 |
-
/// View type
|
| 1242 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1243 |
-
|
| 1244 |
-
/// Scalar type
|
| 1245 |
-
typedef typename TensorView::Element T;
|
| 1246 |
-
|
| 1247 |
-
/// Coordinate in tensor's index space
|
| 1248 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1249 |
-
|
| 1250 |
-
///
|
| 1251 |
-
static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices");
|
| 1252 |
-
|
| 1253 |
-
/// Parameters structure
|
| 1254 |
-
struct Params {
|
| 1255 |
-
TensorView view{};
|
| 1256 |
-
Element element{};
|
| 1257 |
-
FillMode fill_mode{FillMode::kNone};
|
| 1258 |
-
int alignment{0};
|
| 1259 |
-
};
|
| 1260 |
-
|
| 1261 |
-
//
|
| 1262 |
-
// Data members
|
| 1263 |
-
//
|
| 1264 |
-
|
| 1265 |
-
/// Parameters object
|
| 1266 |
-
Params params;
|
| 1267 |
-
|
| 1268 |
-
//
|
| 1269 |
-
// Methods
|
| 1270 |
-
//
|
| 1271 |
-
|
| 1272 |
-
CUTLASS_DEVICE
|
| 1273 |
-
TensorClearPartialFunc(Params const ¶ms): params(params) {
|
| 1274 |
-
|
| 1275 |
-
}
|
| 1276 |
-
|
| 1277 |
-
/// Overwrites the element if it is within the covered region.
|
| 1278 |
-
CUTLASS_DEVICE
|
| 1279 |
-
void operator()(TensorCoord const &coord) {
|
| 1280 |
-
|
| 1281 |
-
bool predicate = true;
|
| 1282 |
-
|
| 1283 |
-
switch (params.fill_mode) {
|
| 1284 |
-
|
| 1285 |
-
case FillMode::kLower:
|
| 1286 |
-
if ((coord[0] >= coord[1]) ||
|
| 1287 |
-
((coord[1] - coord[0]) >= params.alignment)) {
|
| 1288 |
-
predicate = false;
|
| 1289 |
-
break;
|
| 1290 |
-
}
|
| 1291 |
-
break;
|
| 1292 |
-
|
| 1293 |
-
case FillMode::kUpper:
|
| 1294 |
-
if ((coord[0] <= coord[1]) ||
|
| 1295 |
-
((coord[0] - coord[1]) >= params.alignment)) {
|
| 1296 |
-
predicate = false;
|
| 1297 |
-
break;
|
| 1298 |
-
}
|
| 1299 |
-
break;
|
| 1300 |
-
|
| 1301 |
-
case FillMode::kNone: // fall-through
|
| 1302 |
-
|
| 1303 |
-
default:
|
| 1304 |
-
predicate = false;
|
| 1305 |
-
break;
|
| 1306 |
-
}
|
| 1307 |
-
|
| 1308 |
-
if (predicate) {
|
| 1309 |
-
params.view.at(coord) = params.element;
|
| 1310 |
-
}
|
| 1311 |
-
}
|
| 1312 |
-
};
|
| 1313 |
-
|
| 1314 |
-
} // namespace detail
|
| 1315 |
-
|
| 1316 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1317 |
-
|
| 1318 |
-
/// Fills a tensor everywhere with a unique value for its diagonal.
|
| 1319 |
-
template <
|
| 1320 |
-
typename Element, ///< Element type
|
| 1321 |
-
typename Layout> ///< Layout function
|
| 1322 |
-
void TensorFillDiagonal(
|
| 1323 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1324 |
-
Element diag = Element(1), ///< value to write in the diagonal
|
| 1325 |
-
Element other = Element(0), ///< value to write off the diagonal
|
| 1326 |
-
cudaStream_t stream = nullptr) {
|
| 1327 |
-
|
| 1328 |
-
typedef detail::TensorFillDiagonalFunc<Element, Layout> Func;
|
| 1329 |
-
typedef typename Func::Params Params;
|
| 1330 |
-
|
| 1331 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 1332 |
-
view.extent(),
|
| 1333 |
-
Params(view, diag, other),
|
| 1334 |
-
/*grid_size*/0, /*block_size*/0,
|
| 1335 |
-
stream
|
| 1336 |
-
);
|
| 1337 |
-
}
|
| 1338 |
-
|
| 1339 |
-
/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are
|
| 1340 |
-
/// not written.
|
| 1341 |
-
template <
|
| 1342 |
-
typename Element, ///< Element type
|
| 1343 |
-
typename Layout> ///< Layout function
|
| 1344 |
-
void TensorFillPartial(
|
| 1345 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1346 |
-
Element element,
|
| 1347 |
-
FillMode fill_mode,
|
| 1348 |
-
cudaStream_t stream = nullptr) {
|
| 1349 |
-
|
| 1350 |
-
typedef detail::TensorFillPartialFunc<Element, Layout> Func;
|
| 1351 |
-
typedef typename Func::Params Params;
|
| 1352 |
-
|
| 1353 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 1354 |
-
view.extent(),
|
| 1355 |
-
Params(view, element, fill_mode),
|
| 1356 |
-
stream
|
| 1357 |
-
);
|
| 1358 |
-
}
|
| 1359 |
-
|
| 1360 |
-
/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side
|
| 1361 |
-
/// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros)
|
| 1362 |
-
template <
|
| 1363 |
-
typename Element, ///< Element type
|
| 1364 |
-
typename Layout> ///< Layout function
|
| 1365 |
-
void TensorClearPartial(
|
| 1366 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1367 |
-
Element element,
|
| 1368 |
-
FillMode fill_mode,
|
| 1369 |
-
int alignment,
|
| 1370 |
-
cudaStream_t stream = nullptr) {
|
| 1371 |
-
|
| 1372 |
-
typedef detail::TensorClearPartialFunc<Element, Layout> Func;
|
| 1373 |
-
typedef typename Func::Params Params;
|
| 1374 |
-
|
| 1375 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 1376 |
-
view.extent(),
|
| 1377 |
-
Params{view, element, fill_mode, alignment},
|
| 1378 |
-
/*grid_size*/0, /*block_size*/0,
|
| 1379 |
-
stream
|
| 1380 |
-
);
|
| 1381 |
-
}
|
| 1382 |
-
|
| 1383 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1384 |
-
|
| 1385 |
-
/// Fills a tensor with a uniform value
|
| 1386 |
-
template <
|
| 1387 |
-
typename Element, ///< Element type
|
| 1388 |
-
typename Layout> ///< Layout function
|
| 1389 |
-
void TensorFill(
|
| 1390 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1391 |
-
Element val = Element(0), ///< value to uniformly fill it with
|
| 1392 |
-
cudaStream_t stream = nullptr) {
|
| 1393 |
-
|
| 1394 |
-
TensorFillDiagonal(view, val, val, stream);
|
| 1395 |
-
}
|
| 1396 |
-
|
| 1397 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1398 |
-
|
| 1399 |
-
/// Fills a tensor's diagonal with 1 and 0 everywhere else.
|
| 1400 |
-
template <
|
| 1401 |
-
typename Element, ///< Element type
|
| 1402 |
-
typename Layout> ///< Layout function
|
| 1403 |
-
void TensorFillIdentity(
|
| 1404 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1405 |
-
cudaStream_t stream = nullptr) {
|
| 1406 |
-
|
| 1407 |
-
TensorFillDiagonal(view, Element(1), Element(0), stream);
|
| 1408 |
-
}
|
| 1409 |
-
|
| 1410 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1411 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1412 |
-
|
| 1413 |
-
namespace detail {
|
| 1414 |
-
|
| 1415 |
-
/// Computes a random Gaussian distribution
|
| 1416 |
-
template <
|
| 1417 |
-
typename Element, ///< Element type
|
| 1418 |
-
typename Layout> ///< Layout function
|
| 1419 |
-
struct TensorUpdateDiagonalFunc {
|
| 1420 |
-
|
| 1421 |
-
/// View type
|
| 1422 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1423 |
-
|
| 1424 |
-
/// Scalar type
|
| 1425 |
-
typedef typename TensorView::Element T;
|
| 1426 |
-
|
| 1427 |
-
/// Coordinate in tensor's index space
|
| 1428 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1429 |
-
|
| 1430 |
-
/// Parameters structure
|
| 1431 |
-
struct Params {
|
| 1432 |
-
|
| 1433 |
-
//
|
| 1434 |
-
// Data members
|
| 1435 |
-
//
|
| 1436 |
-
|
| 1437 |
-
TensorView view;
|
| 1438 |
-
Element diag;
|
| 1439 |
-
|
| 1440 |
-
/// Default ctor
|
| 1441 |
-
CUTLASS_HOST_DEVICE
|
| 1442 |
-
Params() { }
|
| 1443 |
-
|
| 1444 |
-
//
|
| 1445 |
-
// Methods
|
| 1446 |
-
//
|
| 1447 |
-
|
| 1448 |
-
/// Constructs the parameters from a destination view and the value to write on the diagonal.
|
| 1449 |
-
Params(
|
| 1450 |
-
TensorView view_ = TensorView(),
|
| 1451 |
-
Element diag_ = Element(1)
|
| 1452 |
-
):
|
| 1453 |
-
view(view_), diag(diag_) {
|
| 1454 |
-
|
| 1455 |
-
}
|
| 1456 |
-
};
|
| 1457 |
-
|
| 1458 |
-
//
|
| 1459 |
-
// Data members
|
| 1460 |
-
//
|
| 1461 |
-
|
| 1462 |
-
/// Parameters object
|
| 1463 |
-
Params params;
|
| 1464 |
-
|
| 1465 |
-
//
|
| 1466 |
-
// Methods
|
| 1467 |
-
//
|
| 1468 |
-
|
| 1469 |
-
/// Device-side constructor; keeps a copy of the parameters
|
| 1470 |
-
CUTLASS_DEVICE
|
| 1471 |
-
TensorUpdateDiagonalFunc(Params const &params): params(params) {
|
| 1472 |
-
|
| 1473 |
-
}
|
| 1474 |
-
|
| 1475 |
-
/// Writes the diagonal value at `coord` when all of its indices are equal
|
| 1476 |
-
CUTLASS_DEVICE
|
| 1477 |
-
void operator()(TensorCoord const &coord) {
|
| 1478 |
-
|
| 1479 |
-
bool is_diag = true;
|
| 1480 |
-
|
| 1481 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1482 |
-
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1483 |
-
if (coord[i] != coord[i - 1]) {
|
| 1484 |
-
is_diag = false;
|
| 1485 |
-
break;
|
| 1486 |
-
}
|
| 1487 |
-
}
|
| 1488 |
-
|
| 1489 |
-
if (is_diag) {
|
| 1490 |
-
params.view.at(coord) = params.diag;
|
| 1491 |
-
}
|
| 1492 |
-
}
|
| 1493 |
-
};
|
| 1494 |
-
|
| 1495 |
-
} // namespace detail
|
| 1496 |
-
|
| 1497 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1498 |
-
|
| 1499 |
-
/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements.
|
| 1500 |
-
template <
|
| 1501 |
-
typename Element, ///< Element type
|
| 1502 |
-
typename Layout> ///< Layout function
|
| 1503 |
-
void TensorUpdateDiagonal(
|
| 1504 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1505 |
-
Element diag = Element(1),
|
| 1506 |
-
cudaStream_t stream = nullptr) {
|
| 1507 |
-
|
| 1508 |
-
typedef detail::TensorUpdateDiagonalFunc<Element, Layout> Func;
|
| 1509 |
-
typedef typename Func::Params Params;
|
| 1510 |
-
|
| 1511 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 1512 |
-
view.extent(),
|
| 1513 |
-
Params(view, diag),
|
| 1514 |
-
/*grid_size*/0, /*block_size*/0,
|
| 1515 |
-
stream
|
| 1516 |
-
);
|
| 1517 |
-
}
|
| 1518 |
-
|
| 1519 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1520 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1521 |
-
|
| 1522 |
-
namespace detail {
|
| 1523 |
-
|
| 1524 |
-
/// Functor that writes a uniform value to every off-diagonal element of a tensor, leaving the diagonal untouched
|
| 1525 |
-
template <
|
| 1526 |
-
typename Element, ///< Element type
|
| 1527 |
-
typename Layout> ///< Layout function
|
| 1528 |
-
struct TensorUpdateOffDiagonalFunc {
|
| 1529 |
-
|
| 1530 |
-
/// View type
|
| 1531 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1532 |
-
|
| 1533 |
-
/// Scalar type
|
| 1534 |
-
typedef typename TensorView::Element T;
|
| 1535 |
-
|
| 1536 |
-
/// Coordinate in tensor's index space
|
| 1537 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1538 |
-
|
| 1539 |
-
/// Parameters structure
|
| 1540 |
-
struct Params {
|
| 1541 |
-
|
| 1542 |
-
//
|
| 1543 |
-
// Data members
|
| 1544 |
-
//
|
| 1545 |
-
|
| 1546 |
-
TensorView view;
|
| 1547 |
-
Element other;
|
| 1548 |
-
|
| 1549 |
-
/// Default ctor
|
| 1550 |
-
CUTLASS_HOST_DEVICE
|
| 1551 |
-
Params() { }
|
| 1552 |
-
|
| 1553 |
-
//
|
| 1554 |
-
// Methods
|
| 1555 |
-
//
|
| 1556 |
-
|
| 1557 |
-
/// Constructs the parameters from a destination view and the value to write off the diagonal.
|
| 1558 |
-
Params(
|
| 1559 |
-
TensorView view_ = TensorView(),
|
| 1560 |
-
Element other_ = Element(0)
|
| 1561 |
-
):
|
| 1562 |
-
view(view_), other(other_) {
|
| 1563 |
-
|
| 1564 |
-
}
|
| 1565 |
-
};
|
| 1566 |
-
|
| 1567 |
-
//
|
| 1568 |
-
// Data members
|
| 1569 |
-
//
|
| 1570 |
-
|
| 1571 |
-
/// Parameters object
|
| 1572 |
-
Params params;
|
| 1573 |
-
|
| 1574 |
-
//
|
| 1575 |
-
// Methods
|
| 1576 |
-
//
|
| 1577 |
-
|
| 1578 |
-
/// Device-side constructor; keeps a copy of the parameters
|
| 1579 |
-
CUTLASS_DEVICE
|
| 1580 |
-
TensorUpdateOffDiagonalFunc(Params const &params): params(params) {
|
| 1581 |
-
|
| 1582 |
-
}
|
| 1583 |
-
|
| 1584 |
-
/// Writes the off-diagonal value at `coord` unless all of its indices are equal
|
| 1585 |
-
CUTLASS_DEVICE
|
| 1586 |
-
void operator()(TensorCoord const &coord) {
|
| 1587 |
-
|
| 1588 |
-
bool is_diag = true;
|
| 1589 |
-
|
| 1590 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1591 |
-
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1592 |
-
if (coord[i] != coord[i - 1]) {
|
| 1593 |
-
is_diag = false;
|
| 1594 |
-
break;
|
| 1595 |
-
}
|
| 1596 |
-
}
|
| 1597 |
-
|
| 1598 |
-
if (!is_diag) {
|
| 1599 |
-
params.view.at(coord) = params.other;
|
| 1600 |
-
}
|
| 1601 |
-
}
|
| 1602 |
-
};
|
| 1603 |
-
|
| 1604 |
-
} // namespace detail
|
| 1605 |
-
|
| 1606 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1607 |
-
|
| 1608 |
-
/// Writes a uniform value to all elements in the tensor without modifying diagonal elements.
|
| 1609 |
-
template <
|
| 1610 |
-
typename Element, ///< Element type
|
| 1611 |
-
typename Layout> ///< Layout function
|
| 1612 |
-
void TensorUpdateOffDiagonal(
|
| 1613 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1614 |
-
Element other = Element(1),
|
| 1615 |
-
cudaStream_t stream = nullptr) {
|
| 1616 |
-
|
| 1617 |
-
typedef detail::TensorUpdateOffDiagonalFunc<Element, Layout> Func;
|
| 1618 |
-
typedef typename Func::Params Params;
|
| 1619 |
-
|
| 1620 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 1621 |
-
view.extent(),
|
| 1622 |
-
Params(view, other),
|
| 1623 |
-
/*grid_size*/0, /*block_size*/0,
|
| 1624 |
-
stream
|
| 1625 |
-
);
|
| 1626 |
-
}
|
| 1627 |
-
|
| 1628 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1629 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1630 |
-
|
| 1631 |
-
namespace detail {
|
| 1632 |
-
|
| 1633 |
-
/// Functor that fills a tensor with a linear combination of each element's coordinate and a vector, starting from an offset
|
| 1634 |
-
template <
|
| 1635 |
-
typename Element, ///< Element type
|
| 1636 |
-
typename Layout> ///< Layout function
|
| 1637 |
-
struct TensorFillLinearFunc {
|
| 1638 |
-
|
| 1639 |
-
/// View type
|
| 1640 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1641 |
-
|
| 1642 |
-
/// Scalar type
|
| 1643 |
-
typedef typename TensorView::Element T;
|
| 1644 |
-
|
| 1645 |
-
/// Coordinate in tensor's index space
|
| 1646 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1647 |
-
|
| 1648 |
-
/// Parameters structure
|
| 1649 |
-
struct Params {
|
| 1650 |
-
|
| 1651 |
-
//
|
| 1652 |
-
// Data members
|
| 1653 |
-
//
|
| 1654 |
-
|
| 1655 |
-
TensorView view;
|
| 1656 |
-
Array<Element, Layout::kRank> v;
|
| 1657 |
-
Element s;
|
| 1658 |
-
|
| 1659 |
-
/// Default ctor
|
| 1660 |
-
CUTLASS_HOST_DEVICE
|
| 1661 |
-
Params() { }
|
| 1662 |
-
|
| 1663 |
-
//
|
| 1664 |
-
// Methods
|
| 1665 |
-
//
|
| 1666 |
-
|
| 1667 |
-
/// Constructs the parameters from a destination view, the coefficient vector, and the starting offset.
|
| 1668 |
-
Params(
|
| 1669 |
-
TensorView view_, ///< destination tensor
|
| 1670 |
-
Array<Element, Layout::kRank> const & v_,
|
| 1671 |
-
Element s_ = Element(0)
|
| 1672 |
-
):
|
| 1673 |
-
view(view_), v(v_), s(s_) {
|
| 1674 |
-
|
| 1675 |
-
}
|
| 1676 |
-
};
|
| 1677 |
-
|
| 1678 |
-
//
|
| 1679 |
-
// Data members
|
| 1680 |
-
//
|
| 1681 |
-
|
| 1682 |
-
/// Parameters object
|
| 1683 |
-
Params params;
|
| 1684 |
-
|
| 1685 |
-
//
|
| 1686 |
-
// Methods
|
| 1687 |
-
//
|
| 1688 |
-
|
| 1689 |
-
/// Device-side constructor; keeps a copy of the parameters
|
| 1690 |
-
CUTLASS_DEVICE
|
| 1691 |
-
TensorFillLinearFunc(Params const &params): params(params) {
|
| 1692 |
-
|
| 1693 |
-
}
|
| 1694 |
-
|
| 1695 |
-
/// Computes the linear combination for `coord` and stores it in the tensor
|
| 1696 |
-
CUTLASS_DEVICE
|
| 1697 |
-
void operator()(TensorCoord const &coord) {
|
| 1698 |
-
|
| 1699 |
-
Element sum = params.s;
|
| 1700 |
-
|
| 1701 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1702 |
-
for (int i = 0; i < Layout::kRank; ++i) {
|
| 1703 |
-
if constexpr (is_complex<Element>::value) {
|
| 1704 |
-
if constexpr (sizeof_bits<Element>::value <= 32) {
|
| 1705 |
-
sum = Element(static_cast<complex<float>>(sum) +
|
| 1706 |
-
static_cast<complex<float>>(params.v[i]) * static_cast<complex<float>>(coord[i]));
|
| 1707 |
-
}
|
| 1708 |
-
}
|
| 1709 |
-
else if constexpr (sizeof_bits<Element>::value <= 32) {
|
| 1710 |
-
if constexpr (std::numeric_limits<Element>::is_integer) {
|
| 1711 |
-
sum = Element(static_cast<int32_t>(sum) +
|
| 1712 |
-
static_cast<int32_t>(params.v[i]) * static_cast<int32_t>(coord[i]));
|
| 1713 |
-
}
|
| 1714 |
-
else {
|
| 1715 |
-
sum = Element(static_cast<float>(sum) +
|
| 1716 |
-
static_cast<float>(params.v[i]) * static_cast<float>(coord[i]));
|
| 1717 |
-
}
|
| 1718 |
-
}
|
| 1719 |
-
else {
|
| 1720 |
-
sum += params.v[i] * coord[i];
|
| 1721 |
-
}
|
| 1722 |
-
}
|
| 1723 |
-
|
| 1724 |
-
params.view.at(coord) = sum;
|
| 1725 |
-
}
|
| 1726 |
-
};
|
| 1727 |
-
|
| 1728 |
-
} // namespace detail
|
| 1729 |
-
|
| 1730 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1731 |
-
|
| 1732 |
-
/// Fills tensor with a linear combination of its coordinate and another vector
|
| 1733 |
-
template <
|
| 1734 |
-
typename Element, ///< Element type
|
| 1735 |
-
typename Layout> ///< Layout function
|
| 1736 |
-
void TensorFillLinear(
|
| 1737 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1738 |
-
Array<Element, Layout::kRank> const & v,
|
| 1739 |
-
Element s = Element(0),
|
| 1740 |
-
cudaStream_t stream = nullptr) {
|
| 1741 |
-
|
| 1742 |
-
using Func = detail::TensorFillLinearFunc<Element, Layout>;
|
| 1743 |
-
using Params = typename Func::Params;
|
| 1744 |
-
|
| 1745 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 1746 |
-
view.extent(),
|
| 1747 |
-
Params(view, v, s),
|
| 1748 |
-
/*grid_size*/0, /*block_size*/0,
|
| 1749 |
-
stream
|
| 1750 |
-
);
|
| 1751 |
-
}
|
| 1752 |
-
|
| 1753 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1754 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1755 |
-
|
| 1756 |
-
/// Fills a tensor with random values from a distribution.
|
| 1757 |
-
template <
|
| 1758 |
-
typename Element, ///< Element type
|
| 1759 |
-
typename Layout> ///< Layout function
|
| 1760 |
-
void TensorFillRandom(
|
| 1761 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1762 |
-
uint64_t seed,
|
| 1763 |
-
Distribution dist,
|
| 1764 |
-
cudaStream_t stream = nullptr,
|
| 1765 |
-
int exclude_zero = -1 ///< If non-negative, excludes 0.
|
| 1766 |
-
/// Note that setting this flag will result in more 1's,
|
| 1767 |
-
/// as we use a simple mechanism to replace 0's by adding/subtracting 1's.
|
| 1768 |
-
) {
|
| 1769 |
-
|
| 1770 |
-
using Real = typename RealType<Element>::Type;
|
| 1771 |
-
|
| 1772 |
-
if (dist.kind == Distribution::Gaussian) {
|
| 1773 |
-
TensorFillRandomGaussian<Element, Layout>(
|
| 1774 |
-
view,
|
| 1775 |
-
seed,
|
| 1776 |
-
static_cast<Real>(dist.gaussian.mean),
|
| 1777 |
-
static_cast<Real>(dist.gaussian.stddev),
|
| 1778 |
-
dist.int_scale,
|
| 1779 |
-
exclude_zero,
|
| 1780 |
-
stream);
|
| 1781 |
-
} else if (dist.kind == Distribution::Uniform) {
|
| 1782 |
-
TensorFillRandomUniform<Element, Layout>(
|
| 1783 |
-
view,
|
| 1784 |
-
seed,
|
| 1785 |
-
static_cast<Real>(dist.uniform.max),
|
| 1786 |
-
static_cast<Real>(dist.uniform.min),
|
| 1787 |
-
dist.int_scale,
|
| 1788 |
-
dist.uniform.pnan,
|
| 1789 |
-
exclude_zero,
|
| 1790 |
-
stream);
|
| 1791 |
-
}
|
| 1792 |
-
}
|
| 1793 |
-
|
| 1794 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1795 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1796 |
-
|
| 1797 |
-
/// Fills a block of data with sequential elements
|
| 1798 |
-
template <
|
| 1799 |
-
typename Element
|
| 1800 |
-
>
|
| 1801 |
-
void BlockFillSequential(
|
| 1802 |
-
Element *ptr,
|
| 1803 |
-
int64_t capacity,
|
| 1804 |
-
Element v = Element(1),
|
| 1805 |
-
Element s = Element(0)) {
|
| 1806 |
-
|
| 1807 |
-
using Layout = layout::PackedVectorLayout;
|
| 1808 |
-
Layout::TensorCoord size(static_cast<Layout::Index>(capacity)); // -Wconversion
|
| 1809 |
-
Layout layout = Layout::packed(size);
|
| 1810 |
-
TensorView<Element, Layout> view(ptr, layout, size);
|
| 1811 |
-
|
| 1812 |
-
Array<Element, Layout::kRank> c{};
|
| 1813 |
-
c[0] = v;
|
| 1814 |
-
|
| 1815 |
-
TensorFillLinear(view, c, s);
|
| 1816 |
-
}
|
| 1817 |
-
|
| 1818 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1819 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1820 |
-
|
| 1821 |
-
/// Fills a block of data with random values drawn from a distribution
|
| 1822 |
-
template <
|
| 1823 |
-
typename Element
|
| 1824 |
-
>
|
| 1825 |
-
void BlockFillRandom(
|
| 1826 |
-
Element *ptr,
|
| 1827 |
-
size_t capacity,
|
| 1828 |
-
uint64_t seed,
|
| 1829 |
-
Distribution dist,
|
| 1830 |
-
cudaStream_t stream = nullptr) {
|
| 1831 |
-
|
| 1832 |
-
using Real = typename RealType<Element>::Type;
|
| 1833 |
-
|
| 1834 |
-
if (dist.kind == Distribution::Gaussian) {
|
| 1835 |
-
BlockFillRandomGaussian<Element>(
|
| 1836 |
-
ptr,
|
| 1837 |
-
capacity,
|
| 1838 |
-
seed,
|
| 1839 |
-
static_cast<Real>(dist.gaussian.mean),
|
| 1840 |
-
static_cast<Real>(dist.gaussian.stddev),
|
| 1841 |
-
dist.int_scale,
|
| 1842 |
-
stream);
|
| 1843 |
-
}
|
| 1844 |
-
else if (dist.kind == Distribution::Uniform) {
|
| 1845 |
-
BlockFillRandomUniform<Element>(
|
| 1846 |
-
ptr,
|
| 1847 |
-
capacity,
|
| 1848 |
-
seed,
|
| 1849 |
-
static_cast<Real>(dist.uniform.max),
|
| 1850 |
-
static_cast<Real>(dist.uniform.min),
|
| 1851 |
-
dist.int_scale,
|
| 1852 |
-
dist.uniform.pnan,
|
| 1853 |
-
stream);
|
| 1854 |
-
}
|
| 1855 |
-
}
|
| 1856 |
-
|
| 1857 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1858 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1859 |
-
|
| 1860 |
-
namespace detail {
|
| 1861 |
-
|
| 1862 |
-
/// Functor that copies values from a dense buffer onto the diagonal of a tensor
|
| 1863 |
-
template <
|
| 1864 |
-
typename Element, ///< Element type
|
| 1865 |
-
typename Layout> ///< Layout function
|
| 1866 |
-
struct TensorCopyDiagonalInFunc {
|
| 1867 |
-
|
| 1868 |
-
/// View type
|
| 1869 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1870 |
-
|
| 1871 |
-
/// Scalar type
|
| 1872 |
-
typedef typename TensorView::Element T;
|
| 1873 |
-
|
| 1874 |
-
/// Coordinate in tensor's index space
|
| 1875 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1876 |
-
|
| 1877 |
-
/// Parameters structure
|
| 1878 |
-
struct Params {
|
| 1879 |
-
|
| 1880 |
-
//
|
| 1881 |
-
// Data members
|
| 1882 |
-
//
|
| 1883 |
-
|
| 1884 |
-
TensorView view;
|
| 1885 |
-
Element const *ptr;
|
| 1886 |
-
|
| 1887 |
-
/// Default ctor
|
| 1888 |
-
CUTLASS_HOST_DEVICE
|
| 1889 |
-
Params() { }
|
| 1890 |
-
|
| 1891 |
-
//
|
| 1892 |
-
// Methods
|
| 1893 |
-
//
|
| 1894 |
-
|
| 1895 |
-
/// Constructs the parameters from a destination view and a pointer to the source buffer.
|
| 1896 |
-
Params(
|
| 1897 |
-
TensorView view_, ///< destination tensor
|
| 1898 |
-
Element const *ptr_
|
| 1899 |
-
):
|
| 1900 |
-
view(view_), ptr(ptr_) {
|
| 1901 |
-
|
| 1902 |
-
}
|
| 1903 |
-
};
|
| 1904 |
-
|
| 1905 |
-
//
|
| 1906 |
-
// Data members
|
| 1907 |
-
//
|
| 1908 |
-
|
| 1909 |
-
/// Parameters object
|
| 1910 |
-
Params params;
|
| 1911 |
-
|
| 1912 |
-
//
|
| 1913 |
-
// Methods
|
| 1914 |
-
//
|
| 1915 |
-
|
| 1916 |
-
/// Device-side constructor; keeps a copy of the parameters
|
| 1917 |
-
CUTLASS_DEVICE
|
| 1918 |
-
TensorCopyDiagonalInFunc(Params const &params): params(params) {
|
| 1919 |
-
|
| 1920 |
-
}
|
| 1921 |
-
|
| 1922 |
-
/// Only update the diagonal element
|
| 1923 |
-
CUTLASS_DEVICE
|
| 1924 |
-
void operator()(TensorCoord const &coord) {
|
| 1925 |
-
bool is_diagonal = true;
|
| 1926 |
-
|
| 1927 |
-
CUTLASS_PRAGMA_UNROLL
|
| 1928 |
-
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1929 |
-
if (coord[i] != coord[0]) {
|
| 1930 |
-
is_diagonal = false;
|
| 1931 |
-
}
|
| 1932 |
-
}
|
| 1933 |
-
if (is_diagonal) {
|
| 1934 |
-
params.view.at(coord) = params.ptr[coord[0]];
|
| 1935 |
-
}
|
| 1936 |
-
}
|
| 1937 |
-
};
|
| 1938 |
-
|
| 1939 |
-
} // namespace detail
|
| 1940 |
-
|
| 1941 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1942 |
-
|
| 1943 |
-
/// Copies a diagonal in from a dense buffer without modifying off-diagonal elements (the buffer is read in device code, so it presumably must be device-accessible).
|
| 1944 |
-
template <
|
| 1945 |
-
typename Element, ///< Element type
|
| 1946 |
-
typename Layout> ///< Layout function
|
| 1947 |
-
void TensorCopyDiagonalIn(
|
| 1948 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 1949 |
-
Element const *ptr, ///< dense buffer of elements
|
| 1950 |
-
cudaStream_t stream = nullptr) {
|
| 1951 |
-
|
| 1952 |
-
using Func = detail::TensorCopyDiagonalInFunc<Element, Layout>;
|
| 1953 |
-
using Params = typename Func::Params;
|
| 1954 |
-
|
| 1955 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 1956 |
-
view.extent(),
|
| 1957 |
-
Params(view, ptr),
|
| 1958 |
-
/*grid_size*/0, /*block_size*/0,
|
| 1959 |
-
stream
|
| 1960 |
-
);
|
| 1961 |
-
}
|
| 1962 |
-
|
| 1963 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1964 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1965 |
-
|
| 1966 |
-
|
| 1967 |
-
namespace detail {
|
| 1968 |
-
|
| 1969 |
-
/// Functor that copies the diagonal of a tensor out into a dense buffer
|
| 1970 |
-
template <
|
| 1971 |
-
typename Element, ///< Element type
|
| 1972 |
-
typename Layout> ///< Layout function
|
| 1973 |
-
struct TensorCopyDiagonalOutFunc {
|
| 1974 |
-
|
| 1975 |
-
/// View type
|
| 1976 |
-
using TensorView = TensorView<Element, Layout>;
|
| 1977 |
-
|
| 1978 |
-
/// Scalar type
|
| 1979 |
-
typedef typename TensorView::Element T;
|
| 1980 |
-
|
| 1981 |
-
/// Coordinate in tensor's index space
|
| 1982 |
-
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1983 |
-
|
| 1984 |
-
/// Parameters structure
|
| 1985 |
-
struct Params {
|
| 1986 |
-
|
| 1987 |
-
//
|
| 1988 |
-
// Data members
|
| 1989 |
-
//
|
| 1990 |
-
|
| 1991 |
-
TensorView view;
|
| 1992 |
-
Element *ptr;
|
| 1993 |
-
|
| 1994 |
-
/// Default ctor
|
| 1995 |
-
CUTLASS_HOST_DEVICE
|
| 1996 |
-
Params() { }
|
| 1997 |
-
|
| 1998 |
-
//
|
| 1999 |
-
// Methods
|
| 2000 |
-
//
|
| 2001 |
-
|
| 2002 |
-
/// Constructs the parameters from a source view and a pointer to the output buffer.
|
| 2003 |
-
Params(
|
| 2004 |
-
TensorView view_, ///< destination tensor
|
| 2005 |
-
Element *ptr_
|
| 2006 |
-
):
|
| 2007 |
-
view(view_), ptr(ptr_) {
|
| 2008 |
-
|
| 2009 |
-
}
|
| 2010 |
-
};
|
| 2011 |
-
|
| 2012 |
-
//
|
| 2013 |
-
// Data members
|
| 2014 |
-
//
|
| 2015 |
-
|
| 2016 |
-
/// Parameters object
|
| 2017 |
-
Params params;
|
| 2018 |
-
|
| 2019 |
-
//
|
| 2020 |
-
// Methods
|
| 2021 |
-
//
|
| 2022 |
-
|
| 2023 |
-
/// Device-side constructor; keeps a copy of the parameters
|
| 2024 |
-
CUTLASS_DEVICE
|
| 2025 |
-
TensorCopyDiagonalOutFunc(Params const &params): params(params) {
|
| 2026 |
-
|
| 2027 |
-
}
|
| 2028 |
-
|
| 2029 |
-
/// Copies the element at `coord` into the buffer when `coord` lies on the diagonal
|
| 2030 |
-
CUTLASS_DEVICE
|
| 2031 |
-
void operator()(TensorCoord const &coord) {
|
| 2032 |
-
bool is_diagonal = true;
|
| 2033 |
-
|
| 2034 |
-
CUTLASS_PRAGMA_UNROLL
|
| 2035 |
-
for (int i = 1; i < Layout::kRank; ++i) {
|
| 2036 |
-
if (coord[i] != coord[0]) {
|
| 2037 |
-
is_diagonal = false;
|
| 2038 |
-
}
|
| 2039 |
-
}
|
| 2040 |
-
if (is_diagonal) {
|
| 2041 |
-
params.ptr[coord[0]] = params.view.at(coord);
|
| 2042 |
-
}
|
| 2043 |
-
}
|
| 2044 |
-
};
|
| 2045 |
-
|
| 2046 |
-
} // namespace detail
|
| 2047 |
-
|
| 2048 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2049 |
-
|
| 2050 |
-
/// Copies the diagonal of a tensor into a dense buffer (the buffer is written in device code, so it presumably must be device-accessible).
|
| 2051 |
-
template <
|
| 2052 |
-
typename Element, ///< Element type
|
| 2053 |
-
typename Layout> ///< Layout function
|
| 2054 |
-
void TensorCopyDiagonalOut(
|
| 2055 |
-
Element *ptr, ///< dense buffer of elements
|
| 2056 |
-
TensorView<Element, Layout> view, ///< source tensor
|
| 2057 |
-
cudaStream_t stream = nullptr) {
|
| 2058 |
-
|
| 2059 |
-
using Func = detail::TensorCopyDiagonalOutFunc<Element, Layout>;
|
| 2060 |
-
using Params = typename Func::Params;
|
| 2061 |
-
|
| 2062 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 2063 |
-
view.extent(),
|
| 2064 |
-
Params(view, ptr),
|
| 2065 |
-
/*grid_size*/0, /*block_size*/0,
|
| 2066 |
-
stream
|
| 2067 |
-
);
|
| 2068 |
-
}
|
| 2069 |
-
|
| 2070 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2071 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2072 |
-
|
| 2073 |
-
} // namespace device
|
| 2074 |
-
} // namespace reference
|
| 2075 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h
DELETED
|
@@ -1,142 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
#include <stdexcept>
|
| 34 |
-
#include "cutlass/cutlass.h"
|
| 35 |
-
#include "cutlass/util/reference/device/kernel/tensor_foreach.h"
|
| 36 |
-
|
| 37 |
-
namespace cutlass {
|
| 38 |
-
namespace reference {
|
| 39 |
-
namespace device {
|
| 40 |
-
|
| 41 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
-
|
| 43 |
-
/// Launches a kernel calling a functor for each element in a tensor's index space.
|
| 44 |
-
template <typename Func, int Rank, typename Params>
|
| 45 |
-
struct TensorForEach {
|
| 46 |
-
|
| 47 |
-
/// Constructor performs the operation.
|
| 48 |
-
TensorForEach(
|
| 49 |
-
Coord<Rank> size, Params params = Params(),
|
| 50 |
-
int grid_size = 0, int block_size = 0,
|
| 51 |
-
cudaStream_t stream = nullptr) {
|
| 52 |
-
|
| 53 |
-
if (!grid_size || !block_size) {
|
| 54 |
-
|
| 55 |
-
// if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
|
| 56 |
-
cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
|
| 57 |
-
&grid_size,
|
| 58 |
-
&block_size,
|
| 59 |
-
reinterpret_cast<void const *>(kernel::TensorForEach<Func, Rank, Params>));
|
| 60 |
-
|
| 61 |
-
if (result != cudaSuccess) {
|
| 62 |
-
throw std::runtime_error("Failed to query occupancy.");
|
| 63 |
-
}
|
| 64 |
-
// Limit block size. This has the effect of increasing the number of items processed by a
|
| 65 |
-
// single thread and reduces the impact of initialization overhead.
|
| 66 |
-
block_size = (block_size < 128 ? block_size : 128);
|
| 67 |
-
}
|
| 68 |
-
|
| 69 |
-
dim3 grid(grid_size, 1, 1);
|
| 70 |
-
dim3 block(block_size, 1, 1);
|
| 71 |
-
|
| 72 |
-
kernel::TensorForEach<Func, Rank, Params><<< grid, block, 0, stream >>>(size, params);
|
| 73 |
-
}
|
| 74 |
-
};
|
| 75 |
-
|
| 76 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 77 |
-
|
| 78 |
-
/// Launches a kernel calling a functor for each element along a tensor's diagonal
|
| 79 |
-
template <typename Func, int Rank, typename Params>
|
| 80 |
-
struct TensorDiagonalForEach {
|
| 81 |
-
|
| 82 |
-
/// Constructor performs the operation
|
| 83 |
-
TensorDiagonalForEach(
|
| 84 |
-
Coord<Rank> size, Params params = Params(),
|
| 85 |
-
int start = 0, int end = -1,
|
| 86 |
-
int block_size = 128, cudaStream_t stream = nullptr) {
|
| 87 |
-
|
| 88 |
-
if (end < 0) {
|
| 89 |
-
end = size.min();
|
| 90 |
-
}
|
| 91 |
-
|
| 92 |
-
dim3 block(block_size, 1, 1);
|
| 93 |
-
dim3 grid((end - start + block_size - 1) / block_size, 1, 1);
|
| 94 |
-
|
| 95 |
-
kernel::TensorDiagonalForEach<Func, Rank, Params><<< grid, block, 0, stream >>>(
|
| 96 |
-
size, params, start, end);
|
| 97 |
-
}
|
| 98 |
-
};
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 102 |
-
|
| 103 |
-
template <typename Element, typename Func>
|
| 104 |
-
struct BlockForEach {
|
| 105 |
-
|
| 106 |
-
/// Constructor performs the operation.
|
| 107 |
-
BlockForEach(
|
| 108 |
-
Element *ptr,
|
| 109 |
-
size_t capacity,
|
| 110 |
-
typename Func::Params params = typename Func::Params(),
|
| 111 |
-
int grid_size = 0,
|
| 112 |
-
int block_size = 0,
|
| 113 |
-
cudaStream_t stream = nullptr) {
|
| 114 |
-
|
| 115 |
-
if (!grid_size || !block_size) {
|
| 116 |
-
|
| 117 |
-
// if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
|
| 118 |
-
cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
|
| 119 |
-
&grid_size,
|
| 120 |
-
&block_size,
|
| 121 |
-
reinterpret_cast<void const *>(kernel::BlockForEach<Element, Func>));
|
| 122 |
-
|
| 123 |
-
if (result != cudaSuccess) {
|
| 124 |
-
throw std::runtime_error("Failed to query occupancy.");
|
| 125 |
-
}
|
| 126 |
-
// Limit block size. This has the effect of increasing the number of items processed by a
|
| 127 |
-
// single thread and reduces the impact of initialization overhead.
|
| 128 |
-
block_size = (block_size < 128 ? block_size : 128);
|
| 129 |
-
}
|
| 130 |
-
|
| 131 |
-
dim3 grid(grid_size, 1, 1);
|
| 132 |
-
dim3 block(block_size, 1, 1);
|
| 133 |
-
|
| 134 |
-
kernel::BlockForEach<Element, Func><<< grid, block, 0, stream >>>(ptr, capacity, params);
|
| 135 |
-
}
|
| 136 |
-
};
|
| 137 |
-
|
| 138 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 139 |
-
|
| 140 |
-
} // namespace device
|
| 141 |
-
} // namespace reference
|
| 142 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h
DELETED
|
@@ -1,514 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
#pragma once
|
| 32 |
-
|
| 33 |
-
#include <cmath>
|
| 34 |
-
|
| 35 |
-
#include "cutlass/cutlass.h"
|
| 36 |
-
#include "cutlass/complex.h"
|
| 37 |
-
#include "cutlass/functional.h"
|
| 38 |
-
#include "cutlass/numeric_conversion.h"
|
| 39 |
-
#include "cutlass/tensor_view.h"
|
| 40 |
-
#include "cutlass/util/device_memory.h"
|
| 41 |
-
#include "cutlass/util/reference/detail/linear_to_coordinate.h"
|
| 42 |
-
|
| 43 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
-
|
| 45 |
-
namespace cutlass {
|
| 46 |
-
namespace reference {
|
| 47 |
-
namespace device {
|
| 48 |
-
|
| 49 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
-
|
| 51 |
-
namespace kernel {
|
| 52 |
-
|
| 53 |
-
template <
|
| 54 |
-
typename Element,
|
| 55 |
-
typename Layout,
|
| 56 |
-
typename ComputeType,
|
| 57 |
-
typename ReduceOp,
|
| 58 |
-
typename TransformOp,
|
| 59 |
-
int kBlockSize = 128
|
| 60 |
-
>
|
| 61 |
-
__global__ void TensorTransformReducePartial(
|
| 62 |
-
TensorView<Element, Layout> view, /// View of the tensor to reduce over
|
| 63 |
-
ComputeType identity, /// Identity element of the reduction operation
|
| 64 |
-
ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
|
| 65 |
-
TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
|
| 66 |
-
ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
|
| 67 |
-
|
| 68 |
-
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 69 |
-
int64_t size = view.size();
|
| 70 |
-
|
| 71 |
-
__shared__ ComputeType scratchpad[kBlockSize];
|
| 72 |
-
|
| 73 |
-
for (; idx < size; idx += blockDim.x * gridDim.x) {
|
| 74 |
-
|
| 75 |
-
// Map linear thread ID onto tensor coordinate
|
| 76 |
-
typename Layout::TensorCoord coord;
|
| 77 |
-
|
| 78 |
-
cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view.extent());
|
| 79 |
-
|
| 80 |
-
if (view.contains(coord)) {
|
| 81 |
-
|
| 82 |
-
// Fetch element
|
| 83 |
-
Element x = view.at(coord);
|
| 84 |
-
|
| 85 |
-
// Transform
|
| 86 |
-
identity = reduce(identity, transform(x));
|
| 87 |
-
}
|
| 88 |
-
}
|
| 89 |
-
|
| 90 |
-
scratchpad[threadIdx.x] = identity;
|
| 91 |
-
|
| 92 |
-
__syncthreads();
|
| 93 |
-
|
| 94 |
-
// One thread performs the final reduction and stores out. This could be enhanced via
|
| 95 |
-
// a tree reduction and pipelining.
|
| 96 |
-
if (threadIdx.x == 0) {
|
| 97 |
-
|
| 98 |
-
for (int i = 1; i < kBlockSize; ++i) {
|
| 99 |
-
identity = reduce(identity, scratchpad[i]);
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
workspace[blockIdx.x] = identity;
|
| 103 |
-
}
|
| 104 |
-
}
|
| 105 |
-
|
| 106 |
-
template <
|
| 107 |
-
typename Element,
|
| 108 |
-
typename Layout,
|
| 109 |
-
typename ComputeType,
|
| 110 |
-
typename ReduceOp,
|
| 111 |
-
typename TransformOp,
|
| 112 |
-
int kBlockSize = 128
|
| 113 |
-
>
|
| 114 |
-
__global__ void TensorTransformReducePartial(
|
| 115 |
-
TensorView<Element, Layout> view_A, /// View of the tensor to reduce over
|
| 116 |
-
TensorView<Element, Layout> view_B, /// View of the tensor to reduce over
|
| 117 |
-
ComputeType identity, /// Identity element of the reduction operation
|
| 118 |
-
ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
|
| 119 |
-
TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
|
| 120 |
-
ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
|
| 121 |
-
|
| 122 |
-
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 123 |
-
auto size = static_cast<int64_t>(view_A.size());
|
| 124 |
-
|
| 125 |
-
__shared__ ComputeType scratchpad[kBlockSize];
|
| 126 |
-
|
| 127 |
-
for (; idx < size; idx += blockDim.x * gridDim.x) {
|
| 128 |
-
|
| 129 |
-
// Map linear thread ID onto tensor coordinate
|
| 130 |
-
typename Layout::TensorCoord coord;
|
| 131 |
-
|
| 132 |
-
cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view_A.extent());
|
| 133 |
-
|
| 134 |
-
if (view_A.contains(coord)) {
|
| 135 |
-
|
| 136 |
-
// Fetch element
|
| 137 |
-
Element a = view_A.at(coord);
|
| 138 |
-
Element b = view_B.at(coord);
|
| 139 |
-
|
| 140 |
-
// Transform
|
| 141 |
-
identity = reduce(identity, transform(a, b));
|
| 142 |
-
}
|
| 143 |
-
}
|
| 144 |
-
|
| 145 |
-
scratchpad[threadIdx.x] = identity;
|
| 146 |
-
|
| 147 |
-
__syncthreads();
|
| 148 |
-
|
| 149 |
-
// One thread performs the final reduction and stores out. This could be enhanced via
|
| 150 |
-
// a tree reduction and pipelining.
|
| 151 |
-
if (threadIdx.x == 0) {
|
| 152 |
-
|
| 153 |
-
for (int i = 1; i < kBlockSize; ++i) {
|
| 154 |
-
identity = reduce(identity, scratchpad[i]);
|
| 155 |
-
}
|
| 156 |
-
|
| 157 |
-
workspace[blockIdx.x] = identity;
|
| 158 |
-
}
|
| 159 |
-
}
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
template <
|
| 163 |
-
typename ComputeType,
|
| 164 |
-
typename ReduceOp,
|
| 165 |
-
int kBlockSize = 32
|
| 166 |
-
>
|
| 167 |
-
__global__ void TensorTransformReduceFinalize(
|
| 168 |
-
ComputeType *workspace,
|
| 169 |
-
ComputeType identity,
|
| 170 |
-
int workspace_size,
|
| 171 |
-
ReduceOp reduce) {
|
| 172 |
-
|
| 173 |
-
__shared__ ComputeType scratchpad[kBlockSize];
|
| 174 |
-
|
| 175 |
-
for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) {
|
| 176 |
-
identity = reduce(identity, workspace[idx]);
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
scratchpad[threadIdx.x] = identity;
|
| 180 |
-
|
| 181 |
-
__syncthreads();
|
| 182 |
-
|
| 183 |
-
if (threadIdx.x == 0) {
|
| 184 |
-
|
| 185 |
-
for (int i = 1; i < kBlockSize; ++i) {
|
| 186 |
-
identity = reduce(identity, scratchpad[i]);
|
| 187 |
-
}
|
| 188 |
-
|
| 189 |
-
workspace[0] = identity;
|
| 190 |
-
}
|
| 191 |
-
}
|
| 192 |
-
|
| 193 |
-
} // namespace kernel
|
| 194 |
-
|
| 195 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 196 |
-
|
| 197 |
-
/// Transform-reduce operation over the elements of a tensor
|
| 198 |
-
template <
|
| 199 |
-
typename Element,
|
| 200 |
-
typename Layout,
|
| 201 |
-
typename ComputeType,
|
| 202 |
-
typename ReduceOp,
|
| 203 |
-
typename TransformOp
|
| 204 |
-
>
|
| 205 |
-
ComputeType TensorTransformReduce(
|
| 206 |
-
TensorView<Element, Layout> view, /// View of the tensor to reduce over
|
| 207 |
-
ComputeType identity, /// Identity element of the reduction operation
|
| 208 |
-
ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
|
| 209 |
-
TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
|
| 210 |
-
ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
|
| 211 |
-
int workspace_size, /// Number of elements in workspace
|
| 212 |
-
cudaStream_t stream = nullptr, /// CUDA stream to launch into
|
| 213 |
-
bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned.
|
| 214 |
-
) {
|
| 215 |
-
|
| 216 |
-
int const kBlockSize = 128;
|
| 217 |
-
|
| 218 |
-
dim3 block(kBlockSize, 1);
|
| 219 |
-
dim3 grid(workspace_size, 1);
|
| 220 |
-
|
| 221 |
-
kernel::TensorTransformReducePartial<
|
| 222 |
-
Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize
|
| 223 |
-
><<< grid, block, 0, stream >>>(
|
| 224 |
-
view, identity, reduce, transform, workspace
|
| 225 |
-
);
|
| 226 |
-
|
| 227 |
-
int const kFinalizeBlockSize = 32;
|
| 228 |
-
|
| 229 |
-
kernel::TensorTransformReduceFinalize<
|
| 230 |
-
ComputeType, ReduceOp, kFinalizeBlockSize
|
| 231 |
-
><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>(
|
| 232 |
-
workspace, identity, workspace_size, reduce
|
| 233 |
-
);
|
| 234 |
-
|
| 235 |
-
cudaStreamSynchronize(stream);
|
| 236 |
-
|
| 237 |
-
if (copy_out) {
|
| 238 |
-
cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
|
| 239 |
-
if (result != cudaSuccess) {
|
| 240 |
-
throw std::runtime_error("cudaMemcpy() failed");
|
| 241 |
-
}
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
return identity;
|
| 245 |
-
}
|
| 246 |
-
|
| 247 |
-
/// Transform-reduce operation over the elements of two tensors, zipped together
|
| 248 |
-
template <
|
| 249 |
-
typename Element,
|
| 250 |
-
typename Layout,
|
| 251 |
-
typename ComputeType,
|
| 252 |
-
typename ReduceOp,
|
| 253 |
-
typename TransformOp
|
| 254 |
-
>
|
| 255 |
-
ComputeType TensorTransformReduce(
|
| 256 |
-
TensorView<Element, Layout> view_A, /// View of the tensor to reduce over
|
| 257 |
-
TensorView<Element, Layout> view_B, /// View of the tensor to reduce over
|
| 258 |
-
ComputeType identity, /// Identity element of the reduction operation
|
| 259 |
-
ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
|
| 260 |
-
TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
|
| 261 |
-
ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
|
| 262 |
-
int workspace_size, /// Number of elements in workspace
|
| 263 |
-
cudaStream_t stream = nullptr, /// CUDA stream to launch into
|
| 264 |
-
bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned.
|
| 265 |
-
) {
|
| 266 |
-
|
| 267 |
-
if (view_A.extent() != view_B.extent()) {
|
| 268 |
-
throw std::runtime_error("Extents must be equal.");
|
| 269 |
-
}
|
| 270 |
-
|
| 271 |
-
int const kBlockSize = 128;
|
| 272 |
-
|
| 273 |
-
dim3 block(kBlockSize, 1);
|
| 274 |
-
dim3 grid(workspace_size, 1);
|
| 275 |
-
|
| 276 |
-
kernel::TensorTransformReducePartial<
|
| 277 |
-
Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize
|
| 278 |
-
><<< grid, block, 0, stream >>>(
|
| 279 |
-
view_A, view_B, identity, reduce, transform, workspace
|
| 280 |
-
);
|
| 281 |
-
|
| 282 |
-
int const kFinalizeBlockSize = 32;
|
| 283 |
-
|
| 284 |
-
kernel::TensorTransformReduceFinalize<
|
| 285 |
-
ComputeType, ReduceOp, kFinalizeBlockSize
|
| 286 |
-
><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>(
|
| 287 |
-
workspace, identity, workspace_size, reduce
|
| 288 |
-
);
|
| 289 |
-
|
| 290 |
-
cudaStreamSynchronize(stream);
|
| 291 |
-
|
| 292 |
-
if (copy_out) {
|
| 293 |
-
cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
|
| 294 |
-
if (result != cudaSuccess) {
|
| 295 |
-
throw std::runtime_error("cudaMemcpy() failed");
|
| 296 |
-
}
|
| 297 |
-
}
|
| 298 |
-
|
| 299 |
-
return identity;
|
| 300 |
-
}
|
| 301 |
-
|
| 302 |
-
/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
|
| 303 |
-
/// workspace
|
| 304 |
-
template <
|
| 305 |
-
typename Element,
|
| 306 |
-
typename Layout,
|
| 307 |
-
typename ComputeType,
|
| 308 |
-
typename ReduceOp,
|
| 309 |
-
typename TransformOp
|
| 310 |
-
>
|
| 311 |
-
ComputeType TensorTransformReduce(
|
| 312 |
-
TensorView<Element, Layout> view,
|
| 313 |
-
ComputeType identity,
|
| 314 |
-
ReduceOp reduce,
|
| 315 |
-
TransformOp transform,
|
| 316 |
-
cudaStream_t stream = nullptr,
|
| 317 |
-
int workspace_size = 0
|
| 318 |
-
) {
|
| 319 |
-
|
| 320 |
-
// Optionally query for the SM count to size the workspace.
|
| 321 |
-
if (!workspace_size) {
|
| 322 |
-
|
| 323 |
-
int device_idx = 0;
|
| 324 |
-
cudaDeviceProp prop;
|
| 325 |
-
|
| 326 |
-
cudaError_t result = cudaGetDevice(&device_idx);
|
| 327 |
-
if (result != cudaSuccess) {
|
| 328 |
-
throw std::runtime_error("cudaGetDevice() failed");
|
| 329 |
-
}
|
| 330 |
-
|
| 331 |
-
result = cudaGetDeviceProperties(&prop, device_idx);
|
| 332 |
-
if (result != cudaSuccess) {
|
| 333 |
-
throw std::runtime_error("cudaGetDeviceProp() failed");
|
| 334 |
-
}
|
| 335 |
-
|
| 336 |
-
workspace_size = int(prop.multiProcessorCount);
|
| 337 |
-
}
|
| 338 |
-
|
| 339 |
-
DeviceAllocation<ComputeType> workspace(workspace_size);
|
| 340 |
-
|
| 341 |
-
ComputeType output = TensorTransformReduce(
|
| 342 |
-
view,
|
| 343 |
-
identity,
|
| 344 |
-
reduce,
|
| 345 |
-
transform,
|
| 346 |
-
workspace.get(),
|
| 347 |
-
workspace_size,
|
| 348 |
-
stream,
|
| 349 |
-
true);
|
| 350 |
-
|
| 351 |
-
return output;
|
| 352 |
-
}
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
|
| 356 |
-
/// workspace
|
| 357 |
-
template <
|
| 358 |
-
typename Element,
|
| 359 |
-
typename Layout,
|
| 360 |
-
typename ComputeType,
|
| 361 |
-
typename ReduceOp,
|
| 362 |
-
typename TransformOp
|
| 363 |
-
>
|
| 364 |
-
ComputeType TensorTransformReduce(
|
| 365 |
-
TensorView<Element, Layout> view_A,
|
| 366 |
-
TensorView<Element, Layout> view_B,
|
| 367 |
-
ComputeType identity,
|
| 368 |
-
ReduceOp reduce,
|
| 369 |
-
TransformOp transform,
|
| 370 |
-
cudaStream_t stream = nullptr,
|
| 371 |
-
int workspace_size = 0
|
| 372 |
-
) {
|
| 373 |
-
|
| 374 |
-
// Optionally query for the SM count to size the workspace.
|
| 375 |
-
if (!workspace_size) {
|
| 376 |
-
|
| 377 |
-
int device_idx = 0;
|
| 378 |
-
cudaDeviceProp prop;
|
| 379 |
-
|
| 380 |
-
cudaError_t result = cudaGetDevice(&device_idx);
|
| 381 |
-
if (result != cudaSuccess) {
|
| 382 |
-
throw std::runtime_error("cudaGetDevice() failed");
|
| 383 |
-
}
|
| 384 |
-
|
| 385 |
-
result = cudaGetDeviceProperties(&prop, device_idx);
|
| 386 |
-
if (result != cudaSuccess) {
|
| 387 |
-
throw std::runtime_error("cudaGetDeviceProp() failed");
|
| 388 |
-
}
|
| 389 |
-
|
| 390 |
-
workspace_size = int(prop.multiProcessorCount);
|
| 391 |
-
}
|
| 392 |
-
|
| 393 |
-
DeviceAllocation<ComputeType> workspace(workspace_size);
|
| 394 |
-
|
| 395 |
-
ComputeType output = TensorTransformReduce(
|
| 396 |
-
view_A,
|
| 397 |
-
view_B,
|
| 398 |
-
identity,
|
| 399 |
-
reduce,
|
| 400 |
-
transform,
|
| 401 |
-
workspace.get(),
|
| 402 |
-
workspace_size,
|
| 403 |
-
stream,
|
| 404 |
-
true);
|
| 405 |
-
|
| 406 |
-
return output;
|
| 407 |
-
}
|
| 408 |
-
|
| 409 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 410 |
-
|
| 411 |
-
/// Helper to compute the sum of the elements of a tensor
|
| 412 |
-
template <
|
| 413 |
-
typename Element,
|
| 414 |
-
typename Layout,
|
| 415 |
-
typename ComputeType = Element
|
| 416 |
-
>
|
| 417 |
-
ComputeType TensorSum(
|
| 418 |
-
TensorView<Element, Layout> view,
|
| 419 |
-
ComputeType identity = ComputeType(),
|
| 420 |
-
cudaStream_t stream = nullptr,
|
| 421 |
-
int workspace_size = 0
|
| 422 |
-
) {
|
| 423 |
-
|
| 424 |
-
plus<ComputeType> reduce;
|
| 425 |
-
NumericConverter<ComputeType, Element> transform;
|
| 426 |
-
|
| 427 |
-
return TensorTransformReduce(
|
| 428 |
-
view, identity, reduce, transform, stream, workspace_size);
|
| 429 |
-
}
|
| 430 |
-
|
| 431 |
-
/// Helper to compute the sum of the squares of the elements of a tensor
|
| 432 |
-
template <
|
| 433 |
-
typename Element,
|
| 434 |
-
typename Layout,
|
| 435 |
-
typename ComputeType = Element
|
| 436 |
-
>
|
| 437 |
-
ComputeType TensorSumSq(
|
| 438 |
-
TensorView<Element, Layout> view,
|
| 439 |
-
ComputeType identity = ComputeType(),
|
| 440 |
-
cudaStream_t stream = nullptr,
|
| 441 |
-
int workspace_size = 0
|
| 442 |
-
) {
|
| 443 |
-
|
| 444 |
-
plus<ComputeType> reduce;
|
| 445 |
-
magnitude_squared<Element, ComputeType> transform;
|
| 446 |
-
|
| 447 |
-
return TensorTransformReduce(
|
| 448 |
-
view, identity, reduce, transform, stream, workspace_size);
|
| 449 |
-
}
|
| 450 |
-
|
| 451 |
-
/// Helper to compute the norm of the elements of a tensor.
|
| 452 |
-
template <
|
| 453 |
-
typename Element,
|
| 454 |
-
typename Layout,
|
| 455 |
-
typename ComputeType = double
|
| 456 |
-
>
|
| 457 |
-
ComputeType TensorNorm(
|
| 458 |
-
TensorView<Element, Layout> view,
|
| 459 |
-
ComputeType identity = ComputeType(),
|
| 460 |
-
cudaStream_t stream = nullptr,
|
| 461 |
-
int workspace_size = 0
|
| 462 |
-
) {
|
| 463 |
-
|
| 464 |
-
return std::sqrt(TensorSumSq(view, identity, stream, workspace_size));
|
| 465 |
-
}
|
| 466 |
-
|
| 467 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 468 |
-
|
| 469 |
-
/// Helper to compute the sum of the squares of the differences of two tensors
|
| 470 |
-
template <
|
| 471 |
-
typename Element,
|
| 472 |
-
typename Layout,
|
| 473 |
-
typename ComputeType = double
|
| 474 |
-
>
|
| 475 |
-
ComputeType TensorSumSqDiff(
|
| 476 |
-
TensorView<Element, Layout> view_A,
|
| 477 |
-
TensorView<Element, Layout> view_B,
|
| 478 |
-
ComputeType identity = ComputeType(),
|
| 479 |
-
cudaStream_t stream = nullptr,
|
| 480 |
-
int workspace_size = 0
|
| 481 |
-
) {
|
| 482 |
-
|
| 483 |
-
plus<ComputeType> reduce;
|
| 484 |
-
magnitude_squared_difference<Element, ComputeType> transform;
|
| 485 |
-
|
| 486 |
-
return TensorTransformReduce(
|
| 487 |
-
view_A, view_B, identity, reduce, transform, stream, workspace_size);
|
| 488 |
-
}
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory
|
| 492 |
-
template <
|
| 493 |
-
typename Element,
|
| 494 |
-
typename Layout,
|
| 495 |
-
typename ComputeType = double
|
| 496 |
-
>
|
| 497 |
-
ComputeType TensorNormDiff(
|
| 498 |
-
TensorView<Element, Layout> view_A,
|
| 499 |
-
TensorView<Element, Layout> view_B,
|
| 500 |
-
ComputeType identity = ComputeType(),
|
| 501 |
-
cudaStream_t stream = nullptr,
|
| 502 |
-
int workspace_size = 0
|
| 503 |
-
) {
|
| 504 |
-
|
| 505 |
-
return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size));
|
| 506 |
-
}
|
| 507 |
-
|
| 508 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 509 |
-
|
| 510 |
-
} // namespace device
|
| 511 |
-
} // namespace reference
|
| 512 |
-
} // namespace cutlass
|
| 513 |
-
|
| 514 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h
DELETED
|
@@ -1,141 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/* \file
|
| 32 |
-
\brief Defines device-side elementwise operations on TensorView. Note, the operations defined
|
| 33 |
-
in this header are not specialized for any particular data layout and are therefore not
|
| 34 |
-
intended to offer the best possible performance. Rather, they are intended to be generic
|
| 35 |
-
reference implementations to support the CUTLASS unit tests.
|
| 36 |
-
*/
|
| 37 |
-
|
| 38 |
-
#pragma once
|
| 39 |
-
|
| 40 |
-
// Cutlass includes
|
| 41 |
-
#include "cutlass/cutlass.h"
|
| 42 |
-
#include "cutlass/tensor_view.h"
|
| 43 |
-
|
| 44 |
-
#include "cutlass/util/reference/device/tensor_foreach.h"
|
| 45 |
-
|
| 46 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
-
|
| 48 |
-
namespace cutlass {
|
| 49 |
-
namespace reference {
|
| 50 |
-
namespace device {
|
| 51 |
-
|
| 52 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
-
|
| 55 |
-
namespace detail {
|
| 56 |
-
|
| 57 |
-
template <
|
| 58 |
-
typename Element, ///< Element type
|
| 59 |
-
typename Layout> ///< Layout function
|
| 60 |
-
struct TensorReLuFunc {
|
| 61 |
-
|
| 62 |
-
/// View type
|
| 63 |
-
using TensorView = TensorView<Element, Layout>;
|
| 64 |
-
|
| 65 |
-
/// Coordinate in tensor's index space
|
| 66 |
-
using TensorCoord = typename TensorView::TensorCoord;
|
| 67 |
-
|
| 68 |
-
/// Parameters structure
|
| 69 |
-
struct Params {
|
| 70 |
-
|
| 71 |
-
//
|
| 72 |
-
// Data members
|
| 73 |
-
//
|
| 74 |
-
|
| 75 |
-
TensorView view;
|
| 76 |
-
Element threshold;
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
//
|
| 80 |
-
// Methods
|
| 81 |
-
//
|
| 82 |
-
|
| 83 |
-
Params(
|
| 84 |
-
TensorView view_ = TensorView(),
|
| 85 |
-
Element threshold_ = Element(0)
|
| 86 |
-
):
|
| 87 |
-
view(view_), threshold(threshold_) {
|
| 88 |
-
|
| 89 |
-
}
|
| 90 |
-
};
|
| 91 |
-
|
| 92 |
-
//
|
| 93 |
-
// Data members
|
| 94 |
-
//
|
| 95 |
-
|
| 96 |
-
Params params;
|
| 97 |
-
|
| 98 |
-
//
|
| 99 |
-
// Methods
|
| 100 |
-
//
|
| 101 |
-
|
| 102 |
-
CUTLASS_DEVICE
|
| 103 |
-
TensorReLuFunc(Params const &params): params(params) {
|
| 104 |
-
|
| 105 |
-
}
|
| 106 |
-
|
| 107 |
-
CUTLASS_DEVICE
|
| 108 |
-
void operator()(TensorCoord const &coord) {
|
| 109 |
-
|
| 110 |
-
Element const & value = params.view.at(coord);
|
| 111 |
-
params.view.at(coord) = (value < params.threshold) ? params.threshold : value;
|
| 112 |
-
}
|
| 113 |
-
};
|
| 114 |
-
|
| 115 |
-
} // namespace detail
|
| 116 |
-
|
| 117 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 118 |
-
|
| 119 |
-
/// Apply ReLu on a tensor
|
| 120 |
-
template <
|
| 121 |
-
typename Element, ///< Element type
|
| 122 |
-
typename Layout> ///< Layout function
|
| 123 |
-
void TensorReLu(
|
| 124 |
-
TensorView<Element, Layout> view, ///< destination tensor
|
| 125 |
-
Element threshold = Element(0)) { ///< ReLu threshold
|
| 126 |
-
|
| 127 |
-
using Func = detail::TensorReLuFunc<Element, Layout>;
|
| 128 |
-
using Params = typename Func::Params;
|
| 129 |
-
|
| 130 |
-
TensorForEach<Func, Layout::kRank, Params>(
|
| 131 |
-
view.extent(),
|
| 132 |
-
Params(view, threshold)
|
| 133 |
-
);
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 137 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 138 |
-
|
| 139 |
-
} // namespace device
|
| 140 |
-
} // namespace reference
|
| 141 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h
DELETED
|
@@ -1,186 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for thread-level GEMM, callable from host and device code.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
#include "cutlass/tensor_view.h"
|
| 39 |
-
#include "cutlass/gemm/gemm.h"
|
| 40 |
-
|
| 41 |
-
namespace cutlass {
|
| 42 |
-
namespace reference {
|
| 43 |
-
namespace device {
|
| 44 |
-
namespace thread {
|
| 45 |
-
|
| 46 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
-
|
| 48 |
-
/// Thread-level blocked general matrix product.
|
| 49 |
-
//
|
| 50 |
-
// Note, this is a reference implementation. Performance is not expected to approach peak.
|
| 51 |
-
//
|
| 52 |
-
template <
|
| 53 |
-
typename TensorRefA,
|
| 54 |
-
typename TensorRefB,
|
| 55 |
-
typename TensorRefC,
|
| 56 |
-
typename ScalarType,
|
| 57 |
-
typename AccumulatorType,
|
| 58 |
-
typename OutputTile,
|
| 59 |
-
typename InnerProductOp = multiply_add<AccumulatorType>,
|
| 60 |
-
typename ConvertOp = NumericConverter<typename TensorRefC::Element, ScalarType>
|
| 61 |
-
>
|
| 62 |
-
struct Gemm {
|
| 63 |
-
|
| 64 |
-
using ElementA = typename TensorRefA::Element;
|
| 65 |
-
using ElementB = typename TensorRefB::Element;
|
| 66 |
-
using ElementC = typename TensorRefC::Element;
|
| 67 |
-
|
| 68 |
-
//
|
| 69 |
-
// Data members
|
| 70 |
-
//
|
| 71 |
-
|
| 72 |
-
/// Tile for A operand
|
| 73 |
-
ElementA A_tile[OutputTile::kColumn];
|
| 74 |
-
|
| 75 |
-
/// Tile for B operand
|
| 76 |
-
ElementB B_tile[OutputTile::kRow];
|
| 77 |
-
|
| 78 |
-
/// Tile for Accumulator
|
| 79 |
-
AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow];
|
| 80 |
-
|
| 81 |
-
//
|
| 82 |
-
// Methods
|
| 83 |
-
//
|
| 84 |
-
|
| 85 |
-
/// Constructor
|
| 86 |
-
CUTLASS_HOST_DEVICE
|
| 87 |
-
Gemm(AccumulatorType initial_accum = AccumulatorType(0)) {
|
| 88 |
-
|
| 89 |
-
// Clear fetch registers
|
| 90 |
-
for (int i = 0; i < OutputTile::kColumn; ++i) {
|
| 91 |
-
A_tile[i] = ElementA(0);
|
| 92 |
-
}
|
| 93 |
-
|
| 94 |
-
for (int j = 0; j < OutputTile::kRow; ++j) {
|
| 95 |
-
B_tile[j] = ElementB(0);
|
| 96 |
-
}
|
| 97 |
-
|
| 98 |
-
// Clear accumulators
|
| 99 |
-
CUTLASS_PRAGMA_UNROLL
|
| 100 |
-
for (int j = 0; j < OutputTile::kColumn; ++j) {
|
| 101 |
-
CUTLASS_PRAGMA_UNROLL
|
| 102 |
-
for (int i = 0; i < OutputTile::kRow; ++i) {
|
| 103 |
-
accum[j][i] = initial_accum;
|
| 104 |
-
}
|
| 105 |
-
}
|
| 106 |
-
}
|
| 107 |
-
|
| 108 |
-
/// Computes a matrix product
|
| 109 |
-
CUTLASS_HOST_DEVICE
|
| 110 |
-
Gemm & multiply_add(
|
| 111 |
-
gemm::GemmCoord problem_size,
|
| 112 |
-
TensorRefA tensor_a,
|
| 113 |
-
TensorRefB tensor_b,
|
| 114 |
-
MatrixCoord output_coord = MatrixCoord()) {
|
| 115 |
-
|
| 116 |
-
InnerProductOp inner_product_op;
|
| 117 |
-
|
| 118 |
-
// Loop over the GEMM K dimension
|
| 119 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 120 |
-
for (int k = 0; k < problem_size.k(); ++k) {
|
| 121 |
-
|
| 122 |
-
// Fetch a slice of the A matrix
|
| 123 |
-
CUTLASS_PRAGMA_UNROLL
|
| 124 |
-
for (int i = 0; i < OutputTile::kColumn; ++i) {
|
| 125 |
-
if (output_coord.row() + i < problem_size.m()) {
|
| 126 |
-
A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k));
|
| 127 |
-
}
|
| 128 |
-
}
|
| 129 |
-
|
| 130 |
-
// Fetch a slice of the B matrix
|
| 131 |
-
CUTLASS_PRAGMA_UNROLL
|
| 132 |
-
for (int j = 0; j < OutputTile::kRow; ++j) {
|
| 133 |
-
if (output_coord.column() + j < problem_size.n()) {
|
| 134 |
-
B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j));
|
| 135 |
-
}
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
// Compute an accumulated matrix product
|
| 139 |
-
CUTLASS_PRAGMA_UNROLL
|
| 140 |
-
for (int j = 0; j < OutputTile::kRow; ++j) {
|
| 141 |
-
CUTLASS_PRAGMA_UNROLL
|
| 142 |
-
for (int i = 0; i < OutputTile::kColumn; ++i) {
|
| 143 |
-
accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]);
|
| 144 |
-
}
|
| 145 |
-
}
|
| 146 |
-
}
|
| 147 |
-
|
| 148 |
-
return *this;
|
| 149 |
-
}
|
| 150 |
-
|
| 151 |
-
/// Performs linear scaling of matrix product and updates output tensor
|
| 152 |
-
CUTLASS_HOST_DEVICE
|
| 153 |
-
Gemm & epilogue(
|
| 154 |
-
gemm::GemmCoord problem_size,
|
| 155 |
-
ScalarType alpha,
|
| 156 |
-
ScalarType beta,
|
| 157 |
-
TensorRefC tensor_c,
|
| 158 |
-
TensorRefC tensor_d,
|
| 159 |
-
MatrixCoord output_coord = MatrixCoord()) {
|
| 160 |
-
|
| 161 |
-
ConvertOp convert_op;
|
| 162 |
-
|
| 163 |
-
// Update the output tensor
|
| 164 |
-
for (int j = 0; j < OutputTile::kRow; ++j) {
|
| 165 |
-
for (int i = 0; i < OutputTile::kColumn; ++i) {
|
| 166 |
-
MatrixCoord coord = output_coord + MatrixCoord(i, j);
|
| 167 |
-
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
|
| 168 |
-
|
| 169 |
-
tensor_d.at(coord) = convert_op(
|
| 170 |
-
alpha * ScalarType(accum[j][i]) +
|
| 171 |
-
beta * ScalarType(tensor_c.at(coord))
|
| 172 |
-
);
|
| 173 |
-
}
|
| 174 |
-
}
|
| 175 |
-
}
|
| 176 |
-
|
| 177 |
-
return *this;
|
| 178 |
-
}
|
| 179 |
-
};
|
| 180 |
-
|
| 181 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 182 |
-
|
| 183 |
-
} // namespace thread
|
| 184 |
-
} // namespace device
|
| 185 |
-
} // namespace reference
|
| 186 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp
DELETED
|
@@ -1,782 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for CONV in host-side code.
|
| 33 |
-
*/
|
| 34 |
-
#pragma once
|
| 35 |
-
|
| 36 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 37 |
-
|
| 38 |
-
#include "cutlass/complex.h"
|
| 39 |
-
#include "cutlass/numeric_conversion.h"
|
| 40 |
-
#include "cutlass/epilogue/thread/activation.h"
|
| 41 |
-
|
| 42 |
-
#include "cute/tensor.hpp"
|
| 43 |
-
|
| 44 |
-
#include <cuda_runtime.h>
|
| 45 |
-
|
| 46 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
-
|
| 48 |
-
namespace cutlass::reference::host {
|
| 49 |
-
|
| 50 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
-
|
| 52 |
-
namespace detail {
|
| 53 |
-
|
| 54 |
-
template<class EngineAct, class LayoutAct>
|
| 55 |
-
bool
|
| 56 |
-
is_activation_in_bounds(
|
| 57 |
-
cute::Tensor<EngineAct, LayoutAct> const& activation,
|
| 58 |
-
int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) {
|
| 59 |
-
return ((g_ >= 0 && g_ < size<5>(activation)) &&
|
| 60 |
-
(n_ >= 0 && n_ < size<4>(activation)) &&
|
| 61 |
-
(d_ >= 0 && d_ < size<3>(activation)) &&
|
| 62 |
-
(h_ >= 0 && h_ < size<2>(activation)) &&
|
| 63 |
-
(w_ >= 0 && w_ < size<1>(activation)) &&
|
| 64 |
-
(c_ >= 0 && c_ < size<0>(activation)));
|
| 65 |
-
}
|
| 66 |
-
|
| 67 |
-
template<class EngineAct, class LayoutAct>
|
| 68 |
-
bool
|
| 69 |
-
is_activation_in_bounds(
|
| 70 |
-
cute::Tensor<EngineAct, LayoutAct> const& activation,
|
| 71 |
-
int32_t n_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) {
|
| 72 |
-
return ((g_ >= 0 && g_ < size<4>(activation)) &&
|
| 73 |
-
(n_ >= 0 && n_ < size<3>(activation)) &&
|
| 74 |
-
(h_ >= 0 && h_ < size<2>(activation)) &&
|
| 75 |
-
(w_ >= 0 && w_ < size<1>(activation)) &&
|
| 76 |
-
(c_ >= 0 && c_ < size<0>(activation)));
|
| 77 |
-
}
|
| 78 |
-
|
| 79 |
-
template<class EngineAct, class LayoutAct>
|
| 80 |
-
bool
|
| 81 |
-
is_activation_in_bounds(
|
| 82 |
-
cute::Tensor<EngineAct, LayoutAct> const& activation,
|
| 83 |
-
int32_t n_, int32_t w_, int32_t c_, int32_t g_) {
|
| 84 |
-
return ((g_ >= 0 && g_ < size<3>(activation)) &&
|
| 85 |
-
(n_ >= 0 && n_ < size<2>(activation)) &&
|
| 86 |
-
(w_ >= 0 && w_ < size<1>(activation)) &&
|
| 87 |
-
(c_ >= 0 && c_ < size<0>(activation)));
|
| 88 |
-
}
|
| 89 |
-
|
| 90 |
-
} // namespace detail
|
| 91 |
-
|
| 92 |
-
template<
|
| 93 |
-
class ElementAcc_,
|
| 94 |
-
class ElementScalar_,
|
| 95 |
-
class ElementCompute_,
|
| 96 |
-
class ElementC_,
|
| 97 |
-
class ElementOut_,
|
| 98 |
-
bool ResidualAdd_,
|
| 99 |
-
class TensorAlpha_,
|
| 100 |
-
class TensorBeta_,
|
| 101 |
-
class TensorBias_,
|
| 102 |
-
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>
|
| 103 |
-
>
|
| 104 |
-
struct ConvEpilogueFusionParams {
|
| 105 |
-
using ElementAcc = ElementAcc_;
|
| 106 |
-
using ElementScalar = ElementScalar_;
|
| 107 |
-
using ElementCompute = ElementCompute_;
|
| 108 |
-
using ElementC = ElementC_;
|
| 109 |
-
using ElementOut = ElementOut_;
|
| 110 |
-
using TensorAlpha = TensorAlpha_;
|
| 111 |
-
using TensorBeta = TensorBeta_;
|
| 112 |
-
using TensorBias = TensorBias_;
|
| 113 |
-
using ActivationFunctor = ActivationFunctor_;
|
| 114 |
-
static constexpr bool ResidualAdd = ResidualAdd_; // Source added after activation
|
| 115 |
-
|
| 116 |
-
ElementScalar alpha = ElementScalar(1);
|
| 117 |
-
ElementScalar beta = ElementScalar(0);
|
| 118 |
-
|
| 119 |
-
TensorAlpha tensor_alpha{};
|
| 120 |
-
TensorBeta tensor_beta{};
|
| 121 |
-
TensorBias tensor_bias{};
|
| 122 |
-
};
|
| 123 |
-
|
| 124 |
-
template<
|
| 125 |
-
cutlass::conv::Operator ConvOp,
|
| 126 |
-
int NumSpatialDims,
|
| 127 |
-
class TensorA,
|
| 128 |
-
class TensorB,
|
| 129 |
-
class TensorC,
|
| 130 |
-
class TensorD,
|
| 131 |
-
class ShapePadding,
|
| 132 |
-
class StrideTraversal,
|
| 133 |
-
class ShapeDilation,
|
| 134 |
-
class EpilogueFusionParams
|
| 135 |
-
>
|
| 136 |
-
struct ConvReferenceImpl {
|
| 137 |
-
// Hard code accumulator type to float to avoid data loss in accumulating add.
|
| 138 |
-
using ElementAcc = cutlass::platform::conditional_t<cutlass::platform::is_same_v<typename EpilogueFusionParams::ElementAcc, double>, double, float>;
|
| 139 |
-
using ElementC = typename EpilogueFusionParams::ElementC;
|
| 140 |
-
using ElementOut = typename EpilogueFusionParams::ElementOut;
|
| 141 |
-
using ElementScalar = typename EpilogueFusionParams::ElementScalar;
|
| 142 |
-
using ElementCompute = typename EpilogueFusionParams::ElementCompute;
|
| 143 |
-
using ElementBias = typename EpilogueFusionParams::TensorBias::value_type;
|
| 144 |
-
using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor;
|
| 145 |
-
|
| 146 |
-
// Input related converter
|
| 147 |
-
NumericConverter<ElementCompute, ElementAcc> acc_converter;
|
| 148 |
-
NumericConverter<ElementCompute, ElementC> residual_converter;
|
| 149 |
-
NumericConverter<ElementCompute, ElementBias> bias_converter;
|
| 150 |
-
// Scale related converter
|
| 151 |
-
NumericConverter<ElementCompute, ElementScalar> scale_converter;
|
| 152 |
-
// Output related converter
|
| 153 |
-
NumericConverter<ElementOut, ElementCompute> output_converter;
|
| 154 |
-
|
| 155 |
-
EpilogueFusionParams& epi_fusion_params_;
|
| 156 |
-
TensorA const& tensor_a_;
|
| 157 |
-
TensorB const& tensor_b_;
|
| 158 |
-
TensorC const& tensor_c_;
|
| 159 |
-
TensorD& tensor_d_;
|
| 160 |
-
|
| 161 |
-
ShapePadding const& padding_;
|
| 162 |
-
StrideTraversal const& tstride_;
|
| 163 |
-
ShapeDilation const& dilation_;
|
| 164 |
-
|
| 165 |
-
// Epilogue activation operation
|
| 166 |
-
ActivationFunctor epi_activation;
|
| 167 |
-
|
| 168 |
-
ConvReferenceImpl(
|
| 169 |
-
TensorA const& tensor_a,
|
| 170 |
-
TensorB const& tensor_b,
|
| 171 |
-
TensorC const& tensor_c,
|
| 172 |
-
TensorD& tensor_d,
|
| 173 |
-
ShapePadding const& padding,
|
| 174 |
-
StrideTraversal const& tstride,
|
| 175 |
-
ShapeDilation const& dilation,
|
| 176 |
-
EpilogueFusionParams& epi_fusion_params)
|
| 177 |
-
: tensor_a_(tensor_a),
|
| 178 |
-
tensor_b_(tensor_b),
|
| 179 |
-
tensor_c_(tensor_c),
|
| 180 |
-
tensor_d_(tensor_d),
|
| 181 |
-
padding_(padding),
|
| 182 |
-
tstride_(tstride),
|
| 183 |
-
dilation_(dilation),
|
| 184 |
-
epi_fusion_params_(epi_fusion_params)
|
| 185 |
-
{
|
| 186 |
-
static_assert(rank(ShapePadding{}) == rank(ShapeDilation{}));
|
| 187 |
-
static_assert(rank(ShapePadding{}) == rank(StrideTraversal{}));
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
void compute_reference() {
|
| 191 |
-
if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
|
| 192 |
-
fprop_reference(cute::Int<NumSpatialDims>{});
|
| 193 |
-
}
|
| 194 |
-
else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 195 |
-
dgrad_reference(cute::Int<NumSpatialDims>{});
|
| 196 |
-
}
|
| 197 |
-
else {
|
| 198 |
-
wgrad_reference(cute::Int<NumSpatialDims>{});
|
| 199 |
-
}
|
| 200 |
-
}
|
| 201 |
-
|
| 202 |
-
private:
|
| 203 |
-
// Specialization for 1D fprop kernel
|
| 204 |
-
void fprop_reference(cute::Int<1> spatial_dims) {
|
| 205 |
-
int32_t G = size<3>(tensor_d_);
|
| 206 |
-
int32_t N = size<2>(tensor_d_);
|
| 207 |
-
int32_t Q = size<1>(tensor_d_);
|
| 208 |
-
int32_t K = size<0>(tensor_d_);
|
| 209 |
-
int32_t S = size<1>(tensor_b_);
|
| 210 |
-
int32_t C = size<0>(tensor_b_);
|
| 211 |
-
|
| 212 |
-
#if defined(_OPENMP)
|
| 213 |
-
#pragma omp parallel for collapse(2)
|
| 214 |
-
#endif
|
| 215 |
-
for (int32_t g = 0; g < G; ++g) {
|
| 216 |
-
for (int32_t n = 0; n < N; ++n) {
|
| 217 |
-
for (int32_t q = 0; q < Q; ++q) {
|
| 218 |
-
for (int32_t k = 0; k < K; ++k) {
|
| 219 |
-
auto accumulator = ElementAcc(0);
|
| 220 |
-
for (int32_t s = 0; s < S; ++s) {
|
| 221 |
-
for (int32_t c = 0; c < C; ++c) {
|
| 222 |
-
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
| 223 |
-
if (detail::is_activation_in_bounds(tensor_a_, n, w, c, g)) {
|
| 224 |
-
auto a = tensor_a_(c, w, n, g);
|
| 225 |
-
auto b = tensor_b_(c, s, k, g);
|
| 226 |
-
accumulator += ElementAcc(a * b);
|
| 227 |
-
}
|
| 228 |
-
}
|
| 229 |
-
}
|
| 230 |
-
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
|
| 231 |
-
epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
|
| 232 |
-
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
|
| 233 |
-
epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
|
| 234 |
-
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 235 |
-
if (not EpilogueFusionParams::ResidualAdd) {
|
| 236 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g));
|
| 237 |
-
}
|
| 238 |
-
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 239 |
-
output += bias_converter(epi_fusion_params_.tensor_bias[k]);
|
| 240 |
-
}
|
| 241 |
-
output = epi_activation(output);
|
| 242 |
-
if (EpilogueFusionParams::ResidualAdd) {
|
| 243 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g));
|
| 244 |
-
}
|
| 245 |
-
tensor_d_(k, q, n, g) = output_converter(output);
|
| 246 |
-
}
|
| 247 |
-
}
|
| 248 |
-
}
|
| 249 |
-
}
|
| 250 |
-
|
| 251 |
-
}
|
| 252 |
-
|
| 253 |
-
// Specialization for 2D fprop kernel
|
| 254 |
-
void fprop_reference(cute::Int<2> spatial_dims) {
|
| 255 |
-
int32_t G = size<4>(tensor_d_);
|
| 256 |
-
int32_t N = size<3>(tensor_d_);
|
| 257 |
-
int32_t P = size<2>(tensor_d_);
|
| 258 |
-
int32_t Q = size<1>(tensor_d_);
|
| 259 |
-
int32_t K = size<0>(tensor_d_);
|
| 260 |
-
int32_t R = size<2>(tensor_b_);
|
| 261 |
-
int32_t S = size<1>(tensor_b_);
|
| 262 |
-
int32_t C = size<0>(tensor_b_);
|
| 263 |
-
|
| 264 |
-
#if defined(_OPENMP)
|
| 265 |
-
#pragma omp parallel for collapse(3)
|
| 266 |
-
#endif
|
| 267 |
-
for (int32_t g = 0; g < G; ++g) {
|
| 268 |
-
for (int32_t n = 0; n < N; ++n) {
|
| 269 |
-
for (int32_t p = 0; p < P; ++p) {
|
| 270 |
-
for (int32_t q = 0; q < Q; ++q) {
|
| 271 |
-
for (int32_t k = 0; k < K; ++k) {
|
| 272 |
-
auto accumulator = ElementAcc(0);
|
| 273 |
-
for (int32_t r = 0; r < R; ++r) {
|
| 274 |
-
for (int32_t s = 0; s < S; ++s) {
|
| 275 |
-
for (int32_t c = 0; c < C; ++c) {
|
| 276 |
-
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
| 277 |
-
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
| 278 |
-
if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c, g)) {
|
| 279 |
-
auto a = tensor_a_(c, w, h, n, g);
|
| 280 |
-
auto b = tensor_b_(c, s, r, k, g);
|
| 281 |
-
accumulator += ElementAcc(a * b);
|
| 282 |
-
}
|
| 283 |
-
}
|
| 284 |
-
}
|
| 285 |
-
}
|
| 286 |
-
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
|
| 287 |
-
epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
|
| 288 |
-
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
|
| 289 |
-
epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
|
| 290 |
-
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 291 |
-
if (not EpilogueFusionParams::ResidualAdd) {
|
| 292 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g));
|
| 293 |
-
}
|
| 294 |
-
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 295 |
-
output += bias_converter(epi_fusion_params_.tensor_bias[k]);
|
| 296 |
-
}
|
| 297 |
-
output = epi_activation(output);
|
| 298 |
-
if (EpilogueFusionParams::ResidualAdd) {
|
| 299 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g));
|
| 300 |
-
}
|
| 301 |
-
tensor_d_(k, q, p, n, g) = output_converter(output);
|
| 302 |
-
}
|
| 303 |
-
}
|
| 304 |
-
}
|
| 305 |
-
}
|
| 306 |
-
}
|
| 307 |
-
|
| 308 |
-
}
|
| 309 |
-
|
| 310 |
-
// Specialization for 3D fprop kernel
|
| 311 |
-
void fprop_reference(cute::Int<3> spatial_dims) {
|
| 312 |
-
int32_t G = size<5>(tensor_d_);
|
| 313 |
-
int32_t N = size<4>(tensor_d_);
|
| 314 |
-
int32_t Z = size<3>(tensor_d_);
|
| 315 |
-
int32_t P = size<2>(tensor_d_);
|
| 316 |
-
int32_t Q = size<1>(tensor_d_);
|
| 317 |
-
int32_t K = size<0>(tensor_d_);
|
| 318 |
-
int32_t T = size<3>(tensor_b_);
|
| 319 |
-
int32_t R = size<2>(tensor_b_);
|
| 320 |
-
int32_t S = size<1>(tensor_b_);
|
| 321 |
-
int32_t C = size<0>(tensor_b_);
|
| 322 |
-
|
| 323 |
-
#if defined(_OPENMP)
|
| 324 |
-
#pragma omp parallel for collapse(3)
|
| 325 |
-
#endif
|
| 326 |
-
for (int32_t g = 0; g < G; ++g) {
|
| 327 |
-
for (int32_t n = 0; n < N; ++n) {
|
| 328 |
-
for (int32_t z = 0; z < Z; ++z) {
|
| 329 |
-
for (int32_t p = 0; p < P; ++p) {
|
| 330 |
-
for (int32_t q = 0; q < Q; ++q) {
|
| 331 |
-
for (int32_t k = 0; k < K; ++k) {
|
| 332 |
-
auto accumulator = ElementAcc(0);
|
| 333 |
-
for (int32_t t = 0; t < T; ++t) {
|
| 334 |
-
for (int32_t r = 0; r < R; ++r) {
|
| 335 |
-
for (int32_t s = 0; s < S; ++s) {
|
| 336 |
-
for (int32_t c = 0; c < C; ++c) {
|
| 337 |
-
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
| 338 |
-
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
| 339 |
-
int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
|
| 340 |
-
if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c, g)) {
|
| 341 |
-
auto a = tensor_a_(c, w, h, d, n, g);
|
| 342 |
-
auto b = tensor_b_(c, s, r, t, k, g);
|
| 343 |
-
accumulator += ElementAcc(a * b);
|
| 344 |
-
}
|
| 345 |
-
}
|
| 346 |
-
}
|
| 347 |
-
}
|
| 348 |
-
}
|
| 349 |
-
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
|
| 350 |
-
epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
|
| 351 |
-
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
|
| 352 |
-
epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
|
| 353 |
-
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 354 |
-
if (not EpilogueFusionParams::ResidualAdd) {
|
| 355 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g));
|
| 356 |
-
}
|
| 357 |
-
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 358 |
-
output += bias_converter(epi_fusion_params_.tensor_bias[k]);
|
| 359 |
-
}
|
| 360 |
-
output = epi_activation(output);
|
| 361 |
-
if (EpilogueFusionParams::ResidualAdd) {
|
| 362 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g));
|
| 363 |
-
}
|
| 364 |
-
tensor_d_(k, q, p, z, n, g) = output_converter(output);
|
| 365 |
-
}
|
| 366 |
-
}
|
| 367 |
-
}
|
| 368 |
-
}
|
| 369 |
-
}
|
| 370 |
-
}
|
| 371 |
-
|
| 372 |
-
}
|
| 373 |
-
|
| 374 |
-
// Specialization for 1D dgrad kernel
|
| 375 |
-
void dgrad_reference(cute::Int<1> spatial_dims) {
|
| 376 |
-
int32_t G = size<3>(tensor_d_);
|
| 377 |
-
int32_t N = size<2>(tensor_d_);
|
| 378 |
-
int32_t W = size<1>(tensor_d_);
|
| 379 |
-
int32_t C = size<0>(tensor_d_);
|
| 380 |
-
int32_t K = size<2>(tensor_b_);
|
| 381 |
-
int32_t S = size<1>(tensor_b_);
|
| 382 |
-
|
| 383 |
-
#if defined(_OPENMP)
|
| 384 |
-
#pragma omp parallel for collapse(2)
|
| 385 |
-
#endif
|
| 386 |
-
for (int32_t g = 0; g < G; ++g) {
|
| 387 |
-
for (int32_t n = 0; n < N; ++n) {
|
| 388 |
-
for (int32_t w = 0; w < W; ++w) {
|
| 389 |
-
for (int32_t c = 0; c < C; ++c) {
|
| 390 |
-
auto accumulator = ElementAcc(0);
|
| 391 |
-
for (int32_t k = 0; k < K; ++k) {
|
| 392 |
-
for (int32_t s = 0; s < S; ++s) {
|
| 393 |
-
int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
|
| 394 |
-
|
| 395 |
-
if (q % cute::get<0>(tstride_) == 0) {
|
| 396 |
-
q /= cute::get<0>(tstride_);
|
| 397 |
-
} else {
|
| 398 |
-
continue;
|
| 399 |
-
}
|
| 400 |
-
|
| 401 |
-
if (detail::is_activation_in_bounds(tensor_a_, n, q, k, g)) {
|
| 402 |
-
accumulator += ElementAcc(tensor_a_(k, q, n, g) * tensor_b_(c, s, k, g));
|
| 403 |
-
}
|
| 404 |
-
}
|
| 405 |
-
}
|
| 406 |
-
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
|
| 407 |
-
? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
|
| 408 |
-
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
|
| 409 |
-
? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
|
| 410 |
-
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 411 |
-
if (not EpilogueFusionParams::ResidualAdd) {
|
| 412 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g));
|
| 413 |
-
}
|
| 414 |
-
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 415 |
-
output += bias_converter(epi_fusion_params_.tensor_bias[c]);
|
| 416 |
-
}
|
| 417 |
-
output = epi_activation(output);
|
| 418 |
-
if (EpilogueFusionParams::ResidualAdd) {
|
| 419 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g));
|
| 420 |
-
}
|
| 421 |
-
tensor_d_(c, w, n, g) = output_converter(output);
|
| 422 |
-
}
|
| 423 |
-
}
|
| 424 |
-
}
|
| 425 |
-
}
|
| 426 |
-
|
| 427 |
-
}
|
| 428 |
-
|
| 429 |
-
// Specialization for 2D dgrad kernel
|
| 430 |
-
void dgrad_reference(cute::Int<2> spatial_dims) {
|
| 431 |
-
int32_t G = size<4>(tensor_d_);
|
| 432 |
-
int32_t N = size<3>(tensor_d_);
|
| 433 |
-
int32_t H = size<2>(tensor_d_);
|
| 434 |
-
int32_t W = size<1>(tensor_d_);
|
| 435 |
-
int32_t C = size<0>(tensor_d_);
|
| 436 |
-
int32_t K = size<3>(tensor_b_);
|
| 437 |
-
int32_t R = size<2>(tensor_b_);
|
| 438 |
-
int32_t S = size<1>(tensor_b_);
|
| 439 |
-
|
| 440 |
-
#if defined(_OPENMP)
|
| 441 |
-
#pragma omp parallel for collapse(3)
|
| 442 |
-
#endif
|
| 443 |
-
for (int32_t g = 0; g < G; ++g) {
|
| 444 |
-
for (int32_t n = 0; n < N; ++n) {
|
| 445 |
-
for (int32_t h = 0; h < H; ++h) {
|
| 446 |
-
for (int32_t w = 0; w < W; ++w) {
|
| 447 |
-
for (int32_t c = 0; c < C; ++c) {
|
| 448 |
-
auto accumulator = ElementAcc(0);
|
| 449 |
-
for (int32_t k = 0; k < K; ++k) {
|
| 450 |
-
for (int32_t r = 0; r < R; ++r) {
|
| 451 |
-
for (int32_t s = 0; s < S; ++s) {
|
| 452 |
-
int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
|
| 453 |
-
int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);
|
| 454 |
-
|
| 455 |
-
if (q % cute::get<0>(tstride_) == 0) {
|
| 456 |
-
q /= cute::get<0>(tstride_);
|
| 457 |
-
} else {
|
| 458 |
-
continue;
|
| 459 |
-
}
|
| 460 |
-
|
| 461 |
-
if (p % cute::get<1>(tstride_) == 0) {
|
| 462 |
-
p /= cute::get<1>(tstride_);
|
| 463 |
-
} else {
|
| 464 |
-
continue;
|
| 465 |
-
}
|
| 466 |
-
|
| 467 |
-
if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k, g)) {
|
| 468 |
-
accumulator += ElementAcc(tensor_a_(k, q, p, n, g) * tensor_b_(c, s, r, k, g));
|
| 469 |
-
}
|
| 470 |
-
}
|
| 471 |
-
}
|
| 472 |
-
}
|
| 473 |
-
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
|
| 474 |
-
? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
|
| 475 |
-
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
|
| 476 |
-
? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
|
| 477 |
-
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 478 |
-
if (not EpilogueFusionParams::ResidualAdd) {
|
| 479 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g));
|
| 480 |
-
}
|
| 481 |
-
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 482 |
-
output += bias_converter(epi_fusion_params_.tensor_bias[c]);
|
| 483 |
-
}
|
| 484 |
-
output = epi_activation(output);
|
| 485 |
-
if (EpilogueFusionParams::ResidualAdd) {
|
| 486 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g));
|
| 487 |
-
}
|
| 488 |
-
|
| 489 |
-
tensor_d_(c, w, h, n, g) = output_converter(output);
|
| 490 |
-
}
|
| 491 |
-
}
|
| 492 |
-
}
|
| 493 |
-
}
|
| 494 |
-
}
|
| 495 |
-
|
| 496 |
-
}
|
| 497 |
-
|
| 498 |
-
// Specialization for 3D dgrad kernel
|
| 499 |
-
void dgrad_reference(cute::Int<3> spatial_dims) {
|
| 500 |
-
int32_t G = size<5>(tensor_d_);
|
| 501 |
-
int32_t N = size<4>(tensor_d_);
|
| 502 |
-
int32_t D = size<3>(tensor_d_);
|
| 503 |
-
int32_t H = size<2>(tensor_d_);
|
| 504 |
-
int32_t W = size<1>(tensor_d_);
|
| 505 |
-
int32_t C = size<0>(tensor_d_);
|
| 506 |
-
int32_t K = size<4>(tensor_b_);
|
| 507 |
-
int32_t T = size<3>(tensor_b_);
|
| 508 |
-
int32_t R = size<2>(tensor_b_);
|
| 509 |
-
int32_t S = size<1>(tensor_b_);
|
| 510 |
-
|
| 511 |
-
#if defined(_OPENMP)
|
| 512 |
-
#pragma omp parallel for collapse(3)
|
| 513 |
-
#endif
|
| 514 |
-
for (int32_t g = 0; g < G; ++g) {
|
| 515 |
-
for (int32_t n = 0; n < N; ++n) {
|
| 516 |
-
for (int32_t d = 0; d < D; ++d) {
|
| 517 |
-
for (int32_t h = 0; h < H; ++h) {
|
| 518 |
-
for (int32_t w = 0; w < W; ++w) {
|
| 519 |
-
for (int32_t c = 0; c < C; ++c) {
|
| 520 |
-
auto accumulator = ElementAcc(0);
|
| 521 |
-
for (int32_t k = 0; k < K; ++k) {
|
| 522 |
-
for (int32_t t = 0; t < T; ++t) {
|
| 523 |
-
for (int32_t r = 0; r < R; ++r) {
|
| 524 |
-
for (int32_t s = 0; s < S; ++s) {
|
| 525 |
-
int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
|
| 526 |
-
int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);
|
| 527 |
-
int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_);
|
| 528 |
-
|
| 529 |
-
if (q % cute::get<0>(tstride_) == 0) {
|
| 530 |
-
q /= cute::get<0>(tstride_);
|
| 531 |
-
} else {
|
| 532 |
-
continue;
|
| 533 |
-
}
|
| 534 |
-
|
| 535 |
-
if (p % cute::get<1>(tstride_) == 0) {
|
| 536 |
-
p /= cute::get<1>(tstride_);
|
| 537 |
-
} else {
|
| 538 |
-
continue;
|
| 539 |
-
}
|
| 540 |
-
|
| 541 |
-
if (z % cute::get<2>(tstride_) == 0) {
|
| 542 |
-
z /= cute::get<2>(tstride_);
|
| 543 |
-
} else {
|
| 544 |
-
continue;
|
| 545 |
-
}
|
| 546 |
-
|
| 547 |
-
if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k, g)) {
|
| 548 |
-
accumulator += ElementAcc(tensor_a_(k, q, p, z, n, g) * tensor_b_(c, s, r, t, k, g));
|
| 549 |
-
}
|
| 550 |
-
}
|
| 551 |
-
}
|
| 552 |
-
}
|
| 553 |
-
}
|
| 554 |
-
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
|
| 555 |
-
? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
|
| 556 |
-
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
|
| 557 |
-
? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
|
| 558 |
-
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 559 |
-
if (not EpilogueFusionParams::ResidualAdd) {
|
| 560 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g));
|
| 561 |
-
}
|
| 562 |
-
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 563 |
-
output += bias_converter(epi_fusion_params_.tensor_bias[c]);
|
| 564 |
-
}
|
| 565 |
-
output = epi_activation(output);
|
| 566 |
-
if (EpilogueFusionParams::ResidualAdd) {
|
| 567 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g));
|
| 568 |
-
}
|
| 569 |
-
tensor_d_(c, w, h, d, n, g) = output_converter(output);
|
| 570 |
-
}
|
| 571 |
-
}
|
| 572 |
-
}
|
| 573 |
-
}
|
| 574 |
-
}
|
| 575 |
-
}
|
| 576 |
-
|
| 577 |
-
}
|
| 578 |
-
|
| 579 |
-
// Specialization for 1D wgrad kernel
|
| 580 |
-
void wgrad_reference(cute::Int<1> spatial_dims) {
|
| 581 |
-
int32_t G = size<3>(tensor_d_);
|
| 582 |
-
int32_t N =
|
| 583 |
-
size<2>(tensor_a_);
|
| 584 |
-
int32_t Q =
|
| 585 |
-
size<1>(tensor_a_);
|
| 586 |
-
int32_t K =
|
| 587 |
-
size<0>(tensor_a_);
|
| 588 |
-
int32_t S = size<1>(tensor_d_);
|
| 589 |
-
int32_t C = size<0>(tensor_d_);
|
| 590 |
-
|
| 591 |
-
#if defined(_OPENMP)
|
| 592 |
-
#pragma omp parallel for collapse(2)
|
| 593 |
-
#endif
|
| 594 |
-
for (int32_t g = 0; g < G; ++g) {
|
| 595 |
-
for (int32_t k = 0; k < K; ++k) {
|
| 596 |
-
for (int32_t s = 0; s < S; ++s) {
|
| 597 |
-
for (int32_t c = 0; c < C; ++c) {
|
| 598 |
-
auto accumulator = ElementAcc(0);
|
| 599 |
-
for (int32_t n = 0; n < N; ++n) {
|
| 600 |
-
for (int32_t q = 0; q < Q; ++q) {
|
| 601 |
-
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
| 602 |
-
bool is_in_bounds =
|
| 603 |
-
detail::is_activation_in_bounds(tensor_b_, n, w, c, g);
|
| 604 |
-
if (is_in_bounds) {
|
| 605 |
-
auto act =
|
| 606 |
-
tensor_b_(c, w, n, g);
|
| 607 |
-
auto xformed_act =
|
| 608 |
-
tensor_a_(k, q, n, g);
|
| 609 |
-
accumulator += ElementAcc(act * xformed_act);
|
| 610 |
-
}
|
| 611 |
-
}
|
| 612 |
-
}
|
| 613 |
-
|
| 614 |
-
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
|
| 615 |
-
epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
|
| 616 |
-
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
|
| 617 |
-
epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
|
| 618 |
-
|
| 619 |
-
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 620 |
-
if (not EpilogueFusionParams::ResidualAdd) {
|
| 621 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g));
|
| 622 |
-
}
|
| 623 |
-
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 624 |
-
output += bias_converter(epi_fusion_params_.tensor_bias[c]);
|
| 625 |
-
}
|
| 626 |
-
output = epi_activation(output);
|
| 627 |
-
if (EpilogueFusionParams::ResidualAdd) {
|
| 628 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g));
|
| 629 |
-
}
|
| 630 |
-
tensor_d_(c, s, k, g) = output_converter(output);
|
| 631 |
-
}
|
| 632 |
-
}
|
| 633 |
-
}
|
| 634 |
-
}
|
| 635 |
-
}
|
| 636 |
-
|
| 637 |
-
// Specialization for 2D wgrad kernel
|
| 638 |
-
void wgrad_reference(cute::Int<2> spatial_dims) {
|
| 639 |
-
int32_t G = size<4>(tensor_d_);
|
| 640 |
-
int32_t N =
|
| 641 |
-
size<3>(tensor_a_);
|
| 642 |
-
int32_t P =
|
| 643 |
-
size<2>(tensor_a_);
|
| 644 |
-
int32_t Q =
|
| 645 |
-
size<1>(tensor_a_);
|
| 646 |
-
int32_t K =
|
| 647 |
-
size<0>(tensor_a_);
|
| 648 |
-
int32_t R = size<2>(tensor_d_);
|
| 649 |
-
int32_t S = size<1>(tensor_d_);
|
| 650 |
-
int32_t C = size<0>(tensor_d_);
|
| 651 |
-
|
| 652 |
-
#if defined(_OPENMP)
|
| 653 |
-
#pragma omp parallel for collapse(3)
|
| 654 |
-
#endif
|
| 655 |
-
for (int32_t g = 0; g < G; ++g) {
|
| 656 |
-
for (int32_t k = 0; k < K; ++k) {
|
| 657 |
-
for (int32_t r = 0; r < R; ++r) {
|
| 658 |
-
for (int32_t s = 0; s < S; ++s) {
|
| 659 |
-
for (int32_t c = 0; c < C; ++c) {
|
| 660 |
-
auto accumulator = ElementAcc(0);
|
| 661 |
-
for (int32_t n = 0; n < N; ++n) {
|
| 662 |
-
for (int32_t p = 0; p < P; ++p) {
|
| 663 |
-
for (int32_t q = 0; q < Q; ++q) {
|
| 664 |
-
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
| 665 |
-
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
| 666 |
-
bool is_in_bounds =
|
| 667 |
-
detail::is_activation_in_bounds(tensor_b_, n, h, w, c, g);
|
| 668 |
-
if (is_in_bounds) {
|
| 669 |
-
auto act =
|
| 670 |
-
tensor_b_(c, w, h, n, g);
|
| 671 |
-
auto xformed_act =
|
| 672 |
-
tensor_a_(k, q, p, n, g);
|
| 673 |
-
accumulator += ElementAcc(act * xformed_act);
|
| 674 |
-
}
|
| 675 |
-
}
|
| 676 |
-
}
|
| 677 |
-
}
|
| 678 |
-
|
| 679 |
-
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
|
| 680 |
-
epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
|
| 681 |
-
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
|
| 682 |
-
epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
|
| 683 |
-
|
| 684 |
-
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 685 |
-
if (not EpilogueFusionParams::ResidualAdd) {
|
| 686 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g));
|
| 687 |
-
}
|
| 688 |
-
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 689 |
-
output += bias_converter(epi_fusion_params_.tensor_bias[c]);
|
| 690 |
-
}
|
| 691 |
-
output = epi_activation(output);
|
| 692 |
-
if (EpilogueFusionParams::ResidualAdd) {
|
| 693 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g));
|
| 694 |
-
}
|
| 695 |
-
tensor_d_(c, s, r, k, g) = output_converter(output);
|
| 696 |
-
}
|
| 697 |
-
}
|
| 698 |
-
}
|
| 699 |
-
}
|
| 700 |
-
}
|
| 701 |
-
}
|
| 702 |
-
|
| 703 |
-
// Specialization for 3D wgrad kernel
|
| 704 |
-
void wgrad_reference(cute::Int<3> spatial_dims) {
|
| 705 |
-
int32_t G = size<5>(tensor_d_);
|
| 706 |
-
int32_t N =
|
| 707 |
-
size<4>(tensor_a_);
|
| 708 |
-
int32_t Z =
|
| 709 |
-
size<3>(tensor_a_);
|
| 710 |
-
int32_t P =
|
| 711 |
-
size<2>(tensor_a_);
|
| 712 |
-
int32_t Q =
|
| 713 |
-
size<1>(tensor_a_);
|
| 714 |
-
int32_t K =
|
| 715 |
-
size<0>(tensor_a_);
|
| 716 |
-
int32_t T = size<3>(tensor_d_);
|
| 717 |
-
int32_t R = size<2>(tensor_d_);
|
| 718 |
-
int32_t S = size<1>(tensor_d_);
|
| 719 |
-
int32_t C = size<0>(tensor_d_);
|
| 720 |
-
|
| 721 |
-
#if defined(_OPENMP)
|
| 722 |
-
#pragma omp parallel for collapse(3)
|
| 723 |
-
#endif
|
| 724 |
-
for (int32_t g = 0 ; g < G; ++g) {
|
| 725 |
-
for (int32_t k = 0; k < K; ++k) {
|
| 726 |
-
for (int32_t t = 0; t < T; ++t) {
|
| 727 |
-
for (int32_t r = 0; r < R; ++r) {
|
| 728 |
-
for (int32_t s = 0; s < S; ++s) {
|
| 729 |
-
for (int32_t c = 0; c < C; ++c) {
|
| 730 |
-
auto accumulator = ElementAcc(0);
|
| 731 |
-
for (int32_t n = 0; n < N; ++n) {
|
| 732 |
-
for (int32_t z = 0; z < Z; ++z) {
|
| 733 |
-
for (int32_t p = 0; p < P; ++p) {
|
| 734 |
-
for (int32_t q = 0; q < Q; ++q) {
|
| 735 |
-
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
| 736 |
-
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
| 737 |
-
int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
|
| 738 |
-
bool is_in_bounds =
|
| 739 |
-
detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c, g);
|
| 740 |
-
if (is_in_bounds) {
|
| 741 |
-
auto act =
|
| 742 |
-
tensor_b_(c, w, h, d, n, g);
|
| 743 |
-
auto xformed_act =
|
| 744 |
-
tensor_a_(k, q, p, z, n, g);
|
| 745 |
-
accumulator += ElementAcc(act * xformed_act);
|
| 746 |
-
}
|
| 747 |
-
}
|
| 748 |
-
}
|
| 749 |
-
}
|
| 750 |
-
}
|
| 751 |
-
|
| 752 |
-
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
|
| 753 |
-
epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
|
| 754 |
-
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
|
| 755 |
-
epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
|
| 756 |
-
|
| 757 |
-
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 758 |
-
if (not EpilogueFusionParams::ResidualAdd) {
|
| 759 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g));
|
| 760 |
-
}
|
| 761 |
-
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 762 |
-
output += bias_converter(epi_fusion_params_.tensor_bias[c]);
|
| 763 |
-
}
|
| 764 |
-
output = epi_activation(output);
|
| 765 |
-
if (EpilogueFusionParams::ResidualAdd) {
|
| 766 |
-
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g));
|
| 767 |
-
}
|
| 768 |
-
tensor_d_(c, s, r, t, k, g) = output_converter(output);
|
| 769 |
-
}
|
| 770 |
-
}
|
| 771 |
-
}
|
| 772 |
-
}
|
| 773 |
-
}
|
| 774 |
-
}
|
| 775 |
-
}
|
| 776 |
-
};
|
| 777 |
-
|
| 778 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 779 |
-
|
| 780 |
-
} // cutlass::reference::host
|
| 781 |
-
|
| 782 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
DELETED
|
@@ -1,802 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
|
| 32 |
-
/*! \file
|
| 33 |
-
\brief Reference implementation for convolution in host-side code.
|
| 34 |
-
*/
|
| 35 |
-
|
| 36 |
-
#pragma once
|
| 37 |
-
|
| 38 |
-
#include "cutlass/coord.h"
|
| 39 |
-
#include "cutlass/functional.h"
|
| 40 |
-
#include "cutlass/layout/tensor.h"
|
| 41 |
-
#include "cutlass/numeric_conversion.h"
|
| 42 |
-
#include "cutlass/numeric_types.h"
|
| 43 |
-
#include "cutlass/tensor_ref.h"
|
| 44 |
-
#include "cutlass/tensor_view.h"
|
| 45 |
-
#include "cutlass/conv/convolution.h"
|
| 46 |
-
#include "cutlass/conv/conv2d_problem_size.h"
|
| 47 |
-
#include "cutlass/conv/conv3d_problem_size.h"
|
| 48 |
-
#include <iostream>
|
| 49 |
-
|
| 50 |
-
namespace cutlass {
|
| 51 |
-
namespace reference {
|
| 52 |
-
namespace host {
|
| 53 |
-
|
| 54 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
-
/// Forward propagation
|
| 56 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
-
|
| 58 |
-
/// y = conv2d(x, w)
|
| 59 |
-
template <
|
| 60 |
-
typename ElementA,
|
| 61 |
-
typename LayoutA,
|
| 62 |
-
typename ElementB,
|
| 63 |
-
typename LayoutB,
|
| 64 |
-
typename ElementC,
|
| 65 |
-
typename LayoutC,
|
| 66 |
-
typename ElementCompute,
|
| 67 |
-
typename ElementAccumulator = ElementCompute,
|
| 68 |
-
typename ElementD = ElementC,
|
| 69 |
-
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 70 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 71 |
-
>
|
| 72 |
-
void Conv2dFprop(
|
| 73 |
-
conv::Conv2dProblemSize problem_size,
|
| 74 |
-
TensorRef<ElementA, LayoutA> tensor_x,
|
| 75 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 76 |
-
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 77 |
-
TensorRef<ElementD, LayoutC> tensor_y_out,
|
| 78 |
-
ElementCompute alpha,
|
| 79 |
-
ElementCompute beta) {
|
| 80 |
-
|
| 81 |
-
ConvertOp convert_op;
|
| 82 |
-
InnerProductOp inner_product_op;
|
| 83 |
-
|
| 84 |
-
// Apply MMA and accumulate ElementAccumulator
|
| 85 |
-
for (int n = 0; n < problem_size.N; ++n) {
|
| 86 |
-
for (int p = 0; p < problem_size.P; ++p) {
|
| 87 |
-
for (int q = 0; q < problem_size.Q; ++q) {
|
| 88 |
-
for (int k = 0; k < problem_size.K; ++k) {
|
| 89 |
-
|
| 90 |
-
int group_idx = k / (problem_size.K / problem_size.groups);
|
| 91 |
-
int channels_per_group = problem_size.C / problem_size.groups;
|
| 92 |
-
|
| 93 |
-
ElementAccumulator acc = ElementAccumulator();
|
| 94 |
-
|
| 95 |
-
for (int r = 0; r < problem_size.R; ++r) {
|
| 96 |
-
for (int s = 0; s < problem_size.S; ++s) {
|
| 97 |
-
for (int c = 0; c < channels_per_group; ++c) {
|
| 98 |
-
|
| 99 |
-
int filter_r = r;
|
| 100 |
-
int filter_s = s;
|
| 101 |
-
|
| 102 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 103 |
-
filter_r = problem_size.R - 1 - r;
|
| 104 |
-
filter_s = problem_size.S - 1 - s;
|
| 105 |
-
}
|
| 106 |
-
|
| 107 |
-
int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
|
| 108 |
-
int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
|
| 109 |
-
|
| 110 |
-
if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) {
|
| 111 |
-
|
| 112 |
-
ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group});
|
| 113 |
-
ElementB b = tensor_w.at({k, r, s, c});
|
| 114 |
-
|
| 115 |
-
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 116 |
-
|
| 117 |
-
}
|
| 118 |
-
}
|
| 119 |
-
}
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 123 |
-
ElementC c_ref = ElementC();
|
| 124 |
-
|
| 125 |
-
if (beta != ElementCompute()) {
|
| 126 |
-
c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k));
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) =
|
| 130 |
-
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 131 |
-
}
|
| 132 |
-
}
|
| 133 |
-
}
|
| 134 |
-
}
|
| 135 |
-
}
|
| 136 |
-
|
| 137 |
-
/// Depthwise-separable convolution
|
| 138 |
-
template <typename ElementA,
|
| 139 |
-
typename LayoutA,
|
| 140 |
-
typename ElementB,
|
| 141 |
-
typename LayoutB,
|
| 142 |
-
typename ElementC,
|
| 143 |
-
typename LayoutC,
|
| 144 |
-
typename ElementCompute,
|
| 145 |
-
typename ElementAccumulator = ElementCompute,
|
| 146 |
-
typename ElementD = ElementC,
|
| 147 |
-
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 148 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>>
|
| 149 |
-
void Depsep_Fprop(cutlass::TensorView<ElementA, LayoutA> tensor_A,
|
| 150 |
-
cutlass::TensorView<ElementB, LayoutB> tensor_B,
|
| 151 |
-
cutlass::TensorView<ElementC, LayoutC> tensor_C,
|
| 152 |
-
cutlass::TensorView<ElementD, LayoutC> tensor_D,
|
| 153 |
-
ElementCompute alpha,
|
| 154 |
-
ElementCompute beta,
|
| 155 |
-
cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(),
|
| 156 |
-
cutlass::Coord<2> conv_stride = cutlass::Coord<2>(),
|
| 157 |
-
cutlass::Coord<2> dilation = cutlass::Coord<2>(),
|
| 158 |
-
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) {
|
| 159 |
-
|
| 160 |
-
ConvertOp convert_op;
|
| 161 |
-
InnerProductOp inner_product_op;
|
| 162 |
-
|
| 163 |
-
// Apply MMA and accumulate ElementAccumulator
|
| 164 |
-
for (int n = 0; n < tensor_C.extent().n(); ++n) {
|
| 165 |
-
for (int p = 0; p < tensor_C.extent().h(); ++p) {
|
| 166 |
-
for (int q = 0; q < tensor_C.extent().w(); ++q) {
|
| 167 |
-
for (int g = 0; g < tensor_C.extent().c(); ++g) {
|
| 168 |
-
ElementAccumulator acc = ElementAccumulator();
|
| 169 |
-
for (int r = 0; r < tensor_B.extent().h(); ++r) {
|
| 170 |
-
for (int s = 0; s < tensor_B.extent().w(); ++s) {
|
| 171 |
-
|
| 172 |
-
// input activation H and W
|
| 173 |
-
int h = p * conv_stride[0] - padding[0] + r * dilation[0];
|
| 174 |
-
int w = q * conv_stride[1] - padding[2] + s * dilation[1];
|
| 175 |
-
|
| 176 |
-
if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) {
|
| 177 |
-
ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g));
|
| 178 |
-
|
| 179 |
-
ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation)
|
| 180 |
-
? tensor_B.at(cutlass::make_Coord(g, r, s, 0))
|
| 181 |
-
: tensor_B.at(cutlass::make_Coord(
|
| 182 |
-
g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0));
|
| 183 |
-
|
| 184 |
-
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 185 |
-
}
|
| 186 |
-
}
|
| 187 |
-
}
|
| 188 |
-
|
| 189 |
-
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 190 |
-
ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g));
|
| 191 |
-
tensor_D.at(cutlass::make_Coord(n, p, q, g)) =
|
| 192 |
-
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 193 |
-
}
|
| 194 |
-
}
|
| 195 |
-
}
|
| 196 |
-
}
|
| 197 |
-
}
|
| 198 |
-
|
| 199 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 200 |
-
/// Dgrad / Deconv
|
| 201 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 202 |
-
|
| 203 |
-
/// dx = dgrad(dy, w)
|
| 204 |
-
template <
|
| 205 |
-
typename ElementA,
|
| 206 |
-
typename LayoutA,
|
| 207 |
-
typename ElementB,
|
| 208 |
-
typename LayoutB,
|
| 209 |
-
typename ElementC,
|
| 210 |
-
typename LayoutC,
|
| 211 |
-
typename ElementCompute,
|
| 212 |
-
typename ElementAccumulator = ElementCompute,
|
| 213 |
-
typename ElementD = ElementC,
|
| 214 |
-
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 215 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 216 |
-
>
|
| 217 |
-
void Conv2dDgrad(
|
| 218 |
-
cutlass::conv::Conv2dProblemSize problem_size,
|
| 219 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 220 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 221 |
-
TensorRef<ElementC, LayoutC> tensor_dx_in,
|
| 222 |
-
TensorRef<ElementD, LayoutC> tensor_dx_out,
|
| 223 |
-
ElementCompute alpha,
|
| 224 |
-
ElementCompute beta,
|
| 225 |
-
bool is_deconv = false) {
|
| 226 |
-
|
| 227 |
-
ConvertOp convert_op;
|
| 228 |
-
InnerProductOp inner_product_op;
|
| 229 |
-
|
| 230 |
-
// Apply MMA and accumulate ElementAccumulator
|
| 231 |
-
for (int n = 0; n < problem_size.N; ++n) {
|
| 232 |
-
for (int h = 0; h < problem_size.H; ++h) {
|
| 233 |
-
for (int w = 0; w < problem_size.W; ++w) {
|
| 234 |
-
for (int c = 0; c < problem_size.C; ++c) {
|
| 235 |
-
|
| 236 |
-
ElementAccumulator acc = ElementAccumulator();
|
| 237 |
-
|
| 238 |
-
for (int r = 0; r < problem_size.R; ++r) {
|
| 239 |
-
for (int s = 0; s < problem_size.S; ++s) {
|
| 240 |
-
for (int k = 0; k < problem_size.K; ++k) {
|
| 241 |
-
|
| 242 |
-
int filter_r = r;
|
| 243 |
-
int filter_s = s;
|
| 244 |
-
|
| 245 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 246 |
-
filter_r = problem_size.R - 1 - r;
|
| 247 |
-
filter_s = problem_size.S - 1 - s;
|
| 248 |
-
}
|
| 249 |
-
|
| 250 |
-
int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h;
|
| 251 |
-
int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w;
|
| 252 |
-
|
| 253 |
-
if (p >= 0 && (p % problem_size.stride_h) == 0 &&
|
| 254 |
-
q >= 0 && (q % problem_size.stride_w) == 0) {
|
| 255 |
-
|
| 256 |
-
p = p / problem_size.stride_h;
|
| 257 |
-
q = q / problem_size.stride_w;
|
| 258 |
-
#if 0
|
| 259 |
-
std::cout << "row:"
|
| 260 |
-
<< n * problem_size.H * problem_size.W +
|
| 261 |
-
h * problem_size.W +
|
| 262 |
-
w << " "
|
| 263 |
-
<< "n, p, q: ("
|
| 264 |
-
<< n << ", "
|
| 265 |
-
<< p << ", "
|
| 266 |
-
<< q << ") * "
|
| 267 |
-
<< "r, s: ("
|
| 268 |
-
<< r << ", "
|
| 269 |
-
<< s << ") ["
|
| 270 |
-
<< ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]"
|
| 271 |
-
<< std::endl;
|
| 272 |
-
#endif
|
| 273 |
-
if (p < problem_size.P && q < problem_size.Q) {
|
| 274 |
-
|
| 275 |
-
ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k));
|
| 276 |
-
ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k))
|
| 277 |
-
: tensor_w.at(cutlass::make_Coord(k, r, s, c));
|
| 278 |
-
|
| 279 |
-
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 280 |
-
}
|
| 281 |
-
}
|
| 282 |
-
|
| 283 |
-
} // for (K)
|
| 284 |
-
} // for (S)
|
| 285 |
-
} // for (R)
|
| 286 |
-
|
| 287 |
-
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 288 |
-
ElementC c_ref = ElementC();
|
| 289 |
-
|
| 290 |
-
if (beta != ElementCompute()) {
|
| 291 |
-
c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c));
|
| 292 |
-
}
|
| 293 |
-
|
| 294 |
-
tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) =
|
| 295 |
-
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 296 |
-
|
| 297 |
-
} // for (C)
|
| 298 |
-
} // for (W)
|
| 299 |
-
} // for (H)
|
| 300 |
-
} // for (N)
|
| 301 |
-
}
|
| 302 |
-
|
| 303 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 304 |
-
/// Wgrad
|
| 305 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 306 |
-
|
| 307 |
-
/// dw = wgrad(dy, x)
|
| 308 |
-
template <
|
| 309 |
-
typename ElementA,
|
| 310 |
-
typename LayoutA,
|
| 311 |
-
typename ElementB,
|
| 312 |
-
typename LayoutB,
|
| 313 |
-
typename ElementC,
|
| 314 |
-
typename LayoutC,
|
| 315 |
-
typename ElementCompute,
|
| 316 |
-
typename ElementAccumulator = ElementCompute,
|
| 317 |
-
typename ElementD = ElementC,
|
| 318 |
-
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 319 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 320 |
-
>
|
| 321 |
-
void Conv2dWgrad(
|
| 322 |
-
cutlass::conv::Conv2dProblemSize problem_size,
|
| 323 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 324 |
-
TensorRef<ElementB, LayoutB> tensor_x,
|
| 325 |
-
TensorRef<ElementC, LayoutC> tensor_dw_in,
|
| 326 |
-
TensorRef<ElementD, LayoutC> tensor_dw_out,
|
| 327 |
-
ElementCompute alpha,
|
| 328 |
-
ElementCompute beta) {
|
| 329 |
-
|
| 330 |
-
InnerProductOp inner_product_op;
|
| 331 |
-
ConvertOp convert_op;
|
| 332 |
-
|
| 333 |
-
// Apply MMA and accumulate ElementAccumulator
|
| 334 |
-
for (int k = 0; k < problem_size.K; ++k) {
|
| 335 |
-
for (int r = 0; r < problem_size.R; ++r) {
|
| 336 |
-
for (int s = 0; s < problem_size.S; ++s) {
|
| 337 |
-
for (int c = 0; c < problem_size.C; ++c) {
|
| 338 |
-
|
| 339 |
-
ElementAccumulator acc = ElementAccumulator();
|
| 340 |
-
|
| 341 |
-
for (int n = 0; n < problem_size.N; ++n) {
|
| 342 |
-
for (int p = 0; p < problem_size.P; ++p) {
|
| 343 |
-
for (int q = 0; q < problem_size.Q; ++q) {
|
| 344 |
-
|
| 345 |
-
cutlass::Tensor4DCoord b_coord;
|
| 346 |
-
|
| 347 |
-
int filter_r = r;
|
| 348 |
-
int filter_s = s;
|
| 349 |
-
|
| 350 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 351 |
-
filter_r = problem_size.R - 1 - r;
|
| 352 |
-
filter_s = problem_size.S - 1 - s;
|
| 353 |
-
}
|
| 354 |
-
|
| 355 |
-
b_coord = make_Coord(
|
| 356 |
-
n,
|
| 357 |
-
p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h,
|
| 358 |
-
q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w,
|
| 359 |
-
c);
|
| 360 |
-
|
| 361 |
-
if (b_coord.h() < problem_size.H && b_coord.h() >= 0 &&
|
| 362 |
-
b_coord.w() < problem_size.W && b_coord.w() >= 0) {
|
| 363 |
-
|
| 364 |
-
ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k)));
|
| 365 |
-
ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord));
|
| 366 |
-
acc = inner_product_op(a, b, acc);
|
| 367 |
-
}
|
| 368 |
-
}
|
| 369 |
-
}
|
| 370 |
-
}
|
| 371 |
-
|
| 372 |
-
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 373 |
-
ElementC c_ref = ElementC();
|
| 374 |
-
|
| 375 |
-
if (beta != ElementCompute()) {
|
| 376 |
-
c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c));
|
| 377 |
-
}
|
| 378 |
-
|
| 379 |
-
tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) =
|
| 380 |
-
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 381 |
-
|
| 382 |
-
} // for (C)
|
| 383 |
-
} // for (S)
|
| 384 |
-
} // for (R)
|
| 385 |
-
} // for (K)
|
| 386 |
-
}
|
| 387 |
-
|
| 388 |
-
/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
|
| 389 |
-
template <
|
| 390 |
-
typename ElementA,
|
| 391 |
-
typename LayoutA,
|
| 392 |
-
typename ElementB,
|
| 393 |
-
typename LayoutB,
|
| 394 |
-
typename ElementC,
|
| 395 |
-
typename LayoutC,
|
| 396 |
-
typename ElementCompute,
|
| 397 |
-
typename ElementAccumulator = ElementCompute,
|
| 398 |
-
typename ElementD = ElementC,
|
| 399 |
-
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 400 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 401 |
-
>
|
| 402 |
-
void Conv2d(
|
| 403 |
-
conv::Operator convolutional_operator,
|
| 404 |
-
conv::Conv2dProblemSize problem_size,
|
| 405 |
-
TensorRef<ElementA, LayoutA> tensor_A,
|
| 406 |
-
TensorRef<ElementB, LayoutB> tensor_B,
|
| 407 |
-
TensorRef<ElementC, LayoutC> tensor_C,
|
| 408 |
-
TensorRef<ElementD, LayoutC> tensor_D,
|
| 409 |
-
ElementCompute alpha,
|
| 410 |
-
ElementCompute beta) {
|
| 411 |
-
|
| 412 |
-
switch (convolutional_operator) {
|
| 413 |
-
case conv::Operator::kFprop:
|
| 414 |
-
Conv2dFprop<
|
| 415 |
-
ElementA, LayoutA,
|
| 416 |
-
ElementB, LayoutB,
|
| 417 |
-
ElementC, LayoutC,
|
| 418 |
-
ElementCompute,
|
| 419 |
-
ElementAccumulator,
|
| 420 |
-
ElementD,
|
| 421 |
-
ConvertOp, InnerProductOp
|
| 422 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
|
| 423 |
-
break;
|
| 424 |
-
|
| 425 |
-
case conv::Operator::kDeconv:
|
| 426 |
-
case conv::Operator::kDgrad:
|
| 427 |
-
Conv2dDgrad<
|
| 428 |
-
ElementA, LayoutA,
|
| 429 |
-
ElementB, LayoutB,
|
| 430 |
-
ElementC, LayoutC,
|
| 431 |
-
ElementCompute,
|
| 432 |
-
ElementAccumulator,
|
| 433 |
-
ElementD,
|
| 434 |
-
ConvertOp, InnerProductOp
|
| 435 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
|
| 436 |
-
break;
|
| 437 |
-
|
| 438 |
-
case conv::Operator::kWgrad:
|
| 439 |
-
Conv2dWgrad<
|
| 440 |
-
ElementA, LayoutA,
|
| 441 |
-
ElementB, LayoutB,
|
| 442 |
-
ElementC, LayoutC,
|
| 443 |
-
ElementCompute,
|
| 444 |
-
ElementAccumulator,
|
| 445 |
-
ElementD,
|
| 446 |
-
ConvertOp, InnerProductOp
|
| 447 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
|
| 448 |
-
break;
|
| 449 |
-
|
| 450 |
-
default:
|
| 451 |
-
break;
|
| 452 |
-
}
|
| 453 |
-
}
|
| 454 |
-
|
| 455 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 456 |
-
/// 3D convolution
|
| 457 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 458 |
-
|
| 459 |
-
/// y = conv3d(x, w)
|
| 460 |
-
template <
|
| 461 |
-
typename ElementA,
|
| 462 |
-
typename LayoutA,
|
| 463 |
-
typename ElementB,
|
| 464 |
-
typename LayoutB,
|
| 465 |
-
typename ElementC,
|
| 466 |
-
typename LayoutC,
|
| 467 |
-
typename ElementCompute,
|
| 468 |
-
typename ElementAccumulator = ElementCompute,
|
| 469 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 470 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 471 |
-
>
|
| 472 |
-
void Conv3dFprop(
|
| 473 |
-
conv::Conv3dProblemSize problem_size,
|
| 474 |
-
TensorRef<ElementA, LayoutA> tensor_x,
|
| 475 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 476 |
-
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 477 |
-
TensorRef<ElementC, LayoutC> tensor_y_out,
|
| 478 |
-
ElementCompute alpha,
|
| 479 |
-
ElementCompute beta) {
|
| 480 |
-
|
| 481 |
-
ConvertOp convert_op;
|
| 482 |
-
InnerProductOp inner_product_op;
|
| 483 |
-
|
| 484 |
-
// Apply MMA and accumulate ElementAccumulator
|
| 485 |
-
for (int n = 0; n < problem_size.N; ++n) {
|
| 486 |
-
for (int z = 0; z < problem_size.Z; ++z) {
|
| 487 |
-
for (int p = 0; p < problem_size.P; ++p) {
|
| 488 |
-
for (int q = 0; q < problem_size.Q; ++q) {
|
| 489 |
-
for (int k = 0; k < problem_size.K; ++k) {
|
| 490 |
-
|
| 491 |
-
ElementAccumulator acc = ElementAccumulator();
|
| 492 |
-
|
| 493 |
-
for (int t = 0; t < problem_size.T; ++t) {
|
| 494 |
-
for (int r = 0; r < problem_size.R; ++r) {
|
| 495 |
-
for (int s = 0; s < problem_size.S; ++s) {
|
| 496 |
-
for (int c = 0; c < problem_size.C; ++c) {
|
| 497 |
-
|
| 498 |
-
int filter_t = t;
|
| 499 |
-
int filter_r = r;
|
| 500 |
-
int filter_s = s;
|
| 501 |
-
|
| 502 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 503 |
-
filter_t = problem_size.T - 1 - t;
|
| 504 |
-
filter_r = problem_size.R - 1 - r;
|
| 505 |
-
filter_s = problem_size.S - 1 - s;
|
| 506 |
-
}
|
| 507 |
-
|
| 508 |
-
int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
|
| 509 |
-
int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
|
| 510 |
-
int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
|
| 511 |
-
|
| 512 |
-
if (d >= 0 && d < problem_size.D &&
|
| 513 |
-
h >=0 && h < problem_size.H &&
|
| 514 |
-
w >= 0 && w < problem_size.W) {
|
| 515 |
-
|
| 516 |
-
ElementA a = tensor_x.at({n, d, h, w, c});
|
| 517 |
-
ElementB b = tensor_w.at({k, t, r, s, c});
|
| 518 |
-
|
| 519 |
-
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 520 |
-
}
|
| 521 |
-
}
|
| 522 |
-
}
|
| 523 |
-
}
|
| 524 |
-
}
|
| 525 |
-
|
| 526 |
-
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 527 |
-
ElementC c_ref = ElementC();
|
| 528 |
-
|
| 529 |
-
if (beta != ElementCompute()) {
|
| 530 |
-
c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k));
|
| 531 |
-
}
|
| 532 |
-
|
| 533 |
-
tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) =
|
| 534 |
-
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 535 |
-
}
|
| 536 |
-
}
|
| 537 |
-
}
|
| 538 |
-
}
|
| 539 |
-
}
|
| 540 |
-
}
|
| 541 |
-
|
| 542 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 543 |
-
/// Dgrad / Deconv
|
| 544 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 545 |
-
|
| 546 |
-
/// dx = dgrad(dy, w)
|
| 547 |
-
template <
|
| 548 |
-
typename ElementA,
|
| 549 |
-
typename LayoutA,
|
| 550 |
-
typename ElementB,
|
| 551 |
-
typename LayoutB,
|
| 552 |
-
typename ElementC,
|
| 553 |
-
typename LayoutC,
|
| 554 |
-
typename ElementCompute,
|
| 555 |
-
typename ElementAccumulator = ElementCompute,
|
| 556 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 557 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 558 |
-
>
|
| 559 |
-
void Conv3dDgrad(
|
| 560 |
-
cutlass::conv::Conv3dProblemSize problem_size,
|
| 561 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 562 |
-
TensorRef<ElementB, LayoutB> tensor_w,
|
| 563 |
-
TensorRef<ElementC, LayoutC> tensor_dx_in,
|
| 564 |
-
TensorRef<ElementC, LayoutC> tensor_dx_out,
|
| 565 |
-
ElementCompute alpha,
|
| 566 |
-
ElementCompute beta,
|
| 567 |
-
bool is_deconv = false) {
|
| 568 |
-
|
| 569 |
-
ConvertOp convert_op;
|
| 570 |
-
InnerProductOp inner_product_op;
|
| 571 |
-
|
| 572 |
-
// Apply MMA and accumulate ElementAccumulator
|
| 573 |
-
for (int n = 0; n < problem_size.N; ++n) {
|
| 574 |
-
for (int d = 0; d < problem_size.D; ++d) {
|
| 575 |
-
for (int h = 0; h < problem_size.H; ++h) {
|
| 576 |
-
for (int w = 0; w < problem_size.W; ++w) {
|
| 577 |
-
for (int c = 0; c < problem_size.C; ++c) {
|
| 578 |
-
|
| 579 |
-
ElementAccumulator acc = ElementAccumulator();
|
| 580 |
-
|
| 581 |
-
for (int t = 0; t < problem_size.T; ++t) {
|
| 582 |
-
for (int r = 0; r < problem_size.R; ++r) {
|
| 583 |
-
for (int s = 0; s < problem_size.S; ++s) {
|
| 584 |
-
for (int k = 0; k < problem_size.K; ++k) {
|
| 585 |
-
|
| 586 |
-
int filter_t = t;
|
| 587 |
-
int filter_r = r;
|
| 588 |
-
int filter_s = s;
|
| 589 |
-
|
| 590 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 591 |
-
filter_t = problem_size.T - 1 - t;
|
| 592 |
-
filter_r = problem_size.R - 1 - r;
|
| 593 |
-
filter_s = problem_size.S - 1 - s;
|
| 594 |
-
}
|
| 595 |
-
|
| 596 |
-
int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d;
|
| 597 |
-
int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h;
|
| 598 |
-
int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w;
|
| 599 |
-
|
| 600 |
-
if (z >= 0 && (z % problem_size.stride_d) == 0 &&
|
| 601 |
-
p >= 0 && (p % problem_size.stride_h) == 0 &&
|
| 602 |
-
q >= 0 && (q % problem_size.stride_w) == 0) {
|
| 603 |
-
|
| 604 |
-
z = z / problem_size.stride_d;
|
| 605 |
-
p = p / problem_size.stride_h;
|
| 606 |
-
q = q / problem_size.stride_w;
|
| 607 |
-
|
| 608 |
-
if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {
|
| 609 |
-
|
| 610 |
-
ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k));
|
| 611 |
-
ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k))
|
| 612 |
-
: tensor_w.at(cutlass::make_Coord(k, t, r, s, c));
|
| 613 |
-
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 614 |
-
}
|
| 615 |
-
}
|
| 616 |
-
|
| 617 |
-
} // for (K)
|
| 618 |
-
} // for (S)
|
| 619 |
-
} // for (R)
|
| 620 |
-
} // for (T)
|
| 621 |
-
|
| 622 |
-
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 623 |
-
ElementC c_ref = ElementC();
|
| 624 |
-
|
| 625 |
-
if (beta != ElementCompute()) {
|
| 626 |
-
c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c));
|
| 627 |
-
}
|
| 628 |
-
|
| 629 |
-
tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) =
|
| 630 |
-
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 631 |
-
|
| 632 |
-
} // for (C)
|
| 633 |
-
} // for (W)
|
| 634 |
-
} // for (H)
|
| 635 |
-
} // for (D)
|
| 636 |
-
} // for (N)
|
| 637 |
-
}
|
| 638 |
-
|
| 639 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 640 |
-
/// Wgrad
|
| 641 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 642 |
-
|
| 643 |
-
/// dw = wgrad(dy, x)
|
| 644 |
-
template <
|
| 645 |
-
typename ElementA,
|
| 646 |
-
typename LayoutA,
|
| 647 |
-
typename ElementB,
|
| 648 |
-
typename LayoutB,
|
| 649 |
-
typename ElementC,
|
| 650 |
-
typename LayoutC,
|
| 651 |
-
typename ElementCompute,
|
| 652 |
-
typename ElementAccumulator = ElementCompute,
|
| 653 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 654 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 655 |
-
>
|
| 656 |
-
void Conv3dWgrad(
|
| 657 |
-
cutlass::conv::Conv3dProblemSize problem_size,
|
| 658 |
-
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 659 |
-
TensorRef<ElementB, LayoutB> tensor_x,
|
| 660 |
-
TensorRef<ElementC, LayoutC> tensor_dw_in,
|
| 661 |
-
TensorRef<ElementC, LayoutC> tensor_dw_out,
|
| 662 |
-
ElementCompute alpha,
|
| 663 |
-
ElementCompute beta) {
|
| 664 |
-
|
| 665 |
-
InnerProductOp inner_product_op;
|
| 666 |
-
ConvertOp convert_op;
|
| 667 |
-
|
| 668 |
-
// Apply MMA and accumulate ElementAccumulator
|
| 669 |
-
for (int k = 0; k < problem_size.K; ++k) {
|
| 670 |
-
for (int t = 0; t < problem_size.T; ++t) {
|
| 671 |
-
for (int r = 0; r < problem_size.R; ++r) {
|
| 672 |
-
for (int s = 0; s < problem_size.S; ++s) {
|
| 673 |
-
for (int c = 0; c < problem_size.C; ++c) {
|
| 674 |
-
|
| 675 |
-
ElementAccumulator acc = ElementAccumulator();
|
| 676 |
-
|
| 677 |
-
for (int n = 0; n < problem_size.N; ++n) {
|
| 678 |
-
for (int z = 0; z < problem_size.Z; ++z) {
|
| 679 |
-
for (int p = 0; p < problem_size.P; ++p) {
|
| 680 |
-
for (int q = 0; q < problem_size.Q; ++q) {
|
| 681 |
-
|
| 682 |
-
int filter_t = t;
|
| 683 |
-
int filter_r = r;
|
| 684 |
-
int filter_s = s;
|
| 685 |
-
|
| 686 |
-
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 687 |
-
filter_t = problem_size.T - 1 - t;
|
| 688 |
-
filter_r = problem_size.R - 1 - r;
|
| 689 |
-
filter_s = problem_size.S - 1 - s;
|
| 690 |
-
}
|
| 691 |
-
|
| 692 |
-
Tensor5DCoord b_coord = make_Coord(
|
| 693 |
-
n,
|
| 694 |
-
z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d,
|
| 695 |
-
p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h,
|
| 696 |
-
q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w,
|
| 697 |
-
c);
|
| 698 |
-
|
| 699 |
-
if (b_coord.d() < problem_size.D && b_coord.d() >= 0 &&
|
| 700 |
-
b_coord.h() < problem_size.H && b_coord.h() >= 0 &&
|
| 701 |
-
b_coord.w() < problem_size.W && b_coord.w() >= 0) {
|
| 702 |
-
|
| 703 |
-
ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)));
|
| 704 |
-
ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord));
|
| 705 |
-
|
| 706 |
-
acc = inner_product_op(a, b, acc);
|
| 707 |
-
}
|
| 708 |
-
}
|
| 709 |
-
}
|
| 710 |
-
}
|
| 711 |
-
}
|
| 712 |
-
|
| 713 |
-
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 714 |
-
ElementC c_ref = ElementC();
|
| 715 |
-
|
| 716 |
-
if (beta != ElementCompute()) {
|
| 717 |
-
c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c));
|
| 718 |
-
}
|
| 719 |
-
|
| 720 |
-
tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) =
|
| 721 |
-
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 722 |
-
|
| 723 |
-
} // for (C)
|
| 724 |
-
} // for (S)
|
| 725 |
-
} // for (R)
|
| 726 |
-
} // for (T)
|
| 727 |
-
} // for (K)
|
| 728 |
-
}
|
| 729 |
-
|
| 730 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 731 |
-
|
| 732 |
-
/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
|
| 733 |
-
template <
|
| 734 |
-
typename ElementA,
|
| 735 |
-
typename LayoutA,
|
| 736 |
-
typename ElementB,
|
| 737 |
-
typename LayoutB,
|
| 738 |
-
typename ElementC,
|
| 739 |
-
typename LayoutC,
|
| 740 |
-
typename ElementCompute,
|
| 741 |
-
typename ElementAccumulator = ElementCompute,
|
| 742 |
-
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 743 |
-
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 744 |
-
>
|
| 745 |
-
void Conv3d(
|
| 746 |
-
conv::Operator convolutional_operator,
|
| 747 |
-
conv::Conv3dProblemSize problem_size,
|
| 748 |
-
TensorRef<ElementA, LayoutA> tensor_A,
|
| 749 |
-
TensorRef<ElementB, LayoutB> tensor_B,
|
| 750 |
-
TensorRef<ElementC, LayoutC> tensor_C,
|
| 751 |
-
TensorRef<ElementC, LayoutC> tensor_D,
|
| 752 |
-
ElementCompute alpha,
|
| 753 |
-
ElementCompute beta) {
|
| 754 |
-
|
| 755 |
-
switch (convolutional_operator) {
|
| 756 |
-
case conv::Operator::kFprop:
|
| 757 |
-
Conv3dFprop<
|
| 758 |
-
ElementA, LayoutA,
|
| 759 |
-
ElementB, LayoutB,
|
| 760 |
-
ElementC, LayoutC,
|
| 761 |
-
ElementCompute,
|
| 762 |
-
ElementAccumulator,
|
| 763 |
-
ConvertOp, InnerProductOp
|
| 764 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
|
| 765 |
-
break;
|
| 766 |
-
|
| 767 |
-
case conv::Operator::kDeconv:
|
| 768 |
-
case conv::Operator::kDgrad:
|
| 769 |
-
Conv3dDgrad<
|
| 770 |
-
ElementA, LayoutA,
|
| 771 |
-
ElementB, LayoutB,
|
| 772 |
-
ElementC, LayoutC,
|
| 773 |
-
ElementCompute,
|
| 774 |
-
ElementAccumulator,
|
| 775 |
-
ConvertOp, InnerProductOp
|
| 776 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
|
| 777 |
-
break;
|
| 778 |
-
|
| 779 |
-
case conv::Operator::kWgrad:
|
| 780 |
-
Conv3dWgrad<
|
| 781 |
-
ElementA, LayoutA,
|
| 782 |
-
ElementB, LayoutB,
|
| 783 |
-
ElementC, LayoutC,
|
| 784 |
-
ElementCompute,
|
| 785 |
-
ElementAccumulator,
|
| 786 |
-
ConvertOp, InnerProductOp
|
| 787 |
-
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
|
| 788 |
-
break;
|
| 789 |
-
|
| 790 |
-
default:
|
| 791 |
-
break;
|
| 792 |
-
}
|
| 793 |
-
}
|
| 794 |
-
|
| 795 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 796 |
-
|
| 797 |
-
} // namespace host
|
| 798 |
-
} // namespace reference
|
| 799 |
-
} // namespace cutlass
|
| 800 |
-
|
| 801 |
-
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 802 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h
DELETED
|
@@ -1,66 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
/***************************************************************************************************
|
| 3 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
*
|
| 6 |
-
* Redistribution and use in source and binary forms, with or without
|
| 7 |
-
* modification, are permitted provided that the following conditions are met:
|
| 8 |
-
*
|
| 9 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
* list of conditions and the following disclaimer.
|
| 11 |
-
*
|
| 12 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
* and/or other materials provided with the distribution.
|
| 15 |
-
*
|
| 16 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
* contributors may be used to endorse or promote products derived from
|
| 18 |
-
* this software without specific prior written permission.
|
| 19 |
-
*
|
| 20 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
*
|
| 31 |
-
**************************************************************************************************/
|
| 32 |
-
#pragma once
|
| 33 |
-
|
| 34 |
-
#include <cmath>
|
| 35 |
-
|
| 36 |
-
#include "cutlass/cutlass.h"
|
| 37 |
-
#include "cutlass/complex.h"
|
| 38 |
-
#include "cutlass/util/reference/host/tensor_reduce.h"
|
| 39 |
-
#include "cutlass/core_io.h"
|
| 40 |
-
|
| 41 |
-
namespace cutlass {
|
| 42 |
-
namespace reference {
|
| 43 |
-
namespace host {
|
| 44 |
-
|
| 45 |
-
/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference
|
| 46 |
-
template <
|
| 47 |
-
typename Element,
|
| 48 |
-
typename Layout,
|
| 49 |
-
typename ComputeType = double
|
| 50 |
-
>
|
| 51 |
-
ComputeType TensorRelativeErrorMetric(
|
| 52 |
-
TensorView<Element, Layout> view_A_computed,
|
| 53 |
-
TensorView<Element, Layout> view_B_reference,
|
| 54 |
-
ComputeType identity = ComputeType()
|
| 55 |
-
) {
|
| 56 |
-
|
| 57 |
-
return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) /
|
| 58 |
-
cutlass::reference::host::TensorNorm(view_B_reference, identity);
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 63 |
-
|
| 64 |
-
} // namespace host
|
| 65 |
-
} // namespace reference
|
| 66 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h
DELETED
|
@@ -1,531 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for GEMM in host-side code.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
#include "cutlass/numeric_types.h"
|
| 39 |
-
#include "cutlass/functional.h"
|
| 40 |
-
#include "cutlass/numeric_conversion.h"
|
| 41 |
-
|
| 42 |
-
#include "cutlass/tensor_view.h"
|
| 43 |
-
#include "cutlass/gemm/gemm.h"
|
| 44 |
-
#include "cutlass/arch/mma.h"
|
| 45 |
-
#include "cutlass/util/host_tensor.h"
|
| 46 |
-
|
| 47 |
-
namespace cutlass {
|
| 48 |
-
namespace reference {
|
| 49 |
-
namespace host {
|
| 50 |
-
|
| 51 |
-
template<typename Out, typename In>
|
| 52 |
-
struct CastIfScalar {
|
| 53 |
-
static Out cast(In in) {
|
| 54 |
-
return Out(in);
|
| 55 |
-
}
|
| 56 |
-
};
|
| 57 |
-
|
| 58 |
-
template<typename OutScalar, typename In>
|
| 59 |
-
struct CastIfScalar<cutlass::complex<OutScalar>, In> {
|
| 60 |
-
typedef cutlass::complex<OutScalar> Out;
|
| 61 |
-
static Out cast(In in) {
|
| 62 |
-
return Out(static_cast<OutScalar>(in));
|
| 63 |
-
}
|
| 64 |
-
};
|
| 65 |
-
|
| 66 |
-
template<typename OutScalar, typename InScalar>
|
| 67 |
-
struct CastIfScalar<cutlass::complex<OutScalar>, cutlass::complex<InScalar>> {
|
| 68 |
-
typedef cutlass::complex<OutScalar> Out;
|
| 69 |
-
typedef cutlass::complex<InScalar> In;
|
| 70 |
-
static Out cast(In in) {
|
| 71 |
-
return Out(in);
|
| 72 |
-
}
|
| 73 |
-
};
|
| 74 |
-
|
| 75 |
-
template<typename Out, typename In>
|
| 76 |
-
Out cast_if_scalar(In in) {
|
| 77 |
-
return CastIfScalar<Out, In>::cast(in);
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 81 |
-
|
| 82 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 83 |
-
/// objects.
|
| 84 |
-
template <
|
| 85 |
-
typename ElementA,
|
| 86 |
-
typename LayoutA,
|
| 87 |
-
typename ElementB,
|
| 88 |
-
typename LayoutB,
|
| 89 |
-
typename ElementC,
|
| 90 |
-
typename LayoutC,
|
| 91 |
-
typename ScalarType,
|
| 92 |
-
typename ComputeType,
|
| 93 |
-
typename InnerProductOp = multiply_add<ComputeType>,
|
| 94 |
-
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
| 95 |
-
>
|
| 96 |
-
void compute_gemm(
|
| 97 |
-
gemm::GemmCoord problem_size,
|
| 98 |
-
ScalarType alpha,
|
| 99 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 100 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 101 |
-
ScalarType beta,
|
| 102 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 103 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 104 |
-
ComputeType initial_accum) {
|
| 105 |
-
|
| 106 |
-
static_assert(
|
| 107 |
-
LayoutA::kRank == 2 &&
|
| 108 |
-
LayoutB::kRank == 2 &&
|
| 109 |
-
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
// Note: batch is ignored.
|
| 113 |
-
int const M = problem_size.m();
|
| 114 |
-
int const N = problem_size.n();
|
| 115 |
-
int const K = problem_size.k();
|
| 116 |
-
|
| 117 |
-
// Blocking necessary to speedup reference implementation
|
| 118 |
-
int const Mblock = 16;
|
| 119 |
-
int const Nblock = 16;
|
| 120 |
-
|
| 121 |
-
ConvertOp convert_op;
|
| 122 |
-
InnerProductOp inner_product_op;
|
| 123 |
-
|
| 124 |
-
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
| 125 |
-
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
| 126 |
-
|
| 127 |
-
ComputeType accum[Mblock][Nblock];
|
| 128 |
-
|
| 129 |
-
for (int j = 0; j < Nblock; j++) {
|
| 130 |
-
for (int i = 0; i < Mblock; i++) {
|
| 131 |
-
accum[i][j] = initial_accum;
|
| 132 |
-
}
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
for (int k_block = 0; k_block < K; ++k_block) {
|
| 136 |
-
for (int j = 0; j < Nblock; j++) {
|
| 137 |
-
for (int i = 0; i < Mblock; i++) {
|
| 138 |
-
int row = row_block + i;
|
| 139 |
-
int col = col_block + j;
|
| 140 |
-
|
| 141 |
-
if (row < M && col < N) {
|
| 142 |
-
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
| 143 |
-
ElementB b = tensor_b.at(MatrixCoord(k_block, col));
|
| 144 |
-
|
| 145 |
-
ComputeType compute_a(cast_if_scalar<ComputeType>(a));
|
| 146 |
-
ComputeType compute_b(cast_if_scalar<ComputeType>(b));
|
| 147 |
-
|
| 148 |
-
accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]);
|
| 149 |
-
}
|
| 150 |
-
}
|
| 151 |
-
}
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
for (int j = 0; j < Nblock; j++) {
|
| 155 |
-
for (int i = 0; i < Mblock; i++) {
|
| 156 |
-
int row = row_block + i;
|
| 157 |
-
int col = col_block + j;
|
| 158 |
-
|
| 159 |
-
MatrixCoord coord = MatrixCoord(row, col);
|
| 160 |
-
|
| 161 |
-
if (row < M && col < N) {
|
| 162 |
-
tensor_d.at(coord) = convert_op(
|
| 163 |
-
alpha * ScalarType(accum[i][j]) +
|
| 164 |
-
beta * ScalarType(tensor_c.at(coord)));
|
| 165 |
-
}
|
| 166 |
-
}
|
| 167 |
-
}
|
| 168 |
-
}
|
| 169 |
-
}
|
| 170 |
-
}
|
| 171 |
-
|
| 172 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 173 |
-
|
| 174 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 175 |
-
/// objects.
|
| 176 |
-
template <
|
| 177 |
-
typename ElementA,
|
| 178 |
-
typename LayoutA,
|
| 179 |
-
typename ElementB,
|
| 180 |
-
typename LayoutB,
|
| 181 |
-
typename ElementC,
|
| 182 |
-
typename LayoutC,
|
| 183 |
-
typename ScalarType,
|
| 184 |
-
typename ComputeType,
|
| 185 |
-
typename InnerProductOp = multiply_add<ComputeType>,
|
| 186 |
-
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
| 187 |
-
>
|
| 188 |
-
void compute_gemm(
|
| 189 |
-
gemm::GemmCoord problem_size,
|
| 190 |
-
ScalarType alpha,
|
| 191 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 192 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 193 |
-
ScalarType beta,
|
| 194 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 195 |
-
ComputeType initial_accum) {
|
| 196 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 197 |
-
ScalarType, ComputeType, InnerProductOp, ConvertOp>(
|
| 198 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
|
| 199 |
-
initial_accum);
|
| 200 |
-
}
|
| 201 |
-
|
| 202 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 203 |
-
|
| 204 |
-
template <
|
| 205 |
-
typename ElementA,
|
| 206 |
-
typename LayoutA,
|
| 207 |
-
typename ElementB,
|
| 208 |
-
typename LayoutB,
|
| 209 |
-
typename ElementC,
|
| 210 |
-
typename LayoutC,
|
| 211 |
-
typename ScalarType,
|
| 212 |
-
typename ComputeType,
|
| 213 |
-
typename InnerProductOp = cutlass::arch::OpMultiplyAdd
|
| 214 |
-
>
|
| 215 |
-
struct Gemm;
|
| 216 |
-
|
| 217 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 218 |
-
|
| 219 |
-
/// Partial specialization for multiply-add
|
| 220 |
-
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 221 |
-
typename LayoutB, typename ElementC, typename LayoutC,
|
| 222 |
-
typename ScalarType, typename ComputeType>
|
| 223 |
-
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 224 |
-
ComputeType, arch::OpMultiplyAdd> {
|
| 225 |
-
|
| 226 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 227 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 228 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 229 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 230 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 231 |
-
static_assert(
|
| 232 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 233 |
-
"Tensors must be of rank 2");
|
| 234 |
-
|
| 235 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 236 |
-
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 237 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 238 |
-
}
|
| 239 |
-
|
| 240 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 241 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 242 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 243 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 244 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 245 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 246 |
-
static_assert(
|
| 247 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 248 |
-
"Tensors must be of rank 2");
|
| 249 |
-
|
| 250 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 251 |
-
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 252 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 253 |
-
}
|
| 254 |
-
};
|
| 255 |
-
|
| 256 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 257 |
-
|
| 258 |
-
/// Partial specialization for multiply-add
|
| 259 |
-
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 260 |
-
typename LayoutB, typename ElementC, typename LayoutC,
|
| 261 |
-
typename ScalarType, typename ComputeType>
|
| 262 |
-
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 263 |
-
ComputeType, arch::OpMultiplyAddFastBF16> {
|
| 264 |
-
|
| 265 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 266 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 267 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 268 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 269 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 270 |
-
static_assert(
|
| 271 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 272 |
-
"Tensors must be of rank 2");
|
| 273 |
-
|
| 274 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 275 |
-
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 276 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 277 |
-
}
|
| 278 |
-
|
| 279 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 280 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 281 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 282 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 283 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 284 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 285 |
-
static_assert(
|
| 286 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 287 |
-
"Tensors must be of rank 2");
|
| 288 |
-
|
| 289 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 290 |
-
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 291 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 292 |
-
}
|
| 293 |
-
};
|
| 294 |
-
|
| 295 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 296 |
-
|
| 297 |
-
/// Partial specialization for multiply-add-saturate
|
| 298 |
-
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 299 |
-
typename LayoutB, typename ElementC, typename LayoutC,
|
| 300 |
-
typename ScalarType, typename ComputeType>
|
| 301 |
-
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 302 |
-
ComputeType, arch::OpMultiplyAddSaturate> {
|
| 303 |
-
|
| 304 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 305 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 306 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 307 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 308 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 309 |
-
static_assert(
|
| 310 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 311 |
-
"Tensors must be of rank 2");
|
| 312 |
-
|
| 313 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 314 |
-
ScalarType, ComputeType, multiply_add<ComputeType>,
|
| 315 |
-
NumericConverterClamp<ElementC, ScalarType>>(
|
| 316 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 317 |
-
}
|
| 318 |
-
|
| 319 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 320 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 321 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 322 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 323 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 324 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 325 |
-
static_assert(
|
| 326 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 327 |
-
"Tensors must be of rank 2");
|
| 328 |
-
|
| 329 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 330 |
-
ScalarType, ComputeType, multiply_add<ComputeType>,
|
| 331 |
-
NumericConverterClamp<ElementC, ScalarType>>(
|
| 332 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 333 |
-
}
|
| 334 |
-
};
|
| 335 |
-
|
| 336 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 337 |
-
|
| 338 |
-
/// Partial specialization for XOR-popc
|
| 339 |
-
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 340 |
-
typename LayoutB, typename ElementC, typename LayoutC,
|
| 341 |
-
typename ScalarType, typename ComputeType>
|
| 342 |
-
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 343 |
-
ComputeType, arch::OpXorPopc> {
|
| 344 |
-
|
| 345 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 346 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 347 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 348 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 349 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 350 |
-
static_assert(
|
| 351 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 352 |
-
"Tensors must be of rank 2");
|
| 353 |
-
|
| 354 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 355 |
-
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
|
| 356 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 357 |
-
}
|
| 358 |
-
|
| 359 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 360 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 361 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 362 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 363 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 364 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 365 |
-
static_assert(
|
| 366 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 367 |
-
"Tensors must be of rank 2");
|
| 368 |
-
|
| 369 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 370 |
-
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
|
| 371 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 372 |
-
}
|
| 373 |
-
};
|
| 374 |
-
|
| 375 |
-
/// Partial specialization for AND-popc
|
| 376 |
-
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 377 |
-
typename LayoutB, typename ElementC, typename LayoutC,
|
| 378 |
-
typename ScalarType, typename ComputeType>
|
| 379 |
-
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 380 |
-
ComputeType, arch::OpAndPopc> {
|
| 381 |
-
|
| 382 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 383 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 384 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 385 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 386 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 387 |
-
static_assert(
|
| 388 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 389 |
-
"Tensors must be of rank 2");
|
| 390 |
-
|
| 391 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 392 |
-
ScalarType, ComputeType, and_popc_add<ComputeType>>(
|
| 393 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 394 |
-
}
|
| 395 |
-
|
| 396 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 397 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 398 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 399 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 400 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 401 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 402 |
-
static_assert(
|
| 403 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 404 |
-
"Tensors must be of rank 2");
|
| 405 |
-
|
| 406 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 407 |
-
ScalarType, ComputeType, and_popc_add<ComputeType>>(
|
| 408 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 409 |
-
}
|
| 410 |
-
};
|
| 411 |
-
|
| 412 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 413 |
-
|
| 414 |
-
/// Partial specialization for multiply-add
|
| 415 |
-
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 416 |
-
typename LayoutB, typename ElementC, typename LayoutC,
|
| 417 |
-
typename ScalarType, typename ComputeType>
|
| 418 |
-
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 419 |
-
ComputeType, arch::OpMultiplyAddFastF32> {
|
| 420 |
-
|
| 421 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 422 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 423 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 424 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 425 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 426 |
-
static_assert(
|
| 427 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 428 |
-
"Tensors must be of rank 2");
|
| 429 |
-
|
| 430 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 431 |
-
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 432 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 433 |
-
}
|
| 434 |
-
|
| 435 |
-
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 436 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 437 |
-
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 438 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 439 |
-
TensorRef<ElementC, LayoutC> tensor_d,
|
| 440 |
-
ComputeType initial_accum = ComputeType(0)) {
|
| 441 |
-
static_assert(
|
| 442 |
-
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 443 |
-
"Tensors must be of rank 2");
|
| 444 |
-
|
| 445 |
-
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 446 |
-
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 447 |
-
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 448 |
-
}
|
| 449 |
-
};
|
| 450 |
-
|
| 451 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 452 |
-
|
| 453 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 454 |
-
//
|
| 455 |
-
// Batched GEMM
|
| 456 |
-
//
|
| 457 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 458 |
-
|
| 459 |
-
/// Computes a batch of GEMMs over a set of matrices of common dimension.
|
| 460 |
-
//
|
| 461 |
-
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
| 462 |
-
//
|
| 463 |
-
template <
|
| 464 |
-
typename TensorRefCollectionA,
|
| 465 |
-
typename TensorRefCollectionB,
|
| 466 |
-
typename TensorRefCollectionC,
|
| 467 |
-
typename ScalarType,
|
| 468 |
-
typename AccumulatorType
|
| 469 |
-
>
|
| 470 |
-
void BatchedGemm(
|
| 471 |
-
gemm::GemmCoord problem_size,
|
| 472 |
-
int batch_count,
|
| 473 |
-
ScalarType alpha,
|
| 474 |
-
TensorRefCollectionA const& tensor_a,
|
| 475 |
-
TensorRefCollectionB const& tensor_b,
|
| 476 |
-
ScalarType beta,
|
| 477 |
-
TensorRefCollectionC &tensor_c,
|
| 478 |
-
AccumulatorType initial_accum) {
|
| 479 |
-
|
| 480 |
-
typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin();
|
| 481 |
-
typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin();
|
| 482 |
-
typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin();
|
| 483 |
-
|
| 484 |
-
for (int batch = 0;
|
| 485 |
-
batch < batch_count;
|
| 486 |
-
++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) {
|
| 487 |
-
|
| 488 |
-
Gemm<typename TensorRefCollectionA::Element,
|
| 489 |
-
typename TensorRefCollectionA::Layout,
|
| 490 |
-
typename TensorRefCollectionB::Element,
|
| 491 |
-
typename TensorRefCollectionB::Layout,
|
| 492 |
-
typename TensorRefCollectionC::Element,
|
| 493 |
-
typename TensorRefCollectionC::Layout,
|
| 494 |
-
typename TensorRefCollectionC::Element,
|
| 495 |
-
typename TensorRefCollectionC::Element>
|
| 496 |
-
gemm;
|
| 497 |
-
|
| 498 |
-
gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it,
|
| 499 |
-
initial_accum);
|
| 500 |
-
}
|
| 501 |
-
}
|
| 502 |
-
|
| 503 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 504 |
-
/// objects.
|
| 505 |
-
//
|
| 506 |
-
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
| 507 |
-
//
|
| 508 |
-
template <
|
| 509 |
-
typename TensorRefCollectionA,
|
| 510 |
-
typename TensorRefCollectionB,
|
| 511 |
-
typename TensorRefCollectionC,
|
| 512 |
-
typename ScalarType,
|
| 513 |
-
typename AccumulatorType
|
| 514 |
-
>
|
| 515 |
-
void BatchedGemm(
|
| 516 |
-
gemm::GemmCoord problem_size,
|
| 517 |
-
int batch_count,
|
| 518 |
-
ScalarType alpha,
|
| 519 |
-
TensorRefCollectionA const& tensor_a,
|
| 520 |
-
TensorRefCollectionB const& tensor_b,
|
| 521 |
-
ScalarType beta,
|
| 522 |
-
TensorRefCollectionC &tensor_c) {
|
| 523 |
-
|
| 524 |
-
BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
|
| 525 |
-
}
|
| 526 |
-
|
| 527 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 528 |
-
|
| 529 |
-
} // namespace host
|
| 530 |
-
} // namespace reference
|
| 531 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h
DELETED
|
@@ -1,210 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for complex-valued GEMM in host-side code.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
#include "cutlass/complex.h"
|
| 39 |
-
#include "cutlass/numeric_types.h"
|
| 40 |
-
#include "cutlass/functional.h"
|
| 41 |
-
#include "cutlass/numeric_conversion.h"
|
| 42 |
-
#include "cutlass/matrix_coord.h"
|
| 43 |
-
|
| 44 |
-
#include "cutlass/tensor_view.h"
|
| 45 |
-
|
| 46 |
-
#include "cutlass/gemm/gemm.h"
|
| 47 |
-
|
| 48 |
-
namespace cutlass {
|
| 49 |
-
namespace reference {
|
| 50 |
-
namespace host {
|
| 51 |
-
|
| 52 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
-
|
| 54 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 55 |
-
/// objects.
|
| 56 |
-
///
|
| 57 |
-
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 58 |
-
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 59 |
-
/// ComputeType(0) as the 'initial_accum' argument can be easier than naming all template
|
| 60 |
-
/// arguments explicitly.
|
| 61 |
-
template <
|
| 62 |
-
typename ElementA,
|
| 63 |
-
typename LayoutA,
|
| 64 |
-
typename ElementB,
|
| 65 |
-
typename LayoutB,
|
| 66 |
-
typename ElementC,
|
| 67 |
-
typename LayoutC,
|
| 68 |
-
typename ScalarType,
|
| 69 |
-
typename ComputeType,
|
| 70 |
-
typename ElementD = ElementC,
|
| 71 |
-
typename ConvertOp = NumericConverter<ElementD, ScalarType>,
|
| 72 |
-
typename InnerProductOp = multiply_add<ComputeType>
|
| 73 |
-
>
|
| 74 |
-
void GemmComplex(
|
| 75 |
-
gemm::GemmCoord problem_size,
|
| 76 |
-
ScalarType alpha,
|
| 77 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 78 |
-
ComplexTransform transform_a,
|
| 79 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 80 |
-
ComplexTransform transform_b,
|
| 81 |
-
ScalarType beta,
|
| 82 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 83 |
-
TensorRef<ElementD, LayoutC> tensor_d,
|
| 84 |
-
ComputeType initial_accum,
|
| 85 |
-
int batch_count = 1,
|
| 86 |
-
int64_t batch_stride_A = 0,
|
| 87 |
-
int64_t batch_stride_B = 0,
|
| 88 |
-
int64_t batch_stride_C = 0,
|
| 89 |
-
int64_t batch_stride_D = 0) {
|
| 90 |
-
|
| 91 |
-
static_assert(
|
| 92 |
-
LayoutA::kRank == 2 &&
|
| 93 |
-
LayoutB::kRank == 2 &&
|
| 94 |
-
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 95 |
-
|
| 96 |
-
// Note: problem_size carries no batch dimension; batching is handled by the batch_idx loop below.
|
| 97 |
-
int const M = problem_size.m();
|
| 98 |
-
int const N = problem_size.n();
|
| 99 |
-
int const K = problem_size.k();
|
| 100 |
-
|
| 101 |
-
// Blocking necessary to speedup reference implementation
|
| 102 |
-
int const Mblock = 16;
|
| 103 |
-
int const Nblock = 16;
|
| 104 |
-
|
| 105 |
-
ConvertOp convert_op;
|
| 106 |
-
InnerProductOp inner_product_op;
|
| 107 |
-
|
| 108 |
-
for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
|
| 109 |
-
|
| 110 |
-
// Compute matrix product using blocks
|
| 111 |
-
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
| 112 |
-
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
| 113 |
-
|
| 114 |
-
ComputeType accum[Mblock][Nblock];
|
| 115 |
-
|
| 116 |
-
for (int j = 0; j < Nblock; j++) {
|
| 117 |
-
for (int i = 0; i < Mblock; i++) {
|
| 118 |
-
accum[i][j] = initial_accum;
|
| 119 |
-
}
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
for (int k_block = 0; k_block < K; ++k_block) {
|
| 123 |
-
for (int j = 0; j < Nblock; j++) {
|
| 124 |
-
for (int i = 0; i < Mblock; i++) {
|
| 125 |
-
int row = row_block + i;
|
| 126 |
-
int col = col_block + j;
|
| 127 |
-
|
| 128 |
-
if (row < M && col < N) {
|
| 129 |
-
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
| 130 |
-
ElementB b = tensor_b.at(MatrixCoord(k_block, col));
|
| 131 |
-
|
| 132 |
-
ComputeType a_ik = ComputeType(a);
|
| 133 |
-
ComputeType b_kj = ComputeType(b);
|
| 134 |
-
|
| 135 |
-
if (transform_a == ComplexTransform::kConjugate) {
|
| 136 |
-
a_ik = conj(a_ik);
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
if (transform_b == ComplexTransform::kConjugate) {
|
| 140 |
-
b_kj = conj(b_kj);
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
|
| 144 |
-
}
|
| 145 |
-
}
|
| 146 |
-
}
|
| 147 |
-
}
|
| 148 |
-
|
| 149 |
-
for (int j = 0; j < Nblock; j++) {
|
| 150 |
-
for (int i = 0; i < Mblock; i++) {
|
| 151 |
-
int row = row_block + i;
|
| 152 |
-
int col = col_block + j;
|
| 153 |
-
|
| 154 |
-
MatrixCoord coord = MatrixCoord(row, col);
|
| 155 |
-
|
| 156 |
-
if (row < M && col < N) {
|
| 157 |
-
|
| 158 |
-
tensor_d.at(coord) = convert_op(
|
| 159 |
-
alpha * ScalarType(accum[i][j]) +
|
| 160 |
-
beta * ScalarType(tensor_c.at(coord)));
|
| 161 |
-
}
|
| 162 |
-
}
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
-
} // for (col_block)
|
| 166 |
-
} // for (row_block)
|
| 167 |
-
|
| 168 |
-
tensor_a.add_pointer_offset(batch_stride_A);
|
| 169 |
-
tensor_b.add_pointer_offset(batch_stride_B);
|
| 170 |
-
tensor_c.add_pointer_offset(batch_stride_C);
|
| 171 |
-
tensor_d.add_pointer_offset(batch_stride_D);
|
| 172 |
-
|
| 173 |
-
} // for (batch_idx)
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 177 |
-
|
| 178 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 179 |
-
/// objects.
|
| 180 |
-
///
|
| 181 |
-
/// This assumes the accumulator type is the same type as the scalars.
|
| 182 |
-
template <
|
| 183 |
-
typename ElementA,
|
| 184 |
-
typename LayoutA,
|
| 185 |
-
typename ElementB,
|
| 186 |
-
typename LayoutB,
|
| 187 |
-
typename ElementC,
|
| 188 |
-
typename LayoutC,
|
| 189 |
-
typename ScalarType,
|
| 190 |
-
typename ElementD = ElementC
|
| 191 |
-
>
|
| 192 |
-
void GemmComplex(
|
| 193 |
-
gemm::GemmCoord problem_size,
|
| 194 |
-
ScalarType alpha,
|
| 195 |
-
TensorRef<ElementA, LayoutA> tensor_a,
|
| 196 |
-
ComplexTransform transform_a,
|
| 197 |
-
TensorRef<ElementB, LayoutB> tensor_b,
|
| 198 |
-
ComplexTransform transform_b,
|
| 199 |
-
ScalarType beta,
|
| 200 |
-
TensorRef<ElementC, LayoutC> tensor_c,
|
| 201 |
-
TensorRef<ElementD, LayoutC> tensor_d) {
|
| 202 |
-
|
| 203 |
-
GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
|
| 204 |
-
}
|
| 205 |
-
|
| 206 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 207 |
-
|
| 208 |
-
} // namespace host
|
| 209 |
-
} // namespace reference
|
| 210 |
-
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu126-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h
DELETED
|
@@ -1,228 +0,0 @@
|
|
| 1 |
-
/***************************************************************************************************
|
| 2 |
-
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
-
*
|
| 5 |
-
* Redistribution and use in source and binary forms, with or without
|
| 6 |
-
* modification, are permitted provided that the following conditions are met:
|
| 7 |
-
*
|
| 8 |
-
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
-
* list of conditions and the following disclaimer.
|
| 10 |
-
*
|
| 11 |
-
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
-
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
-
* and/or other materials provided with the distribution.
|
| 14 |
-
*
|
| 15 |
-
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
-
* contributors may be used to endorse or promote products derived from
|
| 17 |
-
* this software without specific prior written permission.
|
| 18 |
-
*
|
| 19 |
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
-
*
|
| 30 |
-
**************************************************************************************************/
|
| 31 |
-
/*! \file
|
| 32 |
-
\brief Reference implementation for planar complex-valued GEMM in host-side code.
|
| 33 |
-
*/
|
| 34 |
-
|
| 35 |
-
#pragma once
|
| 36 |
-
|
| 37 |
-
#include "cutlass/coord.h"
|
| 38 |
-
#include "cutlass/complex.h"
|
| 39 |
-
#include "cutlass/numeric_types.h"
|
| 40 |
-
#include "cutlass/functional.h"
|
| 41 |
-
#include "cutlass/numeric_conversion.h"
|
| 42 |
-
#include "cutlass/tensor_ref_planar_complex.h"
|
| 43 |
-
|
| 44 |
-
#include "cutlass/tensor_view.h"
|
| 45 |
-
#include "cutlass/gemm/gemm.h"
|
| 46 |
-
|
| 47 |
-
namespace cutlass {
|
| 48 |
-
namespace reference {
|
| 49 |
-
namespace host {
|
| 50 |
-
|
| 51 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
-
|
| 53 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 54 |
-
/// objects.
|
| 55 |
-
///
|
| 56 |
-
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 57 |
-
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 58 |
-
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 59 |
-
/// arguments explicitly.
|
| 60 |
-
template <
|
| 61 |
-
typename ElementA,
|
| 62 |
-
typename LayoutA,
|
| 63 |
-
typename ElementB,
|
| 64 |
-
typename LayoutB,
|
| 65 |
-
typename ElementC,
|
| 66 |
-
typename LayoutC,
|
| 67 |
-
typename ScalarType,
|
| 68 |
-
typename ComputeType,
|
| 69 |
-
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 70 |
-
typename InnerProductOp = multiply_add<complex<ComputeType>>
|
| 71 |
-
>
|
| 72 |
-
void GemmPlanarComplex(
|
| 73 |
-
gemm::GemmCoord problem_size,
|
| 74 |
-
complex<ScalarType> alpha,
|
| 75 |
-
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
| 76 |
-
ComplexTransform transform_a,
|
| 77 |
-
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
| 78 |
-
ComplexTransform transform_b,
|
| 79 |
-
complex<ScalarType> beta,
|
| 80 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
| 81 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
|
| 82 |
-
complex<ComputeType> initial_accum) {
|
| 83 |
-
|
| 84 |
-
static_assert(
|
| 85 |
-
LayoutA::kRank == 2 &&
|
| 86 |
-
LayoutB::kRank == 2 &&
|
| 87 |
-
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 88 |
-
|
| 89 |
-
using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
|
| 90 |
-
using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
|
| 91 |
-
using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;
|
| 92 |
-
|
| 93 |
-
// Note: batch is ignored.
|
| 94 |
-
int const M = problem_size.m();
|
| 95 |
-
int const N = problem_size.n();
|
| 96 |
-
int const K = problem_size.k();
|
| 97 |
-
|
| 98 |
-
// Blocking necessary to speedup reference implementation
|
| 99 |
-
int const Mblock = 16;
|
| 100 |
-
int const Nblock = 16;
|
| 101 |
-
|
| 102 |
-
ConvertOp convert_op;
|
| 103 |
-
InnerProductOp inner_product_op;
|
| 104 |
-
|
| 105 |
-
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
| 106 |
-
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
| 107 |
-
|
| 108 |
-
complex<ComputeType> accum[Mblock][Nblock];
|
| 109 |
-
|
| 110 |
-
for (int j = 0; j < Nblock; j++) {
|
| 111 |
-
for (int i = 0; i < Mblock; i++) {
|
| 112 |
-
accum[i][j] = initial_accum;
|
| 113 |
-
}
|
| 114 |
-
}
|
| 115 |
-
|
| 116 |
-
for (int k_block = 0; k_block < K; ++k_block) {
|
| 117 |
-
for (int j = 0; j < Nblock; j++) {
|
| 118 |
-
for (int i = 0; i < Mblock; i++) {
|
| 119 |
-
int row = row_block + i;
|
| 120 |
-
int col = col_block + j;
|
| 121 |
-
|
| 122 |
-
if (row < M && col < N) {
|
| 123 |
-
|
| 124 |
-
ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
|
| 125 |
-
ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
|
| 126 |
-
|
| 127 |
-
complex<ComputeType> a = complex<ComputeType>{
|
| 128 |
-
ComputeType(a_ik.real()),
|
| 129 |
-
ComputeType(a_ik.imag())
|
| 130 |
-
};
|
| 131 |
-
|
| 132 |
-
complex<ComputeType> b = complex<ComputeType>{
|
| 133 |
-
ComputeType(b_kj.real()),
|
| 134 |
-
ComputeType(b_kj.imag())
|
| 135 |
-
};
|
| 136 |
-
|
| 137 |
-
if (transform_a == ComplexTransform::kConjugate) {
|
| 138 |
-
a = conj(a);
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
if (transform_b == ComplexTransform::kConjugate) {
|
| 142 |
-
b = conj(b);
|
| 143 |
-
}
|
| 144 |
-
|
| 145 |
-
accum[i][j] = inner_product_op(a, b, accum[i][j]);
|
| 146 |
-
}
|
| 147 |
-
}
|
| 148 |
-
}
|
| 149 |
-
}
|
| 150 |
-
|
| 151 |
-
for (int j = 0; j < Nblock; j++) {
|
| 152 |
-
for (int i = 0; i < Mblock; i++) {
|
| 153 |
-
int row = row_block + i;
|
| 154 |
-
int col = col_block + j;
|
| 155 |
-
|
| 156 |
-
MatrixCoord coord = MatrixCoord(row, col);
|
| 157 |
-
|
| 158 |
-
if (row < M && col < N) {
|
| 159 |
-
|
| 160 |
-
complex<ScalarType> acc{
|
| 161 |
-
ScalarType(accum[i][j].real()),
|
| 162 |
-
ScalarType(accum[i][j].imag())
|
| 163 |
-
};
|
| 164 |
-
|
| 165 |
-
ComplexC d_ij = tensor_c.at(coord);
|
| 166 |
-
|
| 167 |
-
complex<ScalarType> src{
|
| 168 |
-
ScalarType(d_ij.real()),
|
| 169 |
-
ScalarType(d_ij.imag())
|
| 170 |
-
};
|
| 171 |
-
|
| 172 |
-
complex<ScalarType> result = alpha * acc + beta * src;
|
| 173 |
-
|
| 174 |
-
d_ij.real() = convert_op(result.real());
|
| 175 |
-
d_ij.imag() = convert_op(result.imag());
|
| 176 |
-
|
| 177 |
-
tensor_d.at(coord) = d_ij;
|
| 178 |
-
}
|
| 179 |
-
}
|
| 180 |
-
}
|
| 181 |
-
}
|
| 182 |
-
}
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 186 |
-
|
| 187 |
-
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 188 |
-
/// objects.
|
| 189 |
-
///
|
| 190 |
-
/// This assumes the accumulator type is the same type as the scalars.
|
| 191 |
-
template <
|
| 192 |
-
typename ElementA,
|
| 193 |
-
typename LayoutA,
|
| 194 |
-
typename ElementB,
|
| 195 |
-
typename LayoutB,
|
| 196 |
-
typename ElementC,
|
| 197 |
-
typename LayoutC,
|
| 198 |
-
typename ScalarType
|
| 199 |
-
>
|
| 200 |
-
void GemmPlanarComplex(
|
| 201 |
-
gemm::GemmCoord problem_size,
|
| 202 |
-
complex<ScalarType> alpha,
|
| 203 |
-
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
| 204 |
-
ComplexTransform transform_a,
|
| 205 |
-
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
| 206 |
-
ComplexTransform transform_b,
|
| 207 |
-
complex<ScalarType> beta,
|
| 208 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
| 209 |
-
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
|
| 210 |
-
|
| 211 |
-
GemmPlanarComplex(
|
| 212 |
-
problem_size,
|
| 213 |
-
alpha,
|
| 214 |
-
tensor_a, transform_a,
|
| 215 |
-
tensor_b, transform_b,
|
| 216 |
-
beta,
|
| 217 |
-
tensor_c,
|
| 218 |
-
tensor_d,
|
| 219 |
-
complex<ScalarType>());
|
| 220 |
-
}
|
| 221 |
-
|
| 222 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 223 |
-
|
| 224 |
-
} // namespace host
|
| 225 |
-
} // namespace reference
|
| 226 |
-
} // namespace cutlass
|
| 227 |
-
|
| 228 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|