#pragma once
#include <cstdint>
#include <type_traits>

#include <ATen/cuda/CUDAConfig.h>
#include <c10/core/ScalarType.h>
| |
|
| | |
| | |
| | |
| | |
| |
|
| | namespace at { |
| | namespace cuda { |
| | namespace cub { |
| |
|
// Smallest number of bits that can represent every key in [0, max_key].
// Always returns at least 1 (so max_key == 0 and max_key == 1 both yield 1).
inline int get_num_bits(uint64_t max_key) {
  int bits = 1;
  for (uint64_t v = max_key; v > 1; v >>= 1) {
    ++bits;
  }
  return bits;
}
| |
|
| | namespace detail { |
| |
|
| | |
| | |
| | |
// Size-N, size-aligned byte blob used to type-erase sort values: callers
// reinterpret_cast their value pointers to OpaqueType<sizeof(value_t)> so a
// single radix_sort_pairs_impl instantiation serves every value type of size N.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
| |
|
// Out-of-line worker for radix_sort_pairs: sorts n (key, value) pairs by key,
// considering only key bits [begin_bit, end_bit), ascending unless
// `descending` is set. Values are handled as opaque value_size-byte blobs.
// Explicitly instantiated in cub.cu for the supported key types and the
// power-of-two value sizes checked by radix_sort_pairs.
template<typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t *keys_in, key_t *keys_out,
    const OpaqueType<value_size> *values_in, OpaqueType<value_size> *values_out,
    int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);
| |
|
| | } |
| |
|
// Sorts n (key, value) pairs by key, using only key bits [begin_bit, end_bit);
// ascending order unless `descending` is true. The value array is forwarded as
// an array of size-aligned byte blobs, so one underlying instantiation (see
// cub.cu) covers every trivially copyable value type of a given size.
template<typename key_t, typename value_t>
void radix_sort_pairs(
    const key_t *keys_in, key_t *keys_out,
    const value_t *values_in, value_t *values_out,
    int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8) {
  // Blob forwarding is only sound for trivially copyable values whose size is
  // one of the instantiated power-of-two sizes and equals their alignment.
  static_assert(std::is_trivially_copyable<value_t>::value ||
      AT_ROCM_ENABLED(),
      "radix_sort_pairs value type must be trivially copyable");
  static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
      "This size of value_t is not instantiated. Please instantiate it in cub.cu"
      " and modify this check.");
  static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned");

  using blob_t = detail::OpaqueType<sizeof(value_t)>;
  detail::radix_sort_pairs_impl(
      keys_in, keys_out,
      reinterpret_cast<const blob_t*>(values_in),
      reinterpret_cast<blob_t*>(values_out),
      n, descending, begin_bit, end_bit);
}
| |
|
// Sorts n keys, considering only bits [begin_bit, end_bit); ascending order
// unless `descending` is true. Declaration only — defined and instantiated
// out of line (presumably in cub.cu alongside radix_sort_pairs_impl).
template<typename key_t>
void radix_sort_keys(
    const key_t *keys_in, key_t *keys_out,
    int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8);
| |
|
| | |
// Inclusive prefix sum of input[0..n) into output[0..n); defined out of line.
// NOTE(review): "truncating" presumably means intermediate sums are kept at a
// precision that can truncate wider inputs — confirm against the definition.
template <typename input_t, typename output_t>
void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n);
| |
|
// Inclusive prefix sum for the common case where input and output share one
// scalar type (so the truncation caveat of the general overload is moot).
template <typename scalar_t>
void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  inclusive_sum_truncating(input, output, n);
}
| |
|
| | |
// Exclusive prefix sum of input[0..n) into output[0..n); defined out of line.
// NOTE(review): the name suggests accumulation happens in the common type of
// input_t and output_t — confirm against the definition.
template <typename input_t, typename output_t>
void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n);
| |
|
// Exclusive prefix sum for the common case where input and output share one
// scalar type; forwards to the two-type overload.
template <typename scalar_t>
void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  exclusive_sum_in_common_type(input, output, n);
}
| |
|
// Exclusive prefix sum over a byte mask, writing the running count into
// output_idx[0..n) (typical use: compute scatter indices for masked
// selection). Defined out of line.
void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n);
// bool overload: forwards to the uint8_t version, relying on bool occupying
// one byte so the reinterpret_cast is a plain view change.
inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) {
  const uint8_t *mask_bytes = reinterpret_cast<const uint8_t*>(mask);
  mask_exclusive_sum(mask_bytes, output_idx, n);
}
| |
|
| | }}} |
| |
|