| /****************************************************************************** | |
| * Copyright (c) 2011, Duane Merrill. All rights reserved. | |
| * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. | |
| * | |
| * Redistribution and use in source and binary forms, with or without | |
| * modification, are permitted provided that the following conditions are met: | |
| * * Redistributions of source code must retain the above copyright | |
| * notice, this list of conditions and the following disclaimer. | |
| * * Redistributions in binary form must reproduce the above copyright | |
| * notice, this list of conditions and the following disclaimer in the | |
| * documentation and/or other materials provided with the distribution. | |
| * * Neither the name of the NVIDIA CORPORATION nor the | |
| * names of its contributors may be used to endorse or promote products | |
| * derived from this software without specific prior written permission. | |
| * | |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | |
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | |
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | |
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | |
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| * | |
| ******************************************************************************/ | |
| /** | |
| * \file | |
| * Callback operator types for supplying BlockScan prefixes | |
| */ | |
| #pragma once | |
| #include <iterator> | |
| #include <cub/config.cuh> | |
| #include <cub/detail/strong_load.cuh> | |
| #include <cub/detail/strong_store.cuh> | |
| #include <cub/detail/uninitialized_copy.cuh> | |
| #include <cub/thread/thread_load.cuh> | |
| #include <cub/thread/thread_store.cuh> | |
| #include <cub/util_device.cuh> | |
| #include <cub/warp/warp_reduce.cuh> | |
| #include <nv/target> | |
| CUB_NAMESPACE_BEGIN | |
| /****************************************************************************** | |
| * Prefix functor type for maintaining a running prefix while scanning a | |
| * region independent of other thread blocks | |
| ******************************************************************************/ | |
/**
 * Stateful callback operator type for supplying BlockScan prefixes.
 * Maintains a running prefix that can be applied to consecutive
 * BlockScan operations.
 */
template <
    typename T,        ///< BlockScan value type
    typename ScanOpT>  ///< Wrapped scan operator type
struct BlockScanRunningPrefixOp
{
    ScanOpT op;         ///< Wrapped scan operator
    T running_total;    ///< Running block-wide prefix

    /// Constructor (running prefix left uninitialized)
    __device__ __forceinline__ BlockScanRunningPrefixOp(ScanOpT op)
        : op(op)
    {}

    /// Constructor (running prefix seeded with \p starting_prefix)
    __device__ __forceinline__ BlockScanRunningPrefixOp(T starting_prefix, ScanOpT op)
        : op(op)
        , running_total(starting_prefix)
    {}

    /**
     * Prefix callback operator.  Returns the block-wide running_total in thread-0.
     */
    __device__ __forceinline__ T operator()(
        const T &block_aggregate)  ///< The aggregate sum of the BlockScan inputs
    {
        // Hand back the prefix accumulated so far, then fold in this tile's aggregate
        T prefix_before_tile = running_total;
        running_total        = op(running_total, block_aggregate);
        return prefix_before_tile;
    }
};
| /****************************************************************************** | |
| * Generic tile status interface types for block-cooperative scans | |
| ******************************************************************************/ | |
/**
 * Enumerations of tile status
 */
