| /****************************************************************************** |
| * Copyright (c) 2011, Duane Merrill. All rights reserved. |
| * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions are met: |
| * * Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * * Redistributions in binary form must reproduce the above copyright |
| * notice, this list of conditions and the following disclaimer in the |
| * documentation and/or other materials provided with the distribution. |
| * * Neither the name of the NVIDIA CORPORATION nor the |
| * names of its contributors may be used to endorse or promote products |
| * derived from this software without specific prior written permission. |
| * |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND |
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED |
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY |
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES |
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; |
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS |
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| * |
| ******************************************************************************/ |
| |
| /** |
| * \file |
| * Callback operator types for supplying BlockScan prefixes |
| */ |
| |
| #pragma once |
| |
| #include <iterator> |
| |
| #include "../thread/thread_load.cuh" |
| #include "../thread/thread_store.cuh" |
| #include "../warp/warp_reduce.cuh" |
| #include "../config.cuh" |
| #include "../util_device.cuh" |
| |
| /// Optional outer namespace(s) |
| CUB_NS_PREFIX |
| |
| /// CUB namespace |
| namespace cub { |
| |
| |
| /****************************************************************************** |
| * Prefix functor type for maintaining a running prefix while scanning a |
| * region independent of other thread blocks |
| ******************************************************************************/ |
| |
/**
 * Stateful callback operator type for supplying BlockScan prefixes.
 * Maintains a running prefix that can be applied to consecutive
 * BlockScan operations.
 */
template <
    typename T,                 ///< BlockScan value type
    typename ScanOpT>           ///< Wrapped scan operator type
struct BlockScanRunningPrefixOp
{
    ScanOpT     op;                 ///< Wrapped scan operator
    T           running_total;      ///< Running block-wide prefix

    /// Constructor.  Leaves \p running_total default-initialized; the caller is
    /// expected to seed it before the first scan.
    __device__ __forceinline__ BlockScanRunningPrefixOp(ScanOpT op) : op(op) {}

    /// Constructor seeding the running prefix with \p starting_prefix
    __device__ __forceinline__ BlockScanRunningPrefixOp(T starting_prefix, ScanOpT op)
        : op(op), running_total(starting_prefix)
    {}

    /**
     * Prefix callback operator.  Returns the block-wide running_total in thread-0.
     */
    __device__ __forceinline__ T operator()(
        const T &block_aggregate)   ///< The aggregate sum of the BlockScan inputs
    {
        // Hand back the prefix accumulated so far, then fold the new
        // aggregate into the running total for the next invocation.
        T prefix = running_total;
        running_total = op(prefix, block_aggregate);
        return prefix;
    }
};
| |
|
|
| /****************************************************************************** |
| * Generic tile status interface types for block-cooperative scans |
| ******************************************************************************/ |
| |
/**
 * Enumerations of tile status used by the decoupled look-back protocol.
 *
 * NOTE(review): SCAN_TILE_INVALID is deliberately offset to 99 — presumably so
 * that it cannot collide with a zero-initialized status word (which reads as
 * SCAN_TILE_OOB) — TODO confirm against the library's design notes.
 */
enum ScanTileStatus
{
    SCAN_TILE_OOB,          // Out-of-bounds (e.g., padding)
    SCAN_TILE_INVALID = 99, // Not yet processed
    SCAN_TILE_PARTIAL,      // Tile aggregate is available
    SCAN_TILE_INCLUSIVE,    // Inclusive tile prefix is available
};
| |
| |
/**
 * Tile status interface.
 *
 * Dispatches (below) on whether the status word and value of type T can be
 * packed into a single machine word: the SINGLE_WORD=true specialization
 * stores both in one transaction word; the false specialization keeps
 * status, partial, and inclusive values in separate arrays.
 */
template <
    typename    T,
    bool        SINGLE_WORD = Traits<T>::PRIMITIVE>
struct ScanTileState;
| |
| |
/**
 * Tile status interface specialized for scan status and value types
 * that can be combined into one machine word that can be
 * read/written coherently in a single access.
 */
template <typename T>
struct ScanTileState<T, true>
{
    // Status word type, sized to match T so that {status, value} packs
    // into the power-of-two transaction word below
    typedef typename If<(sizeof(T) == 8),
        long long,
        typename If<(sizeof(T) == 4),
            int,
            typename If<(sizeof(T) == 2),
                short,
                char>::Type>::Type>::Type StatusWord;


    // Unit word type: a single machine word holding both status and value so
    // that a tile's descriptor is read/written coherently in one transaction
    typedef typename If<(sizeof(T) == 8),
        longlong2,
        typename If<(sizeof(T) == 4),
            int2,
            typename If<(sizeof(T) == 2),
                int,
                uchar2>::Type>::Type>::Type TxnWord;


    // Device word type: the logical {status, value} layout aliased onto TxnWord
    struct TileDescriptor
    {
        StatusWord  status;
        T           value;
    };


    // Constants
    enum
    {
        // One warp's worth of leading SCAN_TILE_OOB entries so that look-back
        // windows reaching before tile 0 read valid descriptors instead of
        // requiring bounds checks
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };


    // Device storage: (TILE_STATUS_PADDING + num_tiles) descriptors
    TxnWord *d_tile_descriptors;

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_descriptors(NULL)
    {}


    /// Initializer.  Aliases the descriptor array onto the caller-provided
    /// temporary storage (sized via AllocationSize()).
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,                      ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/)             ///< [in] Size in bytes of \t d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }


    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \t d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor);       // bytes needed for tile status descriptors
        return cudaSuccess;
    }


    /**
     * Initialize (from device).  Launch with at least
     * (TILE_STATUS_PADDING + num_tiles) total threads: each thread stamps one
     * tile descriptor as SCAN_TILE_INVALID, and the first warp of block 0
     * additionally stamps the leading padding entries as SCAN_TILE_OOB.
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

        // Build the descriptor in a register, aliased onto a zeroed TxnWord
        TxnWord val = TxnWord();
        TileDescriptor *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }


    /**
     * Update the specified tile's inclusive value and corresponding status.
     * Status and value are published in a single STORE_CG transaction, so no
     * separate fence is needed between them (unlike the multi-word
     * specialization below, which must fence between the value and status
     * stores).
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value = tile_inclusive;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        // STORE_CG: cache at global level so other thread blocks observe the update
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }


    /**
     * Update the specified tile's partial value and corresponding status
     * (single-transaction publish, as in SetInclusive above)
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_PARTIAL;
        tile_descriptor.value = tile_partial;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid.  Spin-waits as a
     * whole warp (WARP_ANY): all lanes keep polling until every lane's tile
     * descriptor has left the SCAN_TILE_INVALID state.
     */
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        T               &value)
    {
        TileDescriptor tile_descriptor;
        do
        {
            __threadfence_block(); // prevent hoisting loads from loop
            // LOAD_CG: bypass per-SM caching so the freshest descriptor is read
            TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

        status = tile_descriptor.status;
        value = tile_descriptor.value;
    }

};
|
|
|
|
|
|
/**
 * Tile status interface specialized for scan status and value types that
 * cannot be combined into one machine word.
 *
 * Status, partial value, and inclusive value live in three separate arrays,
 * so value stores must be fenced (__threadfence) before the corresponding
 * status store that publishes them.
 */
