File size: 3,771 Bytes
e05eed1
98a67a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <torch/torch.h>

#include "cuda_intellisense.cuh"

// Build-time diagnostic: prints a compiler message stating whether __CUDACC__
// is defined for the translation unit that includes this header.
#ifndef __CUDACC__
#pragma message("__CUDACC__ not defined!")
#else
#pragma message("__CUDACC__ defined!")
#endif

// Function-qualifier macros shared by CUDA and plain C++ translation units.
//
//   __qr_device__ / __qr_host__  expand to __device__ / __host__ when compiling
//                                CUDA code, and to nothing otherwise.
//   __qr_inline__                expands to __forceinline__ when compiling CUDA
//                                code, and to plain `inline` otherwise.
//
// The switch is keyed on __CUDACC__, the same macro that guards every other
// CUDA-only section of this header, so the qualifiers and the code that relies
// on them are always enabled together (also for CUDA compilers that do not
// define __NVCC__).
#ifdef __CUDACC__
#define __qr_device__ __device__
#define __qr_host__ __host__
#define __qr_inline__ __forceinline__
#else
#define __qr_device__
#define __qr_host__
#define __qr_inline__ inline
#endif

#ifdef __CUDACC__
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>


// Unary minus for __half (device only): forwards to __hneg.
__qr_inline__ __device__ __half operator-(__half value) {
    const __half negated = __hneg(value);
    return negated;
}

// Addition of two __half values (device only): forwards to __hadd.
__qr_inline__ __device__ __half operator+(__half lhs, __half rhs) {
    const __half sum = __hadd(lhs, rhs);
    return sum;
}

// Subtraction of two __half values (device only): forwards to __hsub,
// computing lhs - rhs.
__qr_inline__ __device__ __half operator-(__half lhs, __half rhs) {
    const __half difference = __hsub(lhs, rhs);
    return difference;
}

// Multiplication of two __half values (device only): forwards to __hmul.
__qr_inline__ __device__ __half operator*(__half lhs, __half rhs) {
    const __half product = __hmul(lhs, rhs);
    return product;
}

// Division of two __half values (device only): forwards to __hdiv,
// computing numerator / denominator.
__qr_inline__ __device__ __half operator/(__half numerator, __half denominator) {
    const __half quotient = __hdiv(numerator, denominator);
    return quotient;
}

// Equality comparison of two __half values (device only): forwards to __heq.
__qr_inline__ __device__ bool operator==(__half lhs, __half rhs) {
    const bool equal = __heq(lhs, rhs);
    return equal;
}

// Less-than comparison of two __half values (device only): forwards to __hlt.
__qr_inline__ __device__ bool operator<(__half lhs, __half rhs) {
    const bool less = __hlt(lhs, rhs);
    return less;
}

// Greater-than comparison of two __half values (device only): forwards to
// __hgt.
__qr_inline__ __device__ bool operator>(__half lhs, __half rhs) {
    const bool greater = __hgt(lhs, rhs);
    return greater;
}

// sqrt overload for __half (device only): forwards to hsqrt.
__qr_inline__ __device__ __half sqrt(__half value) {
    const __half root = hsqrt(value);
    return root;
}

// floor overload for __half (device only): forwards to hfloor.
__qr_inline__ __device__ __half floor(__half value) {
    const __half floored = hfloor(value);
    return floored;
}

// ceil overload for __half (device only): forwards to hceil.
__qr_inline__ __device__ __half ceil(__half value) {
    const __half ceiled = hceil(value);
    return ceiled;
}

// max overload for __half (device only).
//
// Returns lhs when `lhs > rhs` holds; in every other case (including when the
// operands compare equal) rhs is returned.
__qr_inline__ __device__ __half max(__half lhs, __half rhs) {
    if (lhs > rhs) {
        return lhs;
    }
    return rhs;
}
#endif //__CUDACC__