enum ScanTileStatus
{
    SCAN_TILE_OOB,          // Out-of-bounds (e.g., padding)
    SCAN_TILE_INVALID = 99, // Not yet processed
                            // NOTE(review): non-zero value presumably distinguishes
                            // an explicitly-invalidated entry from zero-initialized
                            // storage -- confirm against consumers
    SCAN_TILE_PARTIAL,      // Tile aggregate is available
    SCAN_TILE_INCLUSIVE,    // Inclusive tile prefix is available
};
| namespace detail | |
| { | |
/**
 * Compile-time-parameterized delay.  On SM70+ targets: no-op when Delay == 0;
 * otherwise a block-scoped fence for small grids (cheap, keeps polling loops
 * from being hoisted) or a ~Delay-nanosecond sleep for large grids.
 * Compiles to nothing on pre-SM70 targets.
 */
template <int Delay, unsigned int GridThreshold = 500>
__device__ __forceinline__ void delay()
{
  NV_IF_TARGET(NV_PROVIDES_SM_70,
               (if (Delay > 0)
                {
                  if (gridDim.x < GridThreshold)
                  {
                    __threadfence_block();
                  }
                  else
                  {
                    __nanosleep(Delay);
                  }
                }));
}
/**
 * Runtime-parameterized delay.  Same policy as the compile-time overload:
 * on SM70+, fence for small grids, sleep ~ns nanoseconds for large grids;
 * no-op when ns <= 0 or on pre-SM70 targets.
 */
template <unsigned int GridThreshold = 500>
__device__ __forceinline__ void delay(int ns)
{
  NV_IF_TARGET(NV_PROVIDES_SM_70,
               (if (ns > 0)
                {
                  if (gridDim.x < GridThreshold)
                  {
                    __threadfence_block();
                  }
                  else
                  {
                    __nanosleep(ns);
                  }
                }));
}
/**
 * Unconditionally sleep ~Delay nanoseconds on SM70+ (requires __nanosleep);
 * no-op on earlier targets.
 */
template <int Delay>
__device__ __forceinline__ void always_delay()
{
  NV_IF_TARGET(NV_PROVIDES_SM_70, (__nanosleep(Delay);));
}
/**
 * Unconditionally sleep ~ns nanoseconds on SM70+; no-op (ns consumed to
 * silence unused-parameter warnings) on earlier targets.
 */
__device__ __forceinline__ void always_delay(int ns)
{
  NV_IF_TARGET(NV_PROVIDES_SM_70, (__nanosleep(ns);), ((void)ns;));
}
/**
 * Grid-size-aware delay on SM70+; on pre-SM70 targets (no __nanosleep),
 * fall back to a block fence so polling-loop loads cannot be hoisted.
 */
template <unsigned int Delay = 350, unsigned int GridThreshold = 500>
__device__ __forceinline__ void delay_or_prevent_hoisting()
{
  NV_IF_TARGET(NV_PROVIDES_SM_70,
               (delay<Delay, GridThreshold>();),
               (__threadfence_block();));
}
/**
 * Runtime-parameterized variant: grid-size-aware delay on SM70+;
 * block fence (to prevent load hoisting) on earlier targets.
 */
template <unsigned int GridThreshold = 500>
__device__ __forceinline__ void delay_or_prevent_hoisting(int ns)
{
  NV_IF_TARGET(NV_PROVIDES_SM_70,
               (delay<GridThreshold>(ns);),
               ((void)ns; __threadfence_block();));
}
/**
 * Unconditional sleep on SM70+ (ignores grid size); block fence
 * (to prevent load hoisting) on earlier targets.
 */
template <unsigned int Delay = 350>
__device__ __forceinline__ void always_delay_or_prevent_hoisting()
{
  NV_IF_TARGET(NV_PROVIDES_SM_70,
               (always_delay(Delay);),
               (__threadfence_block();));
}
/**
 * Runtime-parameterized variant: unconditional sleep on SM70+;
 * block fence (to prevent load hoisting) on earlier targets.
 */
__device__ __forceinline__ void always_delay_or_prevent_hoisting(int ns)
{
  NV_IF_TARGET(NV_PROVIDES_SM_70,
               (always_delay(ns);),
               ((void)ns; __threadfence_block();));
}
/**
 * Delay constructor whose per-poll functor performs no delay at all on SM70+
 * (pre-SM70 it still issues a block fence so the compiler cannot hoist loads
 * out of the polling loop).  Construction itself waits out the expected L2
 * write latency once, before polling begins.
 */
template <unsigned int L2WriteLatency>
struct no_delay_constructor_t
{
  /// Per-poll functor: free on SM70+, fence on earlier targets
  struct delay_t
  {
    __device__ __forceinline__ void operator()()
    {
      NV_IF_TARGET(NV_PROVIDES_SM_70,
                   (),
                   (__threadfence_block();));
    }
  };

  /// Constructor: one up-front wait of ~L2WriteLatency ns (seed unused)
  __device__ __forceinline__ no_delay_constructor_t(unsigned int /* seed */)
  {
    delay<L2WriteLatency>();
  }

  /// Produce a fresh per-poll delay functor
  __device__ __forceinline__ delay_t operator()() { return {}; }
};
/**
 * Delay constructor tuned for reduce-by-key workloads:
 *  - exactly SM80: poll with the tuned compile-time Delay;
 *  - other SM70+ : poll with zero delay (grid-size fence policy still applies);
 *  - pre-SM70    : block fence to prevent load hoisting.
 * Construction waits out the expected L2 write latency once.
 */
template <unsigned int Delay, unsigned int L2WriteLatency, unsigned int GridThreshold = 500>
struct reduce_by_key_delay_constructor_t
{
  /// Per-poll functor with architecture-dispatched behavior
  struct delay_t
  {
    __device__ __forceinline__ void operator()()
    {
      NV_DISPATCH_TARGET(
        NV_IS_EXACTLY_SM_80, (delay<Delay, GridThreshold>();),
        NV_PROVIDES_SM_70,   (delay<    0, GridThreshold>();),
        NV_IS_DEVICE,        (__threadfence_block();));
    }
  };

  /// Constructor: one up-front wait of ~L2WriteLatency ns (seed unused)
  __device__ __forceinline__ reduce_by_key_delay_constructor_t(unsigned int /* seed */)
  {
    delay<L2WriteLatency>();
  }

  /// Produce a fresh per-poll delay functor
  __device__ __forceinline__ delay_t operator()() { return {}; }
};
/**
 * Delay constructor whose per-poll functor always waits a fixed, compile-time
 * Delay (fence fallback on pre-SM70).  Construction waits out the expected
 * L2 write latency once.
 */
template <unsigned int Delay, unsigned int L2WriteLatency>
struct fixed_delay_constructor_t
{
  /// Per-poll functor: fixed Delay-ns wait
  struct delay_t
  {
    __device__ __forceinline__ void operator()() { delay_or_prevent_hoisting<Delay>(); }
  };

  /// Constructor: one up-front wait of ~L2WriteLatency ns (seed unused)
  __device__ __forceinline__ fixed_delay_constructor_t(unsigned int /* seed */)
  {
    delay<L2WriteLatency>();
  }

  /// Produce a fresh per-poll delay functor
  __device__ __forceinline__ delay_t operator()() { return {}; }
};
/**
 * Delay constructor implementing exponential backoff: each poll waits the
 * current delay, then doubles it.  Construction waits out the expected L2
 * write latency once.
 */
template <unsigned int InitialDelay, unsigned int L2WriteLatency>
struct exponential_backoff_constructor_t
{
  /// Per-poll functor: wait, then double the delay
  struct delay_t
  {
    // Unsigned (was `int`) to match exponential_backon_constructor_t::delay_t
    // and to keep repeated doubling well-defined: signed left-shift overflow
    // is undefined behavior, unsigned wraps.
    unsigned int delay;

    __device__ __forceinline__ void operator()()
    {
      always_delay_or_prevent_hoisting(delay);
      delay <<= 1;
    }
  };

  /// Constructor: one up-front wait of ~L2WriteLatency ns (seed unused)
  __device__ __forceinline__ exponential_backoff_constructor_t(unsigned int /* seed */)
  {
    always_delay<L2WriteLatency>();
  }

  /// Produce a fresh functor starting at InitialDelay ns
  __device__ __forceinline__ delay_t operator()() { return {InitialDelay}; }
};
/**
 * Delay constructor implementing exponential backoff with jitter: each poll
 * waits a pseudo-random duration in [0, max_delay], then doubles the window.
 * The jitter is produced by a minstd-style linear congruential generator
 * whose state lives in the constructor and is shared (by reference) with the
 * functors it hands out.  Construction waits out the L2 write latency once.
 */
template <unsigned int InitialDelay, unsigned int L2WriteLatency>
struct exponential_backoff_jitter_constructor_t
{
  struct delay_t
  {
    // LCG parameters (minstd multiplier, additive constant 0, modulus 2^31)
    static constexpr unsigned int a = 16807;
    static constexpr unsigned int c = 0;
    static constexpr unsigned int m = 1u << 31;

    unsigned int max_delay;  ///< Current upper bound of the jitter window (ns)
    unsigned int &seed;      ///< Shared LCG state owned by the constructor

    /// Advance the LCG and map its state into [min, max]
    __device__ __forceinline__ unsigned int next(unsigned int min, unsigned int max)
    {
      return (seed = (a * seed + c) % m) % (max + 1 - min) + min;
    }

    /// Wait a jittered duration in [0, max_delay], then double the window
    __device__ __forceinline__ void operator()()
    {
      always_delay_or_prevent_hoisting(next(0, max_delay));
      max_delay <<= 1;
    }
  };

  unsigned int seed;  ///< LCG state shared with all issued functors

  /// Constructor: record the seed and wait out ~L2WriteLatency ns once
  __device__ __forceinline__ exponential_backoff_jitter_constructor_t(unsigned int seed)
      : seed(seed)
  {
    always_delay<L2WriteLatency>();
  }

  /// Produce a functor whose window starts at InitialDelay ns
  __device__ __forceinline__ delay_t operator()() { return {InitialDelay, seed}; }
};
/**
 * Delay constructor implementing exponential backoff with a sliding jitter
 * window: each poll waits a pseudo-random duration in [max_delay, 2*max_delay],
 * then doubles the window.  Jitter comes from a minstd-style LCG whose state
 * is shared (by reference) with the issued functors.  Construction waits out
 * the L2 write latency once.
 */
template <unsigned int InitialDelay, unsigned int L2WriteLatency>
struct exponential_backoff_jitter_window_constructor_t
{
  struct delay_t
  {
    // LCG parameters (minstd multiplier, additive constant 0, modulus 2^31)
    static constexpr unsigned int a = 16807;
    static constexpr unsigned int c = 0;
    static constexpr unsigned int m = 1u << 31;

    unsigned int max_delay;  ///< Lower bound of the current jitter window (ns)
    unsigned int &seed;      ///< Shared LCG state owned by the constructor

    /// Advance the LCG and map its state into [min, max]
    __device__ __forceinline__ unsigned int next(unsigned int min, unsigned int max)
    {
      return (seed = (a * seed + c) % m) % (max + 1 - min) + min;
    }

    /// Wait a jittered duration in [max_delay, 2*max_delay], then double the window
    __device__ __forceinline__ void operator()()
    {
      unsigned int next_max_delay = max_delay << 1;
      always_delay_or_prevent_hoisting(next(max_delay, next_max_delay));
      max_delay = next_max_delay;
    }
  };

  unsigned int seed;  ///< LCG state shared with all issued functors

  /// Constructor: record the seed and wait out ~L2WriteLatency ns once
  __device__ __forceinline__ exponential_backoff_jitter_window_constructor_t(unsigned int seed)
      : seed(seed)
  {
    always_delay<L2WriteLatency>();
  }

  /// Produce a functor whose window starts at InitialDelay ns
  __device__ __forceinline__ delay_t operator()() { return {InitialDelay, seed}; }
};
/**
 * Delay constructor implementing exponential "back-on" with a sliding jitter
 * window: each poll waits a pseudo-random duration in [max_delay/2, max_delay]
 * and then halves the window, so successive polls get progressively shorter.
 * The constructor also halves its own persistent max_delay each time a functor
 * is requested.  Jitter comes from a minstd-style LCG whose state is shared
 * (by reference) with the issued functors.  Construction waits out the L2
 * write latency once.
 */
template <unsigned int InitialDelay, unsigned int L2WriteLatency>
struct exponential_backon_jitter_window_constructor_t
{
  struct delay_t
  {
    // LCG parameters (minstd multiplier, additive constant 0, modulus 2^31)
    static constexpr unsigned int a = 16807;
    static constexpr unsigned int c = 0;
    static constexpr unsigned int m = 1u << 31;

    unsigned int max_delay;  ///< Upper bound of the current jitter window (ns)
    unsigned int &seed;      ///< Shared LCG state owned by the constructor

    /// Advance the LCG and map its state into [min, max]
    __device__ __forceinline__ unsigned int next(unsigned int min, unsigned int max)
    {
      return (seed = (a * seed + c) % m) % (max + 1 - min) + min;
    }

    /// Wait a jittered duration in [max_delay/2, max_delay], then halve the window
    __device__ __forceinline__ void operator()()
    {
      // Unsigned (was `int`) to avoid mixed signed/unsigned arithmetic with
      // the unsigned max_delay member and next() parameters
      unsigned int prev_delay = max_delay >> 1;
      always_delay_or_prevent_hoisting(next(prev_delay, max_delay));
      max_delay = prev_delay;
    }
  };

  unsigned int seed;                    ///< LCG state shared with all issued functors
  unsigned int max_delay = InitialDelay;  ///< Persistent window, halved per issued functor

  /// Constructor: record the seed and wait out ~L2WriteLatency ns once
  __device__ __forceinline__ exponential_backon_jitter_window_constructor_t(unsigned int seed)
      : seed(seed)
  {
    always_delay<L2WriteLatency>();
  }

  /// Halve the persistent window and issue a functor starting there
  __device__ __forceinline__ delay_t operator()()
  {
    max_delay >>= 1;
    return {max_delay, seed};
  }
};
/**
 * Delay constructor implementing exponential "back-on" with jitter: each poll
 * waits a pseudo-random duration in [0, max_delay] and then halves the window,
 * so successive polls get progressively shorter.  The constructor also halves
 * its own persistent max_delay each time a functor is requested.  Jitter comes
 * from a minstd-style LCG whose state is shared (by reference) with the issued
 * functors.  Construction waits out the L2 write latency once.
 */
template <unsigned int InitialDelay, unsigned int L2WriteLatency>
struct exponential_backon_jitter_constructor_t
{
  struct delay_t
  {
    // LCG parameters (minstd multiplier, additive constant 0, modulus 2^31)
    static constexpr unsigned int a = 16807;
    static constexpr unsigned int c = 0;
    static constexpr unsigned int m = 1u << 31;

    unsigned int max_delay;  ///< Upper bound of the current jitter window (ns)
    unsigned int &seed;      ///< Shared LCG state owned by the constructor

    /// Advance the LCG and map its state into [min, max]
    __device__ __forceinline__ unsigned int next(unsigned int min, unsigned int max)
    {
      return (seed = (a * seed + c) % m) % (max + 1 - min) + min;
    }

    /// Wait a jittered duration in [0, max_delay], then halve the window
    __device__ __forceinline__ void operator()()
    {
      always_delay_or_prevent_hoisting(next(0, max_delay));
      max_delay >>= 1;
    }
  };

  unsigned int seed;                      ///< LCG state shared with all issued functors
  unsigned int max_delay = InitialDelay;  ///< Persistent window, halved per issued functor

  /// Constructor: record the seed and wait out ~L2WriteLatency ns once
  __device__ __forceinline__ exponential_backon_jitter_constructor_t(unsigned int seed)
      : seed(seed)
  {
    always_delay<L2WriteLatency>();
  }

  /// Halve the persistent window and issue a functor starting there
  __device__ __forceinline__ delay_t operator()()
  {
    max_delay >>= 1;
    return {max_delay, seed};
  }
};
/**
 * Delay constructor implementing exponential "back-on" (no jitter): each poll
 * waits the current delay and then halves it.  The constructor also halves its
 * own persistent max_delay each time a functor is requested.  Construction
 * waits out the L2 write latency once.
 */
template <unsigned int InitialDelay, unsigned int L2WriteLatency>
struct exponential_backon_constructor_t
{
  /// Per-poll functor: wait, then halve the delay
  struct delay_t
  {
    unsigned int delay;  ///< Current wait duration (ns)

    __device__ __forceinline__ void operator()()
    {
      always_delay_or_prevent_hoisting(delay);
      delay >>= 1;
    }
  };

  unsigned int max_delay = InitialDelay;  ///< Persistent delay, halved per issued functor

  /// Constructor: one up-front wait of ~L2WriteLatency ns (seed unused)
  __device__ __forceinline__ exponential_backon_constructor_t(unsigned int /* seed */)
  {
    always_delay<L2WriteLatency>();
  }

  /// Halve the persistent delay and issue a functor starting there
  __device__ __forceinline__ delay_t operator()()
  {
    max_delay >>= 1;
    return {max_delay};
  }
};
/// Default "no delay" policy: a single 450ns L2-write-latency wait at
/// construction, then free polling
using default_no_delay_constructor_t = no_delay_constructor_t<450>;
using default_no_delay_t             = default_no_delay_constructor_t::delay_t;

/// Default delay policy for a value type T: primitive types poll with a fixed
/// 350ns delay (450ns construction wait); non-primitive types use no delay
template <class T>
using default_delay_constructor_t = cub::detail::conditional_t<Traits<T>::PRIMITIVE,
                                                               fixed_delay_constructor_t<350, 450>,
                                                               default_no_delay_constructor_t>;

template <class T>
using default_delay_t = typename default_delay_constructor_t<T>::delay_t;

/// Default delay policy for reduce-by-key: the tuned reduce-by-key constructor
/// when the {key, value} pair fits in a single machine word (< 16 bytes and
/// primitive value), otherwise the generic default for the pair type
template <class KeyT, class ValueT>
using default_reduce_by_key_delay_constructor_t =
  detail::conditional_t<(Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16),
                        reduce_by_key_delay_constructor_t<350, 450>,
                        default_delay_constructor_t<KeyValuePair<KeyT, ValueT>>>;
| } | |
/**
 * Tile status interface.
 *
 * SINGLE_WORD selects between two specializations: one where the status word
 * and the value are fused into a single machine word (read/written in one
 * access), and one where they live in separate arrays.
 */