template <typename T>
struct ScanTileState<T, false>
{
    // Status word type
    typedef char StatusWord;

    // Constants
    enum
    {
        // One warp's worth of leading SCAN_TILE_OOB entries so look-back
        // windows reaching before tile 0 need no bounds checks
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Device storage (three parallel arrays, each padded by TILE_STATUS_PADDING)
    StatusWord  *d_tile_status;     // per-tile status flags
    T           *d_tile_partial;    // per-tile partial aggregates
    T           *d_tile_inclusive;  // per-tile inclusive prefixes

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_status(NULL),
        d_tile_partial(NULL),
        d_tile_inclusive(NULL)
    {}


    /// Initializer.  Carves the three arrays out of the caller-provided
    /// temporary storage blob (sized via AllocationSize()).
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     num_tiles,                          ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  temp_storage_bytes)                 ///< [in] Size in bytes of \t d_temp_storage allocation
    {
        cudaError_t error = cudaSuccess;
        do
        {
            void*   allocations[3] = {};
            size_t  allocation_sizes[3];

            allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);          // bytes needed for tile status descriptors
            allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);    // bytes needed for partials
            allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);    // bytes needed for inclusives

            // Compute allocation pointers into the single storage blob
            if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break;

            // Alias the offsets
            d_tile_status       = reinterpret_cast<StatusWord*>(allocations[0]);
            d_tile_partial      = reinterpret_cast<T*>(allocations[1]);
            d_tile_inclusive    = reinterpret_cast<T*>(allocations[2]);
        }
        while (0);

        return error;
    }


    /**
     * Compute device memory needed for tile status.  Must mirror the sizing
     * logic in Init() above.
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \t d_temp_storage allocation
    {
        // Specify storage allocation requirements
        size_t  allocation_sizes[3];
        allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);         // bytes needed for tile status descriptors
        allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);   // bytes needed for partials
        allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);   // bytes needed for inclusives

        // Set the necessary size of the blob
        void* allocations[3] = {};
        return CubDebug(AliasTemporaries(NULL, temp_storage_bytes, allocations, allocation_sizes));
    }


    /**
     * Initialize (from device).  Launch with at least
     * (TILE_STATUS_PADDING + num_tiles) total threads; block 0's first warp
     * additionally stamps the padding entries as SCAN_TILE_OOB.
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            d_tile_status[TILE_STATUS_PADDING + tile_idx] = StatusWord(SCAN_TILE_INVALID);
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            d_tile_status[threadIdx.x] = StatusWord(SCAN_TILE_OOB);
        }
    }


    /**
     * Update the specified tile's inclusive value and corresponding status.
     * The fence guarantees the value store is visible device-wide before the
     * status store that publishes it.
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        // Update tile inclusive value
        ThreadStore<STORE_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx, tile_inclusive);

        // Fence
        __threadfence();

        // Update tile status
        ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_INCLUSIVE));
    }


    /**
     * Update the specified tile's partial value and corresponding status
     * (same store-fence-store publication protocol as SetInclusive)
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        // Update tile partial value
        ThreadStore<STORE_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx, tile_partial);

        // Fence
        __threadfence();

        // Update tile status
        ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_PARTIAL));
    }

    /**
     * Wait for the corresponding tile to become non-invalid, then read the
     * value array selected by the observed status.  Note: unlike the
     * single-word specialization, this spin is per-thread (no WARP_ANY),
     * and the fence inside the loop also orders the value loads below
     * after the status load.
     */
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        T               &value)
    {
        do {
            status = ThreadLoad<LOAD_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx);

            __threadfence();    // prevent hoisting loads from loop or loads below above this one

        } while (status == SCAN_TILE_INVALID);

        if (status == StatusWord(SCAN_TILE_PARTIAL))
            value = ThreadLoad<LOAD_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx);
        else
            value = ThreadLoad<LOAD_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx);
    }
};
| |
|
|
| /****************************************************************************** |
| * ReduceByKey tile status interface types for block-cooperative scans |
| ******************************************************************************/ |
| |
/**
 * Tile status interface for reduction by key.
 *
 * Dispatches (below) on whether the key-value pair plus a status word fits
 * into a single coherently-accessible machine word (pairs of primitive
 * values strictly smaller than 16 bytes combined).
 */
