File size: 11,625 Bytes
d1d4335 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 |
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cassert>
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include "./Types.h"
// __is_identifier(x) is a Clang extension that evaluates to 0 when x is a
// reserved keyword. When the compiler does not provide it, fall back to a
// definition that treats every token as a plain identifier, which makes
// __has_keyword() below always false.
#ifndef __is_identifier
#define __is_identifier(x) 1
#endif
// True when __x is a keyword of the current compiler.
#define __has_keyword(__x) !(__is_identifier(__x))
// TODO: we're disabling native fp16 on Windows to work around test failures
// due to an "undefined symbol __gnu_h2f_ieee" error. We should follow up on
// this later.
//
// native_fp16_t aliases the compiler's 16-bit floating-point type when one is
// available (HAS_NATIVE_FP16_TYPE is then defined); otherwise it is void.
#if __has_keyword(__fp16) && !defined(_WIN32)
#define HAS_NATIVE_FP16_TYPE
using native_fp16_t = __fp16;
#elif __has_keyword(_Float16) && !defined(_WIN32)
#define HAS_NATIVE_FP16_TYPE
using native_fp16_t = _Float16;
#else
using native_fp16_t = void;
#endif
namespace fbgemm {
namespace detail {
// Compile-time description of a binary floating-point layout stored in the
// unsigned integer type T: one sign bit on top, then ExponentBits exponent
// bits, then the mantissa in the remaining low bits.
//
// HasInfinity selects between an IEEE754-like format (all-ones exponent encodes
// Inf/NaN) and a format whose only special value is the all-ones NaN pattern.
template <typename T, int ExponentBits, bool HasInfinity = true>
struct FloatFormat {
  using value_type = T;

  // Format capabilities.
  static constexpr bool has_infinity = HasInfinity;
  static constexpr bool has_nan_payload = HasInfinity;

  // Field widths and positions, in bits.
  static constexpr int bits = CHAR_BIT * sizeof(T);
  static constexpr int exponent_bits = ExponentBits;
  static constexpr int sign_bit_pos = bits - 1;
  static constexpr int mantissa_bits = sign_bit_pos - exponent_bits;

  // Exponent range. Without infinity the all-ones exponent still encodes
  // finite values, so the maximum is one higher.
  static constexpr int exponent_bias = (1 << (exponent_bits - 1)) - 1;
  static constexpr int unbiased_exponent_min = 1 - exponent_bias;
  static constexpr int unbiased_exponent_max =
      exponent_bias + (HasInfinity ? 0 : 1);

  // Bit masks for the three fields.
  static constexpr T sign_bit = T{1} << sign_bit_pos;
  static constexpr T mantissa_mask = (T{1} << mantissa_bits) - 1;
  static constexpr T exponent_mask = ((T{1} << exponent_bits) - 1)
      << mantissa_bits;

  // signaling/quiet encoding is unspecified by IEEE754. This mirrors x86/ARM.
  static constexpr T quiet_nan_bit = T{1} << (mantissa_bits - 1);

