| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #pragma once |
|
|
| #ifdef USE_ROCM |
|
|
| #include <hip/hip_bf16.h> |
| #include <hip/hip_common.h> |
| #include <hip/hip_fp16.h> |
|
|
| |
|
|
| namespace amdgpu { |
|
|
| template <typename T> |
| __forceinline__ __device__ T shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize); |
|
|
| template <typename srcDtype, typename destDtype> |
| __forceinline__ __device__ destDtype cast(srcDtype val) { |
| |
| |
| return static_cast<destDtype>(val); |
| } |
|
|
| |
| template <> |
| __forceinline__ __device__ float shfl_xor_sync(unsigned mask, float var, int laneMask, int width) { |
| return __shfl_xor(var, laneMask, width); |
| } |
|
|
| template <> |
| __forceinline__ __device__ int shfl_xor_sync(unsigned mask, int var, int laneMask, int width) { |
| return __shfl_xor(var, laneMask, width); |
| } |
|
|
| template <> |
| __forceinline__ __device__ float cast<float, float>(float val) { |
| return val; |
| } |
|
|
| template <> |
| __forceinline__ __device__ float cast<__half, float>(__half val) { |
| return __half2float(val); |
| } |
|
|
| template <> |
| __forceinline__ __device__ float cast<__hip_bfloat16, float>(__hip_bfloat16 val) { |
| return __bfloat162float(val); |
| } |
|
|
| template <> |
| __forceinline__ __device__ __half cast<float, __half>(float fval) { |
| return __float2half(fval); |
| } |
|
|
| template <> |
| __forceinline__ __device__ __hip_bfloat16 cast<float, __hip_bfloat16>(float fval) { |
| return __float2bfloat16(fval); |
| } |
|
|
| } |
|
|
| template <typename T> |
| __forceinline__ __device__ T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize) { |
| return amdgpu::shfl_xor_sync(mask, var, laneMask, width); |
| } |
|
|
| template <typename srcDtype> |
| __device__ __forceinline__ float castToFloat(srcDtype val) { |
| return amdgpu::cast<srcDtype, float>(val); |
| } |
|
|
| template <typename dstDtype> |
| __device__ __forceinline__ dstDtype castFromFloat(float val) { |
| return amdgpu::cast<float, dstDtype>(val); |
| } |
|
|
| |
| __host__ __device__ __forceinline__ __half operator*(const __half& x, const __half& y) { |
| __half h_x = x; |
| __half h_y = y; |
| return __hmul(h_x, h_y); |
| } |
|
|
| #endif |
|
|