template <
    typename T,
    bool SINGLE_WORD = Traits<T>::PRIMITIVE>
struct ScanTileState;
/**
 * Tile status interface specialized for scan status and value types
 * that can be combined into one machine word that can be
 * read/written coherently in a single access.
 */
template <typename T>
struct ScanTileState<T, true>
{
    // Status word type: unsigned integer of the same size as T, so that
    // {status, value} pack into a power-of-two-sized descriptor
    using StatusWord = cub::detail::conditional_t<
      sizeof(T) == 8,
      unsigned long long,
      cub::detail::conditional_t<
        sizeof(T) == 4,
        unsigned int,
        cub::detail::conditional_t<sizeof(T) == 2, unsigned short, unsigned char>>>;

    // Unit word type: the single-access transaction word carrying a whole
    // {status, value} descriptor
    using TxnWord = cub::detail::conditional_t<
      sizeof(T) == 8,
      ulonglong2,
      cub::detail::conditional_t<
        sizeof(T) == 4,
        uint2,
        unsigned int>>;

    // Device word type (aliased onto TxnWord via reinterpret_cast below)
    struct TileDescriptor
    {
        StatusWord  status;
        T           value;
    };

    // Constants
    enum
    {
        // One warp's worth of leading padding entries, initialized to
        // SCAN_TILE_OOB -- presumably so predecessor-tile lookups from the
        // first tiles stay in bounds; confirm against TilePrefixCallbackOp
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Device storage (single array of packed descriptors)
    TxnWord *d_tile_descriptors;

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_descriptors(NULL)
    {}