  // Special encodings: NaN is "everything but the sign set"; overflow maps to
  // infinity when the format has one and to NaN otherwise.
  static constexpr T nan = mantissa_mask | exponent_mask;
  static constexpr T overflow_value = HasInfinity ? exponent_mask : nan;
};
// Concrete format descriptors. Mantissa width follows from the storage type:
// total bits minus the exponent bits minus one sign bit.

// 32-bit storage, 8 exponent bits, 23 mantissa bits.
using IEEE754Single = FloatFormat</*T=*/uint32_t, /*ExponentBits=*/8>;
// 16-bit storage, 5 exponent bits, 10 mantissa bits.
using IEEE754Half = FloatFormat</*T=*/uint16_t, /*ExponentBits=*/5>;
// 16-bit storage, 8 exponent bits, 7 mantissa bits.
// See https://arxiv.org/abs/1905.12322v3
using BFloat16 = FloatFormat</*T=*/uint16_t, /*ExponentBits=*/8>;
// 8-bit storage, 5 exponent bits, 2 mantissa bits.
// See https://doi.org/10.48550/arXiv.2209.05433
using FP8_E5M2 = FloatFormat</*T=*/uint8_t, /*ExponentBits=*/5>;
// 8-bit storage, 4 exponent bits, 3 mantissa bits, no infinity encoding
// (overflow maps to the NaN pattern).
// See https://doi.org/10.48550/arXiv.2209.05433
using FP8_E4M3FN = FloatFormat<
/*T=*/uint8_t,
/*ExponentBits=*/4,
/*HasInfinity=*/false>;
// Rounding applied by ieee754_trunc to the mantissa bits that do not fit into
// the target format.
enum class RoundingMode {
// Round to the nearest representable value; exact ties go to the neighbour
// whose mantissa is even.
ToNearestTiesToEven,
// Drop the excess bits (round toward zero).
ToZero,
};
// Generic IEEE754 truncation algorithm.
//
// Narrows `value`, the raw bit pattern of a number laid out as described by
// Src, to the layout described by Tgt and returns the resulting bit pattern.
// Src and Tgt are FloatFormat-style descriptors. Tgt must not have more
// exponent bits than Src and must have strictly fewer mantissa bits (enforced
// by the static_asserts below). The RoundingMode template argument selects how
// the discarded mantissa bits are rounded.
//
// Handled cases, in order: same-exponent truncation fast path, results that
// are zero or denormal in Tgt, overflow / Inf / NaN, and normal numbers.
template <typename Src, typename Tgt, RoundingMode RoundingMode>
[[gnu::always_inline]] inline typename Tgt::value_type ieee754_trunc(
    typename Src::value_type value) {
  static_assert(Src::exponent_bits >= Tgt::exponent_bits);
  static_assert(Src::mantissa_bits > Tgt::mantissa_bits);
  using ST = typename Src::value_type;
  using TT = typename Tgt::value_type;

  // Exponent and mantissa fields, still at their position inside `value`.
  ST src_exponent = value & Src::exponent_mask;
  ST src_mantissa = value & Src::mantissa_mask;

  // Fast-path: If there is no difference in exponent sizes (e.g. fp32 -> bf16)
  // and we round toward zero, then we can just drop the least significant bits.
  if constexpr (
      Src::exponent_bits == Tgt::exponent_bits && Src::has_infinity &&
      Tgt::has_infinity && RoundingMode == RoundingMode::ToZero) {
    TT result = value >> (Src::bits - Tgt::bits);
    // Turn signaling NaN into quiet NaN. This also avoids that the mantissa
    // is completely zero after truncation (which would be misinterpreted as
    // INF).
    if (src_exponent == Src::exponent_mask && src_mantissa != 0) {
      result |= Tgt::quiet_nan_bit;
    }
    return result;
  }

  // Sign bit moved to its position in the target layout.
  ST tgt_sign =
      (value & Src::sign_bit) >> (Src::sign_bit_pos - Tgt::sign_bit_pos);

  // True when every source denormal is too small to be representable in Tgt,
  // even as a denormal.
  constexpr bool denormal_becomes_zero =
      Tgt::unbiased_exponent_min - Src::unbiased_exponent_min >
      Src::mantissa_bits - Tgt::mantissa_bits;
  if constexpr (denormal_becomes_zero) {
    // Fast-path for zero exponent bits: This means the number was zero or a
    // denormal number that will turn into zero in the Tgt format.
    if (src_exponent == 0) {
      return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0
    }
  }

  int unbiased_exponent =
      (src_exponent >> Src::mantissa_bits) - Src::exponent_bias;

  if (unbiased_exponent < Tgt::unbiased_exponent_min) {
    // Too small for a normal number in Tgt. `shift` is how many additional
    // bits the mantissa has to move right to be expressed with the minimum
    // target exponent.
    int shift = Tgt::unbiased_exponent_min - unbiased_exponent;
    if (shift <= Tgt::mantissa_bits + 1) {
      // Result is denormal.
      ST src_mantissa_one = src_mantissa;
      // Add explicit one if the source was not denormal.
      if (denormal_becomes_zero || src_exponent != 0) {
        // Built in the source type: the bit sits above the source mantissa and
        // need not fit into the (narrower) target type.
        src_mantissa_one |= ST{1} << Src::mantissa_bits;
      } else {
        // Source denormals have no implicit one and already use the minimum
        // exponent, which is one shift less.
        shift--;
      }
      TT tgt_mantissa =
          src_mantissa_one >> (Src::mantissa_bits - Tgt::mantissa_bits + shift);
      if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) {
        // `half` is the weight of the most significant discarded bit.
        int half_pos = Src::mantissa_bits - Tgt::mantissa_bits + shift - 1;
        ST half = ST{1} << half_pos;
        ST remainder = src_mantissa_one & ((half << 1) - 1);
        if (remainder > half ||
            (remainder == half && (tgt_mantissa & 1) != 0)) {
          // A carry out of the mantissa lands in the lowest exponent bit,
          // which is exactly the smallest normal number.
          tgt_mantissa += 1;
        }
      } else {
        static_assert(
            RoundingMode == RoundingMode::ToZero, "unsupported rounding mode");
      }
      return tgt_sign | tgt_mantissa; // tgt_exponent == 0
    } else {
      // Result is +/- zero
      return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0
    }
  }

