| /****************************************************************************** |
| * Copyright (c) 2011, Duane Merrill. All rights reserved. |
| * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions are met: |
| * * Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * * Redistributions in binary form must reproduce the above copyright |
| * notice, this list of conditions and the following disclaimer in the |
| * documentation and/or other materials provided with the distribution. |
| * * Neither the name of the NVIDIA CORPORATION nor the |
| * names of its contributors may be used to endorse or promote products |
| * derived from this software without specific prior written permission. |
| * |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND |
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED |
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY |
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES |
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; |
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS |
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| * |
| ******************************************************************************/ |
| |
| /** |
| * \file |
| * cub::BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block |
| */ |
| |
| #pragma once |
| |
| #include <stdint.h> |
| |
| #include "../thread/thread_reduce.cuh" |
| #include "../thread/thread_scan.cuh" |
| #include "../block/block_scan.cuh" |
| #include "../block/radix_rank_sort_operations.cuh" |
| #include "../config.cuh" |
| #include "../util_ptx.cuh" |
| #include "../util_type.cuh" |
| |
| #include <cuda/std/type_traits> |
| |
| CUB_NAMESPACE_BEGIN |
| |
| |
/**
 * \brief Radix ranking algorithm: selects the algorithm used to compute a stable
 * ranking of the keys from a single tile. Note that different ranking algorithms
 * require different initial arrangements of keys to function properly.
 */
enum RadixRankAlgorithm
{
  /** Ranking using the BlockRadixRank algorithm with MEMOIZE_OUTER_SCAN == false. It
   * uses thread-private histograms, and thus uses more shared memory. Requires blocked
   * arrangement of keys. Does not support count callbacks. */
  RADIX_RANK_BASIC,
  /** Ranking using the BlockRadixRank algorithm with MEMOIZE_OUTER_SCAN ==
   * true. Similar to RADIX_RANK_BASIC, it requires blocked arrangement of
   * keys and does not support count callbacks. */
  RADIX_RANK_MEMOIZE,
  /** Ranking using the BlockRadixRankMatch algorithm. It uses warp-private
   * histograms and matching for ranking the keys in a single warp. Therefore,
   * it uses less shared memory compared to RADIX_RANK_BASIC. It requires
   * warp-striped key arrangement and supports count callbacks. */
  RADIX_RANK_MATCH,
  /** Ranking using the BlockRadixRankMatchEarlyCounts algorithm with
   * MATCH_ALGORITHM == WARP_MATCH_ANY. An alternative implementation of
   * match-based ranking that computes bin counts early. Because of this, it
   * works better with onesweep sorting, which requires bin counts for
   * decoupled look-back. Assumes warp-striped key arrangement and supports
   * count callbacks. */
  RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
  /** Ranking using the BlockRadixRankMatchEarlyCounts algorithm with
   * MATCH_ALGORITHM == WARP_MATCH_ATOMIC_OR. It uses extra space in shared
   * memory to generate warp match masks using atomicOr(). This is faster when
   * there are few matches, but can lead to slowdowns if the number of
   * matching keys among warp lanes is high. Assumes warp-striped key
   * arrangement and supports count callbacks. */
  RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR
};
| |
| |
/**
 * \brief No-op counts callback.
 *
 * Used as the default callback by RankKeys overloads that do not take one. Its call
 * operator accepts the per-thread bin counts and discards them.
 *
 * \tparam BINS_PER_THREAD Number of bin counts handed to each thread
 */
