|
|
#pragma once
|
|
|
|
|
|
#include <c10/util/BFloat16.h>
|
|
|
#include <c10/util/Half.h>
|
|
|
|
|
|
C10_CLANG_DIAGNOSTIC_PUSH()
|
|
|
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
|
|
|
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
|
|
|
#endif
|
|
|
|
|
|
namespace c10 {
|
|
|
template <typename T>
|
|
|
struct is_reduced_floating_point
|
|
|
: std::integral_constant<
|
|
|
bool,
|
|
|
std::is_same_v<T, c10::Half> || std::is_same_v<T, c10::BFloat16>> {};
|
|
|
|
|
|
template <typename T>
|
|
|
constexpr bool is_reduced_floating_point_v =
|
|
|
is_reduced_floating_point<T>::value;
|
|
|
}
|
|
|
|
|
|
namespace std {
|
|
|
|
|
|
#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED)
|
|
|
using c10::is_reduced_floating_point;
|
|
|
using c10::is_reduced_floating_point_v;
|
|
|
#endif
|
|
|
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T acos(T a) {
|
|
|
return std::acos(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T asin(T a) {
|
|
|
return std::asin(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T atan(T a) {
|
|
|
return std::atan(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T atanh(T a) {
|
|
|
return std::atanh(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T erf(T a) {
|
|
|
return std::erf(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T erfc(T a) {
|
|
|
return std::erfc(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T exp(T a) {
|
|
|
return std::exp(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T expm1(T a) {
|
|
|
return std::expm1(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline bool isfinite(T a) {
|
|
|
return std::isfinite(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T log(T a) {
|
|
|
return std::log(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T log10(T a) {
|
|
|
return std::log10(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T log1p(T a) {
|
|
|
return std::log1p(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T log2(T a) {
|
|
|
return std::log2(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T ceil(T a) {
|
|
|
return std::ceil(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T cos(T a) {
|
|
|
return std::cos(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T floor(T a) {
|
|
|
return std::floor(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T nearbyint(T a) {
|
|
|
return std::nearbyint(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T sin(T a) {
|
|
|
return std::sin(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T tan(T a) {
|
|
|
return std::tan(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T sinh(T a) {
|
|
|
return std::sinh(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T cosh(T a) {
|
|
|
return std::cosh(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T tanh(T a) {
|
|
|
return std::tanh(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T trunc(T a) {
|
|
|
return std::trunc(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T lgamma(T a) {
|
|
|
return std::lgamma(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T sqrt(T a) {
|
|
|
return std::sqrt(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T rsqrt(T a) {
|
|
|
return 1.0 / std::sqrt(float(a));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T abs(T a) {
|
|
|
return std::abs(float(a));
|
|
|
}
|
|
|
#if defined(_MSC_VER) && defined(__CUDACC__)
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T pow(T a, double b) {
|
|
|
return std::pow(float(a), float(b));
|
|
|
}
|
|
|
#else
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T pow(T a, double b) {
|
|
|
return std::pow(float(a), b);
|
|
|
}
|
|
|
#endif
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T pow(T a, T b) {
|
|
|
return std::pow(float(a), float(b));
|
|
|
}
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
inline T fmod(T a, T b) {
|
|
|
return std::fmod(float(a), float(b));
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
typename T,
|
|
|
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
|
|
|
C10_HOST_DEVICE inline T nextafter(T from, T to) {
|
|
|
|
|
|
|
|
|
using int_repr_t = uint16_t;
|
|
|
constexpr uint8_t bits = 16;
|
|
|
union {
|
|
|
T f;
|
|
|
int_repr_t i;
|
|
|
} ufrom = {from}, uto = {to};
|
|
|
|
|
|
|
|
|
int_repr_t sign_mask = int_repr_t{1} << (bits - 1);
|
|
|
|
|
|
|
|
|
if (from != from || to != to) {
|
|
|
return from + to;
|
|
|
}
|
|
|
|
|
|
|
|
|
if (ufrom.i == uto.i) {
|
|
|
return from;
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int_repr_t abs_from = ufrom.i & ~sign_mask;
|
|
|
int_repr_t abs_to = uto.i & ~sign_mask;
|
|
|
if (abs_from == 0) {
|
|
|
|
|
|
|
|
|
if (abs_to == 0) {
|
|
|
return to;
|
|
|
}
|
|
|
|
|
|
ufrom.i = (uto.i & sign_mask) | int_repr_t{1};
|
|
|
return ufrom.f;
|
|
|
}
|
|
|
|
|
|
|
|
|
if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) {
|
|
|
ufrom.i--;
|
|
|
} else {
|
|
|
ufrom.i++;
|
|
|
}
|
|
|
|
|
|
return ufrom.f;
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
C10_CLANG_DIAGNOSTIC_POP()
|
|
|
|