Spaces:
Runtime error
Runtime error
| /****************************************************************************** | |
| * Copyright (c) 2011, Duane Merrill. All rights reserved. | |
| * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. | |
| * | |
| * Redistribution and use in source and binary forms, with or without | |
| * modification, are permitted provided that the following conditions are met: | |
| * * Redistributions of source code must retain the above copyright | |
| * notice, this list of conditions and the following disclaimer. | |
| * * Redistributions in binary form must reproduce the above copyright | |
| * notice, this list of conditions and the following disclaimer in the | |
| * documentation and/or other materials provided with the distribution. | |
| * * Neither the name of the NVIDIA CORPORATION nor the | |
| * names of its contributors may be used to endorse or promote products | |
| * derived from this software without specific prior written permission. | |
| * | |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | |
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | |
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | |
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | |
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| * | |
| ******************************************************************************/ | |
| /** | |
| * \file | |
| * Utilities for interacting with the opaque CUDA __half type | |
| */ | |
| // There's a ton of type-punning going on in this file. | |
| /****************************************************************************** | |
| * half_t | |
| ******************************************************************************/ | |
| /** | |
| * Host-based fp16 data type compatible and convertible with __half | |
| */ | |
| struct half_t | |
| { | |
| uint16_t __x; | |
| /// Constructor from __half | |
| __host__ __device__ __forceinline__ | |
| half_t(const __half &other) | |
| { | |
| __x = reinterpret_cast<const uint16_t&>(other); | |
| } | |
| /// Constructor from integer | |
| __host__ __device__ __forceinline__ | |
| half_t(int a) | |
| { | |
| *this = half_t(float(a)); | |
| } | |
| /// Default constructor | |
| __host__ __device__ __forceinline__ | |
| half_t() : __x(0) | |
| {} | |
| /// Constructor from float | |
| __host__ __device__ __forceinline__ | |
| half_t(float a) | |
| { | |
| // Stolen from Norbert Juffa | |
| uint32_t ia = *reinterpret_cast<uint32_t*>(&a); | |
| uint16_t ir; | |
| ir = (ia >> 16) & 0x8000; | |
| if ((ia & 0x7f800000) == 0x7f800000) | |
| { | |
| if ((ia & 0x7fffffff) == 0x7f800000) | |
| { | |
| ir |= 0x7c00; /* infinity */ | |
| } | |
| else | |
| { | |
| ir = 0x7fff; /* canonical NaN */ | |
| } | |
| } | |
| else if ((ia & 0x7f800000) >= 0x33000000) | |
| { | |
| int32_t shift = (int32_t) ((ia >> 23) & 0xff) - 127; | |
| if (shift > 15) | |
| { | |
| ir |= 0x7c00; /* infinity */ | |
| } | |
| else | |
| { | |
| ia = (ia & 0x007fffff) | 0x00800000; /* extract mantissa */ | |
| if (shift < -14) | |
| { /* denormal */ | |
| ir |= ia >> (-1 - shift); | |
| ia = ia << (32 - (-1 - shift)); | |
| } | |
| else | |
| { /* normal */ | |
| ir |= ia >> (24 - 11); | |
| ia = ia << (32 - (24 - 11)); | |
| ir = ir + ((14 + shift) << 10); | |
| } | |
| /* IEEE-754 round to nearest of even */ | |
| if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1))) | |
| { | |
| ir++; | |
| } | |
| } | |
| } | |
| this->__x = ir; | |
| } | |
| /// Cast to __half | |
| __host__ __device__ __forceinline__ | |
| operator __half() const | |
| { | |
| return reinterpret_cast<const __half&>(__x); | |
| } | |
| /// Cast to float | |
| __host__ __device__ __forceinline__ | |
| operator float() const | |
| { | |
| // Stolen from Andrew Kerr | |
| int sign = ((this->__x >> 15) & 1); | |
| int exp = ((this->__x >> 10) & 0x1f); | |
| int mantissa = (this->__x & 0x3ff); | |
| uint32_t f = 0; | |
| if (exp > 0 && exp < 31) | |
| { | |
| // normal | |
| exp += 112; | |
| f = (sign << 31) | (exp << 23) | (mantissa << 13); | |
| } | |
| else if (exp == 0) | |
| { | |
| if (mantissa) | |
| { | |
| // subnormal | |
| exp += 113; | |
| while ((mantissa & (1 << 10)) == 0) | |
| { | |
| mantissa <<= 1; | |
| exp--; | |
| } | |
| mantissa &= 0x3ff; | |
| f = (sign << 31) | (exp << 23) | (mantissa << 13); | |
| } | |
| else if (sign) | |
| { | |
| f = 0x80000000; // negative zero | |
| } | |
| else | |
| { | |
| f = 0x0; // zero | |
| } | |
| } | |
| else if (exp == 31) | |
| { | |
| if (mantissa) | |
| { | |
| f = 0x7fffffff; // not a number | |
| } | |
| else | |
| { | |
| f = (0xff << 23) | (sign << 31); // inf | |
| } | |
| } | |
| return *reinterpret_cast<float const *>(&f); | |
| } | |
| /// Get raw storage | |
| __host__ __device__ __forceinline__ | |
| uint16_t raw() | |
| { | |
| return this->__x; | |
| } | |
| /// Equality | |
| __host__ __device__ __forceinline__ | |
| bool operator ==(const half_t &other) | |
| { | |
| return (this->__x == other.__x); | |
| } | |
| /// Inequality | |
| __host__ __device__ __forceinline__ | |
| bool operator !=(const half_t &other) | |
| { | |
| return (this->__x != other.__x); | |
| } | |
| /// Assignment by sum | |
| __host__ __device__ __forceinline__ | |
| half_t& operator +=(const half_t &rhs) | |
| { | |
| *this = half_t(float(*this) + float(rhs)); | |
| return *this; | |
| } | |
| /// Multiply | |
| __host__ __device__ __forceinline__ | |
| half_t operator*(const half_t &other) | |
| { | |
| return half_t(float(*this) * float(other)); | |
| } | |
| /// Add | |
| __host__ __device__ __forceinline__ | |
| half_t operator+(const half_t &other) | |
| { | |
| return half_t(float(*this) + float(other)); | |
| } | |
| /// Less-than | |
| __host__ __device__ __forceinline__ | |
| bool operator<(const half_t &other) const | |
| { | |
| return float(*this) < float(other); | |
| } | |
| /// Less-than-equal | |
| __host__ __device__ __forceinline__ | |
| bool operator<=(const half_t &other) const | |
| { | |
| return float(*this) <= float(other); | |
| } | |
| /// Greater-than | |
| __host__ __device__ __forceinline__ | |
| bool operator>(const half_t &other) const | |
| { | |
| return float(*this) > float(other); | |
| } | |
| /// Greater-than-equal | |
| __host__ __device__ __forceinline__ | |
| bool operator>=(const half_t &other) const | |
| { | |
| return float(*this) >= float(other); | |
| } | |
| /// numeric_traits<half_t>::max | |
| __host__ __device__ __forceinline__ | |
| static half_t max() { | |
| uint16_t max_word = 0x7BFF; | |
| return reinterpret_cast<half_t&>(max_word); | |
| } | |
| /// numeric_traits<half_t>::lowest | |
| __host__ __device__ __forceinline__ | |
| static half_t lowest() { | |
| uint16_t lowest_word = 0xFBFF; | |
| return reinterpret_cast<half_t&>(lowest_word); | |
| } | |
| }; | |
| /****************************************************************************** | |
| * I/O stream overloads | |
| ******************************************************************************/ | |
| /// Insert formatted \p half_t into the output stream | |
| std::ostream& operator<<(std::ostream &out, const half_t &x) | |
| { | |
| out << (float)x; | |
| return out; | |
| } | |
| /// Insert formatted \p __half into the output stream | |
| std::ostream& operator<<(std::ostream &out, const __half &x) | |
| { | |
| return out << half_t(x); | |
| } | |
| /****************************************************************************** | |
| * Traits overloads | |
| ******************************************************************************/ | |
| template <> | |
| struct cub::FpLimits<half_t> | |
| { | |
| static __host__ __device__ __forceinline__ half_t Max() { return half_t::max(); } | |
| static __host__ __device__ __forceinline__ half_t Lowest() { return half_t::lowest(); } | |
| }; | |
| template <> struct cub::NumericTraits<half_t> : cub::BaseTraits<FLOATING_POINT, true, false, unsigned short, half_t> {}; | |