template <int BINS_PER_THREAD>
struct BlockRadixRankEmptyCallback
{
  /// Accepts the per-thread bin counts and does nothing with them.
  __device__ __forceinline__ void operator()(int (& /* bins */)[BINS_PER_THREAD]) {}
};
| |
| |
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
namespace detail
{

/**
 * Chooses which MatchAny overload a warp of a thread block should use.
 *
 * The warp whose id equals \p PartialWarpId calls
 * MatchAny<Bits, PartialWarpThreads>(label). Every other warp calls MatchAny<Bits>(label).
 * The resulting mask is returned unchanged.
 */
template <int Bits, int PartialWarpThreads, int PartialWarpId>
struct warp_in_block_matcher_t
{
  static __device__ std::uint32_t match_any(std::uint32_t label, std::uint32_t warp_id)
  {
    const bool is_partial_warp = (warp_id == static_cast<std::uint32_t>(PartialWarpId));

    return is_partial_warp ? MatchAny<Bits, PartialWarpThreads>(label)
                           : MatchAny<Bits>(label);
  }
};

/**
 * Specialization for PartialWarpThreads == 0 (no partial warp).
 * Every warp uses the full-warp MatchAny overload, so the warp id is ignored.
 */
template <int Bits, int PartialWarpId>
struct warp_in_block_matcher_t<Bits, 0, PartialWarpId>
{
  static __device__ std::uint32_t match_any(std::uint32_t label, std::uint32_t /* warp_id */)
  {
    return MatchAny<Bits>(label);
  }
};

} // namespace detail
#endif // DOXYGEN_SHOULD_SKIP_THIS
| |
| |
/**
 * \brief BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
 * \ingroup BlockModule
 *
 * \tparam BLOCK_DIM_X          The thread block length in threads along the X dimension
 * \tparam RADIX_BITS           The number of radix bits per digit place
 * \tparam IS_DESCENDING        Whether or not the sorted-order is high-to-low
 * \tparam MEMOIZE_OUTER_SCAN   <b>[optional]</b> Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true). See BlockScanAlgorithm::BLOCK_SCAN_RAKING_MEMOIZE for more details.
 * \tparam INNER_SCAN_ALGORITHM <b>[optional]</b> The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS)
 * \tparam SMEM_CONFIG          <b>[optional]</b> Shared memory bank mode (default: \p cudaSharedMemBankSizeFourByte)
 * \tparam BLOCK_DIM_Y          <b>[optional]</b> The thread block length in threads along the Y dimension (default: 1)
 * \tparam BLOCK_DIM_Z          <b>[optional]</b> The thread block length in threads along the Z dimension (default: 1)
 * \tparam LEGACY_PTX_ARCH      <b>[optional]</b> Unused.
 *
 * \par Overview
 * BlockRadixRank computes, for every key held by the threads of a block, the stable rank
 * of that key within the tile with respect to one radix digit. The digit is obtained from
 * a user-supplied digit extractor. Digit occurrences are tallied in per-thread counters,
 * which are then scanned across the block.
 * - Keys must be in a form suitable for radix ranking (i.e., unsigned bits).
 * - \blocked
 * - Digit counters are 16-bit, so a tile may contain at most 65535 keys
 *   (BLOCK_THREADS * KEYS_PER_THREAD). This is enforced with a static_assert.
 *
 * \par Performance Considerations
 * - \granularity
 *
 * \par
 * \code
 * #include <cub/cub.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     constexpr int block_threads = 2;
 *     constexpr int radix_bits = 5;
 *
 *     // Specialize BlockRadixRank for a 1D block of 2 threads ranking in ascending order
 *     using block_radix_rank = cub::BlockRadixRank<block_threads, radix_bits, false>;
 *     using storage_t = typename block_radix_rank::TempStorage;
 *
 *     // Allocate shared memory for BlockRadixRank
 *     __shared__ storage_t temp_storage;
 *
 *     // Obtain a segment of consecutive items that are blocked across threads
 *     int keys[2];
 *     int ranks[2];
 *     ...
 *
 *     cub::BFEDigitExtractor<int> extractor(0, radix_bits);
 *     block_radix_rank(temp_storage).RankKeys(keys, ranks, extractor);
 *
 *     ...
 * \endcode
 * Suppose the set of input `keys` across the block of threads is `{ [16,10], [9,11] }`.
 * The corresponding output `ranks` in those threads will be `{ [3,1], [0,2] }`.
 *
 * \par Re-using dynamically allocated shared memory
 * The following example under the examples/block folder illustrates usage of
 * dynamic shared memory with BlockReduce and how to re-purpose
 * the same memory region:
 * <a href="../../examples/block/example_block_reduce_dyn_smem.cu">example_block_reduce_dyn_smem.cu</a>
 *
 * This example can be easily adapted to the storage required by BlockRadixRank.
 */
template <
    int                 BLOCK_DIM_X,
    int                 RADIX_BITS,
    bool                IS_DESCENDING,
    bool                MEMOIZE_OUTER_SCAN      = true,
    BlockScanAlgorithm  INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    cudaSharedMemConfig SMEM_CONFIG             = cudaSharedMemBankSizeFourByte,
    int                 BLOCK_DIM_Y             = 1,
    int                 BLOCK_DIM_Z             = 1,
    int                 LEGACY_PTX_ARCH         = 0>