    /// Initializer: aliases the descriptor array onto caller-provided storage
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,          ///< [in] Number of tiles
        void    *d_temp_storage,        ///< [in] Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/) ///< [in] Size in bytes of \t d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }

    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,              ///< [in] Number of tiles
        size_t  &temp_storage_bytes)    ///< [out] Size in bytes of \t d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TxnWord);       // bytes needed for tile status descriptors
        return cudaSuccess;
    }

    /**
     * Initialize (from device).  One thread per tile: marks every tile
     * SCAN_TILE_INVALID and the leading padding entries SCAN_TILE_OOB.
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

        // Build the descriptor in a zeroed transaction word, then store it whole
        TxnWord val = TxnWord();
        TileDescriptor *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }

    /**
     * Update the specified tile's inclusive value and corresponding status.
     * Status and value are published atomically in one relaxed single-word
     * store, so no separate fence is needed between them.
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value = tile_inclusive;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;

        detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Update the specified tile's partial value and corresponding status
     * (single relaxed single-word store; see SetInclusive).
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_PARTIAL;
        tile_descriptor.value = tile_partial;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;

        detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid.  The whole warp
     * spins (warp-wide vote) until every lane has observed a valid status;
     * between polls the DelayT functor sleeps and/or fences to throttle
     * traffic and keep the relaxed load inside the loop.
     */
    template <class DelayT = detail::default_delay_t<T>>
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        T               &value,
        DelayT          delay_or_prevent_hoisting = {})
    {
        // First poll (no delay before the initial read)
        TileDescriptor tile_descriptor;

        {
            TxnWord alias = detail::load_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
        }

        // Re-poll while any lane in the warp still sees SCAN_TILE_INVALID
        while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff))
        {
            delay_or_prevent_hoisting();
            TxnWord alias = detail::load_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
        }

        status = tile_descriptor.status;
        value = tile_descriptor.value;
    }

    /**
     * Loads and returns the tile's value. The returned value is undefined if either (a) the tile's status is invalid or
     * (b) there is no memory fence between reading a non-invalid status and the call to LoadValid.
     */
    __device__ __forceinline__ T LoadValid(int tile_idx)
    {
        TxnWord alias = d_tile_descriptors[TILE_STATUS_PADDING + tile_idx];
        TileDescriptor tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
        return tile_descriptor.value;
    }
};
/**
 * Tile status interface specialized for scan status and value types that
 * cannot be combined into one machine word.  Status, partials, and inclusives
 * live in three separate arrays; a release-store / fence protocol orders the
 * value writes before the status that announces them.
 */
