#include <cstdint>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
namespace at::native {
template<int alignment>
inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const out_ptr, char * const in_ptr, const size_t index_stride_bytes, const size_t element_size) {
  using at::native::memory::get_alignment;
  const auto index_element_size = iter.element_size(2);
  // TensorIterator strides and sizes are ordered fastest-moving to slowest-moving,
  // in contrast to regular tensor sizes.
  // The fast path needs contiguous source and destination slices, and the pointers,
  // strides, and slice size must be aligned so vectorized loads can be used.
  // The index also has to be expanded along the last dimension so entire slices can
  // be copied, and the source tensor must keep the 0 stride introduced by restriding
  // (dimension collapse could have removed it; the iterator would then still be 2-d,
  // but the fast path cannot be used).
  return iter.ndim() == 2 && iter.strides(2)[0] == 0 && iter.strides(2)[1] == index_element_size &&
      static_cast<size_t>(iter.strides(0)[0]) == element_size &&
      static_cast<size_t>(iter.strides(1)[0]) == element_size && iter.strides(1)[1] == 0 &&
      get_alignment(out_ptr) == alignment && get_alignment(in_ptr) == alignment &&
      get_alignment(static_cast<size_t>(iter.shape()[0] * element_size)) == alignment &&
      get_alignment(static_cast<size_t>(index_stride_bytes)) == alignment &&
      get_alignment(static_cast<size_t>(iter.strides(0)[1])) == alignment;
}
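// Illustrative example (an assumption for exposition, not taken from this header):
// gathering along dim 0 of a contiguous float32 source of shape (S, 64) with
// alignment == 16 satisfies the check above: each slice is 64 * 4 = 256 contiguous
// bytes, the index is expanded along the slice dimension (0 stride), the restrided
// source keeps a 0 stride across indices, and the slice size and per-index strides
// are multiples of 16, so whole slices can be copied with vectorized loads
// (assuming the base pointers themselves are 16-byte aligned).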
template <int64_t Alignment, typename index_t>
void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
    int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes,
    bool allow_neg_indices=false);
} // namespace at::native
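
// Illustrative usage sketch: how a caller might combine the eligibility check with
// the launcher declared above. The helper name, the caller-supplied fallback, and
// kAlignment = 16 are assumptions; the actual dispatch lives in the .cu
// implementation and may differ.
namespace at::native {

template <typename index_t, typename FallbackFn>
inline void gather_dispatch_sketch(
    TensorIterator& iter, char* out_ptr, char* inp_ptr, index_t* idx_ptr,
    int num_ind, int64_t slice_size_in_bytes, int64_t ind_dim_size,
    int64_t inp_stride_bytes, int64_t out_stride_bytes,
    size_t index_stride_bytes, size_t element_size, FallbackFn&& fallback) {
  constexpr int kAlignment = 16;  // assumed vector width in bytes
  if (fast_gather_kernel_eligible<kAlignment>(
          iter, out_ptr, inp_ptr, index_stride_bytes, element_size)) {
    // Fast path: copy whole contiguous slices with vectorized loads/stores.
    vectorized_gather_kernel_launch<kAlignment, index_t>(
        out_ptr, inp_ptr, idx_ptr, num_ind, slice_size_in_bytes, ind_dim_size,
        inp_stride_bytes, out_stride_bytes);
  } else {
    // Slow path: generic elementwise gather supplied by the caller.
    fallback(iter);
  }
}

} // namespace at::native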