File size: 11,625 Bytes
d1d4335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
/*

 * Copyright (c) Meta Platforms, Inc. and affiliates.

 * All rights reserved.

 *

 * This source code is licensed under the BSD-style license found in the

 * LICENSE file in the root directory of this source tree.

 */

#pragma once

#include <cassert>
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <cstring>

#include "./Types.h"

// Fallback for compilers that do not provide __is_identifier: every name is
// then reported as a plain identifier, which makes __has_keyword() below
// evaluate to false for all arguments.
#ifndef __is_identifier
#define __is_identifier(x) 1
#endif

// Evaluates to true when __x is NOT reported as a plain identifier, i.e. when
// the compiler treats it as a keyword.
#define __has_keyword(__x) !(__is_identifier(__x))

// Select the compiler's native half-precision type, if any. When one is
// found, HAS_NATIVE_FP16_TYPE is defined and native_fp16_t names that type
// (__fp16 is preferred over _Float16); otherwise native_fp16_t is void.
//
// TODO: we're disabling native fp16 on Windows to work around test failures
// due to an "undefined symbol __gnu_h2f_ieee" error. We should follow up on
// this later.
#if __has_keyword(__fp16) && !defined(_WIN32)
#define HAS_NATIVE_FP16_TYPE
using native_fp16_t = __fp16;
#elif __has_keyword(_Float16) && !defined(_WIN32)
#define HAS_NATIVE_FP16_TYPE
using native_fp16_t = _Float16;
#else
using native_fp16_t = void;
#endif

namespace fbgemm {

namespace detail {

// Compile-time description of a binary floating-point format that is stored
// in the unsigned integer type T as [sign | exponent | mantissa].
//
// Template parameters:
//   T            - unsigned integer type holding the raw bits.
//   ExponentBits - width of the exponent field.
//   HasInfinity  - true:  the all-ones exponent is reserved; it encodes
//                         infinity (mantissa == 0) or NaN (mantissa != 0).
//                  false: the all-ones exponent is an ordinary finite
//                         exponent; only the pattern with every exponent and
//                         mantissa bit set is NaN, and overflow yields NaN.
template <typename T, int ExponentBits, bool HasInfinity = true>
struct FloatFormat {
  using value_type = T;
  static constexpr int bits = sizeof(T) * CHAR_BIT;
  static constexpr int exponent_bits = ExponentBits;
  static constexpr int mantissa_bits = bits - exponent_bits - 1;
  static constexpr int sign_bit_pos = bits - 1;
  static constexpr int exponent_bias = (1 << (exponent_bits - 1)) - 1;
  // Smallest / largest unbiased exponent of a normal number.
  static constexpr int unbiased_exponent_min = -exponent_bias + 1;
  static constexpr int unbiased_exponent_max =
      HasInfinity ? exponent_bias : (exponent_bias + 1);
  static constexpr T sign_bit = T{1} << sign_bit_pos;
  static constexpr T exponent_mask = ((T{1} << exponent_bits) - 1)
      << mantissa_bits;
  static constexpr T mantissa_mask = (T{1} << mantissa_bits) - 1;
  // signaling/quiet encoding is unspecified by IEEE754. This mirrors x86/ARM.
  static constexpr T quiet_nan_bit = T{1} << (mantissa_bits - 1);

