|
|
#pragma once |
|
|
|
|
|
#include <ATen/cuda/detail/IndexUtils.cuh> |
|
|
#include <ATen/native/cuda/Loops.cuh> |
|
|
#include <ATen/native/cuda/SortingCommon.cuh> |
|
|
#include <ATen/native/cuda/block_reduce.cuh> |
|
|
|
|
|
namespace at { |
|
|
namespace native { |
|
|
|
|
|
|
|
|
// Element type for the segmented inclusive prefix scan in compute_mode:
// `val` accumulates a running count within a segment of equal (sorted)
// values, and `flag` marks the start of a new segment (i.e. a boundary
// where adjacent sorted values differ).
struct ModeUnsignedBoolPair {
  unsigned int val;
  bool flag;
};
|
|
|
|
|
|
|
|
|
|
|
// (count, position) pair used for the arg-max block reduction in
// compute_mode: `val` is a frequency count and `index` is the position
// in the sorted shared-memory buffer that produced it.
struct ModeUnsignedPair {
  unsigned int val;
  unsigned int index;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// In-place inclusive prefix scan over `smem[0 .. Power2ScanSize)` using a
// work-efficient two-phase (up-sweep then down-sweep) tree pattern.
// `binop(a, b)` combines element `a` (the later position) with element `b`
// (the earlier position) — note the argument order, which segmented-scan
// operators rely on.
//
// Preconditions (implied by the name and the doubling strides — confirm at
// call sites): Power2ScanSize is a power of two, the block has enough
// threads to cover the tree (each thread touches one node per level), and
// all threads of the block reach every __syncthreads() below.
template <int Power2ScanSize, typename T, class BinaryOp>
__device__ void inclusivePrefixScan(T* smem, BinaryOp binop) {
  // Up-sweep: build partial sums at stride-spaced positions. After this
  // loop, smem[Power2ScanSize - 1] holds the total.
#pragma unroll
  for (int stride = 1; stride < Power2ScanSize; stride <<= 1) {
    int index = (threadIdx.x + 1) * stride * 2 - 1;
    if (index < Power2ScanSize) {
      smem[index] = binop(smem[index], smem[index - stride]);
    }
    // Barrier outside the `if`, so divergent threads still all reach it.
    __syncthreads();
  }

  // Down-sweep: propagate the partial sums into the remaining positions,
  // turning the tree of partials into a full inclusive scan.
#pragma unroll
  for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) {
    int index = (threadIdx.x + 1) * stride * 2 - 1;
    if ((index + stride) < Power2ScanSize) {
      smem[index + stride] = binop(smem[index + stride], smem[index]);
    }
    __syncthreads();
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Block-wide reduction where each thread contributes N thread-local values.
// Thread t owns logical positions [t*N, t*N + N); positions at or beyond
// `numVals` contribute `init` instead of their (garbage) slot. The folded
// per-thread value is then reduced across the block via
// cuda_utils::BlockReduce, using `smem` as scratch space.
template <int N, typename T, typename ReduceOp>
__device__ T reduceBlockWithNThreadLocalReductions(
    T* smem,
    T threadVals[N],
    const unsigned int numVals,
    ReduceOp reduceOp,
    T init) {
  T accum = init;
#pragma unroll
  for (int i = 0; i < N; ++i) {
    // Logical position of threadVals[i] within the whole block's data.
    int pos = threadIdx.x * N + i;
    T contribution = pos < numVals ? threadVals[i] : init;
    // Seed with the first value rather than combining it into `init`,
    // so the operator is never applied to an extra identity element.
    accum = (i == 0) ? contribution : reduceOp.combine(accum, contribution);
  }
  return cuda_utils::BlockReduce(accum, reduceOp, init, smem);
}
|
|
|
|
|
// Exchange the contents of two references of the same type.
template <typename T>
__device__ inline void swapVars(T& a, T& b) {
  T held = a;
  a = b;
  b = held;
}
|
|
|
|
|
// Compare-and-swap step of a bitonic network over (key, value, valid)
// triples. If the pair is out of order relative to `dir` it is exchanged.
// An entry with validB == false always counts as out of order, so invalid
// (padding) entries sink toward the end of the sorted region.
//
// `dir` selects the sort direction for this network stage: the swap fires
// when the computed "out of order" predicate equals `dir`.
template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(
    K& kA,
    V& vA,
    bool& validA,
    K& kB,
    V& vB,
    bool& validB,
    bool dir,
    const Comparator& comp) {
  bool swap = (comp(kA, kB) && validA) || !validB;
  if (swap == dir) {
    swapVars(kA, kB);
    swapVars(vA, vB);
    swapVars(validA, validB);
  }
} // fixed: removed stray trailing ';' after the function body
  // (inconsistent with bitonicSwapKeys below)
|
|
|
|
|
// Keys-only variant of bitonicSwap: conditionally exchanges (kA, validA)
// with (kB, validB). Entries flagged invalid always compare as "out of
// order", so they migrate to the end of the sorted region; the exchange
// fires when the out-of-order predicate matches the stage direction `dir`.
template <typename Comparator, typename K>
__device__ inline void bitonicSwapKeys(
    K& kA,
    bool& validA,
    K& kB,
    bool& validB,
    bool dir,
    const Comparator& comp) {
  const bool outOfOrder = (comp(kA, kB) && validA) || !validB;
  if (outOfOrder == dir) {
    swapVars(kA, kB);
    swapVars(validA, validB);
  }
}
|
|
|
|
|
// In-place bitonic sort of `keys[0 .. Power2SortSize)` held in shared
// memory, ordered by `comp`, with `valid[i]` padding flags: invalid slots
// are pushed past all valid ones (see bitonicSwapKeys). Each thread owns
// two elements per stage, so this assumes blockDim.x == Power2SortSize / 2
// — matches how compute_mode below loads two elements per thread; confirm
// at any other call site. All threads of the block must reach every
// __syncthreads().
template <
    typename K,
    typename IndexType,
    int Power2SortSize,
    typename Comparator>
__device__ inline void bitonicSortKeys(
    K keys[Power2SortSize],
    bool valid[Power2SortSize],
    const Comparator& comp) {
// ROCm compilers have not handled this unroll well; only apply it on CUDA.
#if !defined(USE_ROCM)
#pragma unroll
#endif
  // Build bitonic subsequences of increasing size. `flag` alternates the
  // sort direction between adjacent size-`size` runs so merged runs form
  // a bitonic sequence.
  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
    bool flag = ((threadIdx.x & (size / 2)) != 0);

#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
      // Barrier before each stage: every stage reads elements written by
      // other threads in the previous stage.
      __syncthreads();

      // Index of the first element of this thread's compare-exchange pair.
      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      bitonicSwapKeys<Comparator, K>(
          keys[pos],
          valid[pos],
          keys[pos + stride],
          valid[pos + stride],
          flag,
          comp);
    }
  }

