File size: 610 Bytes
c1af2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
#pragma once
// Math headers are included at global scope, ahead of the namespace, so that
// their declarations are never injected into at::native (including a standard
// header inside an open namespace is ill-formed and can hide ::sqrtf/::sqrt
// behind the header's include guard).
#if defined(USE_ROCM)
#include <math.h>
#else
#include <cmath>
#endif

namespace at::native {

#if defined(USE_ROCM)
// ROCm path: dispatch explicitly to the global C math functions.
// TODO: take these out when ROCm implements the std:: math functions.

// Primary template: declaration only. This header provides definitions for
// float and double through the explicit specializations below.
template <typename scalar_t>
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);

// float specialization: uses the single-precision ::sqrtf.
template <>
__forceinline__ __device__ float device_sqrt(float val) {
  return ::sqrtf(val);
}

// double specialization: uses ::sqrt.
template <>
__forceinline__ __device__ double device_sqrt(double val) {
  return ::sqrt(val);
}
#else
// Non-ROCm path: forwards to std::sqrt for any scalar_t and returns the result
// converted to double.
// NOTE(review): the return type is `double` here but `scalar_t` in the ROCm
// branch above. Left unchanged to preserve the existing interface; confirm
// against callers whether the promotion to double is intentional.
template <typename scalar_t>
__forceinline__ __device__ double device_sqrt(scalar_t val) {
  return std::sqrt(val);
}
#endif

} // namespace at::native
|