| | #pragma once |
| | #include <ATen/cuda/cub.h> |
| |
|
| | #include <cstddef> |
| | #include <type_traits> |
| | #include <iterator> |
| | #include <limits> |
| |
|
| | #include <c10/util/C++17.h> |
| |
|
| | #include <ATen/cuda/cub_definitions.cuh> |
| |
|
| | #if USE_GLOBAL_CUB_WRAPPED_NAMESPACE() |
| |
|
| | #include <cub/cub.cuh> |
| |
|
| | #else |
| |
|
| | |
| | |
| | #undef CUB_NS_POSTFIX |
| | #undef CUB_NS_PREFIX |
| | #undef CUB_NS_QUALIFIER |
| | #define CUB_NS_PREFIX namespace at_cuda_detail { |
| | #define CUB_NS_POSTFIX } |
| | #define CUB_NS_QUALIFIER ::at_cuda_detail::cub |
| | #include <cub/cub.cuh> |
| | #undef CUB_NS_POSTFIX |
| | #undef CUB_NS_PREFIX |
| | #undef CUB_NS_QUALIFIER |
| |
|
| | #endif |
| |
|
| | #include <ATen/cuda/Exceptions.h> |
| | #include <c10/cuda/CUDACachingAllocator.h> |
| | #include <c10/cuda/CUDAStream.h> |
| |
|
| | |
// Runs the CUB device-wide algorithm `func` using CUB's two-call convention.
// The first call passes a null temp-storage pointer and only fills in
// `temp_storage_bytes`. A scratch buffer of that size is then taken from the
// CUDA caching allocator, and `func` is called again to do the actual work.
// `__VA_ARGS__` are the algorithm's remaining arguments, forwarded unchanged to both calls.
// Both calls return an error code. Each one is checked with AT_CUDA_CHECK, so a
// failed size query or dispatch is reported instead of being silently dropped.
// `temp_storage` holds the scratch buffer for the rest of the macro body.
#define CUB_WRAPPER(func, ...) do {                                         \
  size_t temp_storage_bytes = 0;                                            \
  AT_CUDA_CHECK(func(nullptr, temp_storage_bytes, __VA_ARGS__));            \
  auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get();      \
  auto temp_storage = caching_allocator.allocate(temp_storage_bytes);       \
  AT_CUDA_CHECK(func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__)); \
  AT_CUDA_CHECK(cudaGetLastError());                                        \
} while (false)
| |
|
// NO_ROCM(x):     expands to `x` on CUDA builds and to nothing on ROCm builds.
// ROCM_HIPCUB(x): expands to `x` on CUDA builds and to `::hipcub` on ROCm builds.
#if !defined(USE_ROCM)
#define NO_ROCM(x) x
#define ROCM_HIPCUB(x) x
#else
#define NO_ROCM(x)
#define ROCM_HIPCUB(x) ::hipcub
#endif
| |
|
// Give CUB (or hipcub on ROCm 4.5+) the numeric traits it needs to treat
// c10::BFloat16 as a floating-point key. This applies on CUDA builds whose CUB
// lacks native bfloat16 support.
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || \
    (defined(USE_ROCM) && ROCM_VERSION >= 40500)

#if !defined(USE_ROCM)
namespace at_cuda_detail {
#endif

// Extreme finite BFloat16 values, built directly from their bit patterns:
// 0x7F7F is the largest finite value and 0xFF7F (same pattern, sign bit set) the lowest.
template <>
struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
{
  static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
    // Build the value from raw bits instead of reinterpret_cast'ing an
    // unsigned short lvalue, which would violate strict aliasing.
    return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
  }

  static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
    return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
  }
};

// Register BFloat16 with CUB as a FLOATING_POINT type whose bits are handled as an unsigned short.
template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
       ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};

#if !defined(USE_ROCM)
}  // namespace at_cuda_detail
#endif

