|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
|
|
#include <math.h> |
|
|
|
|
|
#include <cassert> |
|
|
#include <climits> |
|
|
#include <cstdint> |
|
|
#include <cstdlib> |
|
|
#include <cstring> |
|
|
|
|
|
#include "./Types.h" |
|
|
|
|
|
#ifndef __is_identifier |
|
|
#define __is_identifier(x) 1 |
|
|
#endif |
|
|
|
|
|
#define __has_keyword(__x) !(__is_identifier(__x)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#if __has_keyword(__fp16) && !defined(_WIN32) |
|
|
#define HAS_NATIVE_FP16_TYPE |
|
|
using native_fp16_t = __fp16; |
|
|
#elif __has_keyword(_Float16) && !defined(_WIN32) |
|
|
#define HAS_NATIVE_FP16_TYPE |
|
|
using native_fp16_t = _Float16; |
|
|
#else |
|
|
using native_fp16_t = void; |
|
|
#endif |
|
|
|
|
|
namespace fbgemm { |
|
|
|
|
|
namespace detail { |
|
|
|
|
|
template <typename T, int ExponentBits, bool HasInfinity = true> |
|
|
struct FloatFormat { |
|
|
using value_type = T; |
|
|
static constexpr int bits = sizeof(T) * CHAR_BIT; |
|
|
static constexpr int exponent_bits = ExponentBits; |
|
|
static constexpr int mantissa_bits = bits - exponent_bits - 1; |
|
|
static constexpr int sign_bit_pos = bits - 1; |
|
|
static constexpr int exponent_bias = (1 << (exponent_bits - 1)) - 1; |
|
|
static constexpr int unbiased_exponent_min = -exponent_bias + 1; |
|
|
static constexpr int unbiased_exponent_max = |
|
|
HasInfinity ? exponent_bias : (exponent_bias + 1); |
|
|
static constexpr T sign_bit = T{1} << sign_bit_pos; |
|
|
static constexpr T exponent_mask = ((T{1} << exponent_bits) - 1) |
|
|
<< mantissa_bits; |
|
|
static constexpr T mantissa_mask = (T{1} << mantissa_bits) - 1; |
|
|
|
|
|
static constexpr T quiet_nan_bit = T{1} << (mantissa_bits - 1); |
|
|
|
|
|
static constexpr T nan = exponent_mask | mantissa_mask; |
|
|
static constexpr T overflow_value = HasInfinity ? exponent_mask : nan; |
|
|
static constexpr bool has_infinity = HasInfinity; |
|
|
static constexpr bool has_nan_payload = HasInfinity; |
|
|
}; |
|
|
|
|
|
using IEEE754Single = FloatFormat<uint32_t, 8>; |
|
|
using IEEE754Half = FloatFormat<uint16_t, 5>; |
|
|
|
|
|
using BFloat16 = FloatFormat<uint16_t, 8>; |
|
|
|
|
|
using FP8_E5M2 = FloatFormat<uint8_t, 5>; |
|
|
|
|
|
using FP8_E4M3FN = FloatFormat< |
|
|
uint8_t, |
|
|
4, |
|
|
false>; |
|
|
|
|
|
enum class RoundingMode { |
|
|
ToNearestTiesToEven, |
|
|
ToZero, |
|
|
}; |
|
|
|
|
|
|
|
|
template <typename Src, typename Tgt, RoundingMode RoundingMode> |
|
|
[[gnu::always_inline]] inline typename Tgt::value_type ieee754_trunc( |
|
|
typename Src::value_type value) { |
|
|
static_assert(Src::exponent_bits >= Tgt::exponent_bits); |
|
|
static_assert(Src::mantissa_bits > Tgt::mantissa_bits); |
|
|
using ST = typename Src::value_type; |
|
|
using TT = typename Tgt::value_type; |
|
|
|
|
|
ST src_exponent = value & Src::exponent_mask; |
|
|
ST src_mantissa = value & Src::mantissa_mask; |
|
|
|
|
|
|
|
|
if constexpr ( |
|
|
Src::exponent_bits == Tgt::exponent_bits && Src::has_infinity && |
|
|
Tgt::has_infinity && RoundingMode == RoundingMode::ToZero) { |
|
|
TT result = value >> (Src::bits - Tgt::bits); |
|
|
|
|
|
|
|
|
|
|
|
if (src_exponent == Src::exponent_mask && src_mantissa != 0) { |
|
|
result |= Tgt::quiet_nan_bit; |
|
|
} |
|
|
return result; |
|
|
} |
|
|
|
|
|
ST tgt_sign = |
|
|
(value & Src::sign_bit) >> (Src::sign_bit_pos - Tgt::sign_bit_pos); |
|
|
constexpr bool denormal_becomes_zero = |
|
|
Tgt::unbiased_exponent_min - Src::unbiased_exponent_min > |
|
|
Src::mantissa_bits - Tgt::mantissa_bits; |
|
|
if constexpr (denormal_becomes_zero) { |
|
|
|
|
|
|
|
|
if (src_exponent == 0) { |
|
|
return tgt_sign; |
|
|
} |
|
|
} |
|
|
|
|
|
int unbiased_exponent = |
|
|
(src_exponent >> Src::mantissa_bits) - Src::exponent_bias; |
|
|
if (unbiased_exponent < Tgt::unbiased_exponent_min) { |
|
|
int shift = Tgt::unbiased_exponent_min - unbiased_exponent; |
|
|
if (shift <= Tgt::mantissa_bits + 1) { |
|
|
|
|
|
ST src_mantissa_one = src_mantissa; |
|
|
|
|
|
if (denormal_becomes_zero || src_exponent != 0) { |
|
|
src_mantissa_one |= TT{1} << Src::mantissa_bits; |
|
|
} else { |
|
|
shift--; |
|
|
} |
|
|
TT tgt_mantissa = |
|
|
src_mantissa_one >> (Src::mantissa_bits - Tgt::mantissa_bits + shift); |
|
|
|
|
|
if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) { |
|
|
int half_pos = Src::mantissa_bits - Tgt::mantissa_bits + shift - 1; |
|
|
ST half = 1 << half_pos; |
|
|
ST remainder = src_mantissa_one & ((half << 1) - 1); |
|
|
if (remainder > half || |
|
|
(remainder == half && (tgt_mantissa & 1) != 0)) { |
|
|
tgt_mantissa += 1; |
|
|
} |
|
|
} else { |
|
|
static_assert(RoundingMode == RoundingMode::ToZero); |
|
|
} |
|
|
return tgt_sign | tgt_mantissa; |
|
|
} else { |
|
|
|
|
|
return tgt_sign; |
|
|
} |
|
|
} |
|
|
|
|
|
if (unbiased_exponent > Tgt::unbiased_exponent_max) { |
|
|
if (unbiased_exponent == Src::exponent_bias + 1 && src_mantissa != 0) { |
|
|
TT tgt_mantissa; |
|
|
if constexpr (Tgt::has_nan_payload) { |
|
|
|
|
|
tgt_mantissa = |
|
|
src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits); |
|
|
tgt_mantissa |= Tgt::quiet_nan_bit; |
|
|
} else { |
|
|
tgt_mantissa = Tgt::mantissa_mask; |
|
|
} |
|
|
return tgt_sign | Tgt::exponent_mask | tgt_mantissa; |
|
|
} else { |
|
|
if (RoundingMode == RoundingMode::ToZero && |
|
|
(!Src::has_infinity || src_exponent != Src::exponent_mask)) { |
|
|
|
|
|
return tgt_sign | (Tgt::exponent_mask - Tgt::has_infinity) | |
|
|
Tgt::mantissa_mask; |
|
|
} |
|
|
|
|
|
return tgt_sign | Tgt::overflow_value; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
TT tgt_mantissa = src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits); |
|
|
TT tgt_exponent = (unbiased_exponent + Tgt::exponent_bias) |
|
|
<< Tgt::mantissa_bits; |
|
|
if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) { |
|
|
ST half = 1 << (Src::mantissa_bits - Tgt::mantissa_bits - 1); |
|
|
ST remainder = src_mantissa & ((half << 1) - 1); |
|
|
if (remainder > half || (remainder == half && (tgt_mantissa & 1) != 0)) { |
|
|
if (tgt_mantissa < Tgt::mantissa_mask) { |
|
|
tgt_mantissa += 1; |
|
|
} else { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (Tgt::has_infinity || tgt_exponent != Tgt::exponent_mask) { |
|
|
tgt_mantissa = 0; |
|
|
tgt_exponent += TT{1} << Tgt::mantissa_bits; |
|
|
} else { |
|
|
|
|
|
tgt_mantissa = Tgt::mantissa_mask; |
|
|
} |
|
|
} |
|
|
} |
|
|
} else { |
|
|
static_assert(RoundingMode == RoundingMode::ToZero); |
|
|
} |
|
|
return tgt_sign | tgt_exponent | tgt_mantissa; |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
inline float16 cpu_float2half_rn(float f) { |
|
|
uint32_t f_u32 = 0; |
|
|
std::memcpy(&f_u32, &f, sizeof(f_u32)); |
|
|
return detail::ieee754_trunc< |
|
|
detail::IEEE754Single, |
|
|
detail::IEEE754Half, |
|
|
detail::RoundingMode::ToNearestTiesToEven>(f_u32); |
|
|
} |
|
|
|
|
|
inline float16 cpu_float2half_rz(float f) { |
|
|
uint32_t f_u32 = 0; |
|
|
std::memcpy(&f_u32, &f, sizeof(f_u32)); |
|
|
return detail::ieee754_trunc< |
|
|
detail::IEEE754Single, |
|
|
detail::IEEE754Half, |
|
|
detail::RoundingMode::ToZero>(f_u32); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
inline float cpu_half2float_ref(const float16 h) { |
|
|
constexpr uint32_t f16_num_exponent_bits = 5; |
|
|
constexpr uint32_t f16_num_mantissa_bits = 10; |
|
|
constexpr uint32_t f16_num_non_sign_bits = |
|
|
f16_num_exponent_bits + f16_num_mantissa_bits; |
|
|
constexpr uint32_t f16_exponent_bias = 15; |
|
|
constexpr uint32_t f16_exponent_mask = 0b1'1111; |
|
|
constexpr uint32_t f16_mantissa_mask = 0b11'1111'1111; |
|
|
|
|
|
constexpr uint32_t f32_num_exponent_bits = 8; |
|
|
constexpr uint32_t f32_num_mantissa_bits = 23; |
|
|
constexpr uint32_t f32_num_non_sign_bits = |
|
|
f32_num_exponent_bits + f32_num_mantissa_bits; |
|
|
constexpr uint32_t f32_exponent_bias = 127; |
|
|
constexpr uint32_t f32_exponent_mask = 0b1111'1111; |
|
|
constexpr uint32_t f32_mantissa_mask = 0x7F'FF'FF; |
|
|
constexpr uint32_t f32_most_significant_bit = 1u << 22; |
|
|
|
|
|
|
|
|
uint32_t sign_bit = (h >> f16_num_non_sign_bits) & 1; |
|
|
uint32_t exponent = (h >> f16_num_mantissa_bits) & f16_exponent_mask; |
|
|
|
|
|
uint32_t mantissa = (h & f16_mantissa_mask) |
|
|
<< (f32_num_mantissa_bits - f16_num_mantissa_bits); |
|
|
|
|
|
if (exponent == f16_exponent_mask) { |
|
|
if (mantissa) { |
|
|
mantissa = f32_mantissa_mask; |
|
|
sign_bit = 0; |
|
|
} |
|
|
exponent = f32_exponent_mask; |
|
|
} else if (!exponent) { |
|
|
if (mantissa) { |
|
|
uint32_t msb = 0; |
|
|
exponent = f32_exponent_bias - f16_exponent_bias + 1; |
|
|
do { |
|
|
msb = mantissa & f32_most_significant_bit; |
|
|
mantissa <<= 1; |
|
|
--exponent; |
|
|
} while (!msb); |
|
|
mantissa &= f32_mantissa_mask; |
|
|
} |
|
|
} else { |
|
|
exponent += f32_exponent_bias - f16_exponent_bias; |
|
|
} |
|
|
|
|
|
const uint32_t i = (sign_bit << f32_num_non_sign_bits) | |
|
|
(exponent << f32_num_mantissa_bits) | mantissa; |
|
|
|
|
|
float ret = NAN; |
|
|
std::memcpy(&ret, &i, sizeof(float)); |
|
|
return ret; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
inline float cpu_half2float(const float16 h) { |
|
|
#if defined(HAS_NATIVE_FP16_TYPE) && not defined(MISSING_GNU_F2H_IEEE) |
|
|
__fp16 h_fp16 = NAN; |
|
|
std::memcpy(&h_fp16, &h, sizeof(__fp16)); |
|
|
return h_fp16; |
|
|
#else |
|
|
return cpu_half2float_ref(h); |
|
|
#endif |
|
|
} |
|
|
|
|
|
inline float16 cpu_float2half(const float f) { |
|
|
#if defined(HAS_NATIVE_FP16_TYPE) && not defined(MISSING_GNU_F2H_IEEE) |
|
|
__fp16 h = f; |
|
|
float16 res = 0; |
|
|
std::memcpy(&res, &h, sizeof(__fp16)); |
|
|
return res; |
|
|
#else |
|
|
return cpu_float2half_rn(f); |
|
|
#endif |
|
|
} |
|
|
|
|
|
inline float cpu_bf162float(bfloat16 src) { |
|
|
float ret = NAN; |
|
|
uint32_t val_fp32 = |
|
|
static_cast<uint32_t>(reinterpret_cast<const uint16_t*>(&src)[0]) << 16; |
|
|
std::memcpy(&ret, &val_fp32, sizeof(float)); |
|
|
return ret; |
|
|
} |
|
|
|
|
|
inline bfloat16 cpu_float2bfloat16(float src) { |
|
|
uint32_t temp = 0; |
|
|
std::memcpy(&temp, &src, sizeof(uint32_t)); |
|
|
return (temp + (1u << 15)) >> 16; |
|
|
} |
|
|
|
|
|
} |
|
|
|