template <typename T>
struct ScanTileState<T, false>
{
    // Status word type
    using StatusWord = unsigned int;

    // Constants
    enum
    {
        // One warp's worth of leading SCAN_TILE_OOB padding entries (see the
        // single-word specialization for the rationale)
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Device storage: parallel arrays indexed by (TILE_STATUS_PADDING + tile_idx)
    StatusWord  *d_tile_status;
    T           *d_tile_partial;
    T           *d_tile_inclusive;

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_status(NULL),
        d_tile_partial(NULL),
        d_tile_inclusive(NULL)
    {}

    /// Initializer: carves the three arrays out of one caller-provided blob
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     num_tiles,              ///< [in] Number of tiles
        void    *d_temp_storage,        ///< [in] Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  temp_storage_bytes)     ///< [in] Size in bytes of \t d_temp_storage allocation
    {
        cudaError_t error = cudaSuccess;
        do
        {
            void*   allocations[3] = {};
            size_t  allocation_sizes[3];

            allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);           // bytes needed for tile status descriptors
            allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for partials
            allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for inclusives

            // Compute allocation pointers into the single storage blob
            if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break;

            // Alias the offsets
            d_tile_status       = reinterpret_cast<StatusWord*>(allocations[0]);
            d_tile_partial      = reinterpret_cast<T*>(allocations[1]);
            d_tile_inclusive    = reinterpret_cast<T*>(allocations[2]);
        }
        while (0);

