#pragma once
#include <ATen/NumericUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Load.h>
#include <c10/util/accumulate.h>

#include <limits>
#include <cmath>

namespace at {
namespace native {

template <typename integer>
constexpr inline integer ceil_div(integer n, integer m) {
  return (n + m - 1) / m;
}

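// Merge (lhs, lhs_idx) into the running result (rhs, rhs_idx) for scans that
// track indices (cummax/cummin): a NaN already stored in rhs is kept, a NaN in
// lhs wins over a non-NaN rhs, and otherwise lhs replaces rhs whenever
// binary_op(rhs, lhs) is false, i.e. whenever rhs does not win the comparison.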
template <typename scalar_t, typename idx_t, typename BinaryOperation>
__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) {
  if (!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) {
    rhs = lhs;
    rhs_idx = lhs_idx;
  }
}

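/* Perform an inclusive scan with indices along the innermost dimension of a tensor.
 *
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 *
 * The outer dimensions are treated as a single dimension, i.e. the tensor is
 * viewed as having 'num_rows' rows of size 'row_size'. Each thread block
 * processes one or more sets of contiguous rows (processing multiple rows per
 * block is faster than one row per block, especially for short rows).
 */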
template <typename scalar_t, int num_threads_x, int num_threads_y, class BinaryFunction>
__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
                                                              int num_rows, int row_size,
                                                              scalar_t init, BinaryFunction binary_op) {
  __shared__ scalar_t vbuf[num_threads_y][2 * num_threads_x];
  __shared__ int64_t ibuf[num_threads_y][2 * num_threads_x];
  scalar_t* row_buf = vbuf[threadIdx.y];
  int64_t* row_idx_buf = ibuf[threadIdx.y];

  for (int block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    int row = block_row + threadIdx.y;
    const scalar_t *row_self = self_ + row * row_size;
    scalar_t *row_values = values_ + row * row_size;
    int64_t *row_indices = indices_ + row * row_size;
    scalar_t block_total = init;
    int64_t block_idx_final = 0;

    // Scan one block of 2 * num_threads_x columns at a time, carrying the
    // running total of all previously processed blocks.
    for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      int col1 = block_col + threadIdx.x;
      int col2 = block_col + num_threads_x + threadIdx.x;
      if (row < num_rows) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = c10::load(&row_self[col1]);
          row_idx_buf[threadIdx.x] = col1;
        } else {
          row_buf[threadIdx.x] = init;
          // No need to set the index here: the init value is never selected.
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]);
          row_idx_buf[num_threads_x + threadIdx.x] = col2;
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
          // No need to set the index here: the init value is never selected.
        }

        // Fold the total of all previous blocks into the first value of this block.
        if (threadIdx.x == 0) {
          binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op);
        }
      }
      __syncthreads();

      // Up-sweep (reduce) phase of the parallel scan over the shared buffer.
      for (int s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
        if (row < num_rows && threadIdx.x < s) {
          int offset = (2 * threadIdx.x + 1) * d - 1;
          binary_op_update(row_buf[offset], row_buf[offset + d], row_idx_buf[offset], row_idx_buf[offset + d], binary_op);
        }
        __syncthreads();
      }

      // Down-sweep phase: distribute partial results back across the buffer.
      for (int s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
        if (row < num_rows && threadIdx.x < s - 1) {
          int offset = 2 * (threadIdx.x + 1) * d - 1;
          binary_op_update(row_buf[offset], row_buf[offset + d], row_idx_buf[offset], row_idx_buf[offset + d], binary_op);
        }
        __syncthreads();
      }

      // Write the scanned values and their indices back to global memory.
      if (row < num_rows) {
        if (col1 < row_size) {
          row_values[col1] = row_buf[threadIdx.x];
          row_indices[col1] = row_idx_buf[threadIdx.x];
        }
        if (col2 < row_size) {
          row_values[col2] = row_buf[num_threads_x + threadIdx.x];
          row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x];
        }
      }
      block_total = row_buf[2 * num_threads_x - 1];
      block_idx_final = row_idx_buf[2 * num_threads_x - 1];
      __syncthreads();
    }
  }
}

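/* Perform an inclusive scan with indices along an outer (non-innermost)
 * dimension of a tensor.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 *
 * The dimensions outside and inside the scanned dimension are flattened into
 * num_orows and num_irows respectively. Each thread walks one inner row at a
 * time, stepping through the scanned dimension with stride num_irows.
 */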
template <typename scalar_t, class BinaryFunction>
__global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
                                                          const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) {
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      const scalar_t *self = self_ + orow * row_size * num_irows + irow;
      scalar_t *values = values_ + orow * row_size * num_irows + irow;
      int64_t *indices = indices_ + orow * row_size * num_irows + irow;
      scalar_t out = init;
      int64_t out_idx = 0;

      for (auto col = decltype(row_size){0}; col < row_size; ++col) {
        const auto val = c10::load(self);
        if (at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) {
          out = val;
          out_idx = col;
        }
        *values = out;
        *indices = out_idx;
        self += num_irows;
        values += num_irows;
        indices += num_irows;
      }
    }
  }
}

inline void check_fits_in_unsigned(int64_t val, const char* name) {
  constexpr auto umax = std::numeric_limits<uint32_t>::max();
  TORCH_CHECK(
      val >= 0 && val <= umax, name, " must fit in a 32-bit unsigned integer");
}

template <typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    int dim, scalar_t init, BinaryFunction binary_op) {
  int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (everything before 'dim') as a single flattened dimension.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (everything after 'dim') likewise.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());

  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  tensor_kernel_scan_outer_dim_with_indices<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      self.data_ptr<scalar_t>(), values.data_ptr<scalar_t>(), indices.data_ptr<int64_t>(),
      num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t, class BinaryFunction>
__host__ void scan_innermost_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int row_size = self.size(ndim - 1);
  int num_rows = self.numel() / row_size;

  dim3 threads(16, 32);
  dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y))));

  tensor_kernel_scan_innermost_dim_with_indices<scalar_t, 16, 32><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      self.data_ptr<scalar_t>(), values.data_ptr<scalar_t>(), indices.data_ptr<int64_t>(),
      num_rows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t, typename BinaryFunction>
void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices,
                           int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous());
  if (dim == ndim - 1) {
    scan_innermost_dim_with_indices<scalar_t>(*self_, values, indices, init, binary_op);
  } else {
    scan_outer_dim_with_indices<scalar_t>(*self_, values, indices, dim, init, binary_op);
  }
}

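// A minimal usage sketch (hypothetical call site, not part of this header): a
// cummax-style scan would dispatch here with a "greater or equal" comparison
// and an init value that never wins the comparison, e.g.:
//
//   at::Tensor values = at::empty_like(self);
//   at::Tensor indices = at::empty_like(self, self.options().dtype(at::kLong));
//   scan_dim_with_indices<float>(self, values, indices, /*dim=*/1,
//       -std::numeric_limits<float>::infinity(), std::greater_equal<float>());
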
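/* Perform an inclusive scan along an outer (non-innermost) dimension of a tensor.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 *
 * Each thread handles one inner row at a time, accumulating sequentially along
 * the scanned dimension with stride num_irows.
 */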
template <typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_,
                                             const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
                                             const scalar_t init, BinaryOp binary_op)
{
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      scalar_t *src = src_ + orow * row_size * num_irows + irow;
      scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
      scalar_t acc = init;

      for (uint32_t col = 0; col < row_size; ++col) {
        acc = binary_op(acc, c10::load(src));
        *tgt = acc;

        src += num_irows;
        tgt += num_irows;
      }
    }
  }
}

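/* Core of the inclusive scan along the innermost dimension, implemented as an
 * up-sweep/down-sweep parallel scan in shared memory.
 *
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 *
 * The tensor is viewed as 'num_rows' rows of size 'row_size'. Each thread block
 * processes one or more sets of contiguous rows; row_buf must point to this
 * thread row's shared storage of 2 * num_threads_x elements.
 */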
template <typename T, int num_threads_x, int num_threads_y, class BinaryFunction>
__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *src_,
                                                      const uint32_t num_rows, const uint32_t row_size,
                                                      T init, BinaryFunction binary_op) {
  for (uint32_t block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    uint32_t row = block_row + threadIdx.y;
    T block_total = init;

    T *row_src = src_ + row * row_size;
    T *row_tgt = tgt_ + row * row_size;

    // Scan one block of 2 * num_threads_x columns at a time, carrying the
    // running total of all previously processed blocks.
    for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      uint32_t col1 = block_col + threadIdx.x;
      uint32_t col2 = block_col + num_threads_x + threadIdx.x;
      if (row < num_rows) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = row_src[col1];
        } else {
          row_buf[threadIdx.x] = init;
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = row_src[col2];
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
        }

        // Fold the total of all previous blocks into the first value of this block.
        if (threadIdx.x == 0) {
          row_buf[0] = binary_op(row_buf[0], block_total);
        }
      }
      __syncthreads();

      // Up-sweep (reduce) phase of the parallel scan over the shared buffer.
      for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
        if (row < num_rows && threadIdx.x < s) {
          uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      // Down-sweep phase: distribute partial results back across the buffer.
      for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
        if (row < num_rows && threadIdx.x < s - 1) {
          uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      // Write the scanned values back to global memory.
      if (row < num_rows) {
        if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
        if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
      }
      block_total = row_buf[2 * num_threads_x - 1];
      __syncthreads();
    }
  }
}

template <
    typename T,
    int num_threads_x,
    int num_threads_y,
    class BinaryFunction>
__global__ typename std::enable_if<!c10::is_complex<T>::value, void>::type
tensor_kernel_scan_innermost_dim(
    T* tgt_,
    T* src_,
    const uint32_t num_rows,
    const uint32_t row_size,
    T init,
    BinaryFunction binary_op) {
  __shared__ T sbuf[num_threads_y][2 * num_threads_x];
  T* row_buf = sbuf[threadIdx.y];

  tensor_kernel_scan_innermost_dim_impl<T, num_threads_x, num_threads_y>(
      row_buf, tgt_, src_, num_rows, row_size, init, binary_op);
}

template <
    typename T,
    int num_threads_x,
    int num_threads_y,
    class BinaryFunction>
__global__ typename std::enable_if<c10::is_complex<T>::value, void>::type
tensor_kernel_scan_innermost_dim(
    T* tgt_,
    T* src_,
    const uint32_t num_rows,
    const uint32_t row_size,
    T init,
    BinaryFunction binary_op) {
  // A __shared__ array of a complex type cannot be declared directly, since
  // complex types have a non-trivial default constructor
  // (`error: initializer not allowed for __shared__ variable`). Instead,
  // allocate raw storage of the underlying scalar type and reinterpret it:
  // 2 * num_threads_x complex elements == 4 * num_threads_x base elements.
  using base_t = typename scalar_value_type<T>::type;
  __shared__ base_t sbuf[num_threads_y][4 * num_threads_x];

  T* row_buf = reinterpret_cast<T*>(sbuf[threadIdx.y]);

  tensor_kernel_scan_innermost_dim_impl<T, num_threads_x, num_threads_y>(
      row_buf, tgt_, src_, num_rows, row_size, init, binary_op);
}

template <typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
                             int dim, scalar_t init, BinaryFunction binary_op) {
  const int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (everything before 'dim') as a single flattened dimension.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (everything after 'dim') likewise.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      result.data_ptr<scalar_t>(), self.data_ptr<scalar_t>(),
      num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const TensorBase& self, const TensorBase& result,
                        scalar_t init, BinaryFunction binary_op) {
  int64_t ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int64_t row_size = self.size(ndim - 1);
  int64_t num_rows = self.numel() / row_size;

  dim3 threads(16, 32);
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));

  check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_innermost_dim<scalar_t, 16, 32><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      result.data_ptr<scalar_t>(), self.data_ptr<scalar_t>(),
      num_rows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t, typename BinaryFunction>
void scan_dim(const TensorBase& self, const TensorBase& result,
              int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(result.is_contiguous());

  if (self.numel() == self.size(dim)) {
    // 1D scan over the whole tensor: defer to the CUB-backed implementation.
    cuda::cub::inclusive_scan(self_->data_ptr<scalar_t>(), result.data_ptr<scalar_t>(), binary_op, self.numel());
  } else if (dim == ndim - 1) {
    scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
  } else {
    scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);
  }
}

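// A minimal usage sketch (hypothetical call site, not part of this header): a
// cumsum-style scan over dimension 'dim' would dispatch here with addition as
// the binary op and 0 as its identity element:
//
//   scan_dim<float>(self, result, dim, /*init=*/0.0f, std::plus<float>());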

}} // namespace at::native
|
|
|