| /****************************************************************************** |
| * Copyright (c) 2011, Duane Merrill. All rights reserved. |
| * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions are met: |
| * * Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * * Redistributions in binary form must reproduce the above copyright |
| * notice, this list of conditions and the following disclaimer in the |
| * documentation and/or other materials provided with the distribution. |
| * * Neither the name of the NVIDIA CORPORATION nor the |
| * names of its contributors may be used to endorse or promote products |
| * derived from this software without specific prior written permission. |
| * |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND |
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED |
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY |
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES |
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; |
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS |
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| * |
| ******************************************************************************/ |
| |
| /** |
| * \file |
| * cub::AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide select. |
| */ |
| |
| #pragma once |
| |
| #include <iterator> |
| |
| #include "single_pass_scan_operators.cuh" |
| #include "../block/block_load.cuh" |
| #include "../block/block_store.cuh" |
| #include "../block/block_scan.cuh" |
| #include "../block/block_exchange.cuh" |
| #include "../block/block_discontinuity.cuh" |
| #include "../config.cuh" |
| #include "../grid/grid_queue.cuh" |
| #include "../iterator/cache_modified_input_iterator.cuh" |
| |
| /// Optional outer namespace(s) |
| CUB_NS_PREFIX |
| |
| /// CUB namespace |
| namespace cub { |
| |
| |
| /****************************************************************************** |
| * Tuning policy types |
| ******************************************************************************/ |
| |
| /** |
| * Parameterizable tuning policy type for AgentSelectIf |
| */ |
| template < |
| int _BLOCK_THREADS, ///< Threads per thread block |
| int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) |
| BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use |
| CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements |
| BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use |
| struct AgentSelectIfPolicy |
| { |
| enum |
| { |
| BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block |
| ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) |
| }; |
| |
| static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use |
| static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements |
| static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use |
| }; |
| |
| |
| |
| |
| /****************************************************************************** |
| * Thread block abstractions |
| ******************************************************************************/ |
| |
| |
| /** |
| * \brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide selection |
| * |
| * Performs functor-based selection if SelectOpT functor type != NullType |
| * Otherwise performs flag-based selection if FlagsInputIterator's value type != NullType |
| * Otherwise performs discontinuity selection (keep unique) |
| */ |
| template < |
| typename AgentSelectIfPolicyT, ///< Parameterized AgentSelectIfPolicy tuning policy type |
| typename InputIteratorT, ///< Random-access input iterator type for selection items |
| typename FlagsInputIteratorT, ///< Random-access input iterator type for selections (NullType* if a selection functor or discontinuity flagging is to be used for selection) |
| typename SelectedOutputIteratorT, ///< Random-access input iterator type for selection_flags items |
| typename SelectOpT, ///< Selection operator type (NullType if selections or discontinuity flagging is to be used for selection) |
| typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selections is to be used for selection) |
| typename OffsetT, ///< Signed integer type for global offsets |
| bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output |
| struct AgentSelectIf |
| { |
| //--------------------------------------------------------------------- |
| // Types and constants |
| //--------------------------------------------------------------------- |
| |
| // The input value type |
| typedef typename std::iterator_traits<InputIteratorT>::value_type InputT; |
| |
| // The output value type |
| typedef typename If<(Equals<typename std::iterator_traits<SelectedOutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? |
| typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type, |
| typename std::iterator_traits<SelectedOutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type |
| |
| // The flag value type |
| typedef typename std::iterator_traits<FlagsInputIteratorT>::value_type FlagT; |
| |
| // Tile status descriptor interface type |
| typedef ScanTileState<OffsetT> ScanTileStateT; |
| |
| // Constants |
| enum |
| { |
| USE_SELECT_OP, |
| USE_SELECT_FLAGS, |
| USE_DISCONTINUITY, |
| |
| BLOCK_THREADS = AgentSelectIfPolicyT::BLOCK_THREADS, |
| ITEMS_PER_THREAD = AgentSelectIfPolicyT::ITEMS_PER_THREAD, |
| TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, |
| TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1), |
| |
| SELECT_METHOD = (!Equals<SelectOpT, NullType>::VALUE) ? |
| USE_SELECT_OP : |
| (!Equals<FlagT, NullType>::VALUE) ? |
| USE_SELECT_FLAGS : |
| USE_DISCONTINUITY |
| }; |
| |
| // Cache-modified Input iterator wrapper type (for applying cache modifier) for items |
| typedef typename If<IsPointer<InputIteratorT>::VALUE, |
| CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator |
| InputIteratorT>::Type // Directly use the supplied input iterator type |
| WrappedInputIteratorT; |
| |
| // Cache-modified Input iterator wrapper type (for applying cache modifier) for values |
| typedef typename If<IsPointer<FlagsInputIteratorT>::VALUE, |
| CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator |
| FlagsInputIteratorT>::Type // Directly use the supplied input iterator type |
| WrappedFlagsInputIteratorT; |
| |
| // Parameterized BlockLoad type for input data |
| typedef BlockLoad< |
| OutputT, |
| BLOCK_THREADS, |
| ITEMS_PER_THREAD, |
| AgentSelectIfPolicyT::LOAD_ALGORITHM> |
| BlockLoadT; |
| |
| // Parameterized BlockLoad type for flags |
| typedef BlockLoad< |
| FlagT, |
| BLOCK_THREADS, |
| ITEMS_PER_THREAD, |
| AgentSelectIfPolicyT::LOAD_ALGORITHM> |
| BlockLoadFlags; |
| |
| // Parameterized BlockDiscontinuity type for items |
| typedef BlockDiscontinuity< |
| OutputT, |
| BLOCK_THREADS> |
| BlockDiscontinuityT; |
| |
| // Parameterized BlockScan type |
| typedef BlockScan< |
| OffsetT, |
| BLOCK_THREADS, |
| AgentSelectIfPolicyT::SCAN_ALGORITHM> |
| BlockScanT; |
| |
| // Callback type for obtaining tile prefix during block scan |
| typedef TilePrefixCallbackOp< |
| OffsetT, |
| cub::Sum, |
| ScanTileStateT> |
| TilePrefixCallbackOpT; |
| |
| // Item exchange type |
| typedef OutputT ItemExchangeT[TILE_ITEMS]; |
| |
| // Shared memory type for this thread block |
| union _TempStorage |
| { |
| struct |
| { |
| typename BlockScanT::TempStorage scan; // Smem needed for tile scanning |
| typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback |
| typename BlockDiscontinuityT::TempStorage discontinuity; // Smem needed for discontinuity detection |
| }; |
| |
| // Smem needed for loading items |
| typename BlockLoadT::TempStorage load_items; |
| |
| // Smem needed for loading values |
| typename BlockLoadFlags::TempStorage load_flags; |
| |
| // Smem needed for compacting items (allows non POD items in this union) |
| Uninitialized<ItemExchangeT> raw_exchange; |
| }; |
| |
| // Alias wrapper allowing storage to be unioned |
| struct TempStorage : Uninitialized<_TempStorage> {}; |
| |
|
|
| //--------------------------------------------------------------------- |
| // Per-thread fields |
| //--------------------------------------------------------------------- |
| |
| _TempStorage& temp_storage; ///< Reference to temp_storage |
| WrappedInputIteratorT d_in; ///< Input items |
| SelectedOutputIteratorT d_selected_out; ///< Unique output items |
| WrappedFlagsInputIteratorT d_flags_in; ///< Input selection flags (if applicable) |
| InequalityWrapper<EqualityOpT> inequality_op; ///< T inequality operator |
| SelectOpT select_op; ///< Selection operator |
| OffsetT num_items; ///< Total number of input items |
| |
|
|
| //--------------------------------------------------------------------- |
| // Constructor |
| //--------------------------------------------------------------------- |
| |
| // Constructor |
| __device__ __forceinline__ |
| AgentSelectIf( |
| TempStorage &temp_storage, ///< Reference to temp_storage |
| InputIteratorT d_in, ///< Input data |
| FlagsInputIteratorT d_flags_in, ///< Input selection flags (if applicable) |
| SelectedOutputIteratorT d_selected_out, ///< Output data |
| SelectOpT select_op, ///< Selection operator |
| EqualityOpT equality_op, ///< Equality operator |
| OffsetT num_items) ///< Total number of input items |
| : |
| temp_storage(temp_storage.Alias()), |
| d_in(d_in), |
| d_flags_in(d_flags_in), |
| d_selected_out(d_selected_out), |
| select_op(select_op), |
| inequality_op(equality_op), |
| num_items(num_items) |
| {} |
| |
|
|
| //--------------------------------------------------------------------- |
| // Utility methods for initializing the selections |
| //--------------------------------------------------------------------- |
| |
| /** |
| * Initialize selections (specialized for selection operator) |
| */ |
| template <bool IS_FIRST_TILE, bool IS_LAST_TILE> |
| __device__ __forceinline__ void InitializeSelections( |
| OffsetT /*tile_offset*/, |
| OffsetT num_tile_items, |
| OutputT (&items)[ITEMS_PER_THREAD], |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], |
| Int2Type<USE_SELECT_OP> /*select_method*/) |
| { |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| // Out-of-bounds items are selection_flags |
| selection_flags[ITEM] = 1; |
| |
| if (!IS_LAST_TILE || (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM < num_tile_items)) |
| selection_flags[ITEM] = select_op(items[ITEM]); |
| } |
| } |
| |
|
|
| /** |
| * Initialize selections (specialized for valid flags) |
| */ |
| template <bool IS_FIRST_TILE, bool IS_LAST_TILE> |
| __device__ __forceinline__ void InitializeSelections( |
| OffsetT tile_offset, |
| OffsetT num_tile_items, |
| OutputT (&/*items*/)[ITEMS_PER_THREAD], |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], |
| Int2Type<USE_SELECT_FLAGS> /*select_method*/) |
| { |
| CTA_SYNC(); |
| |
| FlagT flags[ITEMS_PER_THREAD]; |
| |
| if (IS_LAST_TILE) |
| { |
| // Out-of-bounds items are selection_flags |
| BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags, num_tile_items, 1); |
| } |
| else |
| { |
| BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags); |
| } |
| |
| // Convert flag type to selection_flags type |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| selection_flags[ITEM] = flags[ITEM]; |
| } |
| } |
| |
|
|
| /** |
| * Initialize selections (specialized for discontinuity detection) |
| */ |
| template <bool IS_FIRST_TILE, bool IS_LAST_TILE> |
| __device__ __forceinline__ void InitializeSelections( |
| OffsetT tile_offset, |
| OffsetT num_tile_items, |
| OutputT (&items)[ITEMS_PER_THREAD], |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], |
| Int2Type<USE_DISCONTINUITY> /*select_method*/) |
| { |
| if (IS_FIRST_TILE) |
| { |
| CTA_SYNC(); |
| |
| // Set head selection_flags. First tile sets the first flag for the first item |
| BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op); |
| } |
| else |
| { |
| OutputT tile_predecessor; |
| if (threadIdx.x == 0) |
| tile_predecessor = d_in[tile_offset - 1]; |
| |
| CTA_SYNC(); |
| |
| BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op, tile_predecessor); |
| } |
| |
| // Set selection flags for out-of-bounds items |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| // Set selection_flags for out-of-bounds items |
| if ((IS_LAST_TILE) && (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM >= num_tile_items)) |
| selection_flags[ITEM] = 1; |
| } |
| } |
| |
|
|
| //--------------------------------------------------------------------- |
| // Scatter utility methods |
| //--------------------------------------------------------------------- |
| |
| /** |
| * Scatter flagged items to output offsets (specialized for direct scattering) |
| */ |
| template <bool IS_LAST_TILE, bool IS_FIRST_TILE> |
| __device__ __forceinline__ void ScatterDirect( |
| OutputT (&items)[ITEMS_PER_THREAD], |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], |
| OffsetT (&selection_indices)[ITEMS_PER_THREAD], |
| OffsetT num_selections) |
| { |
| // Scatter flagged items |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| if (selection_flags[ITEM]) |
| { |
| if ((!IS_LAST_TILE) || selection_indices[ITEM] < num_selections) |
| { |
| d_selected_out[selection_indices[ITEM]] = items[ITEM]; |
| } |
| } |
| } |
| } |
| |
|
|
| /** |
| * Scatter flagged items to output offsets (specialized for two-phase scattering) |
| */ |
| template <bool IS_LAST_TILE, bool IS_FIRST_TILE> |
| __device__ __forceinline__ void ScatterTwoPhase( |
| OutputT (&items)[ITEMS_PER_THREAD], |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], |
| OffsetT (&selection_indices)[ITEMS_PER_THREAD], |
| int /*num_tile_items*/, ///< Number of valid items in this tile |
| int num_tile_selections, ///< Number of selections in this tile |
| OffsetT num_selections_prefix, ///< Total number of selections prior to this tile |
| OffsetT /*num_rejected_prefix*/, ///< Total number of rejections prior to this tile |
| Int2Type<false> /*is_keep_rejects*/) ///< Marker type indicating whether to keep rejected items in the second partition |
| { |
| CTA_SYNC(); |
| |
| // Compact and scatter items |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix; |
| if (selection_flags[ITEM]) |
| { |
| temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; |
| } |
| } |
| |
| CTA_SYNC(); |
| |
| for (int item = threadIdx.x; item < num_tile_selections; item += BLOCK_THREADS) |
| { |
| d_selected_out[num_selections_prefix + item] = temp_storage.raw_exchange.Alias()[item]; |
| } |
| } |
| |
|
|
| /** |
| * Scatter flagged items to output offsets (specialized for two-phase scattering) |
| */ |
| template <bool IS_LAST_TILE, bool IS_FIRST_TILE> |
| __device__ __forceinline__ void ScatterTwoPhase( |
| OutputT (&items)[ITEMS_PER_THREAD], |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], |
| OffsetT (&selection_indices)[ITEMS_PER_THREAD], |
| int num_tile_items, ///< Number of valid items in this tile |
| int num_tile_selections, ///< Number of selections in this tile |
| OffsetT num_selections_prefix, ///< Total number of selections prior to this tile |
| OffsetT num_rejected_prefix, ///< Total number of rejections prior to this tile |
| Int2Type<true> /*is_keep_rejects*/) ///< Marker type indicating whether to keep rejected items in the second partition |
| { |
| CTA_SYNC(); |
| |
| int tile_num_rejections = num_tile_items - num_tile_selections; |
| |
| // Scatter items to shared memory (rejections first) |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| int item_idx = (threadIdx.x * ITEMS_PER_THREAD) + ITEM; |
| int local_selection_idx = selection_indices[ITEM] - num_selections_prefix; |
| int local_rejection_idx = item_idx - local_selection_idx; |
| int local_scatter_offset = (selection_flags[ITEM]) ? |
| tile_num_rejections + local_selection_idx : |
| local_rejection_idx; |
| |
| temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; |
| } |
| |
| CTA_SYNC(); |
| |
| // Gather items from shared memory and scatter to global |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x; |
| int rejection_idx = item_idx; |
| int selection_idx = item_idx - tile_num_rejections; |
| OffsetT scatter_offset = (item_idx < tile_num_rejections) ? |
| num_items - num_rejected_prefix - rejection_idx - 1 : |
| num_selections_prefix + selection_idx; |
| |
| OutputT item = temp_storage.raw_exchange.Alias()[item_idx]; |
| |
| if (!IS_LAST_TILE || (item_idx < num_tile_items)) |
| { |
| d_selected_out[scatter_offset] = item; |
| } |
| } |
| } |
| |
|
|
| /** |
| * Scatter flagged items |
| */ |
| template <bool IS_LAST_TILE, bool IS_FIRST_TILE> |
| __device__ __forceinline__ void Scatter( |
| OutputT (&items)[ITEMS_PER_THREAD], |
| OffsetT (&selection_flags)[ITEMS_PER_THREAD], |
| OffsetT (&selection_indices)[ITEMS_PER_THREAD], |
| int num_tile_items, ///< Number of valid items in this tile |
| int num_tile_selections, ///< Number of selections in this tile |
| OffsetT num_selections_prefix, ///< Total number of selections prior to this tile |
| OffsetT num_rejected_prefix, ///< Total number of rejections prior to this tile |
| OffsetT num_selections) ///< Total number of selections including this tile |
| { |
| // Do a two-phase scatter if (a) keeping both partitions or (b) two-phase is enabled and the average number of selection_flags items per thread is greater than one |
| if (KEEP_REJECTS || (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS))) |
| { |
| ScatterTwoPhase<IS_LAST_TILE, IS_FIRST_TILE>( |
| items, |
| selection_flags, |
| selection_indices, |
| num_tile_items, |
| num_tile_selections, |
| num_selections_prefix, |
| num_rejected_prefix, |
| Int2Type<KEEP_REJECTS>()); |
| } |
| else |
| { |
| ScatterDirect<IS_LAST_TILE, IS_FIRST_TILE>( |
| items, |
| selection_flags, |
| selection_indices, |
| num_selections); |
| } |
| } |
| |
| //--------------------------------------------------------------------- |
| // Cooperatively scan a device-wide sequence of tiles with other CTAs |
| //--------------------------------------------------------------------- |
| |
|
|
| /** |
| * Process first tile of input (dynamic chained scan). Returns the running count of selections (including this tile) |
| */ |
| template <bool IS_LAST_TILE> |
| __device__ __forceinline__ OffsetT ConsumeFirstTile( |
| int num_tile_items, ///< Number of input items comprising this tile |
| OffsetT tile_offset, ///< Tile offset |
| ScanTileStateT& tile_state) ///< Global tile state descriptor |
| { |
| OutputT items[ITEMS_PER_THREAD]; |
| OffsetT selection_flags[ITEMS_PER_THREAD]; |
| OffsetT selection_indices[ITEMS_PER_THREAD]; |
| |
| // Load items |
| if (IS_LAST_TILE) |
| BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); |
| else |
| BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); |
| |
| // Initialize selection_flags |
| InitializeSelections<true, IS_LAST_TILE>( |
| tile_offset, |
| num_tile_items, |
| items, |
| selection_flags, |
| Int2Type<SELECT_METHOD>()); |
| |
| CTA_SYNC(); |
| |
| // Exclusive scan of selection_flags |
| OffsetT num_tile_selections; |
| BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, num_tile_selections); |
| |
| if (threadIdx.x == 0) |
| { |
| // Update tile status if this is not the last tile |
| if (!IS_LAST_TILE) |
| tile_state.SetInclusive(0, num_tile_selections); |
| } |
| |
| // Discount any out-of-bounds selections |
| if (IS_LAST_TILE) |
| num_tile_selections -= (TILE_ITEMS - num_tile_items); |
| |
| // Scatter flagged items |
| Scatter<IS_LAST_TILE, true>( |
| items, |
| selection_flags, |
| selection_indices, |
| num_tile_items, |
| num_tile_selections, |
| 0, |
| 0, |
| num_tile_selections); |
| |
| return num_tile_selections; |
| } |
| |
|
|
| /** |
| * Process subsequent tile of input (dynamic chained scan). Returns the running count of selections (including this tile) |
| */ |
| template <bool IS_LAST_TILE> |
| __device__ __forceinline__ OffsetT ConsumeSubsequentTile( |
| int num_tile_items, ///< Number of input items comprising this tile |
| int tile_idx, ///< Tile index |
| OffsetT tile_offset, ///< Tile offset |
| ScanTileStateT& tile_state) ///< Global tile state descriptor |
| { |
| OutputT items[ITEMS_PER_THREAD]; |
| OffsetT selection_flags[ITEMS_PER_THREAD]; |
| OffsetT selection_indices[ITEMS_PER_THREAD]; |
| |
| // Load items |
| if (IS_LAST_TILE) |
| BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); |
| else |
| BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); |
| |
| // Initialize selection_flags |
| InitializeSelections<false, IS_LAST_TILE>( |
| tile_offset, |
| num_tile_items, |
| items, |
| selection_flags, |
| Int2Type<SELECT_METHOD>()); |
| |
| CTA_SYNC(); |
| |
| // Exclusive scan of values and selection_flags |
| TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, cub::Sum(), tile_idx); |
| BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, prefix_op); |
| |
| OffsetT num_tile_selections = prefix_op.GetBlockAggregate(); |
| OffsetT num_selections = prefix_op.GetInclusivePrefix(); |
| OffsetT num_selections_prefix = prefix_op.GetExclusivePrefix(); |
| OffsetT num_rejected_prefix = (tile_idx * TILE_ITEMS) - num_selections_prefix; |
| |
| // Discount any out-of-bounds selections |
| if (IS_LAST_TILE) |
| { |
| int num_discount = TILE_ITEMS - num_tile_items; |
| num_selections -= num_discount; |
| num_tile_selections -= num_discount; |
| } |
| |
| // Scatter flagged items |
| Scatter<IS_LAST_TILE, false>( |
| items, |
| selection_flags, |
| selection_indices, |
| num_tile_items, |
| num_tile_selections, |
| num_selections_prefix, |
| num_rejected_prefix, |
| num_selections); |
| |
| return num_selections; |
| } |
| |
|
|
| /** |
| * Process a tile of input |
| */ |
| template <bool IS_LAST_TILE> |
| __device__ __forceinline__ OffsetT ConsumeTile( |
| int num_tile_items, ///< Number of input items comprising this tile |
| int tile_idx, ///< Tile index |
| OffsetT tile_offset, ///< Tile offset |
| ScanTileStateT& tile_state) ///< Global tile state descriptor |
| { |
| OffsetT num_selections; |
| if (tile_idx == 0) |
| { |
| num_selections = ConsumeFirstTile<IS_LAST_TILE>(num_tile_items, tile_offset, tile_state); |
| } |
| else |
| { |
| num_selections = ConsumeSubsequentTile<IS_LAST_TILE>(num_tile_items, tile_idx, tile_offset, tile_state); |
| } |
| |
| return num_selections; |
| } |
| |
|
|
| /** |
| * Scan tiles of items as part of a dynamic chained scan |
| */ |
| template <typename NumSelectedIteratorT> ///< Output iterator type for recording number of items selection_flags |
| __device__ __forceinline__ void ConsumeRange( |
| int num_tiles, ///< Total number of input tiles |
| ScanTileStateT& tile_state, ///< Global tile state descriptor |
| NumSelectedIteratorT d_num_selected_out) ///< Output total number selection_flags |
| { |
| // Blocks are launched in increasing order, so just assign one tile per block |
| int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index |
| OffsetT tile_offset = tile_idx * TILE_ITEMS; // Global offset for the current tile |
| |
| if (tile_idx < num_tiles - 1) |
| { |
| // Not the last tile (full) |
| ConsumeTile<false>(TILE_ITEMS, tile_idx, tile_offset, tile_state); |
| } |
| else |
| { |
| // The last tile (possibly partially-full) |
| OffsetT num_remaining = num_items - tile_offset; |
| OffsetT num_selections = ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state); |
| |
| if (threadIdx.x == 0) |
| { |
| // Output the total number of items selection_flags |
| *d_num_selected_out = num_selections; |
| } |
| } |
| } |
| |
| }; |
|
|
|
|
|
|
| } // CUB namespace |
| CUB_NS_POSTFIX // Optional outer namespace(s) |
|
|
|
|