template <
    typename    ValueT,
    typename    KeyT,
    bool        SINGLE_WORD = (Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)>
struct ReduceByKeyScanTileState;
| |
| |
/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * cannot be combined into one machine word.
 *
 * Simply reuses the generic multi-word ScanTileState with the scan value
 * type being the whole KeyValuePair.
 */
template <
    typename    ValueT,
    typename    KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, false> :
    ScanTileState<KeyValuePair<KeyT, ValueT> >
{
    typedef ScanTileState<KeyValuePair<KeyT, ValueT> > SuperClass;

    /// Constructor
    __host__ __device__ __forceinline__
    ReduceByKeyScanTileState() : SuperClass() {}
};
| |
| |
/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * can be combined into one machine word that can be read/written coherently in a single access.
 */
template <
    typename ValueT,
    typename KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, true>
{
    typedef KeyValuePair<KeyT, ValueT> KeyValuePairT;

    // Constants
    enum
    {
        // Combined size of the key-value payload
        PAIR_SIZE           = sizeof(ValueT) + sizeof(KeyT),

        // Transaction word size: a power of two strictly greater than
        // PAIR_SIZE, leaving at least one byte for the status word
        // (NOTE(review): assumes Log2 rounds up — confirm against util_type)
        TXN_WORD_SIZE       = 1 << Log2<PAIR_SIZE + 1>::VALUE,

        // Whatever bytes remain in the transaction word hold the status
        STATUS_WORD_SIZE    = TXN_WORD_SIZE - PAIR_SIZE,

        // One warp's worth of leading SCAN_TILE_OOB entries so look-back
        // windows reaching before tile 0 need no bounds checks
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Status word type, sized to fill the slack bytes in the transaction word
    typedef typename If<(STATUS_WORD_SIZE == 8),
        long long,
        typename If<(STATUS_WORD_SIZE == 4),
            int,
            typename If<(STATUS_WORD_SIZE == 2),
                short,
                char>::Type>::Type>::Type StatusWord;

    // Transaction word type: one machine word holding key, value, and status
    typedef typename If<(TXN_WORD_SIZE == 16),
        longlong2,
        typename If<(TXN_WORD_SIZE == 8),
            long long,
            int>::Type>::Type TxnWord;

    // Device word type (for when sizeof(ValueT) == sizeof(KeyT))
    struct TileDescriptorBigStatus
    {
        KeyT        key;
        ValueT      value;
        StatusWord  status;
    };

    // Device word type (for when sizeof(ValueT) != sizeof(KeyT))
    struct TileDescriptorLittleStatus
    {
        ValueT      value;
        StatusWord  status;
        KeyT        key;
    };

    // Device word type: field order chosen by relative key/value sizes
    // (presumably to minimize internal padding — TODO confirm)
    typedef typename If<
            (sizeof(ValueT) == sizeof(KeyT)),
            TileDescriptorBigStatus,
            TileDescriptorLittleStatus>::Type
        TileDescriptor;


    // Device storage: (TILE_STATUS_PADDING + num_tiles) descriptors
    TxnWord *d_tile_descriptors;


    /// Constructor
    __host__ __device__ __forceinline__
    ReduceByKeyScanTileState()
    :
        d_tile_descriptors(NULL)
    {}


    /// Initializer.  Aliases the descriptor array onto the caller-provided
    /// temporary storage (sized via AllocationSize()).
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,                      ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/)             ///< [in] Size in bytes of \t d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }


    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \t d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor);       // bytes needed for tile status descriptors
        return cudaSuccess;
    }


    /**
     * Initialize (from device).  Launch with at least
     * (TILE_STATUS_PADDING + num_tiles) total threads; block 0's first warp
     * additionally stamps the padding entries as SCAN_TILE_OOB.
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
        // Build the descriptor in a register, aliased onto a zeroed TxnWord
        TxnWord     val = TxnWord();
        TileDescriptor *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }


    /**
     * Update the specified tile's inclusive value and corresponding status.
     * Key, value, and status are published in one STORE_CG transaction, so
     * no separate fence is needed between them.
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, KeyValuePairT tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value = tile_inclusive.value;
        tile_descriptor.key = tile_inclusive.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }


    /**
     * Update the specified tile's partial value and corresponding status
     * (single-transaction publish, as in SetInclusive above)
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, KeyValuePairT tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_PARTIAL;
        tile_descriptor.value = tile_partial.value;
        tile_descriptor.key = tile_partial.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid.  Spin-waits as a
     * whole warp (WARP_ANY): all lanes keep polling until every lane's tile
     * descriptor has left the SCAN_TILE_INVALID state.
     */
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        KeyValuePairT   &value)
    {
        TileDescriptor tile_descriptor;
        do
        {
            __threadfence_block(); // prevent hoisting loads from loop
            // LOAD_CG: bypass per-SM caching so the freshest descriptor is read
            TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

        status      = tile_descriptor.status;
        value.value = tile_descriptor.value;
        value.key   = tile_descriptor.key;
    }

};
|
|
|
|
| /****************************************************************************** |
| * Prefix call-back operator for coupling local block scan within a |
| * block-cooperative scan |
| ******************************************************************************/ |
| |
/**
 * Stateful block-scan prefix functor.  Provides the running prefix for
 * the current tile by using the call-back warp to wait on
 * aggregates/prefixes from predecessor tiles to become available
 * ("decoupled look-back").
 */