class BlockRadixRank
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    // Integer type for digit counters (to be packed into words of type PackedCounter)
    using DigitCounter = unsigned short;

    // Integer type for packing DigitCounters into columns of shared memory banks
    using PackedCounter =
      cub::detail::conditional_t<SMEM_CONFIG == cudaSharedMemBankSizeEightByte,
                                 unsigned long long,
                                 unsigned int>;

    // Largest number of keys per tile whose per-digit counts still fit in a DigitCounter
    static constexpr DigitCounter max_tile_size = ::cuda::std::numeric_limits<DigitCounter>::max();

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        RADIX_DIGITS                = 1 << RADIX_BITS,

        LOG_WARP_THREADS            = CUB_LOG_WARP_THREADS(0),
        WARP_THREADS                = 1 << LOG_WARP_THREADS,
        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        BYTES_PER_COUNTER           = sizeof(DigitCounter),
        LOG_BYTES_PER_COUNTER       = Log2<BYTES_PER_COUNTER>::VALUE,

        // Number of DigitCounters packed into one PackedCounter
        PACKING_RATIO               = static_cast<int>(sizeof(PackedCounter) / sizeof(DigitCounter)),
        LOG_PACKING_RATIO           = Log2<PACKING_RATIO>::VALUE,

        // Always at least one lane
        LOG_COUNTER_LANES           = CUB_MAX((int(RADIX_BITS) - int(LOG_PACKING_RATIO)), 0),
        COUNTER_LANES               = 1 << LOG_COUNTER_LANES,

        // The number of packed counters per thread (plus one for padding)
        PADDED_COUNTER_LANES        = COUNTER_LANES + 1,
        RAKING_SEGMENT              = PADDED_COUNTER_LANES,
    };

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:

    /// BlockScan type used to scan the per-thread packed counters.
    /// Named BlockScanT rather than BlockScan: a member named BlockScan would redeclare,
    /// and thereby change the meaning of, the class-template name used to define it.
    using BlockScanT = BlockScan<
            PackedCounter,
            BLOCK_DIM_X,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z>;


    /// Shared memory storage layout type for BlockRadixRank
    struct __align__(16) _TempStorage
    {
        union Aliasable
        {
            // Digit counters, indexed [counter lane][thread][sub-counter]
            DigitCounter            digit_counters[PADDED_COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO];
            // The same storage viewed as packed counters, one raking segment per thread
            PackedCounter           raking_grid[BLOCK_THREADS][RAKING_SEGMENT];

        } aliasable;

        // Storage for scanning local ranks
        typename BlockScanT::TempStorage block_scan;
    };


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;

    /// Copy of raking segment, promoted to registers
    PackedCounter cached_segment[RAKING_SEGMENT];


    /******************************************************************************
     * Utility methods
     ******************************************************************************/

    /**
     * Internal storage allocator
     */
    __device__ __forceinline__ _TempStorage& PrivateStorage()
    {
        __shared__ _TempStorage private_storage;
        return private_storage;
    }


    /**
     * Performs upsweep raking reduction, returning the aggregate
     */
    __device__ __forceinline__ PackedCounter Upsweep()
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];
        PackedCounter *raking_ptr;

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data into registers
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                cached_segment[i] = smem_raking_ptr[i];
            }
            raking_ptr = cached_segment;
        }
        else
        {
            raking_ptr = smem_raking_ptr;
        }

        return internal::ThreadReduce<RAKING_SEGMENT>(raking_ptr, Sum());
    }


    /// Performs exclusive downsweep raking scan
    __device__ __forceinline__ void ExclusiveDownsweep(
        PackedCounter raking_partial)
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];

        PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ?
            cached_segment :
            smem_raking_ptr;

        // Exclusive raking downsweep scan
        internal::ThreadScanExclusive<RAKING_SEGMENT>(raking_ptr, raking_ptr, Sum(), raking_partial);

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data back to smem
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                smem_raking_ptr[i] = cached_segment[i];
            }
        }
    }


    /**
     * Reset shared memory digit counters
     */
    __device__ __forceinline__ void ResetCounters()
    {
        // Zero this thread's counters in every lane, one PackedCounter at a time
        #pragma unroll
        for (int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++)
        {
            *reinterpret_cast<PackedCounter*>(temp_storage.aliasable.digit_counters[LANE][linear_tid]) = 0;
        }
    }


    /**
     * Block-scan prefix callback
     */
    struct PrefixCallBack
    {
        __device__ __forceinline__ PackedCounter operator()(PackedCounter block_aggregate)
        {
            PackedCounter block_prefix = 0;

            // Propagate totals in packed fields
            #pragma unroll
            for (int PACKED = 1; PACKED < PACKING_RATIO; PACKED++)
            {
                block_prefix += block_aggregate << (sizeof(DigitCounter) * 8 * PACKED);
            }

            return block_prefix;
        }
    };


    /**
     * Scan shared memory digit counters.
     */
    __device__ __forceinline__ void ScanCounters()
    {
        // Upsweep scan
        PackedCounter raking_partial = Upsweep();

        // Compute exclusive sum
        PackedCounter exclusive_partial;
        PrefixCallBack prefix_call_back;
        BlockScanT(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back);

        // Downsweep scan with exclusive partial
        ExclusiveDownsweep(exclusive_partial);
    }