        return error;
    }

    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,              ///< [in] Number of tiles
        size_t  &temp_storage_bytes)    ///< [out] Size in bytes of \t d_temp_storage allocation
    {
        // Specify storage allocation requirements
        size_t  allocation_sizes[3];
        allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);         // bytes needed for tile status descriptors
        allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);   // bytes needed for partials
        allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);   // bytes needed for inclusives

        // Set the necessary size of the blob
        void* allocations[3] = {};
        return CubDebug(AliasTemporaries(NULL, temp_storage_bytes, allocations, allocation_sizes));
    }

    /**
     * Initialize (from device).  One thread per tile: marks every tile
     * SCAN_TILE_INVALID and the leading padding entries SCAN_TILE_OOB.
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            d_tile_status[TILE_STATUS_PADDING + tile_idx] = StatusWord(SCAN_TILE_INVALID);
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            d_tile_status[threadIdx.x] = StatusWord(SCAN_TILE_OOB);
        }
    }

    /**
     * Update the specified tile's inclusive value and corresponding status.
     * The value is stored first; the release store of the status then makes
     * it visible to any consumer that observes SCAN_TILE_INCLUSIVE.
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        // Update tile inclusive value
        ThreadStore<STORE_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx, tile_inclusive);
        detail::store_release(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_INCLUSIVE));
    }

    /**
     * Update the specified tile's partial value and corresponding status
     * (value first, then release-store of the status; see SetInclusive).
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        // Update tile partial value
        ThreadStore<STORE_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx, tile_partial);
        detail::store_release(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_PARTIAL));
    }

    /**
     * Wait for the corresponding tile to become non-invalid.  The whole warp
     * spins (warp-wide vote) on the status word; the __threadfence() orders
     * the status load before the subsequent value load, pairing with the
     * release store in SetInclusive/SetPartial.
     */
    template <class DelayT = detail::default_no_delay_t>
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        T               &value,
        DelayT          delay = {})
    {
        do
        {
            delay();
            status = detail::load_relaxed(d_tile_status + TILE_STATUS_PADDING + tile_idx);
            __threadfence();
        } while (WARP_ANY((status == SCAN_TILE_INVALID), 0xffffffff));

        // Read from whichever array the announced status refers to
        if (status == StatusWord(SCAN_TILE_PARTIAL))
        {
            value = ThreadLoad<LOAD_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx);
        }
        else
        {
            value = ThreadLoad<LOAD_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx);
        }
    }

    /**
     * Loads and returns the tile's value. The returned value is undefined if either (a) the tile's status is invalid or
     * (b) there is no memory fence between reading a non-invalid status and the call to LoadValid.
     */
    __device__ __forceinline__ T LoadValid(int tile_idx)
    {
        return d_tile_inclusive[TILE_STATUS_PADDING + tile_idx];
    }
};
| /****************************************************************************** | |
| * ReduceByKey tile status interface types for block-cooperative scans | |
| ******************************************************************************/ | |
/**
 * Tile status interface for reduction by key.
 *
 * SINGLE_WORD selects the specialization where the {key, value, status}
 * triple can be packed into a single machine word (primitive value type and
 * total pair payload under 16 bytes).
 */
template <
    typename    ValueT,
    typename    KeyT,
    bool        SINGLE_WORD = (Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)>
struct ReduceByKeyScanTileState;
/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * cannot be combined into one machine word.
 *
 * Simply reuses the generic multi-word ScanTileState over the key/value pair type.
 */
template <
    typename ValueT,
    typename KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, false>
    : ScanTileState<KeyValuePair<KeyT, ValueT> >
{
    typedef ScanTileState<KeyValuePair<KeyT, ValueT> > SuperClass;

    /// Constructor: defers entirely to the base-class initialization
    __host__ __device__ __forceinline__ ReduceByKeyScanTileState()
        : SuperClass()
    {}
};
/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * can be combined into one machine word that can be read/written coherently in a single access.
 */
