| /****************************************************************************** |
| * Copyright (c) 2011, Duane Merrill. All rights reserved. |
| * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions are met: |
| * * Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * * Redistributions in binary form must reproduce the above copyright |
| * notice, this list of conditions and the following disclaimer in the |
| * documentation and/or other materials provided with the distribution. |
| * * Neither the name of the NVIDIA CORPORATION nor the |
| * names of its contributors may be used to endorse or promote products |
| * derived from this software without specific prior written permission. |
| * |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND |
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED |
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY |
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES |
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; |
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS |
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| * |
| ******************************************************************************/ |
| |
| /** |
| * \file |
| * cub::AgentReduceByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key. |
| */ |
| |
| #pragma once |
| |
| #include <iterator> |
| |
| #include "single_pass_scan_operators.cuh" |
| #include "../block/block_load.cuh" |
| #include "../block/block_store.cuh" |
| #include "../block/block_scan.cuh" |
| #include "../block/block_discontinuity.cuh" |
| #include "../config.cuh" |
| #include "../iterator/cache_modified_input_iterator.cuh" |
| #include "../iterator/constant_input_iterator.cuh" |
| |
| /// Optional outer namespace(s) |
| CUB_NS_PREFIX |
| |
| /// CUB namespace |
| namespace cub { |
| |
| |
| /****************************************************************************** |
| * Tuning policy types |
| ******************************************************************************/ |
| |
| /** |
| * Parameterizable tuning policy type for AgentReduceByKey |
| */ |
| template < |
| int _BLOCK_THREADS, ///< Threads per thread block |
| int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) |
| BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use |
| CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements |
| BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use |
| struct AgentReduceByKeyPolicy |
| { |
| enum |
| { |
| BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block |
| ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) |
| }; |
| |
| static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use |
| static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements |
| static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use |
| }; |
| |
| |
| /****************************************************************************** |
| * Thread block abstractions |
| ******************************************************************************/ |
| |
| /** |
| * \brief AgentReduceByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key |
| */ |
| template < |
| typename AgentReduceByKeyPolicyT, ///< Parameterized AgentReduceByKeyPolicy tuning policy type |
| typename KeysInputIteratorT, ///< Random-access input iterator type for keys |
| typename UniqueOutputIteratorT, ///< Random-access output iterator type for keys |
| typename ValuesInputIteratorT, ///< Random-access input iterator type for values |
| typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values |
| typename NumRunsOutputIteratorT, ///< Output iterator type for recording number of items selected |
| typename EqualityOpT, ///< KeyT equality operator type |
| typename ReductionOpT, ///< ValueT reduction operator type |
| typename OffsetT> ///< Signed integer type for global offsets |
| struct AgentReduceByKey |
| { |
| //--------------------------------------------------------------------- |
| // Types and constants |
| //--------------------------------------------------------------------- |
| |
| // The input keys type |
| typedef typename std::iterator_traits<KeysInputIteratorT>::value_type KeyInputT; |
| |
| // The output keys type |
| typedef typename If<(Equals<typename std::iterator_traits<UniqueOutputIteratorT>::value_type, void>::VALUE), // KeyOutputT = (if output iterator's value type is void) ? |
| typename std::iterator_traits<KeysInputIteratorT>::value_type, // ... then the input iterator's value type, |
| typename std::iterator_traits<UniqueOutputIteratorT>::value_type>::Type KeyOutputT; // ... else the output iterator's value type |
| |
| // The input values type |
| typedef typename std::iterator_traits<ValuesInputIteratorT>::value_type ValueInputT; |
| |
| // The output values type |
| typedef typename If<(Equals<typename std::iterator_traits<AggregatesOutputIteratorT>::value_type, void>::VALUE), // ValueOutputT = (if output iterator's value type is void) ? |
| typename std::iterator_traits<ValuesInputIteratorT>::value_type, // ... then the input iterator's value type, |
| typename std::iterator_traits<AggregatesOutputIteratorT>::value_type>::Type ValueOutputT; // ... else the output iterator's value type |
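
    // Example (illustrative): if UniqueOutputIteratorT is a raw int*, its
    // value_type is int and KeyOutputT is int; if it is a discard-style
    // iterator whose value_type is void, KeyOutputT falls back to the input
    // iterator's key type.  The same fallback applies to ValueOutputT above.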
| |
| // Tuple type for scanning (pairs accumulated segment-value with segment-index) |
| typedef KeyValuePair<OffsetT, ValueOutputT> OffsetValuePairT; |
| |
| // Tuple type for pairing keys and values |
| typedef KeyValuePair<KeyOutputT, ValueOutputT> KeyValuePairT; |
| |
| // Tile status descriptor interface type |
| typedef ReduceByKeyScanTileState<ValueOutputT, OffsetT> ScanTileStateT; |
| |
| // Guarded inequality functor |
| template <typename _EqualityOpT> |
| struct GuardedInequalityWrapper |
| { |
| _EqualityOpT op; ///< Wrapped equality operator |
| int num_remaining; ///< Items remaining |
| |
| /// Constructor |
| __host__ __device__ __forceinline__ |
| GuardedInequalityWrapper(_EqualityOpT op, int num_remaining) : op(op), num_remaining(num_remaining) {} |
| |
        /// Boolean inequality operator; returns <tt>(a != b)</tt> for in-bounds items
| template <typename T> |
| __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b, int idx) const |
| { |
| if (idx < num_remaining) |
| return !op(a, b); // In bounds |
| |
            // Return true for the first out-of-bounds item, false otherwise
| return (idx == num_remaining); |
| } |
| }; |
| |
| |
| // Constants |
| enum |
| { |
| BLOCK_THREADS = AgentReduceByKeyPolicyT::BLOCK_THREADS, |
| ITEMS_PER_THREAD = AgentReduceByKeyPolicyT::ITEMS_PER_THREAD, |
| TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, |
| TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1), |
| |
        // Whether or not the scan operation has a zero-valued identity (true when performing addition on a primitive type)
| HAS_IDENTITY_ZERO = (Equals<ReductionOpT, cub::Sum>::VALUE) && (Traits<ValueOutputT>::PRIMITIVE), |
| }; |
| |
| // Cache-modified Input iterator wrapper type (for applying cache modifier) for keys |
| typedef typename If<IsPointer<KeysInputIteratorT>::VALUE, |
        CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,     // Wrap the native input pointer with CacheModifiedInputIterator
| KeysInputIteratorT>::Type // Directly use the supplied input iterator type |
| WrappedKeysInputIteratorT; |
| |
| // Cache-modified Input iterator wrapper type (for applying cache modifier) for values |
| typedef typename If<IsPointer<ValuesInputIteratorT>::VALUE, |
        CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,   // Wrap the native input pointer with CacheModifiedInputIterator
| ValuesInputIteratorT>::Type // Directly use the supplied input iterator type |
| WrappedValuesInputIteratorT; |
| |
| // Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values |
| typedef typename If<IsPointer<AggregatesOutputIteratorT>::VALUE, |
        CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,   // Wrap the native input pointer with CacheModifiedInputIterator
| AggregatesOutputIteratorT>::Type // Directly use the supplied input iterator type |
| WrappedFixupInputIteratorT; |
| |
| // Reduce-value-by-segment scan operator |
| typedef ReduceBySegmentOp<ReductionOpT> ReduceBySegmentOpT; |
| |
| // Parameterized BlockLoad type for keys |
| typedef BlockLoad< |
| KeyOutputT, |
| BLOCK_THREADS, |
| ITEMS_PER_THREAD, |
| AgentReduceByKeyPolicyT::LOAD_ALGORITHM> |
| BlockLoadKeysT; |
| |
| // Parameterized BlockLoad type for values |
| typedef BlockLoad< |
| ValueOutputT, |
| BLOCK_THREADS, |
| ITEMS_PER_THREAD, |
| AgentReduceByKeyPolicyT::LOAD_ALGORITHM> |
| BlockLoadValuesT; |
| |
| // Parameterized BlockDiscontinuity type for keys |
| typedef BlockDiscontinuity< |
| KeyOutputT, |
| BLOCK_THREADS> |
| BlockDiscontinuityKeys; |
| |
| // Parameterized BlockScan type |
| typedef BlockScan< |
| OffsetValuePairT, |
| BLOCK_THREADS, |
| AgentReduceByKeyPolicyT::SCAN_ALGORITHM> |
| BlockScanT; |
| |
| // Callback type for obtaining tile prefix during block scan |
| typedef TilePrefixCallbackOp< |
| OffsetValuePairT, |
| ReduceBySegmentOpT, |
| ScanTileStateT> |
| TilePrefixCallbackOpT; |
| |
| // Key and value exchange types |
| typedef KeyOutputT KeyExchangeT[TILE_ITEMS + 1]; |
| typedef ValueOutputT ValueExchangeT[TILE_ITEMS + 1]; |
| |
| // Shared memory type for this thread block |
| union _TempStorage |
| { |
| struct |
| { |
| typename BlockScanT::TempStorage scan; // Smem needed for tile scanning |
| typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback |
| typename BlockDiscontinuityKeys::TempStorage discontinuity; // Smem needed for discontinuity detection |
| }; |
| |
| // Smem needed for loading keys |
| typename BlockLoadKeysT::TempStorage load_keys; |
| |
| // Smem needed for loading values |
| typename BlockLoadValuesT::TempStorage load_values; |
| |
        // Smem needed for compacting key-value pairs (allows non-POD items in this union)
| Uninitialized<KeyValuePairT[TILE_ITEMS + 1]> raw_exchange; |
| }; |
| |
| // Alias wrapper allowing storage to be unioned |
| struct TempStorage : Uninitialized<_TempStorage> {}; |
| |
|
|
| //--------------------------------------------------------------------- |
| // Per-thread fields |
| //--------------------------------------------------------------------- |
| |
| _TempStorage& temp_storage; ///< Reference to temp_storage |
| WrappedKeysInputIteratorT d_keys_in; ///< Input keys |
| UniqueOutputIteratorT d_unique_out; ///< Unique output keys |
| WrappedValuesInputIteratorT d_values_in; ///< Input values |
| AggregatesOutputIteratorT d_aggregates_out; ///< Output value aggregates |
| NumRunsOutputIteratorT d_num_runs_out; ///< Output pointer for total number of segments identified |
| EqualityOpT equality_op; ///< KeyT equality operator |
| ReductionOpT reduction_op; ///< Reduction operator |
| ReduceBySegmentOpT scan_op; ///< Reduce-by-segment scan operator |
| |
|
|
| //--------------------------------------------------------------------- |
| // Constructor |
| //--------------------------------------------------------------------- |
| |
| // Constructor |
| __device__ __forceinline__ |
| AgentReduceByKey( |
| TempStorage& temp_storage, ///< Reference to temp_storage |
| KeysInputIteratorT d_keys_in, ///< Input keys |
| UniqueOutputIteratorT d_unique_out, ///< Unique output keys |
| ValuesInputIteratorT d_values_in, ///< Input values |
| AggregatesOutputIteratorT d_aggregates_out, ///< Output value aggregates |
| NumRunsOutputIteratorT d_num_runs_out, ///< Output pointer for total number of segments identified |
| EqualityOpT equality_op, ///< KeyT equality operator |
| ReductionOpT reduction_op) ///< ValueT reduction operator |
| : |
| temp_storage(temp_storage.Alias()), |
| d_keys_in(d_keys_in), |
| d_unique_out(d_unique_out), |
| d_values_in(d_values_in), |
| d_aggregates_out(d_aggregates_out), |
| d_num_runs_out(d_num_runs_out), |
| equality_op(equality_op), |
| reduction_op(reduction_op), |
| scan_op(reduction_op) |
| {} |
| |
|
|
| //--------------------------------------------------------------------- |
| // Scatter utility methods |
| //--------------------------------------------------------------------- |
| |
| /** |
| * Directly scatter flagged items to output offsets |
| */ |
| __device__ __forceinline__ void ScatterDirect( |
| KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], |
| OffsetT (&segment_flags)[ITEMS_PER_THREAD], |
| OffsetT (&segment_indices)[ITEMS_PER_THREAD]) |
| { |
| // Scatter flagged keys and values |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| if (segment_flags[ITEM]) |
| { |
| d_unique_out[segment_indices[ITEM]] = scatter_items[ITEM].key; |
| d_aggregates_out[segment_indices[ITEM]] = scatter_items[ITEM].value; |
| } |
| } |
| } |
| |
|
|
| /** |
| * 2-phase scatter flagged items to output offsets |
| * |
     * The exclusive scan pairs each head flag with the previous segment's
     * value aggregate.  Because the first segment head is never flagged, the
     * exclusive running count of preceding heads is already the correct
     * scatter offset for that previous aggregate
| */ |
| __device__ __forceinline__ void ScatterTwoPhase( |
| KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], |
| OffsetT (&segment_flags)[ITEMS_PER_THREAD], |
| OffsetT (&segment_indices)[ITEMS_PER_THREAD], |
| OffsetT num_tile_segments, |
| OffsetT num_tile_segments_prefix) |
| { |
| CTA_SYNC(); |
| |
| // Compact and scatter pairs |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| if (segment_flags[ITEM]) |
| { |
| temp_storage.raw_exchange.Alias()[segment_indices[ITEM] - num_tile_segments_prefix] = scatter_items[ITEM]; |
| } |
| } |
| |
| CTA_SYNC(); |
| |
| for (int item = threadIdx.x; item < num_tile_segments; item += BLOCK_THREADS) |
| { |
| KeyValuePairT pair = temp_storage.raw_exchange.Alias()[item]; |
| d_unique_out[num_tile_segments_prefix + item] = pair.key; |
| d_aggregates_out[num_tile_segments_prefix + item] = pair.value; |
| } |
| } |
| |
|
|
| /** |
| * Scatter flagged items |
| */ |
| __device__ __forceinline__ void Scatter( |
| KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], |
| OffsetT (&segment_flags)[ITEMS_PER_THREAD], |
| OffsetT (&segment_indices)[ITEMS_PER_THREAD], |
| OffsetT num_tile_segments, |
| OffsetT num_tile_segments_prefix) |
| { |
        // Do a one-phase (direct) scatter if (a) two-phase is disabled or (b) the average number of selected items per thread is at most one
| if (TWO_PHASE_SCATTER && (num_tile_segments > BLOCK_THREADS)) |
| { |
| ScatterTwoPhase( |
| scatter_items, |
| segment_flags, |
| segment_indices, |
| num_tile_segments, |
| num_tile_segments_prefix); |
| } |
| else |
| { |
| ScatterDirect( |
| scatter_items, |
| segment_flags, |
| segment_indices); |
| } |
| } |
| |
|
|
| //--------------------------------------------------------------------- |
| // Cooperatively scan a device-wide sequence of tiles with other CTAs |
| //--------------------------------------------------------------------- |
| |
| /** |
| * Process a tile of input (dynamic chained scan) |
| */ |
| template <bool IS_LAST_TILE> ///< Whether the current tile is the last tile |
| __device__ __forceinline__ void ConsumeTile( |
| OffsetT num_remaining, ///< Number of global input items remaining (including this tile) |
| int tile_idx, ///< Tile index |
| OffsetT tile_offset, ///< Tile offset |
| ScanTileStateT& tile_state) ///< Global tile state descriptor |
| { |
| KeyOutputT keys[ITEMS_PER_THREAD]; // Tile keys |
| KeyOutputT prev_keys[ITEMS_PER_THREAD]; // Tile keys shuffled up |
| ValueOutputT values[ITEMS_PER_THREAD]; // Tile values |
| OffsetT head_flags[ITEMS_PER_THREAD]; // Segment head flags |
| OffsetT segment_indices[ITEMS_PER_THREAD]; // Segment indices |
| OffsetValuePairT scan_items[ITEMS_PER_THREAD]; // Zipped values and segment flags|indices |
| KeyValuePairT scatter_items[ITEMS_PER_THREAD]; // Zipped key value pairs for scattering |
| |
| // Load keys |
| if (IS_LAST_TILE) |
| BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys, num_remaining); |
| else |
| BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys); |
| |
| // Load tile predecessor key in first thread |
| KeyOutputT tile_predecessor; |
| if (threadIdx.x == 0) |
| { |
| tile_predecessor = (tile_idx == 0) ? |
| keys[0] : // First tile gets repeat of first item (thus first item will not be flagged as a head) |
| d_keys_in[tile_offset - 1]; // Subsequent tiles get last key from previous tile |
| } |
| |
| CTA_SYNC(); |
| |
| // Load values |
| if (IS_LAST_TILE) |
| BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values, num_remaining); |
| else |
| BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values); |
| |
| CTA_SYNC(); |
| |
| // Initialize head-flags and shuffle up the previous keys |
| if (IS_LAST_TILE) |
| { |
| // Use custom flag operator to additionally flag the first out-of-bounds item |
| GuardedInequalityWrapper<EqualityOpT> flag_op(equality_op, num_remaining); |
| BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads( |
| head_flags, keys, prev_keys, flag_op, tile_predecessor); |
| } |
| else |
| { |
| InequalityWrapper<EqualityOpT> flag_op(equality_op); |
| BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads( |
| head_flags, keys, prev_keys, flag_op, tile_predecessor); |
| } |
| |
| // Zip values and head flags |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| scan_items[ITEM].value = values[ITEM]; |
| scan_items[ITEM].key = head_flags[ITEM]; |
| } |
| |
| // Perform exclusive tile scan |
| OffsetValuePairT block_aggregate; // Inclusive block-wide scan aggregate |
| OffsetT num_segments_prefix; // Number of segments prior to this tile |
| OffsetValuePairT total_aggregate; // The tile prefix folded with block_aggregate |
| if (tile_idx == 0) |
| { |
| // Scan first tile |
| BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, block_aggregate); |
| num_segments_prefix = 0; |
| total_aggregate = block_aggregate; |
| |
| // Update tile status if there are successor tiles |
| if ((!IS_LAST_TILE) && (threadIdx.x == 0)) |
| tile_state.SetInclusive(0, block_aggregate); |
| } |
| else |
| { |
| // Scan non-first tile |
| TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx); |
| BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, prefix_op); |
| |
| block_aggregate = prefix_op.GetBlockAggregate(); |
| num_segments_prefix = prefix_op.GetExclusivePrefix().key; |
| total_aggregate = prefix_op.GetInclusivePrefix(); |
| } |
| |
| // Rezip scatter items and segment indices |
| #pragma unroll |
| for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) |
| { |
| scatter_items[ITEM].key = prev_keys[ITEM]; |
| scatter_items[ITEM].value = scan_items[ITEM].value; |
| segment_indices[ITEM] = scan_items[ITEM].key; |
| } |
| |
| // At this point, each flagged segment head has: |
| // - The key for the previous segment |
| // - The reduced value from the previous segment |
| // - The segment index for the reduced value |
| |
| // Scatter flagged keys and values |
| OffsetT num_tile_segments = block_aggregate.key; |
| Scatter(scatter_items, head_flags, segment_indices, num_tile_segments, num_segments_prefix); |
| |
| // Last thread in last tile will output final count (and last pair, if necessary) |
| if ((IS_LAST_TILE) && (threadIdx.x == BLOCK_THREADS - 1)) |
| { |
| OffsetT num_segments = num_segments_prefix + num_tile_segments; |
| |
            // If the last tile is a whole tile, the first out-of-bounds item was never flagged, so output the final segment's key and value aggregate here
| if (num_remaining == TILE_ITEMS) |
| { |
| d_unique_out[num_segments] = keys[ITEMS_PER_THREAD - 1]; |
| d_aggregates_out[num_segments] = total_aggregate.value; |
| num_segments++; |
| } |
| |
| // Output the total number of items selected |
| *d_num_runs_out = num_segments; |
| } |
| } |
| |
|
|
| /** |
| * Scan tiles of items as part of a dynamic chained scan |
| */ |
| __device__ __forceinline__ void ConsumeRange( |
| int num_items, ///< Total number of input items |
| ScanTileStateT& tile_state, ///< Global tile state descriptor |
| int start_tile) ///< The starting tile for the current grid |
| { |
| // Blocks are launched in increasing order, so just assign one tile per block |
| int tile_idx = start_tile + blockIdx.x; // Current tile index |
| OffsetT tile_offset = OffsetT(TILE_ITEMS) * tile_idx; // Global offset for the current tile |
| OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile) |
| |
| if (num_remaining > TILE_ITEMS) |
| { |
| // Not last tile |
| ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state); |
| } |
| else if (num_remaining > 0) |
| { |
| // Last tile |
| ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state); |
| } |
| } |
| |
| }; |
|
|
|
|
| } // CUB namespace |
| CUB_NS_POSTFIX // Optional outer namespace(s) |
|
|
|
|