#endif
| |
|
// On non-ROCm builds, make ::at_cuda_detail::cub reachable from ATen native code as at::native::cub.
#ifndef USE_ROCM
namespace at {
namespace native {
namespace cub = ::at_cuda_detail::cub;
}  // namespace native
}  // namespace at
#endif
| |
|
| | namespace at { |
| | namespace cuda { |
| | namespace cub { |
| |
|
// cuda_type<T>::type is the type that T pointers are reinterpreted as before
// being handed to CUB. Types without a specialization map to themselves.
namespace detail {

template <typename T>
struct cuda_type {
  typedef T type;
};

// c10::Half is passed to CUB as CUDA's __half.
template <>
struct cuda_type<c10::Half> {
  typedef __half type;
};

#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()

// CUDA builds whose CUB supports bfloat16: use __nv_bfloat16.
template <>
struct cuda_type<c10::BFloat16> {
  typedef __nv_bfloat16 type;
};

#elif (defined(USE_ROCM) && ROCM_VERSION >= 40500)

// ROCm 4.5+: use hip_bfloat16.
template <>
struct cuda_type<c10::BFloat16> {
  typedef hip_bfloat16 type;
};

#endif

}  // namespace detail
| |
|
// Sorts key/value pairs segment by segment with cub::DeviceSegmentedRadixSort
// (SortPairs, or SortPairsDescending when `descending`), on the current CUDA stream.
//
// keys_in / values_in:   input keys and values (num_elements of each).
// keys_out / values_out: destinations for the sorted keys and the values permuted with them.
//                        keys_out may be nullptr. In that case sorted keys go into a
//                        scratch buffer that is dropped when this function returns.
// num_segments, begin_offsets, end_offsets: forwarded unchanged to CUB as the segment description.
// begin_bit / end_bit: key bit range forwarded to the radix sort (default: all bits of key_t).
// Fails a TORCH_CHECK if num_elements or num_segments exceeds INT_MAX.
template<typename key_t, typename value_t, typename OffsetIteratorT>
inline void segmented_sort_pairs(
    const key_t *keys_in, key_t *keys_out,
    const value_t *values_in, value_t *values_out,
    int64_t num_elements, int64_t num_segments,
    OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
    bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
  // Both counts are checked against INT_MAX before being handed to CUB.
  TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
    "cub sort does not support sorting more than INT_MAX elements");
  TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
    "cub sort does not support sorting more than INT_MAX segments");
  using key_t_ = typename detail::cuda_type<key_t>::type;

  // CUB always writes sorted keys. If the caller doesn't want them, give CUB a
  // scratch buffer owned by keys_out_owner, which lives until this function returns.
  c10::DataPtr keys_out_owner;
  if (keys_out == nullptr) {
    auto allocator = c10::cuda::CUDACachingAllocator::get();
    keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
    keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
  }

  // Present the keys to CUB as the matching CUDA type (see detail::cuda_type).
  const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  if (descending) {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
        keys_in_, keys_out_, values_in, values_out,
        num_elements, num_segments, begin_offsets, end_offsets,
        begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
        keys_in_, keys_out_, values_in, values_out,
        num_elements, num_segments, begin_offsets, end_offsets,
        begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  }
}
| |
|
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
// Forwards to cub::DeviceSelect::UniqueByKey on the current CUDA stream.
// Selected keys go to keys_out, their values to values_out, and the selection
// count to num_selected.
// keys_out may be passed as a literal nullptr (type std::nullptr_t) when the caller
// doesn't need the keys. A scratch key buffer is then allocated for CUB to write into.
// Fails a TORCH_CHECK if num_input_items exceeds INT_MAX.
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename KeysOutputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
  KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
  KeysOutputIteratorT keys_out, ValuesOutputIteratorT values_out,
  NumSelectedIteratorT num_selected, int64_t num_input_items)
{
  // Same limit the other wrappers in this file enforce, so an oversized count is
  // rejected instead of being silently narrowed on its way into CUB.
  TORCH_CHECK(num_input_items <= std::numeric_limits<int>::max(),
    "cub unique_by_key does not support more than INT_MAX elements");
  constexpr bool null_keys_out = std::is_same<KeysOutputIteratorT, std::nullptr_t>::value;
  using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
  using RealKeysOutputIteratorT = typename std::conditional<null_keys_out, KeyT *, KeysOutputIteratorT>::type;
  RealKeysOutputIteratorT keys_out_;
  c10::DataPtr keys_out_owner;
  // The assignments go through `_(...)` so that they depend on the lambda
  // parameter. Only the branch that is taken is instantiated, so the unused
  // assignment need not compile for this KeysOutputIteratorT.
  c10::guts::if_constexpr<null_keys_out>(
    [&](auto _) {
      auto allocator = c10::cuda::CUDACachingAllocator::get();
      keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT));
      _(keys_out_) = static_cast<KeyT *>(keys_out_owner.get());
    },
    [&](auto _) {
      _(keys_out_) = keys_out;
    }
  );
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
    keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif
| |
|
namespace impl {

// Kernel that stores scan_op(*a, *b) into *out. Both operands are first converted
// to out's value type. The scan wrappers in this file launch it as <<<1, 1>>>
// to compute the value that carries one chunk's result into the next chunk.
template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, class ScanOpT>
C10_LAUNCH_BOUNDS_1(1)
__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){
  // Combine in the output's value type so the stored result has exactly that type.
  using acc_t = typename std::iterator_traits<OutputIteratorT>::value_type;
  *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}

#if !CUB_SUPPORTS_FUTURE_VALUE()
// Random-access iterator over the sequence [*first, iter[0], iter[1], ...].
// Element 0 is read through `first` and element k > 0 is ValueT(iter[k - 1]).
// This lets a value that exists only in device memory be prepended to a scan input
// when CUB offers no FutureValue.
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
  using iterator_category = std::random_access_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using value_type = ValueT;
  using pointer = ValueT*;
  using reference = ValueT&;

  InputIteratorT iter;
  ValueT *first;
  // Position of this iterator within the chained sequence.
  difference_type offset = 0;

  __device__ ValueT operator[](difference_type i) {
    i += offset;
    if (i == 0) {
      return *first;
    } else {
      return ValueT(iter[i - 1]);
    }
  }
  // Advance relative to the current position. Offsets accumulate, so
  // (it + a) + b refers to the same element as it + (a + b).
  __device__ chained_iterator operator+(difference_type i) {
    return chained_iterator{iter, first, offset + i};
  }
  __device__ ValueT operator*() {
    return (*this)[0];
  }
};
#endif

// Largest item count handed to one CUB scan call: INT_MAX / 2 + 1 == 2^30.
// Chunk sizes therefore always fit in an int.
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1;

}  // namespace impl
| |
|
| | |
| | |
| | |
// Inclusive scan of num_items elements of `input` into `output` with scan_op, on
// the current CUDA stream.
//
// On ROCm 5.0+ the whole range goes to hipcub in a single call. Otherwise the
// range is processed in chunks of at most max_cub_size items, because each CUB
// call takes an int count. For every chunk after the first, a single-thread
// kernel computes the carry scan_op(output[i - 1], input[i]), which is output[i].
// The rest of the chunk is then scanned starting from that carry.
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, int max_cub_size=impl::max_cub_size>
inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
#if defined(USE_ROCM) && (ROCM_VERSION >= 50000)
  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
      input,
      output,
      scan_op,
      num_items,
      at::cuda::getCurrentCUDAStream());
  C10_HIP_KERNEL_LAUNCH_CHECK();
#else
  // First chunk: a plain inclusive scan.
  int size_cub = std::min<int64_t>(num_items, max_cub_size);
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
      input,
      output,
      scan_op,
      size_cub,
      at::cuda::getCurrentCUDAStream());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  using input_t = typename std::iterator_traits<InputIteratorT>::value_type;
  auto allocator = c10::cuda::CUDACachingAllocator::get();
  for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
    // One-element device buffer for this chunk's carry (== output[i]).
    c10::DataPtr first_elem = allocator->allocate(sizeof(input_t));
    auto first_elem_ptr = reinterpret_cast<input_t *>(first_elem.get());

    size_cub = std::min<int64_t>(num_items - i, max_cub_size);
    impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        output + i - 1,
        input + i,
        first_elem_ptr,
        scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#if CUB_SUPPORTS_FUTURE_VALUE()
    // An exclusive scan of n items loads all n inputs. Scanning size_cub items from
    // input + i + 1 therefore reads input[i + size_cub]. That element exists only
    // when another chunk follows. On the final chunk it is one past the end, so
    // that chunk uses the in-bounds path below instead.
    if (i + size_cub < num_items) {
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
          input + i + 1,
          output + i,
          scan_op,
          ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
          size_cub,
          at::cuda::getCurrentCUDAStream());
      continue;
    }
#endif
    // Inclusive scan of input[i, i + size_cub) with element 0 replaced by the carry.
    // Every read stays inside the chunk.
    using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
    using tuple = typename ArgIndexInputIterator::value_type;
    auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
      if (x.key == 0) {
        return *first_elem_ptr;
      } else {
        return x.value;
      }
    };
    auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
      ArgIndexInputIterator(input + i), input_iter_transform);
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input_,
        output + i,
        scan_op,
        size_cub,
        at::cuda::getCurrentCUDAStream());
  }
