|
|
|
|
|
|
|
|
#include <ATen/jit_macros.h> |
|
|
|
|
|
|
|
|
|
|
|
#include <ATen/cuda/CUDAConfig.h> |
|
|
|
|
|
#include <ATen/OpMathType.h> |
|
|
#include <ATen/TensorIterator.h> |
|
|
#include <ATen/native/TensorIteratorDynamicCasting.h> |
|
|
|
|
|
#include <ATen/native/cuda/MemoryAccess.cuh> |
|
|
|
|
|
#include <ATen/native/cuda/CUDAJitLoops.cuh> |
|
|
|
|
|
namespace at { |
|
|
namespace native { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Entry point for jiterator-compiled elementwise kernels.
//
// Validates that every operand lives on a CUDA device, early-returns on an
// empty iterator, splits the work into 32-bit-indexable sub-iterators when
// necessary, determines whether runtime dtype casting is required, and then
// dispatches to jitted_gpu_kernel_impl specialized on the scalar position.
//
// Template parameters:
//   name          - kernel name used for the generated/jitted source
//   return_type   - compile-time output element type of the kernel
//   f_inputs_type - compile-time input element type of the kernel
//   arity         - number of input operands (iterator operand 0 is output)
//   Args...       - types of any extra arguments forwarded to the kernel
template <
    char const* name,
    typename return_type,
    typename f_inputs_type,
    int arity,
    typename... Args>
void jitted_gpu_kernel(
    TensorIteratorBase& iter,
    const std::string& f,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    at::opmath_type<f_inputs_type> scalar_val = 0,
    std::tuple<Args...> extra_args = std::make_tuple()) {
  // Every operand (output and inputs alike) must be a CUDA tensor.
  for (int t = 0; t < iter.ntensors(); ++t) {
    TORCH_INTERNAL_ASSERT(
        iter.device(t).is_cuda(),
        "argument ", t, ": expected a CUDA device but found ", iter.device(t));
  }

  // Nothing to launch for zero elements.
  if (iter.numel() == 0) {
    return;
  }

  // The jitted kernels index with 32-bit offsets; when the iterator is too
  // large for that, recurse over 32-bit-safe sub-iterators instead.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub : iter.with_32bit_indexing()) {
      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
          sub, f, scalar_pos, scalar_val, extra_args);
    }
    return;
  }

  // Runtime casting is required whenever an operand's dtype disagrees with
  // the compile-time types the kernel is generated for: operand 0 (output)
  // is checked against return_type, operands 1..arity (inputs) against
  // f_inputs_type. Short-circuit as soon as one mismatch is found.
  constexpr ScalarType out_dtype = c10::CppTypeToScalarType<return_type>::value;
  constexpr ScalarType in_dtype = c10::CppTypeToScalarType<f_inputs_type>::value;
  bool needs_dynamic_casting = (iter.dtype(0) != out_dtype);
  for (int i = 1; !needs_dynamic_casting && i <= arity; ++i) {
    needs_dynamic_casting = (iter.dtype(i) != in_dtype);
  }

  // Dispatch to the implementation specialized on the scalar position.
  // Any value other than NoScalar/RhsScalar falls through to LhsScalar,
  // matching the original if/else-if/else chain.
  switch (scalar_pos) {
    case at::cuda::jit::BinaryFuncVariant::NoScalar:
      jitted_gpu_kernel_impl<
          name,
          return_type,
          f_inputs_type,
          arity,
          at::cuda::jit::BinaryFuncVariant::NoScalar>(
          iter, f, needs_dynamic_casting, scalar_val, extra_args);
      break;
    case at::cuda::jit::BinaryFuncVariant::RhsScalar:
      jitted_gpu_kernel_impl<
          name,
          return_type,
          f_inputs_type,
          arity,
          at::cuda::jit::BinaryFuncVariant::RhsScalar>(
          iter, f, needs_dynamic_casting, scalar_val, extra_args);
      break;
    default:
      jitted_gpu_kernel_impl<
          name,
          return_type,
          f_inputs_type,
          arity,
          at::cuda::jit::BinaryFuncVariant::LhsScalar>(
          iter, f, needs_dynamic_casting, scalar_val, extra_args);
      break;
  }
}
|
|
|
|
|
|
|
|
// Binary-op entry point that hoists a CPU scalar operand out of the iterator.
//
// When one of the two inputs is a CPU scalar, its value is read in opmath
// precision, the operand is removed from the iterator, and the jitted kernel
// is launched as an arity-1 Lhs/RhsScalar variant with the scalar passed by
// value. Otherwise a plain arity-2 launch is performed.
template <char const *name, typename return_type, typename f_inputs_type>
void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
  // Exactly one output plus two inputs.
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using opmath_t = at::opmath_type<f_inputs_type>;
  if (iter.is_cpu_scalar(1)) {
    // First input is a CPU scalar: capture it, drop the operand, and launch
    // the LhsScalar variant over the remaining input.
    const auto lhs_scalar = iter.scalar_value<opmath_t>(1);
    iter.remove_operand(1);
    // NOTE(review): after remove_operand(1) the surviving input presumably
    // occupies index 1, so the guard pins the current device to that
    // tensor's device before launching — confirm against TensorIterator's
    // operand-shifting behavior.
    const OptionalDeviceGuard device_guard(iter.device(1));
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, lhs_scalar);
  } else if (iter.is_cpu_scalar(2)) {
    // Second input is a CPU scalar: launch the RhsScalar variant.
    const auto rhs_scalar = iter.scalar_value<opmath_t>(2);
    iter.remove_operand(2);
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, rhs_scalar);
  } else {
    // Both inputs are device tensors: ordinary binary launch.
    jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
  }
}
|
|
|
|
|
}} |
|
|
|
|
|
|
|
|
|