/******************************************************************************
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * Utilities for interacting with the opaque CUDA __nv_bfloat16 type
 */

#include <stdint.h>
#include <cuda_bf16.h>
#include <iosfwd>

#include <cstddef>
#include <ostream>

#include <cub/util_type.cuh>

#ifdef __GNUC__
// There's a ton of type-punning going on in this file.
#pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif /****************************************************************************** * bfloat16_t ******************************************************************************/ /** * Host-based fp16 data type compatible and convertible with __nv_bfloat16 */ struct bfloat16_t { uint16_t __x; /// Constructor from __nv_bfloat16 __host__ __device__ __forceinline__ bfloat16_t(const __nv_bfloat16 &other) { __x = reinterpret_cast(other); } /// Constructor from integer __host__ __device__ __forceinline__ bfloat16_t(int a) { *this = bfloat16_t(float(a)); } /// Constructor from std::size_t __host__ __device__ __forceinline__ bfloat16_t(std::size_t a) { *this = bfloat16_t(float(a)); } /// Default constructor bfloat16_t() = default; /// Constructor from float __host__ __device__ __forceinline__ bfloat16_t(float a) { // Refrence: // https://github.com/pytorch/pytorch/blob/44cc873fba5e5ffc4d4d4eef3bd370b653ce1ce1/c10/util/BFloat16.h#L51 uint16_t ir; if (a != a) { ir = UINT16_C(0x7FFF); } else { union { uint32_t U32; float F32; }; F32 = a; uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); ir = static_cast((U32 + rounding_bias) >> 16); } this->__x = ir; } /// Cast to __nv_bfloat16 __host__ __device__ __forceinline__ operator __nv_bfloat16() const { return reinterpret_cast(__x); } /// Cast to float __host__ __device__ __forceinline__ operator float() const { float f = 0; uint32_t *p = reinterpret_cast(&f); *p = uint32_t(__x) << 16; return f; } /// Get raw storage __host__ __device__ __forceinline__ uint16_t raw() const { return this->__x; } /// Equality __host__ __device__ __forceinline__ bool operator ==(const bfloat16_t &other) const { return (this->__x == other.__x); } /// Inequality __host__ __device__ __forceinline__ bool operator !=(const bfloat16_t &other) const { return (this->__x != other.__x); } /// Assignment by sum __host__ __device__ __forceinline__ bfloat16_t& operator 
+=(const bfloat16_t &rhs) { *this = bfloat16_t(float(*this) + float(rhs)); return *this; } /// Multiply __host__ __device__ __forceinline__ bfloat16_t operator*(const bfloat16_t &other) { return bfloat16_t(float(*this) * float(other)); } /// Add __host__ __device__ __forceinline__ bfloat16_t operator+(const bfloat16_t &other) { return bfloat16_t(float(*this) + float(other)); } /// Less-than __host__ __device__ __forceinline__ bool operator<(const bfloat16_t &other) const { return float(*this) < float(other); } /// Less-than-equal __host__ __device__ __forceinline__ bool operator<=(const bfloat16_t &other) const { return float(*this) <= float(other); } /// Greater-than __host__ __device__ __forceinline__ bool operator>(const bfloat16_t &other) const { return float(*this) > float(other); } /// Greater-than-equal __host__ __device__ __forceinline__ bool operator>=(const bfloat16_t &other) const { return float(*this) >= float(other); } /// numeric_traits::max __host__ __device__ __forceinline__ static bfloat16_t (max)() { uint16_t max_word = 0x7F7F; return reinterpret_cast(max_word); } /// numeric_traits::lowest __host__ __device__ __forceinline__ static bfloat16_t lowest() { uint16_t lowest_word = 0xFF7F; return reinterpret_cast(lowest_word); } }; /****************************************************************************** * I/O stream overloads ******************************************************************************/ /// Insert formatted \p bfloat16_t into the output stream std::ostream& operator<<(std::ostream &out, const bfloat16_t &x) { out << (float)x; return out; } /// Insert formatted \p __nv_bfloat16 into the output stream std::ostream& operator<<(std::ostream &out, const __nv_bfloat16 &x) { return out << bfloat16_t(x); } /****************************************************************************** * Traits overloads ******************************************************************************/ template <> struct CUB_NS_QUALIFIER::FpLimits { static 
__host__ __device__ __forceinline__ bfloat16_t Max() { return bfloat16_t::max(); } static __host__ __device__ __forceinline__ bfloat16_t Lowest() { return bfloat16_t::lowest(); } }; template <> struct CUB_NS_QUALIFIER::NumericTraits : CUB_NS_QUALIFIER:: BaseTraits {}; #ifdef __GNUC__ #pragma GCC diagnostic pop #endif