#endif
}
| |
|
// Exclusive scan of num_items elements of `input` into `output`, seeded with
// init_value and combined with scan_op, on the current CUDA stream.
//
// On ROCm 5.0+ the whole range goes to hipcub in one call. Otherwise the range is
// split into chunks of at most max_cub_size items, because each CUB call takes an
// int count. Each later chunk is seeded with a device-computed carry
// scan_op(output[offset - 1], input[offset - 1]), which equals output[offset].
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
#if defined(USE_ROCM) && (ROCM_VERSION >= 50000)
  const auto stream = at::cuda::getCurrentCUDAStream();
  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
      input, output, scan_op, init_value, num_items, stream);
  C10_HIP_KERNEL_LAUNCH_CHECK();
#else
  const auto stream = at::cuda::getCurrentCUDAStream();

  // Leading chunk, seeded with the caller's init_value.
  int chunk = std::min<int64_t>(num_items, max_cub_size);
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
      input, output, scan_op, init_value, chunk, stream);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  for (int64_t offset = max_cub_size; offset < num_items; offset += max_cub_size) {
    // One-element device buffer that will hold this chunk's carry.
    auto allocator = c10::cuda::CUDACachingAllocator::get();
    c10::DataPtr carry_storage = allocator->allocate(sizeof(InitValueT));
    InitValueT *carry = reinterpret_cast<InitValueT *>(carry_storage.get());

    chunk = std::min<int64_t>(num_items - offset, max_cub_size);

    // carry = scan_op(output[offset - 1], input[offset - 1])
    impl::transform_vals<<<1, 1, 0, stream>>>(
        output + offset - 1, input + offset - 1, carry, scan_op);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

#if CUB_SUPPORTS_FUTURE_VALUE()
    // CUB reads the seed from device memory when the scan runs.
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input + offset, output + offset, scan_op,
        ::at_cuda_detail::cub::FutureValue<InitValueT>(carry),
        chunk, stream);
#else
    // No FutureValue: inclusive-scan the sequence [carry, input[offset], input[offset + 1], ...].
    auto chained = impl::chained_iterator<InitValueT, InputIteratorT>{input + offset, carry};
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        chained, output + offset, scan_op, chunk, stream);
#endif
  }
#endif
}
| |
|
#if CUB_SUPPORTS_SCAN_BY_KEY()

// Forwards to cub::DeviceScan::InclusiveSumByKey on the current CUDA stream,
// comparing keys with cub::Equality.
// Fails a TORCH_CHECK if num_items exceeds INT_MAX.
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub InclusiveSumByKey does not support more than INT_MAX elements");
  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
      keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}

// Forwards to cub::DeviceScan::InclusiveScanByKey with the caller's scan_op on the
// current CUDA stream, comparing keys with cub::Equality.
// Fails a TORCH_CHECK if num_items exceeds INT_MAX.
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub InclusiveScanByKey does not support more than INT_MAX elements");
  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
      keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}

#endif
| |
|
// Forwards to cub::DeviceSelect::Unique on the current CUDA stream. Selected items
// are written to `output` and the selection count to num_selected_out.
// Fails a TORCH_CHECK if num_items exceeds INT_MAX.
template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,
            NumSelectedIteratorT num_selected_out, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
              "cub unique does not support more than INT_MAX elements");
  const auto stream = at::cuda::getCurrentCUDAStream();
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
              input, output, num_selected_out, num_items, stream);
}
| |
|
// Forwards to cub::DeviceRunLengthEncode::Encode on the current CUDA stream.
// Run values are written to `output`, run counts to counts_out, and the number of
// runs to length_out.
// Fails a TORCH_CHECK if num_items exceeds INT_MAX.
template <typename InputIteratorT, typename OutputIteratorT, typename CountsOutputIteratorT,
          typename LengthOutputIteratorT>
void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out,
                       LengthOutputIteratorT length_out, int64_t num_items) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
              "cub run_length_encode does not support more than INT_MAX elements");
  const auto stream = at::cuda::getCurrentCUDAStream();
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
              input, output, counts_out, length_out, num_items, stream);
}
| |
|
// Reduces num_items elements of `input` with `op`, starting from `init`, into
// `output` via cub::DeviceReduce::Reduce on the current CUDA stream.
// Fails a TORCH_CHECK if num_items exceeds INT_MAX.
template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) {
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
              "cub reduce does not support more than INT_MAX elements");
  const auto stream = at::cuda::getCurrentCUDAStream();
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce,
              input, output, num_items, op, init, stream);
}
| |
|
| | }}} |
| |
|