Add missing `scalar_type.hpp`
core/scalar_type.hpp
ADDED  +347 -0
@@ -0,0 +1,347 @@
#pragma once

#include <cstddef>
#include <cstdint>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

// For TORCH_CHECK
#include <torch/library.h>

namespace vllm {

//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// The Python-side type definitions can be found in: vllm/scalar_type.py;
// they should be kept in sync with any changes made here.
//
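// For example, a 4-bit integer stored with a bias of 8 (a common quantized
// weight layout) can be described as ScalarType::uint(4, 8) below, even
// though it has no direct torch.dtype equivalent.
//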
class ScalarType {
 public:
  enum NanRepr : uint8_t {
    NAN_NONE = 0,                // nans are not supported
    NAN_IEEE_754 = 1,            // nans are: exp all 1s, mantissa not all 0s
    NAN_EXTD_RANGE_MAX_MIN = 2,  // nans are: exp all 1s, mantissa all 1s

    NAN_REPR_ID_MAX
  };

  constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
                       int32_t bias, bool finite_values_only = false,
                       NanRepr nan_repr = NAN_IEEE_754)
      : exponent(exponent),
        mantissa(mantissa),
        signed_(signed_),
        bias(bias),
        finite_values_only(finite_values_only),
        nan_repr(nan_repr) {}

  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits - 1, true, bias);
  }

  static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits, false, bias);
  }

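  // e.g. int_(4) describes a signed 4-bit integer (values -8..7), uint(4) an
  // unsigned 4-bit integer (values 0..15), and uint(4, 8) an unsigned 4-bit
  // integer stored with a bias of 8 (stored 0..15 representing -8..7)
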
  // IEEE 754 compliant floating point type
  static constexpr ScalarType float_IEEE754(uint8_t exponent,
                                            uint8_t mantissa) {
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
  }

  // IEEE 754 non-compliant floating point type
  static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
                                     bool finite_values_only,
                                     NanRepr nan_repr) {
    TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    TORCH_CHECK(nan_repr != NAN_IEEE_754,
                "use `float_IEEE754` constructor for floating point types "
                "that follow IEEE 754 conventions");
    return ScalarType(exponent, mantissa, true, 0, finite_values_only,
                      nan_repr);
  }

  uint8_t const exponent;  // size of the exponent field (0 for integer types)
  uint8_t const mantissa;  // size of the mantissa field (size of the integer
                           // excluding the sign bit for integer types)
  bool const signed_;  // flag if the type supports negative numbers (i.e. has
                       // a sign bit)
  int32_t const bias;  // stored values equal value + bias,
                       // used for quantized types
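                       // e.g. for kU4B8 (defined below) the stored values
                       // 0..15 represent the logical values -8..7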

  // Extra floating point info
  bool const finite_values_only;  // i.e. no +/-inf if true
  NanRepr const nan_repr;         // how NaNs are represented
                                  // (not applicable for integer types)

  using Id = int64_t;

 private:
  // Field size in id
  template <typename T_>
  static constexpr size_t member_id_field_width() {
    using T = std::decay_t<T_>;
    return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
  }

  template <typename Fn, typename Init, typename Member, typename... Rest>
  static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
                                              Rest... rest) {
    auto new_val = f(val, member);
    if constexpr (sizeof...(rest) > 0) {
      return reduce_members_helper(f, new_val, rest...);
    } else {
      return new_val;
    }
  }

  template <typename Fn, typename Init>
  constexpr auto reduce_members(Fn f, Init init) const {
    // Should be in constructor order for `from_id`
    return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
                                 finite_values_only, nan_repr);
  }

  template <typename Fn, typename Init>
  static constexpr auto reduce_member_types(Fn f, Init init) {
    constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
    return dummy_type.reduce_members(f, init);
  }

  static constexpr auto id_size_bits() {
    return reduce_member_types(
        [](int acc, auto member) -> int {
          return acc + member_id_field_width<decltype(member)>();
        },
        0);
  }

 public:
  // Unique id for this scalar type that can be computed at compile time for
  // C++17 template specialization; this will not be needed once we migrate
  // to C++20 and can pass literal classes as template parameters.
  constexpr Id id() const {
    static_assert(id_size_bits() <= sizeof(Id) * 8,
                  "ScalarType id is too large to be stored");

    auto or_and_advance = [](std::pair<Id, uint32_t> result,
                             auto member) -> std::pair<Id, uint32_t> {
      auto [id, bit_offset] = result;
      auto constexpr bits = member_id_field_width<decltype(member)>();
      return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
                       << bit_offset,
              bit_offset + bits};
    };
    return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
  }

  // Create a ScalarType from an id, for C++17 template specialization;
  // this will not be needed once we migrate to C++20 and can pass literal
  // classes as template parameters.
  static constexpr ScalarType from_id(Id id) {
    auto extract_and_advance = [id](auto result, auto member) {
      using T = decltype(member);
      auto [tuple, bit_offset] = result;
      auto constexpr bits = member_id_field_width<T>();
      auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
                                          ((uint64_t(1) << bits) - 1));
      auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
      return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
    };

    auto [tuple_args, _] = reduce_member_types(extract_and_advance,
                                               std::pair<std::tuple<>, int>{});
    return std::apply([](auto... args) { return ScalarType(args...); },
                      tuple_args);
  }

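  // Note: the id is a bit-packed encoding of the fields above, so
  // `ScalarType::from_id(t.id()) == t` holds for any ScalarType t.
  // Hypothetical dispatch sketch:
  //   template <ScalarType::Id kTypeId>
  //   void run_kernel() {
  //     constexpr auto type = ScalarType::from_id(kTypeId);
  //     // ... specialize on type.size_bits(), type.has_bias(), etc. ...
  //   }
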
  constexpr int64_t size_bits() const {
    return mantissa + exponent + is_signed();
  }
  constexpr bool is_signed() const { return signed_; }
  constexpr bool is_integer() const { return exponent == 0; }
  constexpr bool is_floating_point() const { return exponent > 0; }
  constexpr bool is_ieee_754() const {
    return is_floating_point() && finite_values_only == false &&
           nan_repr == NAN_IEEE_754;
  }
  constexpr bool has_nans() const {
    return is_floating_point() && nan_repr != NAN_NONE;
  }
  constexpr bool has_infs() const {
    return is_floating_point() && finite_values_only == false;
  }
  constexpr bool has_bias() const { return bias != 0; }

 private:
  double _floating_point_max() const {
    TORCH_CHECK(mantissa <= 52 && exponent <= 11,
                "Cannot represent max/min as a double for type ", str());

    uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
      max_mantissa -= 1;
    }

    uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
      TORCH_CHECK(exponent < 11,
                  "Cannot represent max/min as a double for type ", str());
      max_exponent += 1;
    }

    // Adjust the exponent to match that of a double. For now we assume the
    // exponent bias is the standard 2^(e-1) - 1 (where e is the number of
    // exponent bits); there is some precedent for non-standard biases, e.g.
    // `float8_e4m3b11fnuz` in https://github.com/jax-ml/ml_dtypes, but to
    // avoid premature complication we assume the standard exponent bias
    // until there is a need to support non-standard biases.
    uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
    uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1;  // double e = 11

    uint64_t max_exponent_double =
        max_exponent - exponent_bias + exponent_bias_double;

    // shift the mantissa and the exponent into position for a double
    uint64_t double_raw =
        (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);

    return *reinterpret_cast<double*>(&double_raw);
  }

  constexpr std::variant<int64_t, double> _raw_max() const {
    if (is_floating_point()) {
      return {_floating_point_max()};
    } else {
      TORCH_CHECK(size_bits() < 64 || (size_bits() == 64 && is_signed()),
                  "Cannot represent max as an int64_t");
      return {(int64_t(1) << mantissa) - 1};
    }
  }

  constexpr std::variant<int64_t, double> _raw_min() const {
    if (is_floating_point()) {
      TORCH_CHECK(is_signed(),
                  "We currently assume all floating point types are signed");
      constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);

      double max = _floating_point_max();
      uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
      uint64_t min_raw = max_raw | sign_bit_double;
      return {*reinterpret_cast<double*>(&min_raw)};
    } else {
      TORCH_CHECK(!is_signed() || size_bits() <= 64,
                  "Cannot represent min as an int64_t");
      if (is_signed()) {
        // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0,
        // then perform an arithmetic shift right to set all the bits above
        // (size_bits() - 1) to 1
        return {INT64_MIN >> (64 - size_bits())};
      } else {
        return {int64_t(0)};
      }
    }
  }

 public:
  // Max representable value for this scalar type
  // (accounting for bias if there is one)
  constexpr std::variant<int64_t, double> max() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_max());
  }

  // Min representable value for this scalar type
  // (accounting for bias if there is one)
  constexpr std::variant<int64_t, double> min() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_min());
  }

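  // e.g. std::get<int64_t>(kU4B8.max()) == 7 and
  // std::get<double>(kFE5M2.max()) == 57344.0 (constants defined below)
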
  std::string str() const {
    /* naming generally follows: https://github.com/jax-ml/ml_dtypes
     * for floating point types (leading f) the scheme is:
     *   `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
     * flags:
     *   - no-flags: means it follows IEEE 754 conventions
     *   - f: means finite values only (no infinities)
     *   - n: means nans are supported (non-standard encoding)
     * for integer types the scheme is:
     *   `[u]int<size_bits>[b<bias>]`
     *   - if bias is not present it means it is zero
     */
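    // e.g. kFE5M2.str() == "float8_e5m2" (IEEE 754), kFE4M3fn.str() ==
    // "float8_e4m3fn" (finite values, non-standard nans), and
    // kU4B8.str() == "uint4b8" (constants defined at the end of this file)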
    if (is_floating_point()) {
      auto ret = "float" + std::to_string(size_bits()) + "_e" +
                 std::to_string(exponent) + "m" + std::to_string(mantissa);
      if (!is_ieee_754()) {
        if (finite_values_only) {
          ret += "f";
        }
        if (nan_repr != NAN_NONE) {
          ret += "n";
        }
      }
      return ret;
    } else {
      auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
      if (has_bias()) {
        ret += "b" + std::to_string(bias);
      }
      return ret;
    }
  }

  constexpr bool operator==(ScalarType const& other) const {
    return mantissa == other.mantissa && exponent == other.exponent &&
           bias == other.bias && signed_ == other.signed_ &&
           finite_values_only == other.finite_values_only &&
           nan_repr == other.nan_repr;
  }
};

using ScalarTypeId = ScalarType::Id;

// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);

static inline constexpr auto kFE3M2f =
    ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
    ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);

// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;

static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;

// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;

static inline constexpr auto kFloat16Id = kFloat16.id();
}  // namespace vllm
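For reviewers, a minimal usage sketch of the header above. This is illustrative only: the include path and the standalone `main` harness are assumptions, and it must be built in an environment where torch/library.h and libtorch are available (TORCH_CHECK expands to libtorch code). The printed values follow from the definitions in the header.

#include <cstdint>
#include <iostream>
#include <variant>

#include "core/scalar_type.hpp"

int main() {
  // kU4B8: 4-bit unsigned storage with a bias of 8 -> logical range -8..7
  std::cout << vllm::kU4B8.str() << "\n";                     // uint4b8
  std::cout << std::get<int64_t>(vllm::kU4B8.min()) << ".."
            << std::get<int64_t>(vllm::kU4B8.max()) << "\n";  // -8..7

  // kFE4M3fn: fp8 e4m3 with finite values only and non-standard nans
  std::cout << vllm::kFE4M3fn.str() << "\n";                    // float8_e4m3fn
  std::cout << std::get<double>(vllm::kFE4M3fn.max()) << "\n";  // 448

  // id() / from_id() round-trip, usable for C++17 template dispatch
  static_assert(vllm::ScalarType::from_id(vllm::kU4B8.id()) == vllm::kU4B8);
  return 0;
}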