  // Final merge: the whole array is now one bitonic sequence; merge it
  // into fully sorted order (direction fixed to `false`).
#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
    __syncthreads();

    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    bitonicSwapKeys<Comparator, K>(
        keys[pos],
        valid[pos],
        keys[pos + stride],
        valid[pos + stride],
        false,
        comp);
  }

  // Ensure the sorted data is visible to all threads before returning.
  __syncthreads();
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Computes the mode (most frequent value) of one contiguous slice of
// `input` per block, writing the mode into `values` and the index of its
// last occurrence within the slice into `indices`.
//
// Launch contract (implied by the indexing below — confirm at the launch
// site): one block per slice, blockDim.x == Power2Size / 2 (each thread
// loads two elements), and enough dynamic shared memory to hold
// Power2Size T's followed by Power2Size ModeUnsignedBoolPair's, since the
// region past smem[Power2Size] is reinterpreted as bool / pair / unsigned
// scratch at different phases. sliceSize must be <= Power2Size.
template <typename T, unsigned int Power2Size>
// NOTE(review): guard presumably works around an occupancy/regression
// issue on newer toolchains — confirm against upstream history.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11070
__launch_bounds__(1024, 1)
#endif
__global__ void compute_mode(
    T* input,
    at::cuda::detail::TensorInfo<T, unsigned int> values,
    at::cuda::detail::TensorInfo<int64_t, unsigned int> indices,
    int64_t sliceSize,
    int64_t slices) {
  // Each thread handles two slice positions: tidx and stidx.
  int tidx = threadIdx.x;
  int stidx = blockDim.x + threadIdx.x;

  // Linear block id selects which slice this block processes; slices are
  // laid out contiguously in `input` with stride `sliceSize`.
  unsigned int blockId = getLinearBlockId<unsigned int>();
  unsigned int linearOffset = blockId * sliceSize;

  if (blockId >= slices) {
    return;
  }

  // Dynamic shared memory, re-purposed across the phases below.
  extern __shared__ char shmem[];

  // Phase 1: load the slice into shared memory.
  T* smem = reinterpret_cast<T*>(shmem);

  if (tidx < sliceSize) {
    smem[tidx] = c10::load(&input[linearOffset + tidx]);
  }
  if (stidx < sliceSize) {
    smem[stidx] = c10::load(&input[linearOffset + stidx]);
  }

  // Validity flags live directly after the Power2Size T's; positions past
  // sliceSize are padding and marked invalid so the sort pushes them to
  // the end.
  bool* bmem = reinterpret_cast<bool*>(&smem[Power2Size]);

  bmem[tidx] = tidx < sliceSize;
  bmem[stidx] = stidx < sliceSize;
  __syncthreads();

  // Phase 2: sort the slice ascending; equal values become adjacent.
  bitonicSortKeys<T, unsigned int, Power2Size>(
      smem, bmem, [&] GPU_LAMBDA(const auto& a, const auto& b) {
        return a < b;
      });
  __syncthreads();

  // Phase 3: segmented scan to count run lengths of equal values.
  // The bool region is re-used as (count, segment-start) pairs:
  // flag marks where sorted neighbors differ, val is 1 within a run
  // (0 at a boundary) so the scan accumulates per-run counts.
  struct ModeUnsignedBoolPair* ubpmem =
      reinterpret_cast<struct ModeUnsignedBoolPair*>(&smem[Power2Size]);

  if (tidx == 0) {
    ubpmem[0].flag = true;
    ubpmem[0].val = 0;
  }

  // Each thread initializes its odd position...
  ubpmem[tidx * 2 + 1].flag =
      smem[tidx * 2] != smem[tidx * 2 + 1];
  ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag;

  // ...and the following even position (except past the end).
  if (((tidx + 1) * 2) < Power2Size) {
    ubpmem[(tidx + 1) * 2].flag =
        smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2];
    ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag;
  }
  __syncthreads();

