File size: 3,897 Bytes
d1d4335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <cstring>
#include <limits>

// TODO(#146647): Can we remove the below warning?
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif

namespace c10 {

/// Constructors

inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value)
    : x(detail::fp8e8m0fnu_from_fp32_value(value)) {}

/// Implicit conversions

inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const {
  // TODO(#146647): maybe rewrite without control flow

  // if exponent is zero, need to special case to return 2^-127 instead of zero
  if (x == 0) {
    return c10::detail::fp32_from_bits(0x00400000);
  }

  // if exponent is NaN, need to special case to return properly encoded NaN
  if (isnan()) {
    return c10::detail::fp32_from_bits(0x7f800001);
  }

  // leave sign at 0, set the exponent bits, leave stored mantissa at 0
  uint32_t res = x << 23;

  return c10::detail::fp32_from_bits(res);
}

/// Special values helper

inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const {
  return x == 0b11111111;
}

/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e8m0fnu to float.

} // namespace c10

namespace std {

template <>
class numeric_limits<c10::Float8_e8m0fnu> {
 public:
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = false;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = false;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = false;
  static constexpr auto has_denorm = false;
  static constexpr auto has_denorm_loss = false;
  static constexpr auto round_style = numeric_limits<float>::round_style;
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 1;
  static constexpr int digits10 = 0;
  static constexpr int max_digits10 = 1; // just a 2!
  static constexpr int radix = 2;
  static constexpr int min_exponent = -126;
  static constexpr int min_exponent10 = -38;
  static constexpr int max_exponent = 128;
  static constexpr int max_exponent10 = 38;
  static constexpr auto traps = numeric_limits<float>::traps;
  static constexpr auto tinyness_before = false;

  static constexpr c10::Float8_e8m0fnu min() {
    // 2^-127
    return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu lowest() {
    // 2^-127
    return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu max() {
    // 254 biased, which is 127 unbiased, so 2^127
    return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu epsilon() {
    // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this
    // is "the difference between 1.0 and the next representable value of the
    // given floating-point type". The next representable value is 2.0, so the
    // difference is 1.0 which is 2^0. 0 unbiased is 127 biased.
    return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu round_error() {
    // 0.5 in float, which is 2^-1, and -1 + 127 = 126
    return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu quiet_NaN() {
    return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits());
  }
};

} // namespace std

C10_CLANG_DIAGNOSTIC_POP()