template <
    typename ValueT,
    typename KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, true>
{
    using KeyValuePairT = KeyValuePair<KeyT, ValueT>;

    // Constants
    enum
    {
        // Bytes occupied by the {key, value} payload
        PAIR_SIZE           = static_cast<int>(sizeof(ValueT) + sizeof(KeyT)),
        // Transaction word rounded up to the next power of two strictly
        // greater than PAIR_SIZE, leaving room for the status word
        TXN_WORD_SIZE       = 1 << Log2<PAIR_SIZE + 1>::VALUE,
        // Remaining bytes of the transaction word hold the status
        STATUS_WORD_SIZE    = TXN_WORD_SIZE - PAIR_SIZE,

        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Status word type (sized to exactly fill the leftover descriptor bytes)
    using StatusWord = cub::detail::conditional_t<
      STATUS_WORD_SIZE == 8,
      unsigned long long,
      cub::detail::conditional_t<
        STATUS_WORD_SIZE == 4,
        unsigned int,
        cub::detail::conditional_t<STATUS_WORD_SIZE == 2, unsigned short, unsigned char>>>;

    // Transaction word type (single-access carrier for a whole descriptor)
    using TxnWord = cub::detail::conditional_t<
      TXN_WORD_SIZE == 16,
      ulonglong2,
      cub::detail::conditional_t<TXN_WORD_SIZE == 8, unsigned long long, unsigned int>>;

    // Device word type (for when sizeof(ValueT) == sizeof(KeyT))
    struct TileDescriptorBigStatus
    {
        KeyT        key;
        ValueT      value;
        StatusWord  status;
    };

    // Device word type (for when sizeof(ValueT) != sizeof(KeyT))
    struct TileDescriptorLittleStatus
    {
        ValueT      value;
        StatusWord  status;
        KeyT        key;
    };

    // Device word type: field layout chosen by the key/value size relationship
    using TileDescriptor =
      cub::detail::conditional_t<sizeof(ValueT) == sizeof(KeyT),
                                 TileDescriptorBigStatus,
                                 TileDescriptorLittleStatus>;

    // Device storage (single array of packed descriptors)
    TxnWord *d_tile_descriptors;

    /// Constructor
    __host__ __device__ __forceinline__
    ReduceByKeyScanTileState()
    :
        d_tile_descriptors(NULL)
    {}

    /// Initializer: aliases the descriptor array onto caller-provided storage
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,          ///< [in] Number of tiles
        void    *d_temp_storage,        ///< [in] Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/) ///< [in] Size in bytes of \t d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }

    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,              ///< [in] Number of tiles
        size_t  &temp_storage_bytes)    ///< [out] Size in bytes of \t d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TxnWord);       // bytes needed for tile status descriptors
        return cudaSuccess;
    }

    /**
     * Initialize (from device).  One thread per tile: marks every tile
     * SCAN_TILE_INVALID and the leading padding entries SCAN_TILE_OOB.
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

        // Build the descriptor in a zeroed transaction word, then store it whole
        TxnWord val = TxnWord();
        TileDescriptor *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }

    /**
     * Update the specified tile's inclusive value and corresponding status.
     * Status, key, and value are published atomically in one relaxed
     * single-word store, so no separate fence is needed between them.
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, KeyValuePairT tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value = tile_inclusive.value;
        tile_descriptor.key = tile_inclusive.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;

        detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Update the specified tile's partial value and corresponding status
     * (single relaxed single-word store; see SetInclusive).
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, KeyValuePairT tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_PARTIAL;
        tile_descriptor.value = tile_partial.value;
        tile_descriptor.key = tile_partial.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;

        detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid.  The whole warp
     * spins (warp-wide vote) until every lane has observed a valid status;
     * between polls the DelayT functor sleeps and/or fences to throttle
     * traffic and keep the relaxed load inside the loop.
     *
     * (A legacy per-thread ThreadLoad<LOAD_CG> polling implementation was
     * previously kept here as commented-out code; see version history.)
     */
    template <class DelayT = detail::fixed_delay_constructor_t<350, 450>::delay_t>
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        KeyValuePairT   &value,
        DelayT          delay_or_prevent_hoisting = {})
    {
        TileDescriptor tile_descriptor;

        do
        {
            delay_or_prevent_hoisting();
            TxnWord alias = detail::load_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

        status = tile_descriptor.status;
        value.value = tile_descriptor.value;
        value.key = tile_descriptor.key;
    }
};
| /****************************************************************************** | |
| * Prefix call-back operator for coupling local block scan within a | |
| * block-cooperative scan | |
| ******************************************************************************/ | |
| /** | |
 * Stateful block-scan prefix functor. Provides the running prefix for
 * the current tile by using the call-back warp to wait on
 * aggregates/prefixes from predecessor tiles to become available.
| * | |
| * @tparam DelayConstructorT | |
| * Implementation detail, do not specify directly, requirements on the | |
| * content of this type are subject to breaking change. | |
| */ | |
template <
    typename T,
    typename ScanOpT,
    typename ScanTileStateT,
    int LEGACY_PTX_ARCH = 0,
    typename DelayConstructorT = detail::default_delay_constructor_t<T>>
struct TilePrefixCallbackOp
{
    // Parameterized warp reduce (one warp's worth of lanes)
    typedef WarpReduce<T, CUB_PTX_WARP_THREADS> WarpReduceT;

    // Temporary storage type shared by the calling warp
    struct _TempStorage
    {
        typename WarpReduceT::TempStorage warp_reduce;
        T exclusive_prefix;     // published by thread0 for later retrieval
        T inclusive_prefix;     // published by thread0 for later retrieval
        T block_aggregate;      // published by thread0 for later retrieval
    };

    // Alias wrapper allowing temporary storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};

    // Type of status word
    typedef typename ScanTileStateT::StatusWord StatusWord;