public:

    /// \smemstorage{BlockRadixRank}
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{

    /**
     * \brief Collective constructor using a private static allocation of shared memory as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank()
    :
        temp_storage(PrivateStorage()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /**
     * \brief Rank keys.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        DigitExtractorT digit_extractor)                    ///< [in] The digit extractor
    {
        static_assert(BLOCK_THREADS * KEYS_PER_THREAD <= max_tile_size,
                      "DigitCounter type is too small to hold this number of keys");

        DigitCounter    thread_prefixes[KEYS_PER_THREAD];   // For each key, the count of previous keys in this tile having the same digit
        DigitCounter*   digit_counters[KEYS_PER_THREAD];    // For each key, a pointer to its corresponding digit counter in smem

        // Reset shared memory digit counters
        ResetCounters();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Get digit
            std::uint32_t digit = digit_extractor.Digit(keys[ITEM]);

            // Get sub-counter
            std::uint32_t sub_counter = digit >> LOG_COUNTER_LANES;

            // Get counter lane
            std::uint32_t counter_lane = digit & (COUNTER_LANES - 1);

            if (IS_DESCENDING)
            {
                sub_counter = PACKING_RATIO - 1 - sub_counter;
                counter_lane = COUNTER_LANES - 1 - counter_lane;
            }

            // Pointer to smem digit counter
            digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane][linear_tid][sub_counter];

            // Load thread-exclusive prefix
            thread_prefixes[ITEM] = *digit_counters[ITEM];

            // Store inclusive prefix
            *digit_counters[ITEM] = thread_prefixes[ITEM] + 1;
        }

        CTA_SYNC();

        // Scan shared memory counters
        ScanCounters();

        CTA_SYNC();

        // Extract the local ranks of each key
        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Add in thread block exclusive prefix
            ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM];
        }
    }


    /**
     * \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        DigitExtractorT digit_extractor,                    ///< [in] The digit extractor
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])            ///< [out] The exclusive prefix sum for the digits [(linear_tid * BINS_TRACKED_PER_THREAD) ... (linear_tid * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1], where linear_tid is the calling thread's row-major id within the block
    {
        static_assert(BLOCK_THREADS * KEYS_PER_THREAD <= max_tile_size,
                      "DigitCounter type is too small to hold this number of keys");

        // Rank keys
        RankKeys(keys, ranks, digit_extractor);

        // Get the exclusive digit totals corresponding to the calling thread.
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                // Obtain exclusive digit counts. (Unfortunately these all reside in the
                // first counter column, resulting in unavoidable bank conflicts.)
                unsigned int counter_lane   = (bin_idx & (COUNTER_LANES - 1));
                unsigned int sub_counter    = bin_idx >> (LOG_COUNTER_LANES);

                exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane][0][sub_counter];
            }
        }
    }
};
| |
|
|
|
|
|
|
|
|
/**
 * \brief Radix-rank using match.any.
 *
 * Each warp ranks its keys by matching the lanes that share a digit. The matching lanes
 * update warp-private digit counters in shared memory. Those counters are then
 * exclusive-scanned across the block to turn warp-local ranks into tile-wide ranks.
 * Requires a warp-striped arrangement of keys and supports count callbacks.
 *
 * \tparam BLOCK_DIM_X          The thread block length in threads along the X dimension
 * \tparam RADIX_BITS           The number of radix bits per digit place
 * \tparam IS_DESCENDING        Whether or not the sorted-order is high-to-low
 * \tparam INNER_SCAN_ALGORITHM <b>[optional]</b> The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS)
 * \tparam BLOCK_DIM_Y          <b>[optional]</b> The thread block length in threads along the Y dimension (default: 1)
 * \tparam BLOCK_DIM_Z          <b>[optional]</b> The thread block length in threads along the Z dimension (default: 1)
 * \tparam LEGACY_PTX_ARCH      <b>[optional]</b> Unused.
 */
template <
    int                 BLOCK_DIM_X,
    int                 RADIX_BITS,
    bool                IS_DESCENDING,
    BlockScanAlgorithm  INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    int                 BLOCK_DIM_Y             = 1,
    int                 BLOCK_DIM_Z             = 1,
    int                 LEGACY_PTX_ARCH         = 0>
class BlockRadixRankMatch
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    typedef int32_t    RankT;
    typedef int32_t    DigitCounterT;

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        RADIX_DIGITS                = 1 << RADIX_BITS,

        LOG_WARP_THREADS            = CUB_LOG_WARP_THREADS(0),
        WARP_THREADS                = 1 << LOG_WARP_THREADS,
        PARTIAL_WARP_THREADS        = BLOCK_THREADS % WARP_THREADS,
        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        // Round the warp dimension of the counter array up to an odd count
        // (presumably to spread digits across shared-memory banks)
        PADDED_WARPS            = ((WARPS & 0x1) == 0) ?
                                    WARPS + 1 :
                                    WARPS,

        COUNTERS                = PADDED_WARPS * RADIX_DIGITS,
        RAKING_SEGMENT          = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,
        PADDED_RAKING_SEGMENT   = ((RAKING_SEGMENT & 0x1) == 0) ?
                                    RAKING_SEGMENT + 1 :
                                    RAKING_SEGMENT,
    };

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:

    /// BlockScan type. BlockScan takes the block's X extent (not the total thread count)
    /// followed by the Y and Z extents; passing BLOCK_THREADS here would count Y and Z twice.
    typedef BlockScan<
            DigitCounterT,
            BLOCK_DIM_X,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z>
        BlockScanT;


    /// Shared memory storage layout type for BlockRadixRankMatch
    struct __align__(16) _TempStorage
    {
        typename BlockScanT::TempStorage            block_scan;

        union __align__(16) Aliasable
        {
            // Per-warp digit counters, indexed [digit][warp]
            volatile DigitCounterT                  warp_digit_counters[RADIX_DIGITS][PADDED_WARPS];
            // The same storage viewed as one raking segment per thread, used for the block scan
            DigitCounterT                           raking_grid[BLOCK_THREADS][PADDED_RAKING_SEGMENT];

        } aliasable;
    };


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;


