Kernels:
Trusted publisher
| // #include <torch/extension.h> | |
| namespace megablocks { | |
| // Short alias for the CUDA packed half-precision pair type. | |
| using half2 = __half2; | |
| // Two half2 values stored back to back, with 8-byte alignment requested via | |
| // __align__(8). | |
| struct __align__(8) half4 { | |
|   half2 x; | |
|   half2 y; | |
| }; | |
| // Four half2 values stored back to back, with 16-byte alignment requested via | |
| // __align__(16). | |
| struct __align__(16) half8 { | |
|   half2 x; | |
|   half2 y; | |
|   half2 z; | |
|   half2 w; | |
| }; | |
| // Reinterprets the object representation of `src` as a value of type `To`. | |
| // | |
| // Exactly sizeof(To) bytes are copied out of `src` into a fresh `To`, so | |
| // `From` must be at least that large. The static_assert rejects any | |
| // instantiation that would otherwise read past the end of `src`. | |
| // | |
| // Template parameters: | |
| //   To   - destination type; must be default-constructible. | |
| //   From - source type; deduced from the argument. | |
| // Returns: a `To` whose bytes equal the first sizeof(To) bytes of `src`. | |
| template <class To, class From> | |
| __device__ __forceinline__ To BitCast(const From& src) noexcept { | |
|   static_assert(sizeof(To) <= sizeof(From), | |
|                 "BitCast would read past the end of the source object"); | |
|   To dst; | |
|   std::memcpy(&dst, &src, sizeof(To)); | |
|   return dst; | |
| } | |
| // Writes `value` to the location `ptr` points at using plain assignment. | |
| template <typename T> | |
| __device__ __forceinline__ void Store(const T& value, T* ptr) { | |
|   ptr[0] = value; | |
| } | |
| // Generic load: reads one T from `address` through __ldg (the read-only | |
| // data cache load) and returns it by value. T must be a type that __ldg | |
| // provides an overload for. | |
| template <typename T> | |
| __device__ __forceinline__ T Load(const T* address) { | |
|   const T fetched = __ldg(address); | |
|   return fetched; | |
| } | |
| // half4 overload: __ldg is called on the address viewed as a float2, and | |
| // the loaded bytes are then reinterpreted as a half4 with BitCast. | |
| __device__ __forceinline__ half4 Load(const half4* address) { | |
|   const float2* as_float2 = reinterpret_cast<const float2*>(address); | |
|   const float2 raw = __ldg(as_float2); | |
|   return BitCast<half4>(raw); | |
| } | |
| // half8 overload: __ldg is called on the address viewed as a float4, and | |
| // the loaded bytes are then reinterpreted as a half8 with BitCast. | |
| __device__ __forceinline__ half8 Load(const half8* address) { | |
|   const float4* as_float4 = reinterpret_cast<const float4*>(address); | |
|   const float4 raw = __ldg(as_float4); | |
|   return BitCast<half8>(raw); | |
| } | |
| // Primary template: yields the value obtained by implicitly converting the | |
| // integer literal 0 to T. Types without such a conversion need their own | |
| // explicit specialization. | |
| template <typename T> | |
| __device__ __forceinline__ T Zero() { | |
|   return 0; | |
| } | |
| // Specialization for half2: returns a pair with both lanes set to zero. | |
| // | |
| // The value is produced with the cuda_fp16 conversion intrinsic from a float | |
| // literal, so no double-precision literal is involved and no half type from | |
| // outside the CUDA headers is needed to build the result. | |
| template <> | |
| __device__ __forceinline__ half2 Zero<half2>() { | |
|   return __float2half2_rn(0.0f); | |
| } | |
| // Specialization for half4: each half2 member is filled from Zero<half2>(). | |
| template <> | |
| __device__ __forceinline__ half4 Zero<half4>() { | |
|   return half4{Zero<half2>(), Zero<half2>()}; | |
| } | |
| } // namespace megablocks | |