  // Segmented operator: `a` is the later element (see inclusivePrefixScan
  // argument order) — if it starts a new segment, the earlier count `b`
  // does not carry over.
  inclusivePrefixScan<Power2Size>(
      ubpmem, [=] GPU_LAMBDA(const auto& a, const auto& b) {
        ModeUnsignedBoolPair c;
        c.val = a.flag ? a.val : a.val + b.val;
        c.flag = a.flag | b.flag;
        return c;
      });

  // Phase 4: arg-max over the counts to find the largest run.
  // ModeUnsignedPair has the same footprint needs as the pair region, so
  // reuse it as BlockReduce scratch.
  struct ModeUnsignedPair* uupmem =
      reinterpret_cast<struct ModeUnsignedPair*>(ubpmem);

  // Snapshot this thread's two (count, position) candidates into
  // registers before the shared region is clobbered by the reduction.
  struct ModeUnsignedPair uup[2];
  uup[0].index = tidx * 2;
  uup[0].val = ubpmem[tidx * 2].val;
  uup[1].index = tidx * 2 + 1;
  uup[1].val = ubpmem[tidx * 2 + 1].val;
  __syncthreads();

  struct ModeUnsignedPair max = {0, 0};

  // Reduction operator: keep the pair with the larger count. On ties it
  // prefers `a` (the accumulator) since `b.val > a.val` is strict.
  struct MaxOp {
    inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const {
      return b.val > a.val ? b : a;
    }

    inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const {
      ModeUnsignedPair ret;
      ret.index = WARP_SHFL_DOWN(acc.index, offset);
      ret.val = WARP_SHFL_DOWN(acc.val, offset);
      return ret;
    }
  } max_op;

  max = reduceBlockWithNThreadLocalReductions<2>(
      uupmem,
      uup,
      sliceSize,
      max_op,
      max);

  // Broadcast the winning value (read from the still-sorted smem) to the
  // whole block.
  __shared__ T mode;

  if (tidx == 0) {
    mode = smem[max.index];
  }
  __syncthreads();

  // Phase 5: find the mode's index in the ORIGINAL (unsorted) input.
  // Each thread checks its two positions; taking the max over matching
  // positions yields the last occurrence (position 0 is indistinguishable
  // from "no match", which is harmless: if 0 wins, input[0] is the mode).
  unsigned mode_index[2] = {0u, 0u};
  if (tidx * 2 < sliceSize) {
    const unsigned idx = tidx * 2;
    mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
  }
  if (tidx * 2 + 1 < sliceSize) {
    const unsigned idx = tidx * 2 + 1;
    mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
  }

  struct MaxIndexOp {
    inline __device__ unsigned combine(unsigned a, unsigned b) const {
      return b > a ? b : a;
    }

    inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const {
      return WARP_SHFL_DOWN(acc, offset);
    }
  } max_index_op;

  // Reuse the start of shared memory as reduction scratch (smem contents
  // are no longer needed: `mode` was already copied out above).
  int64_t index = reduceBlockWithNThreadLocalReductions<2>(
      reinterpret_cast<unsigned*>(&shmem[0]),
      mode_index,
      sliceSize,
      max_index_op,
      0u);

  // Phase 6: thread 0 writes the (mode, index) result for this slice.
  if (tidx == 0) {
    unsigned int outputOffset =
        at::cuda::detail::IndexToOffset<T, unsigned int, -1>::get(
            blockId, values);
    values.data[outputOffset] = mode;
    indices.data[outputOffset] = index;
  }
}
|
|
|
|
|
} |
|
|
} |
|
|
|