    // Fields
    _TempStorage& temp_storage; ///< Reference to a warp-reduction instance
    ScanTileStateT& tile_status; ///< Interface to tile status
    ScanOpT scan_op; ///< Binary scan operator
    int tile_idx; ///< The current tile index
    T exclusive_prefix; ///< Exclusive prefix for the tile
    T inclusive_prefix; ///< Inclusive prefix for the tile

    // Constructs prefix functor for a given tile index.
    // Precondition: thread blocks processing all of the predecessor tiles were scheduled.
    __device__ __forceinline__ TilePrefixCallbackOp(ScanTileStateT &tile_status,
                                                    TempStorage &temp_storage,
                                                    ScanOpT scan_op,
                                                    int tile_idx)
        : temp_storage(temp_storage.Alias())
        , tile_status(tile_status)
        , scan_op(scan_op)
        , tile_idx(tile_idx)
    {}

    // Computes the tile index and constructs prefix functor with it.
    // Precondition: thread block per tile assignment (tile index == blockIdx.x).
    __device__ __forceinline__ TilePrefixCallbackOp(ScanTileStateT &tile_status,
                                                    TempStorage &temp_storage,
                                                    ScanOpT scan_op)
        : TilePrefixCallbackOp(tile_status, temp_storage, scan_op, blockIdx.x)
    {}

    // Block until all predecessors within the warp-wide window have non-invalid status.
    // Each calling lane inspects one predecessor tile; the window aggregate is the
    // combined contribution of this window of predecessors.
    template <class DelayT = detail::default_delay_t<T>>
    __device__ __forceinline__
    void ProcessWindow(
        int predecessor_idx, ///< Preceding tile index to inspect
        StatusWord &predecessor_status, ///< [out] Preceding tile status
        T &window_aggregate, ///< [out] Relevant partial reduction from this window of preceding tiles
        DelayT delay = {})
    {
        T value;
        tile_status.WaitForValid(predecessor_idx, predecessor_status, value, delay);

        // Perform a segmented reduction to get the prefix for the current window.
        // Use the swizzled scan operator because we are now scanning *down* towards thread0.
        // A tile flagged INCLUSIVE ends the segment: tiles before it are already
        // folded into its inclusive prefix and must not be re-counted.
        int tail_flag = (predecessor_status == StatusWord(SCAN_TILE_INCLUSIVE));
        window_aggregate = WarpReduceT(temp_storage.warp_reduce).TailSegmentedReduce(
            value,
            tail_flag,
            SwizzleScanOp<ScanOpT>(scan_op));
    }

    // BlockScan prefix callback functor (called by the first warp).
    // Performs the decoupled look-back: publishes this tile's partial aggregate,
    // then slides a warp-wide window back over predecessor tiles until one with
    // a known inclusive prefix is found, accumulating the exclusive prefix.
    __device__ __forceinline__
    T operator()(T block_aggregate)
    {
        // Update our status with our tile-aggregate
        if (threadIdx.x == 0)
        {
            detail::uninitialized_copy(&temp_storage.block_aggregate,
                                       block_aggregate);

            tile_status.SetPartial(tile_idx, block_aggregate);
        }

        // Lane i inspects predecessor tile (tile_idx - i - 1)
        int predecessor_idx = tile_idx - threadIdx.x - 1;
        StatusWord predecessor_status;
        T window_aggregate;

        // Wait for the warp-wide window of predecessor tiles to become valid
        DelayConstructorT construct_delay(tile_idx);
        ProcessWindow(predecessor_idx, predecessor_status, window_aggregate, construct_delay());

        // The exclusive tile prefix starts out as the current window aggregate
        exclusive_prefix = window_aggregate;

        // Keep sliding the window back until we come across a tile whose inclusive prefix is known
        while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff))
        {
            predecessor_idx -= CUB_PTX_WARP_THREADS;

            // Update exclusive tile prefix with the window prefix
            ProcessWindow(predecessor_idx, predecessor_status, window_aggregate, construct_delay());
            exclusive_prefix = scan_op(window_aggregate, exclusive_prefix);
        }

        // Compute the inclusive tile prefix and update the status for this tile
        if (threadIdx.x == 0)
        {
            inclusive_prefix = scan_op(exclusive_prefix, block_aggregate);
            tile_status.SetInclusive(tile_idx, inclusive_prefix);

            detail::uninitialized_copy(&temp_storage.exclusive_prefix,
                                       exclusive_prefix);

            detail::uninitialized_copy(&temp_storage.inclusive_prefix,
                                       inclusive_prefix);
        }

        // Return exclusive_prefix
        return exclusive_prefix;
    }

    // Get the exclusive prefix stored in temporary storage
    // (valid only after operator() and a subsequent block-wide barrier)
    __device__ __forceinline__
    T GetExclusivePrefix()
    {
        return temp_storage.exclusive_prefix;
    }

    // Get the inclusive prefix stored in temporary storage
    // (valid only after operator() and a subsequent block-wide barrier)
    __device__ __forceinline__
    T GetInclusivePrefix()
    {
        return temp_storage.inclusive_prefix;
    }

    // Get the block aggregate stored in temporary storage
    // (valid only after operator() and a subsequent block-wide barrier)
    __device__ __forceinline__
    T GetBlockAggregate()
    {
        return temp_storage.block_aggregate;
    }

    // Get the index of the tile this functor was constructed for
    __device__ __forceinline__
    int GetTileIdx() const
    {
        return tile_idx;
    }
};
| CUB_NAMESPACE_END | |