// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <torch/torch.h>
#include "cuda_intellisense.cuh"
#ifndef __CUDACC__
#pragma message("__CUDACC__ not defined!")
#else
#pragma message("__CUDACC__ defined!")
#endif
#ifdef __NVCC__
#define __qr_device__ __device__
#define __qr_host__ __host__
#define __qr_inline__ __forceinline__
#else
#define __qr_device__
#define __qr_host__
#define __qr_inline__ inline
#endif
#ifdef __CUDACC__
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// Unary minus for __half in device code: delegates to the __hneg intrinsic.
__qr_inline__ __device__ __half operator-(__half v) {
return __hneg(v);
}
// Half-precision addition via the __hadd intrinsic (device only).
__qr_inline__ __device__ __half operator+(__half a, __half b) {
return __hadd(a, b);
}
// Half-precision subtraction via the __hsub intrinsic (device only).
__qr_inline__ __device__ __half operator-(__half a, __half b) {
return __hsub(a, b);
}
// Half-precision multiplication via the __hmul intrinsic (device only).
__qr_inline__ __device__ __half operator*(__half a, __half b) {
return __hmul(a, b);
}
// Half-precision division via the __hdiv intrinsic (device only).
__qr_inline__ __device__ __half operator/(__half a, __half b) {
return __hdiv(a, b);
}
// Half-precision equality via the __heq intrinsic (device only).
// NOTE(review): __heq is an ordered comparison -- presumably false when
// either operand is NaN; confirm callers expect that.
__qr_inline__ __device__ bool operator==(__half a, __half b) {
return __heq(a, b);
}
// Half-precision less-than via the __hlt intrinsic (device only).
__qr_inline__ __device__ bool operator<(__half a, __half b) {
return __hlt(a, b);
}
// Half-precision greater-than via the __hgt intrinsic (device only).
__qr_inline__ __device__ bool operator>(__half a, __half b) {
return __hgt(a, b);
}
// Overload of sqrt for __half in device code, delegating to hsqrt.
__qr_inline__ __device__ __half sqrt(__half v) {
return hsqrt(v);
}
// Overload of floor for __half in device code, delegating to hfloor.
__qr_inline__ __device__ __half floor(__half v) {
return hfloor(v);
}
// Overload of ceil for __half in device code, delegating to hceil.
__qr_inline__ __device__ __half ceil(__half v) {
return hceil(v);
}
// Returns the larger of two __half values, using the operator> overload
// defined above (i.e. the __hgt intrinsic). When the comparison is false,
// b is returned -- this matches the original ternary form, including for
// unordered (NaN) operands where __hgt presumably yields false.
__qr_inline__ __device__ __half max(__half a, __half b) {
    if (a > b) {
        return a;
    }
    return b;
}
#endif //__CUDACC__
// Bidirectional conversion trait between a Src and a Dest type.
// The generic implementation relies on static_cast in both directions;
// specializations (e.g. for __half below) substitute intrinsic-based
// conversions. From/LeftToRight map Src -> Dest; To/RightToLeft map
// Dest -> Src.
template<typename Src, typename Dest>
struct Convert {
__qr_inline__ static __qr_host__ __qr_device__ constexpr Dest From(Src v) { return static_cast<Dest>(v); }
__qr_inline__ static __qr_host__ __qr_device__ constexpr Dest LeftToRight(Src v) { return static_cast<Dest>(v); }
__qr_inline__ static __qr_host__ __qr_device__ constexpr Src To(Dest v) { return static_cast<Src>(v); }
__qr_inline__ static __qr_host__ __qr_device__ constexpr Src RightToLeft(Dest v) { return static_cast<Src>(v); }
};
#ifdef __CUDACC__
// Specialization for __half <-> float using the CUDA conversion
// intrinsics instead of static_cast.
// NOTE(review): unlike the primary template these methods are not
// constexpr -- presumably because the intrinsics are not
// constant-evaluable; confirm if constexpr is ever needed here.
template<>
struct Convert<__half, float> {
__qr_inline__ static __host__ __device__ float From(__half value) { return __half2float(value); }
__qr_inline__ static __host__ __device__ __half To(float value) { return __float2half(value); }
__qr_inline__ static __host__ __device__ float LeftToRight(__half value) { return __half2float(value); }
__qr_inline__ static __host__ __device__ __half RightToLeft(float value) { return __float2half(value); }
};
// Fallback partial specialization: any other Dest paired with __half
// inherits the float-based conversions above, so a conversion to/from
// Dest goes through float (relying on implicit Dest <-> float
// conversion at the call site).
template<typename Dest>
struct Convert<__half, Dest> : Convert<__half, float> {
};
namespace at {
// Let tensors with ScalarType::Half hand out their storage directly as
// CUDA __half*. NOTE(review): this relies on at::Half and __half being
// layout-compatible 16-bit types -- confirm against ATen's definition
// of at::Half.
template<>
inline __half* TensorBase::mutable_data_ptr() const {
TORCH_CHECK(scalar_type() == ScalarType::Half,
"expected scalar type Half but found ",
c10::toString(scalar_type()));
return static_cast<__half*>(this->unsafeGetTensorImpl()->mutable_data());
}
// Same body as mutable_data_ptr<__half>() above: this also calls
// mutable_data() and returns a mutable pointer. NOTE(review): this
// mirrors ATen's own data_ptr<T>, which is a mutable accessor despite
// the name -- verify against the installed torch version.
template<>
inline __half* TensorBase::data_ptr() const {
TORCH_CHECK(scalar_type() == ScalarType::Half,
"expected scalar type Half but found ",
c10::toString(scalar_type()));
return static_cast<__half*>(this->unsafeGetTensorImpl()->mutable_data());
}
}
// Type-mapping trait: the primary template maps every type to itself.
// It is specialized for at::Half (mapped to CUDA's __half) elsewhere in
// this header.
template<typename T>
struct remap_half {
    using type = T;
};
// at::Half is remapped to the CUDA-native __half type.
template<>
struct remap_half<at::Half> {
    using type = __half;
};
// Converts a value of type T to __half via the Convert<__half, T> trait.
// Fix: the original carried no execution-space qualifiers, so although it
// is only compiled under __CUDACC__ and wraps the __qr_host__
// __qr_device__ Convert methods, it could not be called from device code.
// Adding the __qr_* qualifiers makes it callable on both host and device,
// consistent with every other helper in this header (they expand to
// nothing when __NVCC__ is not defined, so host callers are unaffected).
template<typename T>
__qr_inline__ __qr_host__ __qr_device__ __half to_half(T val) {
return Convert<__half, T>::RightToLeft(val);
}
// Floating-point promotion trait: maps a type to the (possibly wider)
// type computations should use. The primary template is the identity;
// it is specialized for __half (promoted to float) elsewhere in this
// header.
template<typename T>
struct fp_promote {
    using type = T;
};
// __half values are promoted to float for computation.
template<>
struct fp_promote<__half> {
    using type = float;
};
#endif //__CUDACC__
|