// Generic two-way conversion between Src and Dest, implemented with
// static_cast in both directions.
//
//   From / LeftToRight : Src  -> Dest
//   To   / RightToLeft : Dest -> Src
//
// Each direction is exposed under two names that do the same thing.
template<typename Src, typename Dest>
struct Convert {
    // Src -> Dest
    __qr_inline__ static __qr_host__ __qr_device__ constexpr Dest From(Src source) {
        return static_cast<Dest>(source);
    }

    // Dest -> Src
    __qr_inline__ static __qr_host__ __qr_device__ constexpr Src To(Dest target) {
        return static_cast<Src>(target);
    }

    // Src -> Dest (same conversion as From)
    __qr_inline__ static __qr_host__ __qr_device__ constexpr Dest LeftToRight(Src source) {
        return static_cast<Dest>(source);
    }

    // Dest -> Src (same conversion as To)
    __qr_inline__ static __qr_host__ __qr_device__ constexpr Src RightToLeft(Dest target) {
        return static_cast<Src>(target);
    }
};

#ifdef __CUDACC__
// Convert specialisation for the __half <-> float pair.
//
//   From / LeftToRight : __half -> float, via __half2float
//   To   / RightToLeft : float  -> __half, via __float2half
template<>
struct Convert<__half, float> {
    // __half -> float
    __qr_inline__ static __host__ __device__ float From(__half h) {
        return __half2float(h);
    }

    // float -> __half
    __qr_inline__ static __host__ __device__ __half To(float f) {
        return __float2half(f);
    }

    // __half -> float (same conversion as From)
    __qr_inline__ static __host__ __device__ float LeftToRight(__half h) {
        return __half2float(h);
    }

    // float -> __half (same conversion as To)
    __qr_inline__ static __host__ __device__ __half RightToLeft(float f) {
        return __float2half(f);
    }
};

// Convert<__half, Dest> for any Dest: adds no members of its own and simply
// inherits the __half <-> float conversions from Convert<__half, float>.
template<typename Dest>
struct Convert<__half, Dest> : Convert<__half, float> {};

namespace at {

// Specialisation of TensorBase::mutable_data_ptr for __half.
//
// Fails via TORCH_CHECK unless the tensor's scalar type is ScalarType::Half,
// then returns the tensor implementation's mutable data pointer cast to
// __half*.
template<>
inline __half* TensorBase::mutable_data_ptr() const {
    TORCH_CHECK(ScalarType::Half == scalar_type(),
                "expected scalar type Half but found ",
                c10::toString(scalar_type()));
    auto* const raw = this->unsafeGetTensorImpl()->mutable_data();
    return static_cast<__half*>(raw);
}

// Specialisation of TensorBase::data_ptr for __half.
//
// Performs the same scalar-type check and returns the same pointer as the
// mutable_data_ptr<__half>() specialisation it forwards to.
template<>
inline __half* TensorBase::data_ptr() const {
    return this->mutable_data_ptr<__half>();
}

} // namespace at

// Type-mapping trait: by default `remap_half<T>::type` is T itself.
template<typename T>
struct remap_half {
    using type = T;
};

// remap_half specialisation: maps at::Half to the CUDA __half type.
template<>
struct remap_half<at::Half> {
    using type = __half;
};

// Converts `val` to __half by forwarding to Convert<__half, T>::RightToLeft.
//
// Declared __host__ __device__ (and inline via __qr_inline__) to match the
// Convert<__half, ...> members it calls; without the qualifiers the helper
// would be host-only and could not be used from device code.
//
// T    : source value type.
// val  : value to convert.
// Returns the __half produced by Convert<__half, T>::RightToLeft(val).
template<typename T>
__qr_inline__ __host__ __device__ __half to_half(T val) {
    return Convert<__half, T>::RightToLeft(val);
}

// Type-promotion trait: by default `fp_promote<T>::type` is T itself.
template<typename T>
struct fp_promote {
    using type = T;
};

// fp_promote specialisation: __half is promoted to float.
template<>
struct fp_promote<__half> {
    using type = float;
};

#endif //__CUDACC__