| namespace at::native { | |
| // take these out when ROCm implements std:: math functions | |
| template <typename scalar_t> | |
| static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val); | |
| template <> | |
| __forceinline__ __device__ float device_sqrt(float val) { | |
| return ::sqrtf(val); | |
| } | |
| template <> | |
| __forceinline__ __device__ double device_sqrt(double val) { | |
| return ::sqrt(val); | |
| } | |
| template<typename scalar_t> | |
| __forceinline__ __device__ double device_sqrt(scalar_t val) { | |
| return std::sqrt(val); | |
| } | |
| } // namespace at::native | |