|
|
#pragma once
|
|
|
#include <ATen/native/Pow.h>
|
|
|
#include <c10/core/Scalar.h>
|
|
|
|
|
|
namespace at::native {
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef _MSC_VER
|
|
|
|
|
|
|
|
|
// pow_ for at::Half operands.
//
// Both operands are widened to float, std::pow is evaluated on the two floats,
// and the float result is narrowed back to at::Half.
// NOTE(review): this overload sits in the _MSC_VER-only branch; presumably the
// MSVC/NVCC toolchain needs explicit overloads here instead of the single
// generic ::pow template used on other compilers — confirm.
static inline __host__ __device__ auto pow_(at::Half base, at::Half exp) -> at::Half {
  const float base_f = static_cast<float>(base);
  const float exp_f = static_cast<float>(exp);
  return static_cast<at::Half>(std::pow(base_f, exp_f));
}
|
|
|
|
|
|
// pow_ for at::BFloat16 operands.
//
// Both operands are widened to float, std::pow is evaluated on the two floats,
// and the float result is narrowed back to at::BFloat16.
// NOTE(review): this overload sits in the _MSC_VER-only branch; presumably the
// MSVC/NVCC toolchain needs explicit overloads here instead of the single
// generic ::pow template used on other compilers — confirm.
static inline __host__ __device__ auto pow_(at::BFloat16 base, at::BFloat16 exp) -> at::BFloat16 {
  const float base_f = static_cast<float>(base);
  const float exp_f = static_cast<float>(exp);
  return static_cast<at::BFloat16>(std::pow(base_f, exp_f));
}
|
|
|
|
|
|
// pow_ for a floating-point base whose exponent is the same type or int.
//
// Forwards the operands unchanged to std::pow. When ExpT is int, std::pow's
// arithmetic overloads promote to double. The result is then converted back
// to BaseT on return.
//
// The overload only participates in resolution when BaseT is a floating-point
// type AND ExpT is either BaseT or int. That condition is mutually exclusive
// with the mixed-type pow_ overload in this branch, which requires ExpT to
// differ from BaseT and not be int.
template <typename BaseT, typename ExpT>
static inline __host__ __device__ auto pow_(BaseT base, ExpT exp)
    -> std::enable_if_t<
        std::is_floating_point_v<BaseT> &&
            (std::is_same_v<BaseT, ExpT> || std::is_same_v<ExpT, int>),
        BaseT> {
  return std::pow(base, exp);
}
|
|
|
|
|
|
// Fallback pow_ for mixed operand types.
//
// Enabled only when ExpT is neither BaseT nor int. Both operands are cast to
// double, std::pow is evaluated in double precision, and the result is cast
// back to BaseT. BaseT is not required to be floating-point here, e.g. an
// integral base with a floating exponent also lands in this overload.
template <typename BaseT, typename ExpT>
static inline __host__ __device__ auto pow_(BaseT base, ExpT exp)
    -> std::enable_if_t<!(std::is_same_v<BaseT, ExpT> || std::is_same_v<ExpT, int>), BaseT> {
  const double wide_result = std::pow(static_cast<double>(base), static_cast<double>(exp));
  return static_cast<BaseT>(wide_result);
}
|
|
|
#else
|
|
|
// Generic pow_ used when _MSC_VER is not defined.
//
// Passes both operands unchanged to the global ::pow. Whatever ::pow returns
// is converted to BaseT on return. Same-type integral operands and
// c10::complex operands have more specialized pow_ overloads in this header,
// which partial ordering prefers over this template when both are visible.
template <typename BaseT, typename ExpT>
static inline __host__ __device__ auto pow_(BaseT base, ExpT exp) -> BaseT {
  return ::pow(base, exp);
}
|
|
|
#endif
|
|
|
|
|
|
// pow_ for a base and exponent of the same integral type.
//
// The overload is SFINAE-restricted to integral types and delegates the
// computation to at::native::powi.
// NOTE(review): powi is presumably declared in ATen/native/Pow.h. Its
// handling of negative exponents is defined there, not here — confirm.
template <typename IntT>
static inline __host__ __device__ auto pow_(IntT base, IntT exp)
    -> std::enable_if_t<std::is_integral_v<IntT>, IntT> {
  return at::native::powi(base, exp);
}
|
|
|
|
|
|
// pow_ for two c10::complex operands sharing the same value type.
// This overload simply forwards to c10_complex_math::pow and returns its
// complex result.
template <typename ValueT>
static inline __host__ __device__ auto pow_(c10::complex<ValueT> base, c10::complex<ValueT> exp)
    -> c10::complex<ValueT> {
  return c10_complex_math::pow(base, exp);
}
|
|
|
|
|
|
}
|
|
|
}
|
|
|
|