  if (unbiased_exponent > Tgt::unbiased_exponent_max) {
    if (unbiased_exponent == Src::exponent_bias + 1 && src_mantissa != 0) {
      // Source is NaN (all-ones exponent, non-zero mantissa).
      TT tgt_mantissa;
      if constexpr (Tgt::has_nan_payload) {
        // Keep the top payload bits and force the result to be a quiet NaN.
        tgt_mantissa =
            src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits);
        tgt_mantissa |= Tgt::quiet_nan_bit;
      } else {
        tgt_mantissa = Tgt::mantissa_mask;
      }
      return tgt_sign | Tgt::exponent_mask | tgt_mantissa;
    } else {
      if (RoundingMode == RoundingMode::ToZero &&
          (!Src::has_infinity || src_exponent != Src::exponent_mask)) {
        // Finite source that overflows Tgt: return largest finite number.
        // That is the encoding directly below the overflow value, both for
        // formats with infinity (Inf - 1) and for formats without (NaN - 1).
        constexpr TT largest_finite =
            static_cast<TT>(Tgt::overflow_value - 1);
        return tgt_sign | largest_finite;
      }
      // Infinity or NaN for formats without infinity.
      return tgt_sign | Tgt::overflow_value;
    }
  }

  // Normal number.
  // NOTE(review): for targets without infinity and ToZero rounding, a value in
  // the highest binade whose truncated mantissa is all ones still produces the
  // NaN bit pattern here; confirm whether it should clamp to the largest
  // finite value instead.
  TT tgt_mantissa = src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits);
  TT tgt_exponent = (unbiased_exponent + Tgt::exponent_bias)
      << Tgt::mantissa_bits;
  if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) {
    // `half` is the weight of the most significant discarded bit.
    ST half = ST{1} << (Src::mantissa_bits - Tgt::mantissa_bits - 1);
    ST remainder = src_mantissa & ((half << 1) - 1);
    if (remainder > half || (remainder == half && (tgt_mantissa & 1) != 0)) {
      if (tgt_mantissa < Tgt::mantissa_mask) {
        tgt_mantissa += 1;
      } else {
        // Mantissa overflowed, increment exponent.
        // Normally we can just add to the exponent and will naturally end up
        // on infinity on overflow. But we need special treatments for formats
        // without infinity.
        if (Tgt::has_infinity || tgt_exponent != Tgt::exponent_mask) {
          tgt_mantissa = 0;
          tgt_exponent += TT{1} << Tgt::mantissa_bits;
        } else {
          // Return NaN.
          tgt_mantissa = Tgt::mantissa_mask;
        }
      }
    }
  } else {
    static_assert(
        RoundingMode == RoundingMode::ToZero, "unsupported rounding mode");
  }
  return tgt_sign | tgt_exponent | tgt_mantissa;
}
} // namespace detail
// Converts a single-precision float to the bit pattern of an IEEE754
// half-precision value, rounding to nearest with ties to even.
inline float16 cpu_float2half_rn(float f) {
  using detail::IEEE754Half;
  using detail::IEEE754Single;
  constexpr auto kRounding = detail::RoundingMode::ToNearestTiesToEven;

  // Extract the raw bits of the float without violating aliasing rules.
  IEEE754Single::value_type raw_bits;
  std::memcpy(&raw_bits, &f, sizeof(raw_bits));

  return detail::ieee754_trunc<IEEE754Single, IEEE754Half, kRounding>(
      raw_bits);
}
// Converts a single-precision float to the bit pattern of an IEEE754
// half-precision value, rounding toward zero.
inline float16 cpu_float2half_rz(float f) {
  // Extract the raw bits of the float without violating aliasing rules.
  uint32_t f_u32;
  std::memcpy(&f_u32, &f, sizeof(f_u32));
  return detail::ieee754_trunc<
      /*Src=*/detail::IEEE754Single,
      /*Tgt=*/detail::IEEE754Half,
      detail::RoundingMode::ToZero>(f_u32);
}
// Reference conversion from the 16-bit integer representation of an IEEE754
// half-precision float to an IEEE754 32-bit single-precision float, done purely
// with integer bit manipulation.
//
// Any half-precision NaN is returned as the positive NaN with an all-ones
// mantissa; the input's sign and payload are not preserved.
inline float cpu_half2float_ref(const float16 h) {
  // Half-precision layout.
  constexpr uint32_t kHalfExponentBits = 5;
  constexpr uint32_t kHalfMantissaBits = 10;
  constexpr uint32_t kHalfSignShift = kHalfExponentBits + kHalfMantissaBits;
  constexpr uint32_t kHalfBias = 15;
  constexpr uint32_t kHalfExponentMax = 0b1'1111;
  constexpr uint32_t kHalfMantissaMask = 0b11'1111'1111;

  // Single-precision layout.
  constexpr uint32_t kSingleExponentBits = 8;
  constexpr uint32_t kSingleMantissaBits = 23;
  constexpr uint32_t kSingleSignShift =
      kSingleExponentBits + kSingleMantissaBits;
  constexpr uint32_t kSingleBias = 127;
  constexpr uint32_t kSingleExponentMax = 0b1111'1111;
  constexpr uint32_t kSingleMantissaMask = 0x7F'FF'FF;
  constexpr uint32_t kSingleMantissaTopBit = 1u << 22;

  // Split the half into its three fields. The mantissa is immediately moved to
  // the top of the 23-bit single-precision mantissa field.
  uint32_t sign = (h >> kHalfSignShift) & 1;
  uint32_t exponent = (h >> kHalfMantissaBits) & kHalfExponentMax;
  uint32_t mantissa = (h & kHalfMantissaMask)
      << (kSingleMantissaBits - kHalfMantissaBits);

  if (exponent == kHalfExponentMax) {
    // Inf (zero mantissa) or NaN (non-zero mantissa).
    exponent = kSingleExponentMax;
    if (mantissa != 0) {
      mantissa = kSingleMantissaMask;
      sign = 0;
    }
  } else if (exponent != 0) {
    // Normal number: only the bias changes.
    exponent += kSingleBias - kHalfBias;
  } else if (mantissa != 0) {
    // Half denormal: it is a normal number in single precision. Shift left
    // until the leading one reaches the top mantissa bit, then shift it out
    // once more because the leading one is implicit in a normal number.
    exponent = kSingleBias - kHalfBias;
    while ((mantissa & kSingleMantissaTopBit) == 0) {
      mantissa <<= 1;
      --exponent;
    }
    mantissa = (mantissa << 1) & kSingleMantissaMask;
  }
  // Remaining case: +/- zero, all fields are already correct.

  const uint32_t bits = (sign << kSingleSignShift) |
      (exponent << kSingleMantissaBits) | mantissa;
  float result;
  std::memcpy(&result, &bits, sizeof(float));
  return result;
}
// Same contract as cpu_half2float_ref, but uses the built-in fp16 to fp32
// conversion provided by the compiler when a native 16-bit float type is
// available and HAVE_GNU_F2H_IEEE is defined. Falls back to
// cpu_half2float_ref otherwise.
inline float cpu_half2float(const float16 h) {
#if defined(HAS_NATIVE_FP16_TYPE) && defined(HAVE_GNU_F2H_IEEE)
  // Use the alias rather than __fp16 directly: HAS_NATIVE_FP16_TYPE is also
  // defined when only _Float16 is available.
  native_fp16_t h_fp16;
  std::memcpy(&h_fp16, &h, sizeof(native_fp16_t));
  return h_fp16;
#else
  return cpu_half2float_ref(h);
#endif
}
// Converts a single-precision float to the 16-bit representation of a
// half-precision float. Uses the compiler's built-in fp32 to fp16 conversion
// when a native 16-bit float type is available and HAVE_GNU_F2H_IEEE is
// defined; falls back to cpu_float2half_rn otherwise.
inline float16 cpu_float2half(const float f) {
#if defined(HAS_NATIVE_FP16_TYPE) && defined(HAVE_GNU_F2H_IEEE)
  // Use the alias rather than __fp16 directly: HAS_NATIVE_FP16_TYPE is also
  // defined when only _Float16 is available.
  const native_fp16_t h = static_cast<native_fp16_t>(f);
  float16 res;
  std::memcpy(&res, &h, sizeof(native_fp16_t));
  return res;
#else
  return cpu_float2half_rn(f);
#endif
}
// Converts a bfloat16 bit pattern to a single-precision float by placing its
// 16 bits in the upper half of a 32-bit word and zero-filling the lower half.
inline float cpu_bf162float(bfloat16 src) {
  // Read the 16 raw bits with memcpy instead of a reinterpret_cast pointer
  // pun, matching how every other conversion in this file accesses bits.
  uint16_t src_bits;
  std::memcpy(&src_bits, &src, sizeof(src_bits));

  const uint32_t val_fp32 = static_cast<uint32_t>(src_bits) << 16;
  float ret;
  std::memcpy(&ret, &val_fp32, sizeof(float));
  return ret;
}
// Converts a single-precision float to a bfloat16 bit pattern by keeping its
// upper 16 bits. Finite values are rounded by adding half of the discarded
// range before truncating (ties round away from zero); values that round past
// the largest finite bfloat16 become infinity. NaN inputs produce a quiet NaN
// with the same sign.
inline bfloat16 cpu_float2bfloat16(float src) {
  uint32_t temp;
  std::memcpy(&temp, &src, sizeof(uint32_t));

  // NaN: exponent all ones and a non-zero mantissa. Adding the rounding
  // increment could carry into the exponent or sign and turn the NaN into an
  // infinity or a zero, and plain truncation could clear the mantissa, so
  // truncate and force the quiet bit instead.
  if ((temp & 0x7FFFFFFFu) > 0x7F800000u) {
    return (temp >> 16) | 0x0040u;
  }

  return (temp + (1u << 15)) >> 16;
}
} // namespace fbgemm
|