public:

    /// \smemstorage{BlockRadixRankMatch}
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{


    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRankMatch(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /** \brief Computes the count of keys for each digit value, and calls the
     * callback with the array of key counts.
     *
     * @tparam CountsCallback The callback type. It should implement an instance
     * overload of operator()(int (&bins)[BINS_TRACKED_PER_THREAD]), where bins
     * is an array of key counts for each digit value distributed in block
     * distribution among the threads of the thread block. Key counts can be
     * used to update other data structures in global or shared
     * memory. Depending on the implementation of the ranking algorithm
     * (see BlockRadixRankMatchEarlyCounts), key counts may become available
     * early; therefore, they are returned through a callback rather than a
     * separate output parameter of RankKeys().
     *
     * Must be called after the warp counters have been exclusive-scanned and the
     * block has been synchronized.
     */
    template <int KEYS_PER_THREAD, typename CountsCallback>
    __device__ __forceinline__ void CallBack(CountsCallback callback)
    {
        const int TILE_ITEMS = KEYS_PER_THREAD * BLOCK_THREADS;

        int bins[BINS_TRACKED_PER_THREAD];
        // Get count for each digit
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                // The counters are indexed by the digit as used during ranking (flipped when
                // descending). After the exclusive scan, warp_digit_counters[d][0] is the number
                // of keys whose ranking digit is < d, so the count for d is the gap to the next
                // digit's counter (or to the tile size for the last digit). This holds in both
                // orders once the index has been flipped.
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                bins[track] = (bin_idx < RADIX_DIGITS - 1 ?
                    temp_storage.aliasable.warp_digit_counters[bin_idx + 1][0] : TILE_ITEMS) -
                    temp_storage.aliasable.warp_digit_counters[bin_idx][0];
            }
        }
        callback(bins);
    }

    /**
     * \brief Rank keys.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT,
        typename        CountsCallback>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        DigitExtractorT digit_extractor,                    ///< [in] The digit extractor
        CountsCallback    callback)                         ///< [in] Receives per-digit key counts (skipped for BlockRadixRankEmptyCallback)
    {
        // Initialize shared digit counters

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = 0;

        CTA_SYNC();

        // Each warp will strip-mine its section of input, one strip at a time

        volatile DigitCounterT  *digit_counters[KEYS_PER_THREAD];
        uint32_t                warp_id         = linear_tid >> LOG_WARP_THREADS;
        uint32_t                lane_mask_lt    = LaneMaskLt();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // My digit
            std::uint32_t digit = digit_extractor.Digit(keys[ITEM]);

            if (IS_DESCENDING)
                digit = RADIX_DIGITS - digit - 1;

            // Mask of peers who have same digit as me
            uint32_t peer_mask =
              detail::warp_in_block_matcher_t<
                RADIX_BITS,
                PARTIAL_WARP_THREADS,
                WARPS - 1>::match_any(digit, warp_id);

            // Pointer to smem digit counter for this key
            digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit][warp_id];

            // Number of occurrences in previous strips
            DigitCounterT warp_digit_prefix = *digit_counters[ITEM];

            // Warp-sync
            WARP_SYNC(0xFFFFFFFF);

            // Number of peers having same digit as me
            int32_t digit_count = __popc(peer_mask);

            // Number of lower-ranked peers having same digit seen so far
            int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt);

            if (peer_digit_prefix == 0)
            {
                // First thread for each digit updates the shared warp counter
                *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count);
            }

            // Warp-sync
            WARP_SYNC(0xFFFFFFFF);

            // Number of prior keys having same digit
            ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
        }

        CTA_SYNC();

        // Scan warp counters

        DigitCounterT scan_counters[PADDED_RAKING_SEGMENT];

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid][ITEM];

        BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters);

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = scan_counters[ITEM];

        CTA_SYNC();
        if (!std::is_same<
                CountsCallback,
                BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>>::value)
        {
            CallBack<KEYS_PER_THREAD>(callback);
        }

        // Seed ranks with counter values from previous warps
        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
            ranks[ITEM] += *digit_counters[ITEM];
    }

    /**
     * \brief Rank keys without a counts callback.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD],
        DigitExtractorT digit_extractor)
    {
        RankKeys(keys, ranks, digit_extractor,
                 BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>());
    }

    /**
     * \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT,
        typename        CountsCallback>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        DigitExtractorT digit_extractor,                    ///< [in] The digit extractor
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD],            ///< [out] The exclusive prefix sum for the digits [(linear_tid * BINS_TRACKED_PER_THREAD) ... (linear_tid * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1], where linear_tid is the calling thread's row-major id within the block
        CountsCallback    callback)                         ///< [in] Receives per-digit key counts
    {
        RankKeys(keys, ranks, digit_extractor, callback);

        // Get exclusive count for each digit
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx][0];
            }
        }
    }

    /**
     * \brief Rank keys and report per-thread exclusive digit prefixes, without a counts callback.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        DigitExtractorT digit_extractor,                    ///< [in] The digit extractor
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])            ///< [out] The exclusive prefix sum for the digits [(linear_tid * BINS_TRACKED_PER_THREAD) ... (linear_tid * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1], where linear_tid is the calling thread's row-major id within the block
    {
        RankKeys(keys, ranks, digit_extractor, exclusive_digit_prefix,
                 BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>());
    }
};
| |
/**
 * \brief How BlockRadixRankMatchEarlyCounts finds the lanes of a warp that share a digit
 * (selected through its MATCH_ALGORITHM template parameter).
 */
enum WarpMatchAlgorithm
{
  WARP_MATCH_ANY       = 0, ///< Warp match-any based matching
  WARP_MATCH_ATOMIC_OR = 1  ///< Match masks built in shared memory with atomicOr()
};
| |
/**
 * \brief Radix-rank using matching which computes the counts of keys for each
 * digit value early, at the expense of doing more work. This may be useful e.g.
 * for decoupled look-back, where it reduces the time other thread blocks need
 * to wait for digit counts to become available.
 *
 * \tparam BLOCK_DIM_X          Number of threads in the thread block.
 * \tparam RADIX_BITS           Number of bits per digit (2^RADIX_BITS digit values).
 * \tparam IS_DESCENDING        If true, digit values are ranked in reverse order.
 * \tparam INNER_SCAN_ALGORITHM BlockScan algorithm used for the prefix sum over
 *                              block-wide digit counts.
 * \tparam MATCH_ALGORITHM      How lanes holding equal digits are grouped inside
 *                              a warp (see WarpMatchAlgorithm).
 * \tparam NUM_PARTS            Number of sub-histograms per warp; lanes are spread
 *                              over them (lane % NUM_PARTS) to reduce atomic
 *                              contention, and they are summed before the scan.
 *
 * \note BLOCK_WARPS is BLOCK_THREADS / WARP_THREADS (truncating) and the per-warp
 *       storage is sized by it, so BLOCK_DIM_X is presumably expected to be a
 *       multiple of WARP_THREADS. TODO confirm before using partial warps.
 */
template <int BLOCK_DIM_X, int RADIX_BITS, bool IS_DESCENDING,
          BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
          WarpMatchAlgorithm MATCH_ALGORITHM = WARP_MATCH_ANY, int NUM_PARTS = 1>