  static constexpr T nan = exponent_mask | mantissa_mask;
  // Bit pattern (without sign) produced when a value is too large.
  static constexpr T overflow_value = HasInfinity ? exponent_mask : nan;
  // Bit pattern (without sign) of the largest finite magnitude:
  //  - with infinity: largest non-reserved exponent, mantissa all ones;
  //  - without infinity: the encoding directly below the single NaN pattern.
  static constexpr T max_finite = HasInfinity
      ? static_cast<T>(
            (exponent_mask - (T{1} << mantissa_bits)) | mantissa_mask)
      : static_cast<T>(nan - 1);
  static constexpr bool has_infinity = HasInfinity;
  // Whether NaN encodings have spare mantissa bits to carry a payload.
  static constexpr bool has_nan_payload = HasInfinity;
};

using IEEE754Single = FloatFormat</*T=*/uint32_t, /*ExponentBits=*/8>;
using IEEE754Half = FloatFormat</*T=*/uint16_t, /*ExponentBits=*/5>;
// See https://arxiv.org/abs/1905.12322v3
using BFloat16 = FloatFormat</*T=*/uint16_t, /*ExponentBits=*/8>;
// See https://doi.org/10.48550/arXiv.2209.05433
using FP8_E5M2 = FloatFormat</*T=*/uint8_t, /*ExponentBits=*/5>;
// See https://doi.org/10.48550/arXiv.2209.05433
using FP8_E4M3FN = FloatFormat<
    /*T=*/uint8_t,
    /*ExponentBits=*/4,
    /*HasInfinity=*/false>;

// Rounding applied to the mantissa bits that do not fit the target format.
enum class RoundingMode {
  ToNearestTiesToEven,
  ToZero,
};

// Generic IEEE754 truncation algorithm.
//
// Converts the raw bit pattern `value` of format Src into the raw bit pattern
// of the narrower format Tgt (both are FloatFormat instantiations), using the
// rounding mode `Mode`.
//
// Behavior by input class:
//  - NaN stays NaN (quiet bit forced; payload kept when Tgt can hold one).
//  - Values too small for a Tgt denormal become signed zero.
//  - Values too large for Tgt become Tgt::overflow_value when rounding to
//    nearest, and the largest finite Tgt value when rounding toward zero
//    (a source infinity still maps to Tgt::overflow_value).
template <typename Src, typename Tgt, RoundingMode Mode>
[[gnu::always_inline]] inline typename Tgt::value_type ieee754_trunc(
    typename Src::value_type value) {
  static_assert(Src::exponent_bits >= Tgt::exponent_bits);
  static_assert(Src::mantissa_bits > Tgt::mantissa_bits);
  using ST = typename Src::value_type;
  using TT = typename Tgt::value_type;
  // Number of low mantissa bits that have no room in the target format.
  constexpr int dropped_bits = Src::mantissa_bits - Tgt::mantissa_bits;

  const ST src_exponent = value & Src::exponent_mask;
  const ST src_mantissa = value & Src::mantissa_mask;
  // Fast-path: If there is no difference in exponent sizes (e.g. fp32 -> bf16)
  // and we round toward zero, then we can just drop the least significant bits.
  if constexpr (
      Src::exponent_bits == Tgt::exponent_bits && Src::has_infinity &&
      Tgt::has_infinity && Mode == RoundingMode::ToZero) {
    TT result = value >> (Src::bits - Tgt::bits);
    // Turn signaling NaN into quiet NaN. This also avoids that the mantissa
    // is completely zero after truncation (which would be misinterpreted as
    // INF).
    if (src_exponent == Src::exponent_mask && src_mantissa != 0) {
      result |= Tgt::quiet_nan_bit;
    }
    return result;
  }

  // Sign bit moved to its position in the target format.
  const ST tgt_sign =
      (value & Src::sign_bit) >> (Src::sign_bit_pos - Tgt::sign_bit_pos);
  // True when even the largest source denormal is below the smallest target
  // denormal, so every source denormal flushes to zero.
  constexpr bool denormal_becomes_zero =
      Tgt::unbiased_exponent_min - Src::unbiased_exponent_min > dropped_bits;
  if constexpr (denormal_becomes_zero) {
    // Fast-path for zero exponent bits: This means the number was zero or a
    // denormal number that will turn into zero in the Tgt format.
    if (src_exponent == 0) {
      return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0
    }
  }

  const int unbiased_exponent =
      static_cast<int>(src_exponent >> Src::mantissa_bits) -
      Src::exponent_bias;

  if (unbiased_exponent < Tgt::unbiased_exponent_min) {
    int shift = Tgt::unbiased_exponent_min - unbiased_exponent;
    if (shift > Tgt::mantissa_bits + 1) {
      // Too small even for a denormal: result is +/- zero.
      return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0
    }

    // Result is denormal.
    ST src_mantissa_one = src_mantissa;
    // Add explicit one if the source was not denormal. The one must be built
    // in the source type: it sits above the source mantissa, which may not
    // fit into the (narrower) target type.
    if (denormal_becomes_zero || src_exponent != 0) {
      src_mantissa_one |= ST{1} << Src::mantissa_bits;
    } else {
      shift--;
    }
    TT tgt_mantissa = src_mantissa_one >> (dropped_bits + shift);

    if constexpr (Mode == RoundingMode::ToNearestTiesToEven) {
      const int half_pos = dropped_bits + shift - 1;
      const ST half = ST{1} << half_pos;
      const ST remainder = src_mantissa_one & ((half << 1) - 1);
      if (remainder > half || (remainder == half && (tgt_mantissa & 1) != 0)) {
        // A carry out of the mantissa lands in the exponent field and thereby
        // produces the smallest normal number, which is the intended result.
        tgt_mantissa += 1;
      }
    } else {
      static_assert(
          Mode == RoundingMode::ToZero, "unsupported rounding mode");
    }
    return tgt_sign | tgt_mantissa; // tgt_exponent == 0
  }

  if (unbiased_exponent > Tgt::unbiased_exponent_max) {
    if (unbiased_exponent == Src::exponent_bias + 1 && src_mantissa != 0) {
      // NaN; not a number
      TT tgt_mantissa;
      if constexpr (Tgt::has_nan_payload) {
        tgt_mantissa = src_mantissa >> dropped_bits;
        tgt_mantissa |= Tgt::quiet_nan_bit;
      } else {
        tgt_mantissa = Tgt::mantissa_mask;
      }
      return tgt_sign | Tgt::exponent_mask | tgt_mantissa;
    }
    if constexpr (Mode == RoundingMode::ToZero) {
      if (!Src::has_infinity || src_exponent != Src::exponent_mask) {
        // Finite input: return the largest finite number. For formats without
        // infinity this must not be exponent_mask | mantissa_mask, because
        // that pattern is NaN.
        return tgt_sign | Tgt::max_finite;
      }
    }
    // Infinity or NaN for formats without infinity.
    return tgt_sign | Tgt::overflow_value;
  }

  // Normal number.
  TT tgt_mantissa = src_mantissa >> dropped_bits;
  TT tgt_exponent = (unbiased_exponent + Tgt::exponent_bias)
      << Tgt::mantissa_bits;
  if constexpr (Mode == RoundingMode::ToNearestTiesToEven) {
    const ST half = ST{1} << (dropped_bits - 1);
    const ST remainder = src_mantissa & ((half << 1) - 1);
    if (remainder > half || (remainder == half && (tgt_mantissa & 1) != 0)) {
      if (tgt_mantissa < Tgt::mantissa_mask) {
        tgt_mantissa += 1;
      } else {
        // Mantissa overflowed, increment exponent.
        // Normally we can just add to the exponent and will naturally end up
        // on infinity on overflow. But we need special treatments for formats
        // without infinity.
        if (Tgt::has_infinity || tgt_exponent != Tgt::exponent_mask) {
          tgt_mantissa = 0;
          tgt_exponent += TT{1} << Tgt::mantissa_bits;
        } else {
          // Return NaN.
          tgt_mantissa = Tgt::mantissa_mask;
        }
      }
    }
  } else {
    static_assert(Mode == RoundingMode::ToZero, "unsupported rounding mode");
    if constexpr (!Tgt::has_infinity) {
      // In formats without infinity the top finite exponent shares its
      // encoding space with NaN. A finite input truncated toward zero must
      // never become NaN, so clamp to the largest finite number instead.
      if ((tgt_exponent | tgt_mantissa) == Tgt::nan) {
        return tgt_sign | Tgt::max_finite;
      }
    }
  }
  return tgt_sign | tgt_exponent | tgt_mantissa;
}

} // namespace detail

// Converts a float to the bit pattern of an IEEE754 half-precision value,
// rounding to nearest with ties to even.
inline float16 cpu_float2half_rn(float f) {
  using detail::IEEE754Half;
  using detail::IEEE754Single;
  constexpr auto kRounding = detail::RoundingMode::ToNearestTiesToEven;

  // Reinterpret the float as its raw 32 bits without violating aliasing rules.
  uint32_t raw_bits;
  std::memcpy(&raw_bits, &f, sizeof(raw_bits));
  return detail::ieee754_trunc<IEEE754Single, IEEE754Half, kRounding>(
      raw_bits);
}

// Converts a float to the bit pattern of an IEEE754 half-precision value,
// rounding toward zero (excess mantissa bits are dropped).
inline float16 cpu_float2half_rz(float f) {
  // Reinterpret the float as its raw 32 bits without violating aliasing rules.
  uint32_t f_u32;
  std::memcpy(&f_u32, &f, sizeof(f_u32));
  return detail::ieee754_trunc<
      /*Src=*/detail::IEEE754Single,
      /*Tgt=*/detail::IEEE754Half,
      detail::RoundingMode::ToZero>(f_u32);
}

// Reference (portable, bit-manipulation only) conversion from the 16-bit
// integer representation of an IEEE754 half-precision float to an IEEE754
// 32-bit single-precision float.
//
// Zero, denormal, normal and infinite inputs convert exactly. Every NaN input
// is returned as the same positive NaN with all mantissa bits set; the input's
// sign and payload are not preserved.
inline float cpu_half2float_ref(const float16 h) {
  // Half-precision layout.
  constexpr uint32_t kHalfMantissaBits = 10;
  constexpr uint32_t kHalfExponentBits = 5;
  constexpr uint32_t kHalfSignShift = kHalfExponentBits + kHalfMantissaBits;
  constexpr uint32_t kHalfBias = 15;
  constexpr uint32_t kHalfExponentAllOnes = 0b1'1111;
  constexpr uint32_t kHalfMantissaMask = 0b11'1111'1111;

  // Single-precision layout.
  constexpr uint32_t kSingleMantissaBits = 23;
  constexpr uint32_t kSingleExponentBits = 8;
  constexpr uint32_t kSingleSignShift =
      kSingleExponentBits + kSingleMantissaBits;
  constexpr uint32_t kSingleBias = 127;
  constexpr uint32_t kSingleExponentAllOnes = 0b1111'1111;
  constexpr uint32_t kSingleMantissaMask = 0x7F'FF'FF;
  constexpr uint32_t kSingleMantissaTopBit = 1u << 22;

  constexpr uint32_t kBiasDelta = kSingleBias - kHalfBias;

  // Split the input into its three fields. The mantissa is left-aligned so it
  // occupies the most significant bits of a single-precision mantissa.
  uint32_t sign = (h >> kHalfSignShift) & 1;
  uint32_t exp = (h >> kHalfMantissaBits) & kHalfExponentAllOnes;
  uint32_t frac = (h & kHalfMantissaMask)
      << (kSingleMantissaBits - kHalfMantissaBits);

  if (exp == kHalfExponentAllOnes) {
    // Infinity keeps its sign and zero mantissa; NaN is canonicalized.
    exp = kSingleExponentAllOnes;
    if (frac != 0) {
      frac = kSingleMantissaMask;
      sign = 0;
    }
  } else if (exp != 0) {
    // Normal number: only the exponent bias changes.
    exp += kBiasDelta;
  } else if (frac != 0) {
    // Denormal half: it is a normal single. Shift the mantissa left until its
    // leading one has moved out into the implicit-one position, lowering the
    // exponent by one per shift.
    exp = kBiasDelta + 1;
    bool leading_one_shifted_out = false;
    while (!leading_one_shifted_out) {
      leading_one_shifted_out = (frac & kSingleMantissaTopBit) != 0;
      frac <<= 1;
      --exp;
    }
    frac &= kSingleMantissaMask; // the leading one is implicit
  }
  // Remaining case: exp == 0 && frac == 0, i.e. signed zero; nothing to do.

  const uint32_t single_bits = (sign << kSingleSignShift) |
      (exp << kSingleMantissaBits) | frac;

  float result;
  std::memcpy(&result, &single_bits, sizeof(float));
  return result;
}

// Converts the 16-bit integer representation of an IEEE754 half-precision
// float to a float. Uses the compiler's built-in fp16 -> fp32 conversion when
// both HAS_NATIVE_FP16_TYPE and HAVE_GNU_F2H_IEEE are defined, and falls back
// to cpu_half2float_ref() otherwise.
inline float cpu_half2float(const float16 h) {
#if defined(HAS_NATIVE_FP16_TYPE) && defined(HAVE_GNU_F2H_IEEE)
  // Use native_fp16_t rather than naming __fp16 directly:
  // HAS_NATIVE_FP16_TYPE is also defined when only _Float16 is available, in
  // which case __fp16 would not compile.
  native_fp16_t h_fp16;
  std::memcpy(&h_fp16, &h, sizeof(h_fp16));
  return static_cast<float>(h_fp16);
#else
  return cpu_half2float_ref(h);
#endif
}

// Converts a float to the 16-bit integer representation of an IEEE754
// half-precision float. Uses the compiler's built-in fp32 -> fp16 conversion
// when both HAS_NATIVE_FP16_TYPE and HAVE_GNU_F2H_IEEE are defined, and falls
// back to cpu_float2half_rn() (round to nearest, ties to even) otherwise.
inline float16 cpu_float2half(const float f) {
#if defined(HAS_NATIVE_FP16_TYPE) && defined(HAVE_GNU_F2H_IEEE)
  // Use native_fp16_t rather than naming __fp16 directly:
  // HAS_NATIVE_FP16_TYPE is also defined when only _Float16 is available, in
  // which case __fp16 would not compile. The cast is explicit because this is
  // a narrowing floating-point conversion.
  const native_fp16_t h = static_cast<native_fp16_t>(f);
  float16 res;
  std::memcpy(&res, &h, sizeof(h));
  return res;
#else
  return cpu_float2half_rn(f);
#endif
}

// Widens a bfloat16 bit pattern to float: the 16 stored bits become the upper
// half of the 32-bit float representation and the lower half is zero.
inline float cpu_bf162float(bfloat16 src) {
  // Read the first 16 bits of src as an integer.
  uint16_t upper_half;
  std::memcpy(&upper_half, &src, sizeof(upper_half));

  const uint32_t float_bits = static_cast<uint32_t>(upper_half) << 16;

  float widened;
  std::memcpy(&widened, &float_bits, sizeof(float));
  return widened;
}

// Converts a float to bfloat16 by keeping the upper 16 bits of its
// representation, after adding half of the dropped range (1 << 15) so that the
// result is rounded to the nearest bfloat16 (ties round away from zero).
// Finite values that round past the largest bfloat16 become infinity.
//
// NaN inputs are returned as a quiet NaN with the same sign and upper payload
// bits. They must bypass the rounding add: the carry could otherwise run into
// the exponent or sign bit and turn a NaN into infinity or zero.
inline bfloat16 cpu_float2bfloat16(float src) {
  constexpr uint32_t kExponentMask = 0x7F800000u;
  constexpr uint32_t kMantissaMask = 0x007FFFFFu;
  // Most significant mantissa bit of a bfloat16 (7 mantissa bits).
  constexpr uint32_t kBf16QuietNanBit = 0x0040u;

  uint32_t temp;
  std::memcpy(&temp, &src, sizeof(uint32_t));

  const bool is_nan =
      (temp & kExponentMask) == kExponentMask && (temp & kMantissaMask) != 0;
  if (is_nan) {
    // Truncate instead of rounding and force the quiet bit so the mantissa
    // cannot end up all zero (which would read as infinity).
    return (temp >> 16) | kBf16QuietNanBit;
  }
  return (temp + (1u << 15)) >> 16;
}

} // namespace fbgemm