/******************************************************************************
 * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * @file AgentScanByKey implements a stateful abstraction of CUDA thread
 *       blocks for participating in device-wide prefix scan by key.
*/ #pragma once #include #include #include #include #include #include #include #include #include CUB_NAMESPACE_BEGIN /****************************************************************************** * Tuning policy types ******************************************************************************/ /** * Parameterizable tuning policy type for AgentScanByKey * * @tparam DelayConstructorT * Implementation detail, do not specify directly, requirements on the * content of this type are subject to breaking change. */ template > struct AgentScanByKeyPolicy { static constexpr int BLOCK_THREADS = _BLOCK_THREADS; static constexpr int ITEMS_PER_THREAD = _ITEMS_PER_THREAD; static constexpr BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; static constexpr CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; static constexpr BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; static constexpr BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM; struct detail { using delay_constructor_t = DelayConstructorT; }; }; /****************************************************************************** * Thread block abstractions ******************************************************************************/ /** * @brief AgentScanByKey implements a stateful abstraction of CUDA thread * blocks for participating in device-wide prefix scan by key. 
* * @tparam AgentScanByKeyPolicyT * Parameterized AgentScanPolicyT tuning policy type * * @tparam KeysInputIteratorT * Random-access input iterator type * * @tparam ValuesInputIteratorT * Random-access input iterator type * * @tparam ValuesOutputIteratorT * Random-access output iterator type * * @tparam EqualityOp * Equality functor type * * @tparam ScanOpT * Scan functor type * * @tparam InitValueT * The init_value element for ScanOpT type (cub::NullType for inclusive scan) * * @tparam OffsetT * Signed integer type for global offsets * */ template struct AgentScanByKey { //--------------------------------------------------------------------- // Types and constants //--------------------------------------------------------------------- using KeyT = cub::detail::value_t; using InputT = cub::detail::value_t; using SizeValuePairT = KeyValuePair; using KeyValuePairT = KeyValuePair; using ReduceBySegmentOpT = ReduceBySegmentOp; using ScanTileStateT = ReduceByKeyScanTileState; // Constants // Inclusive scan if no init_value type is provided static constexpr int IS_INCLUSIVE = std::is_same::value; static constexpr int BLOCK_THREADS = AgentScanByKeyPolicyT::BLOCK_THREADS; static constexpr int ITEMS_PER_THREAD = AgentScanByKeyPolicyT::ITEMS_PER_THREAD; static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD; using WrappedKeysInputIteratorT = cub::detail::conditional_t< std::is_pointer::value, CacheModifiedInputIterator, KeysInputIteratorT>; using WrappedValuesInputIteratorT = cub::detail::conditional_t< std::is_pointer::value, CacheModifiedInputIterator, ValuesInputIteratorT>; using BlockLoadKeysT = BlockLoad; using BlockLoadValuesT = BlockLoad; using BlockStoreValuesT = BlockStore; using BlockDiscontinuityKeysT = BlockDiscontinuity; using DelayConstructorT = typename AgentScanByKeyPolicyT::detail::delay_constructor_t; using TilePrefixCallbackT = TilePrefixCallbackOp; using BlockScanT = BlockScan; union TempStorage_ { struct ScanStorage { typename 
BlockScanT::TempStorage scan; typename TilePrefixCallbackT::TempStorage prefix; typename BlockDiscontinuityKeysT::TempStorage discontinuity; } scan_storage; typename BlockLoadKeysT::TempStorage load_keys; typename BlockLoadValuesT::TempStorage load_values; typename BlockStoreValuesT::TempStorage store_values; }; struct TempStorage : cub::Uninitialized {}; //--------------------------------------------------------------------- // Per-thread fields //--------------------------------------------------------------------- TempStorage_ &storage; WrappedKeysInputIteratorT d_keys_in; KeyT *d_keys_prev_in; WrappedValuesInputIteratorT d_values_in; ValuesOutputIteratorT d_values_out; InequalityWrapper inequality_op; ScanOpT scan_op; ReduceBySegmentOpT pair_scan_op; InitValueT init_value; //--------------------------------------------------------------------- // Block scan utility methods (first tile) //--------------------------------------------------------------------- // Exclusive scan specialization __device__ __forceinline__ void ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], SizeValuePairT &tile_aggregate, Int2Type /* is_inclusive */) { BlockScanT(storage.scan_storage.scan) .ExclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate); } // Inclusive scan specialization __device__ __forceinline__ void ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], SizeValuePairT &tile_aggregate, Int2Type /* is_inclusive */) { BlockScanT(storage.scan_storage.scan) .InclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate); } //--------------------------------------------------------------------- // Block scan utility methods (subsequent tiles) //--------------------------------------------------------------------- // Exclusive scan specialization (with prefix from predecessors) __device__ __forceinline__ void ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], SizeValuePairT &tile_aggregate, TilePrefixCallbackT &prefix_op, Int2Type /* 
is_incclusive */) { BlockScanT(storage.scan_storage.scan) .ExclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op); tile_aggregate = prefix_op.GetBlockAggregate(); } // Inclusive scan specialization (with prefix from predecessors) __device__ __forceinline__ void ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], SizeValuePairT &tile_aggregate, TilePrefixCallbackT &prefix_op, Int2Type /* is_inclusive */) { BlockScanT(storage.scan_storage.scan) .InclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op); tile_aggregate = prefix_op.GetBlockAggregate(); } //--------------------------------------------------------------------- // Zip utility methods //--------------------------------------------------------------------- template __device__ __forceinline__ void ZipValuesAndFlags(OffsetT num_remaining, AccumT (&values)[ITEMS_PER_THREAD], OffsetT (&segment_flags)[ITEMS_PER_THREAD], SizeValuePairT (&scan_items)[ITEMS_PER_THREAD]) { // Zip values and segment_flags #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { // Set segment_flags for first out-of-bounds item, zero for others if (IS_LAST_TILE && OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM == num_remaining) { segment_flags[ITEM] = 1; } scan_items[ITEM].value = values[ITEM]; scan_items[ITEM].key = segment_flags[ITEM]; } } __device__ __forceinline__ void UnzipValues(AccumT (&values)[ITEMS_PER_THREAD], SizeValuePairT (&scan_items)[ITEMS_PER_THREAD]) { // Zip values and segment_flags #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { values[ITEM] = scan_items[ITEM].value; } } template ::value, typename std::enable_if::type = 0> __device__ __forceinline__ void AddInitToScan(AccumT (&items)[ITEMS_PER_THREAD], OffsetT (&flags)[ITEMS_PER_THREAD]) { #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { items[ITEM] = flags[ITEM] ? 
init_value : scan_op(init_value, items[ITEM]); } } template ::value, typename std::enable_if::type = 0> __device__ __forceinline__ void AddInitToScan(AccumT (&/*items*/)[ITEMS_PER_THREAD], OffsetT (&/*flags*/)[ITEMS_PER_THREAD]) {} //--------------------------------------------------------------------- // Cooperatively scan a device-wide sequence of tiles with other CTAs //--------------------------------------------------------------------- // Process a tile of input (dynamic chained scan) // template __device__ __forceinline__ void ConsumeTile(OffsetT /*num_items*/, OffsetT num_remaining, int tile_idx, OffsetT tile_base, ScanTileStateT &tile_state) { // Load items KeyT keys[ITEMS_PER_THREAD]; AccumT values[ITEMS_PER_THREAD]; OffsetT segment_flags[ITEMS_PER_THREAD]; SizeValuePairT scan_items[ITEMS_PER_THREAD]; if (IS_LAST_TILE) { // Fill last element with the first element // because collectives are not suffix guarded BlockLoadKeysT(storage.load_keys) .Load(d_keys_in + tile_base, keys, num_remaining, *(d_keys_in + tile_base)); } else { BlockLoadKeysT(storage.load_keys).Load(d_keys_in + tile_base, keys); } CTA_SYNC(); if (IS_LAST_TILE) { // Fill last element with the first element // because collectives are not suffix guarded BlockLoadValuesT(storage.load_values) .Load(d_values_in + tile_base, values, num_remaining, *(d_values_in + tile_base)); } else { BlockLoadValuesT(storage.load_values) .Load(d_values_in + tile_base, values); } CTA_SYNC(); // first tile if (tile_idx == 0) { BlockDiscontinuityKeysT(storage.scan_storage.discontinuity) .FlagHeads(segment_flags, keys, inequality_op); // Zip values and segment_flags ZipValuesAndFlags(num_remaining, values, segment_flags, scan_items); // Exclusive scan of values and segment_flags SizeValuePairT tile_aggregate; ScanTile(scan_items, tile_aggregate, Int2Type()); if (threadIdx.x == 0) { if (!IS_LAST_TILE) { tile_state.SetInclusive(0, tile_aggregate); } scan_items[0].key = 0; } } else { KeyT tile_pred_key = (threadIdx.x 
== 0) ? d_keys_prev_in[tile_idx] : KeyT(); BlockDiscontinuityKeysT(storage.scan_storage.discontinuity) .FlagHeads(segment_flags, keys, inequality_op, tile_pred_key); // Zip values and segment_flags ZipValuesAndFlags(num_remaining, values, segment_flags, scan_items); SizeValuePairT tile_aggregate; TilePrefixCallbackT prefix_op(tile_state, storage.scan_storage.prefix, pair_scan_op, tile_idx); ScanTile(scan_items, tile_aggregate, prefix_op, Int2Type()); } CTA_SYNC(); UnzipValues(values, scan_items); AddInitToScan(values, segment_flags); // Store items if (IS_LAST_TILE) { BlockStoreValuesT(storage.store_values) .Store(d_values_out + tile_base, values, num_remaining); } else { BlockStoreValuesT(storage.store_values) .Store(d_values_out + tile_base, values); } } //--------------------------------------------------------------------- // Constructor //--------------------------------------------------------------------- // Dequeue and scan tiles of items as part of a dynamic chained scan // with Init functor __device__ __forceinline__ AgentScanByKey(TempStorage &storage, KeysInputIteratorT d_keys_in, KeyT *d_keys_prev_in, ValuesInputIteratorT d_values_in, ValuesOutputIteratorT d_values_out, EqualityOp equality_op, ScanOpT scan_op, InitValueT init_value) : storage(storage.Alias()) , d_keys_in(d_keys_in) , d_keys_prev_in(d_keys_prev_in) , d_values_in(d_values_in) , d_values_out(d_values_out) , inequality_op(equality_op) , scan_op(scan_op) , pair_scan_op(scan_op) , init_value(init_value) {} /** * Scan tiles of items as part of a dynamic chained scan * * @param num_items * Total number of input items * * @param tile_state * Global tile state descriptor * * start_tile * The starting tile for the current grid */ __device__ __forceinline__ void ConsumeRange(OffsetT num_items, ScanTileStateT &tile_state, int start_tile) { int tile_idx = blockIdx.x; OffsetT tile_base = OffsetT(ITEMS_PER_TILE) * tile_idx; OffsetT num_remaining = num_items - tile_base; if (num_remaining > 
ITEMS_PER_TILE) { // Not the last tile (full) ConsumeTile(num_items, num_remaining, tile_idx, tile_base, tile_state); } else if (num_remaining > 0) { // The last tile (possibly partially-full) ConsumeTile(num_items, num_remaining, tile_idx, tile_base, tile_state); } } }; CUB_NAMESPACE_END