| import functools |
| import struct |
| from dataclasses import dataclass |
| from enum import Enum |
| from typing import Optional, Union |
|
|
|
|
| |
| class NanRepr(Enum): |
| NONE = 0 |
| IEEE_754 = 1 |
| EXTD_RANGE_MAX_MIN = 2 |
|
|
|
|
| |
| |
| |
| @dataclass(frozen=True) |
| class ScalarType: |
| """ |
| ScalarType can represent a wide range of floating point and integer |
| types, in particular it can be used to represent sub-byte data types |
| (something that torch.dtype currently does not support). It is also |
| capable of representing types with a bias, i.e.: |
| `stored_value = value + bias`, |
| this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias |
| of 8). The implementation for this class can be found in |
| csrc/core/scalar_type.hpp, these type signatures should be kept in sync |
| with that file. |
| """ |
|
|
| exponent: int |
| """ |
| Number of bits in the exponent if this is a floating point type |
| (zero if this an integer type) |
| """ |
|
|
| mantissa: int |
| """ |
| Number of bits in the mantissa if this is a floating point type, |
| or the number bits representing an integer excluding the sign bit if |
| this an integer type. |
| """ |
|
|
| signed: bool |
| "If the type is signed (i.e. has a sign bit)" |
|
|
| bias: int |
| """ |
| bias used to encode the values in this scalar type |
| (value = stored_value - bias, default 0) for example if we store the |
| type as an unsigned integer with a bias of 128 then the value 0 will be |
| stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. |
| """ |
|
|
| _finite_values_only: bool = False |
| """ |
| Private: if infs are supported, used `has_infs()` instead. |
| """ |
|
|
| nan_repr: NanRepr = NanRepr.IEEE_754 |
| """ |
| How NaNs are represent in this scalar type, returns NanRepr value. |
| (not applicable for integer types) |
| """ |
|
|
| def _floating_point_max_int(self) -> int: |
| assert ( |
| self.mantissa <= 52 and self.exponent <= 11 |
| ), f"Cannot represent max/min as a double for type {self.__str__()}" |
|
|
| max_mantissa = (1 << self.mantissa) - 1 |
| if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: |
| max_mantissa = max_mantissa - 1 |
|
|
| max_exponent = (1 << self.exponent) - 2 |
| if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN |
| or self.nan_repr == NanRepr.NONE): |
| assert ( |
| self.exponent < 11 |
| ), f"Cannot represent max/min as a double for type {self.__str__()}" |
| max_exponent = max_exponent + 1 |
|
|
| |
| |
| |
| |
| |
| |
| |
| exponent_bias = (1 << (self.exponent - 1)) - 1 |
| exponent_bias_double = (1 << 10) - 1 |
|
|
| max_exponent_double = (max_exponent - exponent_bias + |
| exponent_bias_double) |
|
|
| |
| |
| return (max_mantissa << |
| (52 - self.mantissa)) | (max_exponent_double << 52) |
|
|
| def _floating_point_max(self) -> float: |
| double_raw = self._floating_point_max_int() |
| return struct.unpack('!d', struct.pack('!Q', double_raw))[0] |
|
|
| def _raw_max(self) -> Union[int, float]: |
| if self.is_floating_point(): |
| return self._floating_point_max() |
| else: |
| assert (self.size_bits < 64 or self.size_bits == 64 |
| and self.is_signed()), "Cannot represent max as an int" |
| return (1 << self.mantissa) - 1 |
|
|
| def _raw_min(self) -> Union[int, float]: |
| if self.is_floating_point(): |
| assert self.is_signed( |
| ), "We currently assume all floating point types are signed" |
| sign_bit_double = 1 << 63 |
|
|
| max_raw = self._floating_point_max_int() |
| min_raw = max_raw | sign_bit_double |
| return struct.unpack('!d', struct.pack('!Q', min_raw))[0] |
| else: |
| assert (not self.is_signed() or |
| self.size_bits <= 64), "Cannot represent min as a int64_t" |
|
|
| if self.is_signed(): |
| return -(1 << (self.size_bits - 1)) |
| else: |
| return 0 |
|
|
| @functools.cached_property |
| def id(self) -> int: |
| """ |
| Convert the ScalarType to an int which can be passed to pytorch custom |
| ops. This layout of the int must be kept in sync with the C++ |
| ScalarType's from_id method. |
| """ |
| val = 0 |
| offset = 0 |
|
|
| def or_and_advance(member, bit_width): |
| nonlocal val |
| nonlocal offset |
| bit_mask = (1 << bit_width) - 1 |
| val = val | (int(member) & bit_mask) << offset |
| offset = offset + bit_width |
|
|
| or_and_advance(self.exponent, 8) |
| or_and_advance(self.mantissa, 8) |
| or_and_advance(self.signed, 1) |
| or_and_advance(self.bias, 32) |
| or_and_advance(self._finite_values_only, 1) |
| or_and_advance(self.nan_repr.value, 8) |
|
|
| assert offset <= 64, \ |
| f"ScalarType fields too big {offset} to fit into an int64" |
|
|
| return val |
|
|
| @property |
| def size_bits(self) -> int: |
| return self.exponent + self.mantissa + int(self.signed) |
|
|
| def min(self) -> Union[int, float]: |
| """ |
| Min representable value for this scalar type. |
| (accounting for bias if there is one) |
| """ |
| return self._raw_min() - self.bias |
|
|
| def max(self) -> Union[int, float]: |
| """ |
| Max representable value for this scalar type. |
| (accounting for bias if there is one) |
| """ |
| return self._raw_max() - self.bias |
|
|
| def is_signed(self) -> bool: |
| """ |
| If the type is signed (i.e. has a sign bit), same as `signed` |
| added for consistency with: |
| https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html |
| """ |
| return self.signed |
|
|
| def is_floating_point(self) -> bool: |
| "If the type is a floating point type" |
| return self.exponent != 0 |
|
|
| def is_integer(self) -> bool: |
| "If the type is an integer type" |
| return self.exponent == 0 |
|
|
| def has_bias(self) -> bool: |
| "If the type has a non-zero bias" |
| return self.bias != 0 |
|
|
| def has_infs(self) -> bool: |
| "If the type is floating point and supports infinity" |
| return not self._finite_values_only |
|
|
| def has_nans(self) -> bool: |
| return self.nan_repr != NanRepr.NONE.value |
|
|
| def is_ieee_754(self) -> bool: |
| """ |
| If the type is a floating point type that follows IEEE 754 |
| conventions |
| """ |
| return self.nan_repr == NanRepr.IEEE_754.value and \ |
| not self._finite_values_only |
|
|
| def __str__(self) -> str: |
| """ |
| naming generally follows: https://github.com/jax-ml/ml_dtypes |
| for floating point types (leading f) the scheme is: |
| `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]` |
| flags: |
| - no-flags: means it follows IEEE 754 conventions |
| - f: means finite values only (no infinities) |
| - n: means nans are supported (non-standard encoding) |
| for integer types the scheme is: |
| `[u]int<size_bits>[b<bias>]` |
| - if bias is not present it means its zero |
| """ |
| if self.is_floating_point(): |
| ret = "float" + str(self.size_bits) + "_e" + str( |
| self.exponent) + "m" + str(self.mantissa) |
|
|
| if not self.is_ieee_754(): |
| if self._finite_values_only: |
| ret = ret + "f" |
| if self.nan_repr != NanRepr.NONE: |
| ret = ret + "n" |
|
|
| return ret |
| else: |
| ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) |
| if self.has_bias(): |
| ret = ret + "b" + str(self.bias) |
| return ret |
|
|
| def __repr__(self) -> str: |
| return "ScalarType." + self.__str__() |
|
|
| |
| |
| def __len__(self) -> int: |
| raise TypeError |
|
|
| |
| |
| |
|
|
| @classmethod |
| def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': |
| "Create a signed integer scalar type (size_bits includes sign-bit)." |
| ret = cls(0, size_bits - 1, True, bias if bias else 0) |
| ret.id |
| return ret |
|
|
| @classmethod |
| def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': |
| """Create a unsigned integer scalar type.""" |
| ret = cls(0, size_bits, False, bias if bias else 0) |
| ret.id |
| return ret |
|
|
| @classmethod |
| def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': |
| """ |
| Create a standard floating point type |
| (i.e. follows IEEE 754 conventions). |
| """ |
| assert (mantissa > 0 and exponent > 0) |
| ret = cls(exponent, mantissa, True, 0) |
| ret.id |
| return ret |
|
|
| @classmethod |
| def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, |
| nan_repr: NanRepr) -> 'ScalarType': |
| """ |
| Create a non-standard floating point type |
| (i.e. does not follow IEEE 754 conventions). |
| """ |
| assert (mantissa > 0 and exponent > 0) |
| assert (nan_repr != NanRepr.IEEE_754), ( |
| "use `float_IEEE754` constructor for floating point types that " |
| "follow IEEE 754 conventions") |
| ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) |
| ret.id |
| return ret |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| class scalar_types: |
| int4 = ScalarType.int_(4, None) |
| uint4 = ScalarType.uint(4, None) |
| int8 = ScalarType.int_(8, None) |
| uint8 = ScalarType.uint(8, None) |
| float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) |
| float8_e5m2 = ScalarType.float_IEEE754(5, 2) |
| float16_e8m7 = ScalarType.float_IEEE754(8, 7) |
| float16_e5m10 = ScalarType.float_IEEE754(5, 10) |
|
|
| |
| float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) |
|
|
| |
| uint2b2 = ScalarType.uint(2, 2) |
| uint3b4 = ScalarType.uint(3, 4) |
| uint4b8 = ScalarType.uint(4, 8) |
| uint8b128 = ScalarType.uint(8, 128) |
|
|
| |
| bfloat16 = float16_e8m7 |
| float16 = float16_e5m10 |
|
|