| #include <ATen/ATen.h> |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
// Type-dispatch helper: switches on TYPE and, in the branch matching it,
// introduces the alias `scalar_t_##LEVEL` (float or at::Half) before running
// __VA_ARGS__. Nested dispatches need distinct LEVEL tokens so their aliases
// (scalar_t_0, scalar_t_1, ...) do not collide. Any other dtype raises through
// AT_ERROR; NAME is stringized with #, so a string-literal NAME shows up in the
// message together with its quotes. TYPE is evaluated by the switch and again
// on the error path. The expansion is a bare switch statement, so a trailing
// semicolon after the macro is optional.
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                       \
  switch (TYPE) {                                                             \
    case at::ScalarType::Half: {                                              \
      using scalar_t_##LEVEL = at::Half;                                      \
      __VA_ARGS__;                                                            \
    } break;                                                                  \
    case at::ScalarType::Float: {                                             \
      using scalar_t_##LEVEL = float;                                         \
      __VA_ARGS__;                                                            \
    } break;                                                                  \
    default:                                                                  \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");         \
  }
|
|
|
|
// Type-dispatch helper: switches on TYPE and, in the branch matching it,
// introduces the alias `scalar_t_##LEVEL` (double, float or at::Half) before
// running __VA_ARGS__. Nested dispatches need distinct LEVEL tokens so their
// aliases do not collide. Any other dtype raises through AT_ERROR; NAME is
// stringized with #, so a string-literal NAME shows up in the message together
// with its quotes. TYPE is evaluated by the switch and again on the error path.
// The expansion is a bare switch statement, so a trailing semicolon after the
// macro is optional.
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                \
  switch (TYPE) {                                                             \
    case at::ScalarType::Half: {                                              \
      using scalar_t_##LEVEL = at::Half;                                      \
      __VA_ARGS__;                                                            \
    } break;                                                                  \
    case at::ScalarType::Float: {                                             \
      using scalar_t_##LEVEL = float;                                         \
      __VA_ARGS__;                                                            \
    } break;                                                                  \
    case at::ScalarType::Double: {                                            \
      using scalar_t_##LEVEL = double;                                        \
      __VA_ARGS__;                                                            \
    } break;                                                                  \
    default:                                                                  \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");         \
  }
|
|
|
|
// Type-dispatch helper: switches on TYPE and, in the branch matching it,
// introduces the alias `scalar_t_##LEVEL` (double or float) before running
// __VA_ARGS__. Nested dispatches need distinct LEVEL tokens so their aliases
// do not collide. Any other dtype (including Half) raises through AT_ERROR;
// NAME is stringized with #, so a string-literal NAME shows up in the message
// together with its quotes. TYPE is evaluated by the switch and again on the
// error path. The expansion is a bare switch statement, so a trailing semicolon
// after the macro is optional.
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...)                     \
  switch (TYPE) {                                                             \
    case at::ScalarType::Float: {                                             \
      using scalar_t_##LEVEL = float;                                         \
      __VA_ARGS__;                                                            \
    } break;                                                                  \
    case at::ScalarType::Double: {                                            \
      using scalar_t_##LEVEL = double;                                        \
      __VA_ARGS__;                                                            \
    } break;                                                                  \
    default:                                                                  \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");         \
  }
|
|
|
|
// Block-wide sum reduction.
//
// Every thread of the block contributes `val`. Afterwards lane t (t < lanes) of
// the first warp holds the sum of the contributions of all threads whose linear
// index tid = threadIdx.x + threadIdx.y * blockDim.x satisfies
// tid % lanes == t. With the default lanes == 1, thread 0 receives the sum over
// the whole block.
//
// Preconditions (not checked here):
//   * Every thread of the block must call this with the same `lanes` and
//     `share_result`: __syncthreads() is executed when
//     blockDim.x * blockDim.y >= 64 and when share_result is true.
//   * blockDim.x * blockDim.y is assumed to be a power of two that is >= 32.
//     For non-powers of two the tree loop below skips elements. Below 32
//     threads, the full-warp shuffle mask names lanes that do not exist.
//   * blockDim.z is assumed to be 1 (threadIdx.z is not part of tid).
//   * `lanes` is intended to be a power of two in [1, 32].
//   * When the block has >= 64 threads, x must hold at least
//     blockDim.x * blockDim.y elements visible to the whole block
//     (presumably a __shared__ buffer -- confirm at call sites).
//
// Parameters:
//   x             scratch buffer; written when the block has >= 64 threads or
//                 when share_result is true.
//   val           this thread's contribution.
//   lanes         number of partial results to keep (see above). For the
//                 shuffle stage, values < 1 are treated as 1.
//   share_result  if true, threads tid < lanes also store their result in
//                 x[tid] and the block synchronizes, so every thread may read
//                 x[0 .. lanes-1] afterwards.
//
// Returns the partial sum on threads tid < lanes. The value returned on any
// other thread is not meaningful.
//
// NOTE: when share_result is false and the block has >= 64 threads, warp 0
// still reads x after the last barrier. A caller that reuses x (for example,
// a second call) must __syncthreads() first.
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
  (T *x,
   T val,
   int lanes=1,
   bool share_result=false)
{
  const int tid = threadIdx.x + threadIdx.y * blockDim.x;
  const int blockSize = blockDim.x * blockDim.y;

  // Stage 1: tree reduction through x down to 64 partial sums. This is only
  // needed when the block spans more than one warp.
  if (blockSize >= 64)
  {
    x[tid] = val;
    __syncthreads();
  }

  #pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1)
  {
    if (tid < i)
      x[tid] = x[tid] + x[tid + i];
    __syncthreads();
  }

  // Initialized so that no thread returns an indeterminate value. Warp 0
  // replaces it with the folded partial when the block has >= 64 threads.
  T final = val;

  // Stage 2: warp 0 folds the 64 partials into 32, then finishes with shuffles.
  if (tid < 32)
  {
    if (blockSize >= 64)
      final = x[tid] + x[tid + 32];

    // Clamp so the loop always terminates. With lanes < 1, `i >= lanes` would
    // stay true forever once i reaches 0.
    const int stop = lanes > 1 ? lanes : 1;

    #pragma unroll
    for (int i = 16; i >= stop; i >>= 1)
      final = final + __shfl_down_sync(0xffffffffu, final, i);
  }

  // Stage 3 (optional): publish the low `lanes` results to the whole block.
  if (share_result)
  {
    if (tid < lanes)
      x[tid] = final;

    // Make the stored results visible to all warps before anyone reads them.
    __syncthreads();
  }

  return final;
}
|
|