template <
    typename    T,
    typename    ScanOpT,
    typename    ScanTileStateT,
    int         PTX_ARCH = CUB_PTX_ARCH>
struct TilePrefixCallbackOp
{
    // Parameterized warp reduce
    typedef WarpReduce<T, CUB_PTX_WARP_THREADS, PTX_ARCH> WarpReduceT;

    // Temporary storage type (shared between the call-back warp and the
    // rest of the block via the Get* accessors below)
    struct _TempStorage
    {
        typename WarpReduceT::TempStorage   warp_reduce;
        T                                   exclusive_prefix;
        T                                   inclusive_prefix;
        T                                   block_aggregate;
    };

    // Alias wrapper allowing temporary storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};

    // Type of status word
    typedef typename ScanTileStateT::StatusWord StatusWord;

    // Fields
    _TempStorage&               temp_storage;       ///< Reference to temporary warp-reduction storage
    ScanTileStateT&             tile_status;        ///< Interface to tile status
    ScanOpT                     scan_op;            ///< Binary scan operator
    int                         tile_idx;           ///< The current tile index
    T                           exclusive_prefix;   ///< Exclusive prefix for the tile
    T                           inclusive_prefix;   ///< Inclusive prefix for the tile

    // Constructor
    __device__ __forceinline__
    TilePrefixCallbackOp(
        ScanTileStateT       &tile_status,
        TempStorage         &temp_storage,
        ScanOpT              scan_op,
        int                  tile_idx)
    :
        temp_storage(temp_storage.Alias()),
        tile_status(tile_status),
        scan_op(scan_op),
        tile_idx(tile_idx) {}


