|
|
|
|
|
|
|
|
#include <ATen/jit_macros.h>
|
|
|
|
|
|
|
|
|
|
|
|
#include <ATen/cuda/CUDAConfig.h>
|
|
|
|
|
|
#include <ATen/OpMathType.h>
|
|
|
#include <ATen/TensorIterator.h>
|
|
|
#include <ATen/native/TensorIteratorDynamicCasting.h>
|
|
|
|
|
|
#include <ATen/native/cuda/MemoryAccess.cuh>
|
|
|
|
|
|
#include <ATen/native/cuda/CUDAJitLoops.cuh>
|
|
|
|
|
|
namespace at::native {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Entry point for jiterator-compiled elementwise kernels.
//
// Validates that every operand lives on a CUDA device, splits iterators that
// are too large for 32-bit indexing into 32-bit-indexable sub-iterators,
// determines whether runtime dtype casting is needed, and dispatches to the
// implementation specialized for the requested scalar-position variant.
//
// Template parameters:
//   name          - compile-time kernel name (used to identify the jitted fn)
//   return_type   - C++ type the jitted function returns
//   f_inputs_type - C++ type of the jitted function's tensor inputs
//   arity         - number of input tensors
//
// Runtime parameters:
//   iter       - tensor iterator describing outputs/inputs
//   f          - source string of the function to jit-compile
//   scalar_pos - whether a scalar replaces the lhs/rhs input (NoScalar default)
//   scalar_val - the scalar's value (ignored for NoScalar)
//   extra_args - additional trailing arguments forwarded to the kernel
template <
    char const* name,
    typename return_type,
    typename f_inputs_type,
    int arity,
    typename... Args>
void jitted_gpu_kernel(
    TensorIteratorBase& iter,
    const std::string& f,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    at::opmath_type<f_inputs_type> scalar_val = 0,
    std::tuple<Args...> extra_args = std::make_tuple()) {
  // Every operand (outputs and inputs alike) must be a CUDA tensor.
  for (int tensor_idx = 0; tensor_idx < iter.ntensors(); ++tensor_idx) {
    TORCH_INTERNAL_ASSERT(
        iter.device(tensor_idx).is_cuda(),
        "argument ", tensor_idx, ": expected a CUDA device but found ",
        iter.device(tensor_idx));
  }

  // Nothing to launch for an empty iterator.
  if (iter.numel() == 0) {
    return;
  }

  // When 64-bit indexing would be required, recurse over 32-bit-indexable
  // sub-iterators instead.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
          sub_iter, f, scalar_pos, scalar_val, extra_args);
    }
    return;
  }

  // Runtime casts are needed whenever an operand's dtype disagrees with the
  // compile-time return/input types the function was jitted for. Operand 0 is
  // the output; operands 1..arity are the inputs.
  const ScalarType return_scalar_type =
      c10::CppTypeToScalarType<return_type>::value;
  const ScalarType inputs_scalar_type =
      c10::CppTypeToScalarType<f_inputs_type>::value;

  bool needs_dynamic_casting = (iter.dtype(0) != return_scalar_type);
  for (auto i = decltype(arity){1};
       !needs_dynamic_casting && i < (arity + 1);
       ++i) {
    needs_dynamic_casting = (iter.dtype(i) != inputs_scalar_type);
  }

  // Map the runtime variant onto the matching compile-time specialization.
  switch (scalar_pos) {
    case at::cuda::jit::BinaryFuncVariant::NoScalar:
      jitted_gpu_kernel_impl<
          name,
          return_type,
          f_inputs_type,
          arity,
          at::cuda::jit::BinaryFuncVariant::NoScalar>(
          iter, f, needs_dynamic_casting, scalar_val, extra_args);
      break;
    case at::cuda::jit::BinaryFuncVariant::RhsScalar:
      jitted_gpu_kernel_impl<
          name,
          return_type,
          f_inputs_type,
          arity,
          at::cuda::jit::BinaryFuncVariant::RhsScalar>(
          iter, f, needs_dynamic_casting, scalar_val, extra_args);
      break;
    default:
      jitted_gpu_kernel_impl<
          name,
          return_type,
          f_inputs_type,
          arity,
          at::cuda::jit::BinaryFuncVariant::LhsScalar>(
          iter, f, needs_dynamic_casting, scalar_val, extra_args);
      break;
  }
}
|
|
|
|
|
|
|
|
|
// Jitted binary kernel that computes in opmath_type. When one of the two
// inputs is a CPU scalar, it is hoisted out of the iterator and handed to the
// jitted kernel as a plain scalar argument (LhsScalar/RhsScalar variants);
// otherwise the full two-input kernel is launched.
template <char const *name, typename return_type, typename f_inputs_type>
void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
  // Expects exactly one output plus two inputs.
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
  using opmath_t = at::opmath_type<f_inputs_type>;
  if (iter.is_cpu_scalar(1)) {
    // First input is a CPU scalar: extract it, then drop the operand so the
    // remaining iterator has a single tensor input.
    const auto lhs_scalar = iter.scalar_value<opmath_t>(1);
    iter.remove_operand(1);
    // NOTE(review): guards on the remaining input's device (index 1 after the
    // removal) — presumably a workaround for device-guard handling in
    // non-structured kernels; confirm before removing.
    const OptionalDeviceGuard device_guard(iter.device(1));
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(
        iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, lhs_scalar);
  } else if (iter.is_cpu_scalar(2)) {
    // Second input is a CPU scalar: same hoisting, RhsScalar variant.
    const auto rhs_scalar = iter.scalar_value<opmath_t>(2);
    iter.remove_operand(2);
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(
        iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, rhs_scalar);
  } else {
    // No CPU scalar: launch the ordinary two-input kernel.
    jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
  }
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|