|
|
#pragma once

#include <cmath>

#if defined(USE_ROCM)
#include <math.h>
#endif
|
|
|
|
|
namespace at { namespace native {

// device_sqrt: square root callable from device code. One of two
// implementations is selected at preprocessing time depending on USE_ROCM.
//
// This block relies on <math.h> (ROCm path: ::sqrtf / ::sqrt) and <cmath>
// (default path: std::sqrt) being included at file scope, before this
// namespace is opened. A standard header must never be #included from inside
// a namespace: it is ill-formed, and on first inclusion it would declare the C
// math functions inside at::native instead of the global namespace.

#if defined(USE_ROCM)

// ROCm path: only float and double are provided. The primary template below is
// declared but intentionally has no definition, so instantiating device_sqrt
// with any other scalar_t fails to build.
// Uses the global C functions ::sqrtf / ::sqrt instead of std::sqrt.
// NOTE(review): presumably because std:: math functions were not usable in
// ROCm device code when this was written — confirm before switching to std::.
template <typename scalar_t>
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);

// float specialization: single-precision square root via ::sqrtf.
template <>
__forceinline__ __device__ float device_sqrt(float val) {
  return ::sqrtf(val);
}

// double specialization: double-precision square root via ::sqrt.
template <>
__forceinline__ __device__ double device_sqrt(double val) {
  return ::sqrt(val);
}

#else

// Default (non-ROCm) path: generic over scalar_t, forwarding to the matching
// std::sqrt overload.
// NOTE(review): unlike the ROCm path (which returns scalar_t), the return type
// here is always `double`. A float argument therefore yields a float square
// root widened to double, and any arithmetic a caller performs on the result
// may be carried out in double precision. It is kept as-is so existing
// callers' interface and numerics do not change; confirm whether `scalar_t`
// was the intended return type.
template <typename scalar_t>
__forceinline__ __device__ double device_sqrt(scalar_t val) {
  return std::sqrt(val);
}

#endif

}} // namespace at::native
|
|
|