struct BlockRadixRankMatchEarlyCounts
{
    // constants
    enum
    {
        BLOCK_THREADS = BLOCK_DIM_X,
        RADIX_DIGITS = 1 << RADIX_BITS,
        // Each thread owns this many digit bins for the block-wide scan.
        BINS_PER_THREAD = (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS,
        BINS_TRACKED_PER_THREAD = BINS_PER_THREAD,
        // True when the threads' bins cover RADIX_DIGITS exactly (no out-of-range bins).
        FULL_BINS = BINS_PER_THREAD * BLOCK_THREADS == RADIX_DIGITS,
        WARP_THREADS = CUB_PTX_WARP_THREADS,
        PARTIAL_WARP_THREADS = BLOCK_THREADS % WARP_THREADS,
        BLOCK_WARPS = BLOCK_THREADS / WARP_THREADS,
        PARTIAL_WARP_ID = BLOCK_WARPS - 1,
        // Full-warp participation mask for warp syncs and shuffles.
        WARP_MASK = ~0,
        NUM_MATCH_MASKS = MATCH_ALGORITHM == WARP_MATCH_ATOMIC_OR ? BLOCK_WARPS : 0,
        // Guard against declaring zero-sized array:
        MATCH_MASKS_ALLOC_SIZE = NUM_MATCH_MASKS < 1 ? 1 : NUM_MATCH_MASKS,
    };

    // types
    typedef cub::BlockScan<int, BLOCK_THREADS, INNER_SCAN_ALGORITHM> BlockScan;

    // temporary storage
    struct TempStorage
    {
        // warp_histograms holds each warp's per-part digit counts; warp_offsets
        // shares the same storage and holds each warp's per-digit counts, later
        // turned into offsets. When NUM_PARTS == 1 both views address the same
        // ints, so no extra copy is needed.
        union
        {
            int warp_offsets[BLOCK_WARPS][RADIX_DIGITS];
            int warp_histograms[BLOCK_WARPS][RADIX_DIGITS][NUM_PARTS];
        };

        // Per-warp, per-digit lane masks (only used by WARP_MATCH_ATOMIC_OR).
        int match_masks[MATCH_MASKS_ALLOC_SIZE][RADIX_DIGITS];

        typename BlockScan::TempStorage prefix_tmp;
    };

    TempStorage& temp_storage;

    // internal ranking implementation
    template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT,
              typename CountsCallback>
    struct BlockRadixRankMatchInternal
    {
        TempStorage& s;
        DigitExtractorT digit_extractor;
        CountsCallback callback;
        int warp;   // threadIdx.x / WARP_THREADS
        int lane;   // LaneId()

        /// Extracts the digit of \p key, reversed when IS_DESCENDING.
        __device__ __forceinline__ std::uint32_t Digit(UnsignedBits key)
        {
            std::uint32_t digit = digit_extractor.Digit(key);
            return IS_DESCENDING ? RADIX_DIGITS - 1 - digit : digit;
        }

        /// Returns the u-th digit bin owned by this thread, reversed when
        /// IS_DESCENDING. May lie outside [0, RADIX_DIGITS) when !FULL_BINS.
        __device__ __forceinline__ int ThreadBin(int u)
        {
            int bin = threadIdx.x * BINS_PER_THREAD + u;
            return IS_DESCENDING ? RADIX_DIGITS - 1 - bin : bin;
        }

        /**
         * \brief Builds this warp's digit histogram for \p keys.
         *
         * Each lane accumulates into sub-histogram (lane % NUM_PARTS). When
         * NUM_PARTS > 1 the sub-histograms are summed into s.warp_offsets[warp];
         * when NUM_PARTS == 1 the counts already live there because the two
         * arrays alias. Contains CTA_SYNC() when NUM_PARTS > 1, so every thread
         * of the block must call it.
         */
        __device__ __forceinline__
        void ComputeHistogramsWarp(UnsignedBits (&keys)[KEYS_PER_THREAD])
        {
            int (&warp_histograms)[RADIX_DIGITS][NUM_PARTS] = s.warp_histograms[warp];

            // zero the warp-private histograms
            #pragma unroll
            for (int bin = lane; bin < RADIX_DIGITS; bin += WARP_THREADS)
            {
                #pragma unroll
                for (int part = 0; part < NUM_PARTS; ++part)
                {
                    warp_histograms[bin][part] = 0;
                }
            }
            if (MATCH_ALGORITHM == WARP_MATCH_ATOMIC_OR)
            {
                // zero the warp's per-digit match masks used by ComputeRanksItem
                int* match_masks = &s.match_masks[warp][0];
                #pragma unroll
                for (int bin = lane; bin < RADIX_DIGITS; bin += WARP_THREADS)
                {
                    match_masks[bin] = 0;
                }
            }
            WARP_SYNC(WARP_MASK);

            // compute private per-part histograms
            int part = lane % NUM_PARTS;
            #pragma unroll
            for (int u = 0; u < KEYS_PER_THREAD; ++u)
            {
                atomicAdd(&warp_histograms[Digit(keys[u])][part], 1);
            }

            // sum different parts;
            // no extra work is necessary if NUM_PARTS == 1
            if (NUM_PARTS > 1)
            {
                WARP_SYNC(WARP_MASK);
                // Ceiling division so that every bin is reduced even when
                // RADIX_DIGITS is not a multiple of WARP_THREADS; this also keeps
                // the array non-empty when RADIX_DIGITS < WARP_THREADS.
                const int WARP_BINS_PER_THREAD =
                    (RADIX_DIGITS + WARP_THREADS - 1) / WARP_THREADS;
                int bins[WARP_BINS_PER_THREAD];
                #pragma unroll
                for (int u = 0; u < WARP_BINS_PER_THREAD; ++u)
                {
                    int bin = lane + u * WARP_THREADS;
                    bins[u] = 0;
                    if (bin < RADIX_DIGITS)
                    {
                        bins[u] = internal::ThreadReduce(warp_histograms[bin], Sum());
                    }
                }
                // warp_offsets shares storage with warp_histograms (including other
                // warps' rows), so every warp must finish reading before any writes.
                CTA_SYNC();

                // store the resulting histogram in temporary storage
                int* warp_offsets = &s.warp_offsets[warp][0];
                #pragma unroll
                for (int u = 0; u < WARP_BINS_PER_THREAD; ++u)
                {
                    int bin = lane + u * WARP_THREADS;
                    if (bin < RADIX_DIGITS)
                    {
                        warp_offsets[bin] = bins[u];
                    }
                }
            }
        }

