| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #pragma once |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | #include <stdint.h> |
| | #include <cuda_bf16.h> |
| | #include <iosfwd> |
| |
|
| | #include <cub/util_type.cuh> |
| |
|
| | #ifdef __GNUC__ |
| | |
| | #pragma GCC diagnostic push |
| | #pragma GCC diagnostic ignored "-Wstrict-aliasing" |
| | #endif |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | struct bfloat16_t |
| | { |
| | uint16_t __x; |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bfloat16_t(const __nv_bfloat16 &other) |
| | { |
| | __x = reinterpret_cast<const uint16_t&>(other); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bfloat16_t(int a) |
| | { |
| | *this = bfloat16_t(float(a)); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bfloat16_t(std::size_t a) |
| | { |
| | *this = bfloat16_t(float(a)); |
| | } |
| |
|
| | |
| | bfloat16_t() = default; |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bfloat16_t(float a) |
| | { |
| | |
| | |
| | uint16_t ir; |
| | if (a != a) { |
| | ir = UINT16_C(0x7FFF); |
| | } else { |
| | union { |
| | uint32_t U32; |
| | float F32; |
| | }; |
| |
|
| | F32 = a; |
| | uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); |
| | ir = static_cast<uint16_t>((U32 + rounding_bias) >> 16); |
| | } |
| | this->__x = ir; |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | operator __nv_bfloat16() const |
| | { |
| | return reinterpret_cast<const __nv_bfloat16&>(__x); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | operator float() const |
| | { |
| | float f = 0; |
| | uint32_t *p = reinterpret_cast<uint32_t *>(&f); |
| | *p = uint32_t(__x) << 16; |
| | return f; |
| | } |
| |
|
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | uint16_t raw() const |
| | { |
| | return this->__x; |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bool operator ==(const bfloat16_t &other) const |
| | { |
| | return (this->__x == other.__x); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bool operator !=(const bfloat16_t &other) const |
| | { |
| | return (this->__x != other.__x); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bfloat16_t& operator +=(const bfloat16_t &rhs) |
| | { |
| | *this = bfloat16_t(float(*this) + float(rhs)); |
| | return *this; |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bfloat16_t operator*(const bfloat16_t &other) |
| | { |
| | return bfloat16_t(float(*this) * float(other)); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bfloat16_t operator+(const bfloat16_t &other) |
| | { |
| | return bfloat16_t(float(*this) + float(other)); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bool operator<(const bfloat16_t &other) const |
| | { |
| | return float(*this) < float(other); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bool operator<=(const bfloat16_t &other) const |
| | { |
| | return float(*this) <= float(other); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bool operator>(const bfloat16_t &other) const |
| | { |
| | return float(*this) > float(other); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | bool operator>=(const bfloat16_t &other) const |
| | { |
| | return float(*this) >= float(other); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | static bfloat16_t (max)() { |
| | uint16_t max_word = 0x7F7F; |
| | return reinterpret_cast<bfloat16_t&>(max_word); |
| | } |
| |
|
| | |
| | __host__ __device__ __forceinline__ |
| | static bfloat16_t lowest() { |
| | uint16_t lowest_word = 0xFF7F; |
| | return reinterpret_cast<bfloat16_t&>(lowest_word); |
| | } |
| | }; |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | std::ostream& operator<<(std::ostream &out, const bfloat16_t &x) |
| | { |
| | out << (float)x; |
| | return out; |
| | } |
| |
|
| |
|
| | |
| | std::ostream& operator<<(std::ostream &out, const __nv_bfloat16 &x) |
| | { |
| | return out << bfloat16_t(x); |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | template <> |
| | struct CUB_NS_QUALIFIER::FpLimits<bfloat16_t> |
| | { |
| | static __host__ __device__ __forceinline__ bfloat16_t Max() { return bfloat16_t::max(); } |
| |
|
| | static __host__ __device__ __forceinline__ bfloat16_t Lowest() { return bfloat16_t::lowest(); } |
| | }; |
| |
|
| | template <> |
| | struct CUB_NS_QUALIFIER::NumericTraits<bfloat16_t> |
| | : CUB_NS_QUALIFIER:: |
| | BaseTraits<FLOATING_POINT, true, false, unsigned short, bfloat16_t> |
| | {}; |
| |
|
| | #ifdef __GNUC__ |
| | #pragma GCC diagnostic pop |
| | #endif |
| |
|