|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <c10/macros/Export.h>
|
|
|
#include <c10/macros/Macros.h>
|
|
|
#include <c10/util/floating_point_utils.h>
|
|
|
#include <type_traits>
|
|
|
|
|
|
|
|
|
#if defined(__cplusplus)
|
|
|
#include <cstdint>
|
|
|
#elif !defined(__OPENCL_VERSION__)
|
|
|
#include <math.h>
|
|
|
#include <stdint.h>
|
|
|
#endif
|
|
|
|
|
|
#include <iosfwd>
|
|
|
#include <ostream>
|
|
|
|
|
|
namespace c10 {
|
|
|
|
|
|
namespace detail {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Converts a float32 value into the raw 8-bit pattern stored by
/// Float8_e8m0fnu.
///
/// The result is the biased exponent field of `f`, optionally incremented by
/// one depending on the top mantissa bits. The sign bit of `f` is never
/// inspected.
///
/// @param f  float32 value to convert.
/// @return   the 8-bit exponent encoding after rounding.
inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) {
  const uint32_t raw = c10::detail::fp32_to_bits(f);

  // Biased exponent: the eight bits sitting directly above the 23 mantissa
  // bits.
  uint32_t biased_exp = (raw >> 23) & 0xFFu;

  // An all-ones exponent (float32 +-inf or NaN) is forwarded unchanged as
  // 0xFF; no rounding is applied to it.
  if (biased_exp == 0xFFu) {
    return static_cast<uint8_t>(biased_exp);
  }

  // Rounding inputs, all taken from the mantissa:
  //  - guard: the most significant mantissa bit (bit 22, zero-indexed);
  //  - tail:  any bit below the guard (bits 21..0), i.e. the round bit and
  //           the sticky bits taken together;
  //  - lsb:   stands in for the least significant kept bit and is set
  //           whenever the exponent field is non-zero.
  const bool guard_set = (raw & 0x400000u) != 0;
  const bool tail_set = (raw & 0x3FFFFFu) != 0;
  const bool lsb_set = biased_exp != 0;

  // With the guard bit clear the exponent is kept as-is. With the guard bit
  // set, the exponent is bumped when anything below the guard is set, or,
  // on an exact tie, when lsb is set.
  if (guard_set && (tail_set || lsb_set)) {
    // The 0xFF case returned above, so the increment always fits in 8 bits
    // (an exponent of 0xFE becomes 0xFF here).
    ++biased_exp;
  }

  return static_cast<uint8_t>(biased_exp);
}
|
|
|
|
|
|
}
|
|
|
|
|
|
struct alignas(1) Float8_e8m0fnu {
|
|
|
uint8_t x;
|
|
|
|
|
|
struct from_bits_t {};
|
|
|
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
|
|
return from_bits_t();
|
|
|
}
|
|
|
|
|
|
Float8_e8m0fnu() = default;
|
|
|
|
|
|
constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t)
|
|
|
: x(bits) {}
|
|
|
inline C10_HOST_DEVICE Float8_e8m0fnu(float value);
|
|
|
inline C10_HOST_DEVICE operator float() const;
|
|
|
inline C10_HOST_DEVICE bool isnan() const;
|
|
|
};
|
|
|
|
|
|
/// Streams a Float8_e8m0fnu by converting it to float and inserting that
/// float into `out`.
///
/// @param out    destination stream.
/// @param value  value to print.
/// @return       `out`, so insertions can be chained.
C10_API inline std::ostream& operator<<(
    std::ostream& out,
    const Float8_e8m0fnu& value) {
  return out << static_cast<float>(value);
}
|
|
|
|
|
|
}
|
|
|
|
|
|
#include <c10/util/Float8_e8m0fnu-inl.h>
|
|
|
|