    // Block until all predecessors within the warp-wide window have non-invalid status
    __device__ __forceinline__
    void ProcessWindow(
        int         predecessor_idx,        ///< Preceding tile index to inspect
        StatusWord  &predecessor_status,    ///< [out] Preceding tile status
        T           &window_aggregate)      ///< [out] Relevant partial reduction from this window of preceding tiles
    {
        T value;
        tile_status.WaitForValid(predecessor_idx, predecessor_status, value);

        // Perform a segmented reduction to get the prefix for the current window.
        // Use the swizzled scan operator because we are now scanning *down* towards thread0
        // (lane 0 holds the nearest predecessor; a tail flag marks the first
        // tile whose inclusive prefix is known, ending the segment).

        int tail_flag = (predecessor_status == StatusWord(SCAN_TILE_INCLUSIVE));
        window_aggregate = WarpReduceT(temp_storage.warp_reduce).TailSegmentedReduce(
            value,
            tail_flag,
            SwizzleScanOp<ScanOpT>(scan_op));
    }


    // BlockScan prefix callback functor (called by the first warp).
    // Returns the exclusive prefix for this tile (valid in all lanes of the
    // call-back warp) and publishes the tile's inclusive prefix.
    __device__ __forceinline__
    T operator()(T block_aggregate)
    {

        // Update our status with our tile-aggregate (published early so
        // successor tiles can begin accumulating even before our inclusive
        // prefix is known)
        if (threadIdx.x == 0)
        {
            temp_storage.block_aggregate = block_aggregate;
            tile_status.SetPartial(tile_idx, block_aggregate);
        }

        // Each lane inspects a distinct predecessor (lane 0 gets the nearest)
        int         predecessor_idx = tile_idx - threadIdx.x - 1;
        StatusWord  predecessor_status;
        T           window_aggregate;

        // Wait for the warp-wide window of predecessor tiles to become valid
        ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);

        // The exclusive tile prefix starts out as the current window aggregate
        exclusive_prefix = window_aggregate;

        // Keep sliding the window back until we come across a tile whose inclusive prefix is known
        while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff))
        {
            predecessor_idx -= CUB_PTX_WARP_THREADS;

            // Update exclusive tile prefix with the window prefix
            ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);
            exclusive_prefix = scan_op(window_aggregate, exclusive_prefix);
        }

        // Compute the inclusive tile prefix and update the status for this tile
        if (threadIdx.x == 0)
        {
            inclusive_prefix = scan_op(exclusive_prefix, block_aggregate);
            tile_status.SetInclusive(tile_idx, inclusive_prefix);

            // Stash prefixes in shared storage for the rest of the block
            temp_storage.exclusive_prefix = exclusive_prefix;
            temp_storage.inclusive_prefix = inclusive_prefix;
        }

        // Return exclusive_prefix
        return exclusive_prefix;
    }

    // Get the exclusive prefix stored in temporary storage
    __device__ __forceinline__
    T GetExclusivePrefix()
    {
        return temp_storage.exclusive_prefix;
    }

    // Get the inclusive prefix stored in temporary storage
    __device__ __forceinline__
    T GetInclusivePrefix()
    {
        return temp_storage.inclusive_prefix;
    }

    // Get the block aggregate stored in temporary storage
    __device__ __forceinline__
    T GetBlockAggregate()
    {
        return temp_storage.block_aggregate;
    }

};
|
|
|
|
| } // CUB namespace |
| CUB_NS_POSTFIX // Optional outer namespace(s) |
|
|
|
|