File size: 12,068 Bytes
c1af2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 |
#pragma once
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/thread_constants.h>
#include <thrust/tuple.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <tuple>
namespace at::native {
// Builds an OffsetCalculator<N> for the input operands of `iter`.
//
// Inputs are taken to be the operands at indices
// [iter.noutputs(), iter.noutputs() + N); the assertion below checks that N
// equals iter.ntensors() - iter.noutputs().
template<int N>
static OffsetCalculator<N> make_input_offset_calculator(const TensorIteratorBase& iter) {
  // A zero-length array is not allowed, so keep one (unused) slot when N == 0.
  constexpr int array_size = N > 0 ? N : 1;
  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());

  std::array<const int64_t*, array_size> stride_ptrs;
  std::array<int64_t, array_size> elem_sizes;
  for (int k = 0; k < N; ++k) {
    // Skip past the outputs to reach the k-th input operand.
    const int operand = k + iter.noutputs();
    stride_ptrs[k] = iter.strides(operand).data();
    elem_sizes[k] = iter.element_size(operand);
  }

  return OffsetCalculator<N>(
      iter.ndim(), iter.shape().data(), stride_ptrs.data(), elem_sizes.data());
}
// Builds an OffsetCalculator<num_outputs> for the output operands of `iter`.
//
// Outputs are the operands at indices [0, num_outputs); the assertion below
// checks that num_outputs equals iter.noutputs().
template <int num_outputs = 1>
static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());

  std::array<const int64_t*, num_outputs> stride_ptrs;
  std::array<int64_t, num_outputs> elem_sizes;
  for (int k = 0; k < num_outputs; ++k) {
    stride_ptrs[k] = iter.strides(k).data();
    elem_sizes[k] = iter.element_size(k);
  }

  return OffsetCalculator<num_outputs>(
      iter.ndim(), iter.shape().data(), stride_ptrs.data(), elem_sizes.data());
}
// Per-thread body shared by elementwise kernels: load arguments through
// `policy`, apply `f` to every in-bounds element, then store through `policy`.
//
// Template parameters:
//   reverted_idx - when true, the block index handed to the policy is
//                  gridDim.x - blockIdx.x - 1 instead of blockIdx.x.
//   func_t       - callable; its argument tuple and result type are taken
//                  from function_traits<func_t>.
//   policy_t     - provides `tws` (elements handled per thread) together with
//                  load(), check_inbounds() and store().
template <bool reverted_idx = false, typename func_t, typename policy_t>
__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
  using fn_traits = function_traits<func_t>;
  using result_t = typename fn_traits::result_type;
  using packed_args_t = typename fn_traits::ArgsTuple;
  constexpr int per_thread = policy_t::tws;

  // Block index passed to the policy's load/store.
  int block = blockIdx.x;
  if constexpr (reverted_idx) {
    block = gridDim.x - blockIdx.x - 1;
  }

  result_t outputs[per_thread];
  packed_args_t inputs[per_thread];

  // Load: the policy fills `inputs` for this block.
  policy.load(inputs, block);

  // Compute: only slots the policy reports as in-bounds are evaluated.
  #pragma unroll
  for (int slot = 0; slot < per_thread; ++slot) {
    if (policy.check_inbounds(slot)) {
      outputs[slot] = c10::guts::apply(f, inputs[slot]);
    }
  }

  // Store: the policy writes `outputs` back for this block.
  policy.store(outputs, block);
}
} // namespace at::native
#include <ATen/native/cuda/CUDALoops.cuh>
namespace at:: native {
// Validates `iter` and forwards it to gpu_kernel_impl_nocast.
//
// Every operand must be on a CUDA device (internal assert otherwise). An
// iterator with zero elements is a no-op. If `iter` cannot use 32-bit
// indexing it is split via with_32bit_indexing() and this function is applied
// to each piece.
template <typename func_t>
void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) {
  for (int arg = 0; arg < iter.ntensors(); ++arg) {
    TORCH_INTERNAL_ASSERT(
        iter.device(arg).is_cuda(),
        "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  // Nothing to compute.
  if (iter.numel() == 0) {
    return;
  }

  // Common case: the whole iterator fits 32-bit indexing.
  if (iter.can_use_32bit_indexing()) {
    gpu_kernel_impl_nocast(iter, f);
    return;
  }

  // Otherwise process each 32-bit-indexable piece separately.
  for (auto& piece : iter.with_32bit_indexing()) {
    gpu_kernel_nocast(piece, f);
  }
}
// Validates `iter` and forwards it to gpu_kernel_impl.
//
// Every operand must be on a CUDA device (internal assert otherwise). An
// iterator with zero elements is a no-op. If `iter` cannot use 32-bit
// indexing it is split via with_32bit_indexing() and this function is applied
// to each piece.
template <typename func_t>
void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {
  for (int arg = 0; arg < iter.ntensors(); ++arg) {
    TORCH_INTERNAL_ASSERT(
        iter.device(arg).is_cuda(),
        "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  // Nothing to compute.
  if (iter.numel() == 0) {
    return;
  }

  // Common case: the whole iterator fits 32-bit indexing.
  if (iter.can_use_32bit_indexing()) {
    gpu_kernel_impl(iter, f);
    return;
  }

  // Otherwise process each 32-bit-indexable piece separately.
  for (auto& piece : iter.with_32bit_indexing()) {
    gpu_kernel(piece, f);
  }
}
// Turns a binary functor into a unary one by binding its FIRST argument to a
// scalar supplied at construction; operator() receives the second argument
// (as arg2_t) and returns the result converted to return_t.
template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct AUnaryFunctor {
  using traits = function_traits<func_t>;
  using opmath_arg1_t = typename traits::template arg<0>::type;

  // NB: the bound scalar is kept as func_t's own first-parameter type, which
  // may be higher precision than arg1_t.
  AUnaryFunctor(func_t f_, opmath_arg1_t a_)
      : fn_(f_), bound_first_(a_) {}

  __device__ return_t operator()(arg2_t b) const {
    return fn_(bound_first_, b);
  }

 private:
  func_t fn_;
  opmath_arg1_t bound_first_;
};
// Turns a binary functor into a unary one by binding its SECOND argument to a
// scalar supplied at construction; operator() receives the first argument
// (as arg1_t) and returns the result converted to return_t.
template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct BUnaryFunctor {
  using traits = function_traits<func_t>;
  using opmath_arg2_t = typename traits::template arg<1>::type;

  // NB: the bound scalar is kept as func_t's own second-parameter type, which
  // may be higher precision than arg2_t.
  BUnaryFunctor(func_t f_, opmath_arg2_t b_)
      : fn_(f_), bound_second_(b_) {}

  __device__ return_t operator()(arg1_t a) const {
    return fn_(a, bound_second_);
  }

 private:
  func_t fn_;
  opmath_arg2_t bound_second_;
};
// Thin wrapper around a binary functor. It looks like a no-op, but the
// signature matters: arguments arrive as arg1_t / arg2_t and are converted to
// whatever func_t accepts (possibly a higher-precision type), and the result
// is converted to return_t.
template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct BinaryFunctor {
  BinaryFunctor(func_t f_)
      : fn_(f_) {}

  __device__ return_t operator()(arg1_t a, arg2_t b) const {
    return fn_(a, b);
  }

 private:
  func_t fn_;
};
// Binary elementwise launcher that understands CPU-scalar operands.
//
// Unlike gpu_kernel_with_scalars, `f` may accept its inputs at a higher
// precision (typically opmath_t) while the wrapper functors are declared with
// arg1_t / arg2_t / return_t, so that memory is accessed at the storage
// precision (scalar_t) and expensive loads are avoided. For the whole story
// see
// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302
//
// `iter` must have exactly three operands. If operand 1 or operand 2 is a CPU
// scalar, its value is captured into a unary functor and the operand is
// removed from `iter` before dispatching to gpu_kernel.
template <typename arg1_t, typename arg2_t = arg1_t, typename return_t = arg1_t, typename func_t>
void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using fn_traits = function_traits<func_t>;
  static_assert(
      fn_traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");
  using opmath_arg1_t = typename fn_traits::template arg<0>::type;
  using opmath_arg2_t = typename fn_traits::template arg<1>::type;

  // Case 1: the first input is a CPU scalar -> bind it as f's first argument.
  if (iter.is_cpu_scalar(1)) {
    AUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bound_first(
        f, iter.scalar_value<opmath_arg1_t>(1));
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted. This
    // works around incorrect device guard generation for pre-structured
    // kernels device guards, but structured kernels do it right and
    // we can assume the device is already set correctly
    // NOTE(review): device(1) is read after remove_operand(1); presumably it
    // now refers to the remaining input operand - confirm against
    // TensorIterator::remove_operand.
    const OptionalDeviceGuard device_guard(iter.device(1));
    gpu_kernel(iter, bound_first);
    return;
  }

  // Case 2: the second input is a CPU scalar -> bind it as f's second argument.
  if (iter.is_cpu_scalar(2)) {
    BUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bound_second(
        f, iter.scalar_value<opmath_arg2_t>(2));
    iter.remove_operand(2);
    gpu_kernel(iter, bound_second);
    return;
  }

  // Case 3: neither input is a CPU scalar -> plain binary kernel.
  gpu_kernel(iter, BinaryFunctor<arg1_t, arg2_t, return_t, func_t>(f));
}
// Variant of opmath_gpu_kernel_with_scalars for symmetric functors.
//
// Requires f(a, b) == f(b, a). Because of that, a CPU scalar found in either
// input position is always bound as f's FIRST argument, so only AUnaryFunctor
// (and BinaryFunctor) are instantiated, reducing the number of kernels.
//
// `iter` must have exactly three operands, and both parameters of `f` must
// have the same type.
template <typename scalar_t, typename return_t = scalar_t, typename func_t>
void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using fn_traits = function_traits<func_t>;
  static_assert(
      fn_traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");
  using opmath_arg_t = typename fn_traits::template arg<0>::type;
  static_assert(
      std::is_same_v<opmath_arg_t, typename fn_traits::template arg<1>::type>,
      "f is not symmetric");

  // Declared at function scope so that, once set, it stays active for the
  // gpu_kernel call at the end.
  OptionalDeviceGuard device_guard;
  opmath_arg_t captured_scalar{};

  if (iter.is_cpu_scalar(1)) {
    captured_scalar = iter.scalar_value<opmath_arg_t>(1);
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted. This
    // works around incorrect device guard generation for pre-structured
    // kernels device guards, but structured kernels do it right and
    // we can assume the device is already set correctly
    device_guard.reset_device(iter.device(1));
  } else if (iter.is_cpu_scalar(2)) {
    captured_scalar = iter.scalar_value<opmath_arg_t>(2);
    iter.remove_operand(2);
  }

  // An operand was removed above exactly when fewer than two inputs remain.
  if (iter.ninputs() != 2) {
    AUnaryFunctor<scalar_t, scalar_t, return_t, func_t> bound_fn(f, captured_scalar);
    gpu_kernel(iter, bound_fn);
  } else {
    gpu_kernel(iter, BinaryFunctor<scalar_t, scalar_t, return_t, func_t>(f));
  }
}
// Legacy entry point: assumes func_t's own parameter and result types are
// already the types we want to load from / store to memory, and forwards them
// unchanged to opmath_gpu_kernel_with_scalars.
template <typename func_t>
void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  using fn_traits = function_traits<func_t>;
  static_assert(
      fn_traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");

  using first_t = typename fn_traits::template arg<0>::type;
  using second_t = typename fn_traits::template arg<1>::type;
  using result_t = typename fn_traits::result_type;

  opmath_gpu_kernel_with_scalars<first_t, second_t, result_t, func_t>(iter, f);
}
namespace { // functions for `gpu_kernel_multiple_outputs`.
// Type trait: true only for `thrust::tuple<...>` instantiations. Any other
// type, including `std::tuple`, yields false.
template <typename T>
struct is_tuple : std::false_type {};

template <typename... Ts>
struct is_tuple<thrust::tuple<Ts...>> : std::true_type {};
// Kernel for functors that produce several outputs per element.
//
// Each block computes how many of the N elements are left starting at its own
// offset (block_work_size() * blockIdx.x) and hands that count, together with
// the data pointers and offset calculators, to a multi_outputs_unroll policy
// that drives elementwise_kernel_helper.
template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) {
  using policy_t =
      memory::policies::multi_outputs_unroll<array_t, inp_calc_t, out_calc_t, num_outputs>;

  // Elements not yet covered by blocks with a smaller index.
  int elems_left = N - block_work_size() * blockIdx.x;
  elementwise_kernel_helper(f, policy_t(data, elems_left, ic, oc));
}
// Launches unrolled_elementwise_kernel_for_multi_outputs on the current CUDA
// stream.
//
// N must be positive and fit in int32. The grid has
// ceil(N / block_work_size()) blocks of num_threads() threads each, with no
// dynamic shared memory. The launch is checked with
// C10_CUDA_KERNEL_LAUNCH_CHECK().
template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());

  // Ceil-divide so a partially filled trailing block is still launched.
  const int64_t num_blocks = (N + block_work_size() - 1) / block_work_size();
  auto cuda_stream = at::cuda::getCurrentCUDAStream();

  unrolled_elementwise_kernel_for_multi_outputs<num_outputs, func_t, array_t>
      <<<num_blocks, num_threads(), 0, cuda_stream>>>(N, f, data, ic, oc);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Implementation of gpu_kernel_multiple_outputs for an iterator that already
// supports 32-bit indexing.
//
// `f` must return a `thrust::tuple`; the tuple size is the number of outputs
// and f's arity is the number of inputs. `iter` must hold exactly that many
// operands. Contiguous iterators use TrivialOffsetCalculator; all others use
// offset calculators built from the iterator's strides.
template <typename func_t>
void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) {
  using fn_traits = function_traits<func_t>;
  using output_t = typename fn_traits::result_type;
  static_assert(is_tuple<output_t>::value, "f's return type must be `thrust::tuple`");

  constexpr int num_outputs = thrust::tuple_size<output_t>::value;
  constexpr int num_inputs = fn_traits::arity;
  constexpr int ntensors = num_outputs + num_inputs;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);

  // Collect the base pointer of every operand as char*.
  std::array<char*, ntensors> base_ptrs;
  for (int t = 0; t < ntensors; ++t) {
    base_ptrs[t] = static_cast<char*>(iter.data_ptr(t));
  }

  const int64_t total = iter.numel();

  if (!iter.is_contiguous()) {
    // General case: offsets come from the iterator's strides.
    auto in_offsets = make_input_offset_calculator<num_inputs>(iter);
    auto out_offsets = make_output_offset_calculator<num_outputs>(iter);
    launch_unrolled_kernel_for_multi_outputs<num_outputs>(total, f, base_ptrs, in_offsets, out_offsets);
  } else {
    // Contiguous case: trivial offset calculators suffice.
    auto in_offsets = TrivialOffsetCalculator<num_inputs>();
    auto out_offsets = TrivialOffsetCalculator<num_outputs>();
    launch_unrolled_kernel_for_multi_outputs<num_outputs>(total, f, base_ptrs, in_offsets, out_offsets);
  }
}
} // namespace
// Public entry point for elementwise kernels whose functor returns several
// outputs at once (as a `thrust::tuple`).
//
// Every operand of `iter` must be on a CUDA device (internal assert
// otherwise). An iterator with zero elements is a no-op. If `iter` cannot use
// 32-bit indexing it is split via with_32bit_indexing() and this function is
// applied to each piece; otherwise the work is handed to
// gpu_kernel_multiple_outputs_impl.
template <typename func_t>
void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  for (int arg = 0; arg < iter.ntensors(); arg++) {
    // Same diagnostic as gpu_kernel / gpu_kernel_nocast, so a failure names
    // the offending operand and the device it was found on.
    TORCH_INTERNAL_ASSERT(
        iter.device(arg).is_cuda(),
        "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  // Nothing to compute.
  if (iter.numel() == 0) {
    return;
  }

  // Too large for 32-bit indexing: process each 32-bit-indexable piece.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel_multiple_outputs(sub_iter, f);
    }
    return;
  }

  gpu_kernel_multiple_outputs_impl(iter, f);
}
} //namespace at::native
|