/******************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* @file AgentScanByKey implements a stateful abstraction of CUDA thread blocks
* for participating in device-wide prefix scan by key.
*/
#pragma once
#include <cub/agent/single_pass_scan_operators.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include <cub/config.cuh>
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/util_type.cuh>
#include <iterator>
CUB_NAMESPACE_BEGIN
/******************************************************************************
* Tuning policy types
******************************************************************************/
/**
* Parameterizable tuning policy type for AgentScanByKey
*
* @tparam DelayConstructorT
* Implementation detail, do not specify directly, requirements on the
* content of this type are subject to breaking change.
*/
template <int _BLOCK_THREADS,
int _ITEMS_PER_THREAD = 1,
BlockLoadAlgorithm _LOAD_ALGORITHM = BLOCK_LOAD_DIRECT,
CacheLoadModifier _LOAD_MODIFIER = LOAD_DEFAULT,
BlockScanAlgorithm _SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
BlockStoreAlgorithm _STORE_ALGORITHM = BLOCK_STORE_DIRECT,
typename DelayConstructorT = detail::fixed_delay_constructor_t<350, 450>>
struct AgentScanByKeyPolicy
{
static constexpr int BLOCK_THREADS = _BLOCK_THREADS;
static constexpr int ITEMS_PER_THREAD = _ITEMS_PER_THREAD;
static constexpr BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM;
static constexpr CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER;
static constexpr BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM;
static constexpr BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM;
struct detail
{
using delay_constructor_t = DelayConstructorT;
};
};
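// The parameters above are normally selected per architecture by the dispatch
// layer. As an illustrative sketch only (the alias name and values below are
// hypothetical, not a tuning recommendation), a policy for 128-thread blocks
// scanning 12 items per thread with cache-modified loads could be written as:
//
//   using example_scan_by_key_policy_t =
//     AgentScanByKeyPolicy<128,
//                          12,
//                          BLOCK_LOAD_WARP_TRANSPOSE,
//                          LOAD_CA,
//                          BLOCK_SCAN_WARP_SCANS,
//                          BLOCK_STORE_WARP_TRANSPOSE>;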
/******************************************************************************
* Thread block abstractions
******************************************************************************/
/**
* @brief AgentScanByKey implements a stateful abstraction of CUDA thread
* blocks for participating in device-wide prefix scan by key.
*
* @tparam AgentScanByKeyPolicyT
* Parameterized AgentScanByKeyPolicyT tuning policy type
*
* @tparam KeysInputIteratorT
* Random-access input iterator type
*
* @tparam ValuesInputIteratorT
* Random-access input iterator type
*
* @tparam ValuesOutputIteratorT
* Random-access output iterator type
*
* @tparam EqualityOp
* Equality functor type
*
* @tparam ScanOpT
* Scan functor type
*
* @tparam InitValueT
* The init_value element type for ScanOpT (cub::NullType for inclusive scan)
*
* @tparam OffsetT
* Signed integer type for global offsets
*
*/
template <typename AgentScanByKeyPolicyT,
typename KeysInputIteratorT,
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename EqualityOp,
typename ScanOpT,
typename InitValueT,
typename OffsetT,
typename AccumT>
struct AgentScanByKey
{
//---------------------------------------------------------------------
// Types and constants
//---------------------------------------------------------------------
using KeyT = cub::detail::value_t<KeysInputIteratorT>;
using InputT = cub::detail::value_t<ValuesInputIteratorT>;
using SizeValuePairT = KeyValuePair<OffsetT, AccumT>;
using KeyValuePairT = KeyValuePair<KeyT, AccumT>;
using ReduceBySegmentOpT = ReduceBySegmentOp<ScanOpT>;
using ScanTileStateT = ReduceByKeyScanTileState<AccumT, OffsetT>;
// Constants
// Inclusive scan if no init_value type is provided
static constexpr int IS_INCLUSIVE = std::is_same<InitValueT, NullType>::value;
static constexpr int BLOCK_THREADS = AgentScanByKeyPolicyT::BLOCK_THREADS;
static constexpr int ITEMS_PER_THREAD =
AgentScanByKeyPolicyT::ITEMS_PER_THREAD;
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
using WrappedKeysInputIteratorT = cub::detail::conditional_t<
std::is_pointer<KeysInputIteratorT>::value,
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,
KeysInputIteratorT>;
using WrappedValuesInputIteratorT = cub::detail::conditional_t<
std::is_pointer<ValuesInputIteratorT>::value,
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER,
InputT,
OffsetT>,
ValuesInputIteratorT>;
using BlockLoadKeysT = BlockLoad<KeyT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentScanByKeyPolicyT::LOAD_ALGORITHM>;
using BlockLoadValuesT = BlockLoad<AccumT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentScanByKeyPolicyT::LOAD_ALGORITHM>;
using BlockStoreValuesT = BlockStore<AccumT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentScanByKeyPolicyT::STORE_ALGORITHM>;
using BlockDiscontinuityKeysT = BlockDiscontinuity<KeyT, BLOCK_THREADS, 1, 1>;
using DelayConstructorT = typename AgentScanByKeyPolicyT::detail::delay_constructor_t;
using TilePrefixCallbackT =
TilePrefixCallbackOp<SizeValuePairT, ReduceBySegmentOpT, ScanTileStateT, 0, DelayConstructorT>;
using BlockScanT = BlockScan<SizeValuePairT,
BLOCK_THREADS,
AgentScanByKeyPolicyT::SCAN_ALGORITHM,
1,
1>;
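// Shared memory is reused across phases: the load, scan, and store collectives
// below alias the same union storage, with the CTA_SYNC() calls in ConsumeTile
// separating the phases.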
union TempStorage_
{
struct ScanStorage
{
typename BlockScanT::TempStorage scan;
typename TilePrefixCallbackT::TempStorage prefix;
typename BlockDiscontinuityKeysT::TempStorage discontinuity;
} scan_storage;
typename BlockLoadKeysT::TempStorage load_keys;
typename BlockLoadValuesT::TempStorage load_values;
typename BlockStoreValuesT::TempStorage store_values;
};
struct TempStorage : cub::Uninitialized<TempStorage_>
{};
//---------------------------------------------------------------------
// Per-thread fields
//---------------------------------------------------------------------
TempStorage_ &storage;
WrappedKeysInputIteratorT d_keys_in;
KeyT *d_keys_prev_in;
WrappedValuesInputIteratorT d_values_in;
ValuesOutputIteratorT d_values_out;
InequalityWrapper<EqualityOp> inequality_op;
ScanOpT scan_op;
ReduceBySegmentOpT pair_scan_op;
InitValueT init_value;
//---------------------------------------------------------------------
// Block scan utility methods (first tile)
//---------------------------------------------------------------------
// Exclusive scan specialization
__device__ __forceinline__ void
ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
SizeValuePairT &tile_aggregate,
Int2Type<false> /* is_inclusive */)
{
BlockScanT(storage.scan_storage.scan)
.ExclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate);
}
// Inclusive scan specialization
__device__ __forceinline__ void
ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
SizeValuePairT &tile_aggregate,
Int2Type<true> /* is_inclusive */)
{
BlockScanT(storage.scan_storage.scan)
.InclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate);
}
//---------------------------------------------------------------------
// Block scan utility methods (subsequent tiles)
//---------------------------------------------------------------------
// Exclusive scan specialization (with prefix from predecessors)
__device__ __forceinline__ void
ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
SizeValuePairT &tile_aggregate,
TilePrefixCallbackT &prefix_op,
Int2Type<false> /* is_inclusive */)
{
BlockScanT(storage.scan_storage.scan)
.ExclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op);
tile_aggregate = prefix_op.GetBlockAggregate();
}
// Inclusive scan specialization (with prefix from predecessors)
__device__ __forceinline__ void
ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
SizeValuePairT &tile_aggregate,
TilePrefixCallbackT &prefix_op,
Int2Type<true> /* is_inclusive */)
{
BlockScanT(storage.scan_storage.scan)
.InclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op);
tile_aggregate = prefix_op.GetBlockAggregate();
}
//---------------------------------------------------------------------
// Zip utility methods
//---------------------------------------------------------------------
template <bool IS_LAST_TILE>
__device__ __forceinline__ void
ZipValuesAndFlags(OffsetT num_remaining,
AccumT (&values)[ITEMS_PER_THREAD],
OffsetT (&segment_flags)[ITEMS_PER_THREAD],
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD])
{
// Zip values and segment_flags
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
// Set segment_flags for first out-of-bounds item, zero for others
if (IS_LAST_TILE &&
OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM == num_remaining)
{
segment_flags[ITEM] = 1;
}
scan_items[ITEM].value = values[ITEM];
scan_items[ITEM].key = segment_flags[ITEM];
}
}
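// Zipping (segment_flag, value) pairs lets the segmented scan run as an
// ordinary scan: ReduceBySegmentOpT combines two pairs roughly as sketched
// below (illustrative pseudocode, not the exact definition in
// single_pass_scan_operators.cuh):
//
//   result.key   = a.key + b.key;                     // head flags accumulate
//   result.value = b.key ? b.value                    // restart at a segment head
//                        : scan_op(a.value, b.value); // otherwise combine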
__device__ __forceinline__ void
UnzipValues(AccumT (&values)[ITEMS_PER_THREAD],
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD])
{
// Unzip values from the scanned (flag, value) pairs
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
values[ITEM] = scan_items[ITEM].value;
}
}
template <bool IsNull = std::is_same<InitValueT, NullType>::value,
typename std::enable_if<!IsNull, int>::type = 0>
__device__ __forceinline__ void
AddInitToScan(AccumT (&items)[ITEMS_PER_THREAD],
OffsetT (&flags)[ITEMS_PER_THREAD])
{
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
items[ITEM] = flags[ITEM] ? init_value : scan_op(init_value, items[ITEM]);
}
}
template <bool IsNull = std::is_same<InitValueT, NullType>::value,
typename std::enable_if<IsNull, int>::type = 0>
__device__ __forceinline__ void
AddInitToScan(AccumT (&/*items*/)[ITEMS_PER_THREAD],
OffsetT (&/*flags*/)[ITEMS_PER_THREAD])
{}
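// Worked example (illustrative): for an exclusive sum by key with init_value 0,
// keys {0, 0, 1, 1} and values {1, 2, 3, 4}, the exclusive pair scan leaves
// segment heads holding an unspecified prefix; AddInitToScan replaces heads with
// init_value and folds init_value into the remaining items, yielding {0, 1, 0, 3}.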
//---------------------------------------------------------------------
// Cooperatively scan a device-wide sequence of tiles with other CTAs
//---------------------------------------------------------------------
// Process a tile of input (dynamic chained scan)
//
template <bool IS_LAST_TILE>
__device__ __forceinline__ void ConsumeTile(OffsetT /*num_items*/,
OffsetT num_remaining,
int tile_idx,
OffsetT tile_base,
ScanTileStateT &tile_state)
{
// Load items
KeyT keys[ITEMS_PER_THREAD];
AccumT values[ITEMS_PER_THREAD];
OffsetT segment_flags[ITEMS_PER_THREAD];
SizeValuePairT scan_items[ITEMS_PER_THREAD];
if (IS_LAST_TILE)
{
// Fill last element with the first element
// because collectives are not suffix guarded
BlockLoadKeysT(storage.load_keys)
.Load(d_keys_in + tile_base,
keys,
num_remaining,
*(d_keys_in + tile_base));
}
else
{
BlockLoadKeysT(storage.load_keys).Load(d_keys_in + tile_base, keys);
}
CTA_SYNC();
if (IS_LAST_TILE)
{
// Fill last element with the first element
// because collectives are not suffix guarded
BlockLoadValuesT(storage.load_values)
.Load(d_values_in + tile_base,
values,
num_remaining,
*(d_values_in + tile_base));
}
else
{
BlockLoadValuesT(storage.load_values)
.Load(d_values_in + tile_base, values);
}
CTA_SYNC();
// first tile
if (tile_idx == 0)
{
BlockDiscontinuityKeysT(storage.scan_storage.discontinuity)
.FlagHeads(segment_flags, keys, inequality_op);
// Zip values and segment_flags
ZipValuesAndFlags<IS_LAST_TILE>(num_remaining,
values,
segment_flags,
scan_items);
// Exclusive scan of values and segment_flags
SizeValuePairT tile_aggregate;
ScanTile(scan_items, tile_aggregate, Int2Type<IS_INCLUSIVE>());
if (threadIdx.x == 0)
{
if (!IS_LAST_TILE)
{
tile_state.SetInclusive(0, tile_aggregate);
}
scan_items[0].key = 0;
}
}
else
{
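// Subsequent tiles seed head-flagging with the last key of the preceding tile
// (staged in d_keys_prev_in by the dispatch layer), so a segment boundary that
// falls exactly on this tile's first element is still detected.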
KeyT tile_pred_key = (threadIdx.x == 0) ? d_keys_prev_in[tile_idx]
: KeyT();
BlockDiscontinuityKeysT(storage.scan_storage.discontinuity)
.FlagHeads(segment_flags, keys, inequality_op, tile_pred_key);
// Zip values and segment_flags
ZipValuesAndFlags<IS_LAST_TILE>(num_remaining,
values,
segment_flags,
scan_items);
SizeValuePairT tile_aggregate;
TilePrefixCallbackT prefix_op(tile_state,
storage.scan_storage.prefix,
pair_scan_op,
tile_idx);
ScanTile(scan_items, tile_aggregate, prefix_op, Int2Type<IS_INCLUSIVE>());
}
CTA_SYNC();
UnzipValues(values, scan_items);
AddInitToScan(values, segment_flags);
// Store items
if (IS_LAST_TILE)
{
BlockStoreValuesT(storage.store_values)
.Store(d_values_out + tile_base, values, num_remaining);
}
else
{
BlockStoreValuesT(storage.store_values)
.Store(d_values_out + tile_base, values);
}
}
//---------------------------------------------------------------------
// Constructor
//---------------------------------------------------------------------
// Dequeue and scan tiles of items as part of a dynamic chained scan
// with Init functor
__device__ __forceinline__ AgentScanByKey(TempStorage &storage,
KeysInputIteratorT d_keys_in,
KeyT *d_keys_prev_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
EqualityOp equality_op,
ScanOpT scan_op,
InitValueT init_value)
: storage(storage.Alias())
, d_keys_in(d_keys_in)
, d_keys_prev_in(d_keys_prev_in)
, d_values_in(d_values_in)
, d_values_out(d_values_out)
, inequality_op(equality_op)
, scan_op(scan_op)
, pair_scan_op(scan_op)
, init_value(init_value)
{}
/**
* Scan tiles of items as part of a dynamic chained scan
*
* @param num_items
* Total number of input items
*
* @param tile_state
* Global tile state descriptor
*
* @param start_tile
* The starting tile for the current grid
*/
__device__ __forceinline__ void ConsumeRange(OffsetT num_items,
ScanTileStateT &tile_state,
int start_tile)
{
int tile_idx = blockIdx.x;
OffsetT tile_base = OffsetT(ITEMS_PER_TILE) * tile_idx;
OffsetT num_remaining = num_items - tile_base;
if (num_remaining > ITEMS_PER_TILE)
{
// Not the last tile (full)
ConsumeTile<false>(num_items,
num_remaining,
tile_idx,
tile_base,
tile_state);
}
else if (num_remaining > 0)
{
// The last tile (possibly partially-full)
ConsumeTile<true>(num_items,
num_remaining,
tile_idx,
tile_base,
tile_state);
}
}
};
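// Usage sketch (illustrative only): the scan-by-key dispatch layer drives this
// agent from a kernel with one thread block per tile, roughly as follows, where
// AgentScanByKeyT stands for a fully specialized AgentScanByKey and the other
// names are the kernel's arguments:
//
//   __shared__ typename AgentScanByKeyT::TempStorage temp_storage;
//   AgentScanByKeyT(temp_storage,
//                   d_keys_in, d_keys_prev_in, d_values_in, d_values_out,
//                   equality_op, scan_op, init_value)
//     .ConsumeRange(num_items, tile_state, start_tile);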
CUB_NAMESPACE_END