        /**
         * \brief For each bin owned by this thread, replaces every warp's count
         * with the exclusive sum of the counts of lower-numbered warps, and
         * returns the block-wide total for that bin in \p bins[u] (0 for
         * out-of-range bins).
         */
        __device__ __forceinline__
        void ComputeOffsetsWarpUpsweep(int (&bins)[BINS_PER_THREAD])
        {
            // sum up warp-private histograms
            #pragma unroll
            for (int u = 0; u < BINS_PER_THREAD; ++u)
            {
                bins[u] = 0;
                int bin = ThreadBin(u);
                if (FULL_BINS || (bin >= 0 && bin < RADIX_DIGITS))
                {
                    #pragma unroll
                    for (int j_warp = 0; j_warp < BLOCK_WARPS; ++j_warp)
                    {
                        int warp_offset = s.warp_offsets[j_warp][bin];
                        s.warp_offsets[j_warp][bin] = bins[u];
                        bins[u] += warp_offset;
                    }
                }
            }
        }

        /**
         * \brief Adds the block-wide exclusive digit prefix \p offsets[u] to every
         * warp's offset for each in-range bin owned by this thread.
         */
        __device__ __forceinline__
        void ComputeOffsetsWarpDownsweep(int (&offsets)[BINS_PER_THREAD])
        {
            #pragma unroll
            for (int u = 0; u < BINS_PER_THREAD; ++u)
            {
                int bin = ThreadBin(u);
                if (FULL_BINS || (bin >= 0 && bin < RADIX_DIGITS))
                {
                    int digit_offset = offsets[u];
                    #pragma unroll
                    for (int j_warp = 0; j_warp < BLOCK_WARPS; ++j_warp)
                    {
                        s.warp_offsets[j_warp][bin] += digit_offset;
                    }
                }
            }
        }

        /**
         * \brief Computes key ranks, grouping equal-digit lanes with atomicOr on a
         * per-digit mask. The highest lane of each group reserves popc slots in
         * the warp's offset for that digit and broadcasts the base offset.
         */
        __device__ __forceinline__
        void ComputeRanksItem(
            UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD],
            Int2Type<WARP_MATCH_ATOMIC_OR>)
        {
            // compute key ranks
            int lane_mask = 1 << lane;
            int* warp_offsets = &s.warp_offsets[warp][0];
            int* match_masks = &s.match_masks[warp][0];
            #pragma unroll
            for (int u = 0; u < KEYS_PER_THREAD; ++u)
            {
                std::uint32_t bin = Digit(keys[u]);
                int* p_match_mask = &match_masks[bin];
                atomicOr(p_match_mask, lane_mask);
                WARP_SYNC(WARP_MASK);
                int bin_mask = *p_match_mask;
                // leader = highest lane holding this digit
                int leader = (WARP_THREADS - 1) - __clz(bin_mask);
                int warp_offset = 0;
                // number of lanes <= this one holding the same digit
                int popc = __popc(bin_mask & LaneMaskLe());
                if (lane == leader)
                {
                    // atomic is a bit faster
                    warp_offset = atomicAdd(&warp_offsets[bin], popc);
                }
                warp_offset = SHFL_IDX_SYNC(warp_offset, leader, WARP_MASK);
                // every lane has read the mask (shuffle above), so it can be reset
                if (lane == leader) *p_match_mask = 0;
                WARP_SYNC(WARP_MASK);
                ranks[u] = warp_offset + popc - 1;
            }
        }

        /**
         * \brief Computes key ranks, grouping equal-digit lanes with the
         * warp_in_block_matcher_t match-any helper.
         */
        __device__ __forceinline__
        void ComputeRanksItem(
            UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD],
            Int2Type<WARP_MATCH_ANY>)
        {
            // compute key ranks
            int* warp_offsets = &s.warp_offsets[warp][0];
            #pragma unroll
            for (int u = 0; u < KEYS_PER_THREAD; ++u)
            {
                std::uint32_t bin = Digit(keys[u]);
                int bin_mask = detail::warp_in_block_matcher_t<RADIX_BITS,
                                                               PARTIAL_WARP_THREADS,
                                                               BLOCK_WARPS - 1>::match_any(bin,
                                                                                           warp);
                // leader = highest lane holding this digit
                int leader = (WARP_THREADS - 1) - __clz(bin_mask);
                int warp_offset = 0;
                // number of lanes <= this one holding the same digit
                int popc = __popc(bin_mask & LaneMaskLe());
                if (lane == leader)
                {
                    // atomic is a bit faster
                    warp_offset = atomicAdd(&warp_offsets[bin], popc);
                }
                warp_offset = SHFL_IDX_SYNC(warp_offset, leader, WARP_MASK);
                ranks[u] = warp_offset + popc - 1;
            }
        }

