| |
| |
|
|
| #pragma once |
|
|
| #include "ck/utility/data_type.hpp" |
|
|
| namespace ck { |
|
|
| |
| |
| |
// Rounding policy selector for conversions into 8-bit floating point formats.
// NOTE(review): this enum is not referenced in this header; the meanings below are inferred
// from the enumerator names — confirm against the call sites that consume it.
enum class f8_rounding_mode
{
    standard = 0,  // presumably deterministic round-to-nearest
    stochastic = 1 // presumably stochastic rounding driven by caller-supplied random bits
};
|
|
| __host__ inline int clz(uint32_t x) { return __builtin_clz(x); } |
| __device__ inline int clz(uint32_t x) { return __clz(x); } |
|
|
| } |
|
|
| namespace ck::utils { |
|
|
| namespace { |
|
|
| template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch> |
| __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) |
| { |
| |
| constexpr int out_exp = NumericUtils<Y>::exp; |
| constexpr int out_mant = NumericUtils<Y>::mant; |
|
|
| |
| constexpr int in_exp = NumericUtils<X>::exp; |
| constexpr int in_mant = NumericUtils<X>::mant; |
|
|
| int exponent, bias; |
| uint32_t head, mantissa, sign; |
| |
| constexpr Y nan_code = 0x80; |
| constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask; |
|
|
| |
| using T_bitwise = typename NumericUtils<X>::bitwise_type; |
| T_bitwise x_bitwise = bit_cast<T_bitwise>(x); |
|
|
| |
| head = x_bitwise & NumericUtils<X>::head_mask; |
| mantissa = x_bitwise & NumericUtils<X>::mant_mask; |
| exponent = (head >> in_mant) & NumericUtils<X>::exp_mask; |
| sign = head >> (in_exp + in_mant); |
| bias = NumericUtils<X>::bias; |
|
|
| uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); |
| uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; |
| constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); |
|
|
| if constexpr(negative_zero_nan) |
| { |
| if((x_bitwise & nan_mask) == nan_mask) |
| return nan_code; |
| } |
| else |
| { |
| if((x_bitwise & nan_mask) == nan_mask) |
| return signed_inf + (mantissa != 0 ? 1 : 0); |
| } |
|
|
| |
| if(x_bitwise == 0) |
| return 0; |
|
|
| |
| |
| |
| |
| |
|
|
| |
| const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0); |
| const int out_denormal_act_exponent = 1 - out_bias; |
| |
| |
| |
| |
| int act_exponent, out_exponent, exponent_diff; |
|
|
| if(exponent == 0) |
| { |
| |
| |
| |
| |
| |
| |
| act_exponent = exponent - bias + 1; |
| exponent_diff = out_denormal_act_exponent - |
| act_exponent; |
| } |
| else |
| { |
| act_exponent = exponent - bias; |
| if(act_exponent <= out_denormal_act_exponent) |
| { |
| |
| |
| |
| |
| |
| exponent_diff = out_denormal_act_exponent - act_exponent; |
| } |
| else |
| { |
| exponent_diff = |
| 0; |
| |
| } |
| mantissa += (1 << in_mant); |
| } |
|
|
| bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == |
| (1 << (in_mant - out_mant + exponent_diff - 1)); |
| |
| |
| |
| |
|
|
| if(exponent_diff > 0) |
| mantissa >>= exponent_diff; |
| else if(exponent_diff == -1) |
| mantissa <<= -exponent_diff; |
| bool implicit_one = mantissa & (1 << in_mant); |
| |
| out_exponent = |
| (act_exponent + exponent_diff) + out_bias - (implicit_one ? 0 : 1); |
|
|
| |
| bool odd = |
| mantissa & |
| (1 << (in_mant - out_mant)); |
| mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; |
|
|
| |
| if(out_exponent == 0) |
| { |
| if((1 << in_mant) & mantissa) |
| { |
| out_exponent = 1; |
| |
| } |
| } |
| else |
| { |
| if((1 << (in_mant + 1)) & mantissa) |
| { |
| mantissa >>= 1; |
| out_exponent++; |
| |
| } |
| } |
|
|
| mantissa >>= (in_mant - out_mant); |
|
|
| if(out_exponent > max_exp) |
| { |
| if constexpr(clip) |
| { |
| mantissa = (1 << out_mant) - 1; |
| out_exponent = max_exp; |
| } |
| else |
| { |
| return signed_inf; |
| } |
| } |
|
|
| |
| if(out_exponent == 0 && mantissa == 0) |
| return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); |
| mantissa &= (1 << out_mant) - 1; |
| return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; |
| } |
|
|
| template <typename X, typename Y, bool negative_zero_nan> |
| __host__ __device__ Y run_cast_from_f8(X x) |
| { |
| |
| constexpr int in_exp = NumericUtils<X>::exp; |
| constexpr int in_mant = NumericUtils<X>::mant; |
|
|
| |
| constexpr int out_exp = NumericUtils<Y>::exp; |
| constexpr int out_mant = NumericUtils<Y>::mant; |
|
|
| |
| constexpr X nan_code = 0x80; |
| using T_bitwise = typename NumericUtils<Y>::bitwise_type; |
|
|
| constexpr T_bitwise Inf_bitwise = NumericUtils<Y>::Inf; |
| constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf; |
| constexpr T_bitwise NaN_bitwise = NumericUtils<Y>::NaN; |
| constexpr T_bitwise Neg0_bitwise = NumericUtils<Y>::Neg0; |
|
|
| constexpr Y Inf = bit_cast<Y>(Inf_bitwise); |
| constexpr Y NegInf = bit_cast<Y>(NegInf_bitwise); |
| constexpr Y NaN = bit_cast<Y>(NaN_bitwise); |
| constexpr Y Neg0 = bit_cast<Y>(Neg0_bitwise); |
|
|
| |
| if(x == 0) |
| return static_cast<Y>(0); |
|
|
| |
| uint32_t sign = x >> (in_exp + in_mant); |
| uint32_t mantissa = x & ((1 << in_mant) - 1); |
| int exponent = (x & 0x7F) >> in_mant; |
|
|
| constexpr int exp_low_cutoff = |
| (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); |
| T_bitwise retval; |
|
|
| if constexpr(negative_zero_nan) |
| { |
| if(x == nan_code) |
| return NaN; |
| } |
| else |
| { |
| if(x == nan_code) |
| return Neg0; |
| if(exponent == ((1 << in_exp) - 1)) |
| return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; |
| } |
|
|
| if constexpr((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && |
| !negative_zero_nan) |
| { |
| retval = x; |
| retval <<= 8; |
| return bit_cast<Y>(retval); |
| } |
|
|
| |
| if(exponent == 0) |
| { |
| |
| int sh = 1 + clz(mantissa) - (32 - in_mant); |
| mantissa <<= sh; |
| exponent += 1 - sh; |
| mantissa &= ((1 << in_mant) - 1); |
| } |
| exponent += exp_low_cutoff - 1; |
| mantissa <<= out_mant - in_mant; |
|
|
| |
| if(exponent <= 0) |
| { |
| mantissa |= 1 << out_mant; |
| mantissa >>= 1 - exponent; |
| exponent = 0; |
| } |
|
|
| retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; |
| return bit_cast<Y>(retval); |
| } |
|
|
| } |
|
|
| template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch> |
| __host__ __device__ Y cast_to_f8(X x, uint32_t rng) |
| { |
| |
| constexpr bool is_half = std::is_same<X, half_t>::value; |
| constexpr bool is_float = std::is_same<X, float>::value; |
| static_assert(is_half || is_float, "Only half and float can be casted."); |
|
|
| return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng); |
| } |
|
|
| template <typename X, typename Y, bool negative_zero_nan> |
| __host__ __device__ Y cast_from_f8(X x) |
| { |
| |
| constexpr bool is_half = std::is_same<Y, half_t>::value; |
| constexpr bool is_float = std::is_same<Y, float>::value; |
| static_assert(is_half || is_float, "only half and float are supported."); |
|
|
| return run_cast_from_f8<X, Y, negative_zero_nan>(x); |
| } |
|
|
| } |
|
|