        /**
         * \brief Full ranking pipeline: warp histograms -> per-warp exclusive
         * offsets and block-wide counts -> callback(counts) -> block-wide
         * exclusive scan -> add digit prefixes -> per-key ranks.
         * Contains CTA_SYNC(), so every thread of the block must call it.
         */
        __device__ __forceinline__ void RankKeys(
            UnsignedBits (&keys)[KEYS_PER_THREAD],
            int (&ranks)[KEYS_PER_THREAD],
            int (&exclusive_digit_prefix)[BINS_PER_THREAD])
        {
            ComputeHistogramsWarp(keys);

            CTA_SYNC();
            int bins[BINS_PER_THREAD];
            ComputeOffsetsWarpUpsweep(bins);
            // digit counts are available here, before the scan
            callback(bins);

            BlockScan(s.prefix_tmp).ExclusiveSum(bins, exclusive_digit_prefix);

            ComputeOffsetsWarpDownsweep(exclusive_digit_prefix);
            CTA_SYNC();
            ComputeRanksItem(keys, ranks, Int2Type<MATCH_ALGORITHM>());
        }

        __device__ __forceinline__ BlockRadixRankMatchInternal
        (TempStorage& temp_storage, DigitExtractorT digit_extractor, CountsCallback callback)
            : s(temp_storage), digit_extractor(digit_extractor),
              callback(callback), warp(threadIdx.x / WARP_THREADS), lane(LaneId())
            {}
    };

    __device__ __forceinline__ BlockRadixRankMatchEarlyCounts
    (TempStorage& temp_storage) : temp_storage(temp_storage) {}

    /**
     * \brief Rank keys.
     *
     * Each thread receives in \p exclusive_digit_prefix[u] the block-wide
     * exclusive prefix for its bin threadIdx.x * BINS_PER_THREAD + u (reversed
     * when IS_DESCENDING). Before the prefix scan, \p callback is invoked with
     * this thread's block-wide digit counts for the same bins (0 for bins beyond
     * RADIX_DIGITS), so callers can publish counts early.
     */
    template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT,
              typename CountsCallback>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],
        int             (&ranks)[KEYS_PER_THREAD],
        DigitExtractorT digit_extractor,
        int             (&exclusive_digit_prefix)[BINS_PER_THREAD],
        CountsCallback  callback)
    {
        BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, DigitExtractorT, CountsCallback>
            internal(temp_storage, digit_extractor, callback);
        internal.RankKeys(keys, ranks, exclusive_digit_prefix);
    }

    /// Rank keys, also returning per-thread exclusive digit prefixes (no counts callback).
    template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],
        int             (&ranks)[KEYS_PER_THREAD],
        DigitExtractorT digit_extractor,
        int             (&exclusive_digit_prefix)[BINS_PER_THREAD])
    {
        typedef BlockRadixRankEmptyCallback<BINS_PER_THREAD> CountsCallback;
        BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, DigitExtractorT, CountsCallback>
            internal(temp_storage, digit_extractor, CountsCallback());
        internal.RankKeys(keys, ranks, exclusive_digit_prefix);
    }

    /// Rank keys only; the exclusive digit prefixes are computed and discarded.
    template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],
        int             (&ranks)[KEYS_PER_THREAD],
        DigitExtractorT digit_extractor)
    {
        int exclusive_digit_prefix[BINS_PER_THREAD];
        RankKeys(keys, ranks, digit_extractor, exclusive_digit_prefix);
    }
};
| |
|
|
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
namespace detail
{

// Maps a RadixRankAlgorithm value to the concrete block-level ranking type:
//
//   RADIX_RANK_BASIC                  -> BlockRadixRank<..., false, ...>
//   RADIX_RANK_MEMOIZE                -> BlockRadixRank<..., true, ...>
//   RADIX_RANK_MATCH                  -> BlockRadixRankMatch<...>
//   RADIX_RANK_MATCH_EARLY_COUNTS_ANY -> BlockRadixRankMatchEarlyCounts<..., WARP_MATCH_ANY>
//   any other value                   -> BlockRadixRankMatchEarlyCounts<..., WARP_MATCH_ATOMIC_OR>
//
// `BlockRadixRank` doesn't conform to the typical pattern, not exposing the algorithm
// template parameter. Other algorithms don't provide the same template parameters, not allowing
// multi-dimensional thread block specializations.
//
// TODO(senior-zero) for 3.0:
// - Put existing implementations into the detail namespace
// - Support multi-dimensional thread blocks in the rest of implementations
// - Repurpose BlockRadixRank as an entry name with the algorithm template parameter
template <RadixRankAlgorithm RankAlgorithm,
          int BlockDimX,
          int RadixBits,
          bool IsDescending,
          BlockScanAlgorithm ScanAlgorithm>
using block_radix_rank_t = cub::detail::conditional_t<
  RankAlgorithm == RADIX_RANK_BASIC,
  BlockRadixRank<BlockDimX, RadixBits, IsDescending, false, ScanAlgorithm>,
  cub::detail::conditional_t<
    RankAlgorithm == RADIX_RANK_MEMOIZE,
    BlockRadixRank<BlockDimX, RadixBits, IsDescending, true, ScanAlgorithm>,
    cub::detail::conditional_t<
      RankAlgorithm == RADIX_RANK_MATCH,
      BlockRadixRankMatch<BlockDimX, RadixBits, IsDescending, ScanAlgorithm>,
      // Both early-counts variants differ only in the warp-match strategy, so
      // the choice is made directly in the non-type template argument.
      BlockRadixRankMatchEarlyCounts<
        BlockDimX,
        RadixBits,
        IsDescending,
        ScanAlgorithm,
        (RankAlgorithm == RADIX_RANK_MATCH_EARLY_COUNTS_ANY ? WARP_MATCH_ANY
                                                            : WARP_MATCH_ATOMIC_OR)>>>>;

} // namespace detail
#endif // DOXYGEN_SHOULD_SKIP_THIS
| |
| |
| CUB_NAMESPACE_END |
| |
| |
| |