// thrust/install/include/cub/agent/agent_merge_sort.cuh
// (Vendored from NVIDIA CUB; upstream commit 0dc1b04.)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "../config.cuh"
#include "../util_type.cuh"
#include "../util_namespace.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_merge_sort.cuh"
#include <thrust/system/cuda/detail/core/util.h>
CUB_NAMESPACE_BEGIN
/**
 * \brief Compile-time tuning policy consumed by the merge-sort agents below.
 *
 * Bundles the CTA shape (threads per block, items per thread) together with
 * the block-wide load/store algorithms and the cache-load modifier that
 * AgentBlockSort and AgentMerge use to parameterize their CUB primitives.
 */
template <
int _BLOCK_THREADS,
int _ITEMS_PER_THREAD = 1,
cub::BlockLoadAlgorithm _LOAD_ALGORITHM = cub::BLOCK_LOAD_DIRECT,
cub::CacheLoadModifier _LOAD_MODIFIER = cub::LOAD_LDG,
cub::BlockStoreAlgorithm _STORE_ALGORITHM = cub::BLOCK_STORE_DIRECT>
struct AgentMergeSortPolicy
{
  // Threads per CTA.
  static constexpr int BLOCK_THREADS = _BLOCK_THREADS;
  // Items each thread keeps in registers.
  static constexpr int ITEMS_PER_THREAD = _ITEMS_PER_THREAD;
  // Total items one CTA processes per tile.
  static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
  static constexpr cub::BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM;
  static constexpr cub::CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER;
  static constexpr cub::BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM;
};
/// \brief This agent is responsible for the initial in-tile sorting.
///
/// One CTA per tile: each block loads up to ITEMS_PER_TILE keys (and values,
/// unless ValueT is NullType), sorts them with BlockMergeSort, and stores the
/// sorted tile to one of two destinations selected by \p ping — iterator-based
/// outputs when \p ping is true, raw-pointer outputs otherwise (the two
/// buffers of the merge sort's ping-pong scheme).
template <typename Policy,
          typename KeyInputIteratorT,
          typename ValueInputIteratorT,
          typename KeyIteratorT,
          typename ValueIteratorT,
          typename OffsetT,
          typename CompareOpT,
          typename KeyT,
          typename ValueT>
struct AgentBlockSort
{
  //---------------------------------------------------------------------
  // Types and constants
  //---------------------------------------------------------------------

  // Keys-only sorting is requested by passing NullType as the value type.
  static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;

  // Block-wide merge sort used for the in-tile sort.
  using BlockMergeSortT =
    BlockMergeSort<KeyT, Policy::BLOCK_THREADS, Policy::ITEMS_PER_THREAD, ValueT>;

  // Input iterators wrapped per the policy's cache-load modifier.
  using KeysLoadIt = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeyInputIteratorT>::type;
  using ItemsLoadIt = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ValueInputIteratorT>::type;

  // Block-wide load primitives parameterized by the policy.
  using BlockLoadKeys = typename cub::BlockLoadType<Policy, KeysLoadIt>::type;
  using BlockLoadItems = typename cub::BlockLoadType<Policy, ItemsLoadIt>::type;

  // Two store flavors per sequence: iterator-based (used when ping == true)
  // and raw-pointer (used when ping == false).
  using BlockStoreKeysIt = typename cub::BlockStoreType<Policy, KeyIteratorT>::type;
  using BlockStoreItemsIt = typename cub::BlockStoreType<Policy, ValueIteratorT>::type;
  using BlockStoreKeysRaw = typename cub::BlockStoreType<Policy, KeyT *>::type;
  using BlockStoreItemsRaw = typename cub::BlockStoreType<Policy, ValueT *>::type;

  // Shared memory is reused (unioned) across the load/sort/store phases;
  // CTA_SYNC() calls in consume_tile separate the phases.
  union _TempStorage
  {
    typename BlockLoadKeys::TempStorage load_keys;
    typename BlockLoadItems::TempStorage load_items;
    typename BlockStoreKeysIt::TempStorage store_keys_it;
    typename BlockStoreItemsIt::TempStorage store_items_it;
    typename BlockStoreKeysRaw::TempStorage store_keys_raw;
    typename BlockStoreItemsRaw::TempStorage store_items_raw;
    typename BlockMergeSortT::TempStorage block_merge;
  };

  /// Alias wrapper allowing storage to be unioned
  struct TempStorage : Uninitialized<_TempStorage> {};

  static constexpr int BLOCK_THREADS = Policy::BLOCK_THREADS;
  static constexpr int ITEMS_PER_THREAD = Policy::ITEMS_PER_THREAD;
  static constexpr int ITEMS_PER_TILE = Policy::ITEMS_PER_TILE;
  // Shared memory footprint of this agent, in bytes.
  static constexpr int SHARED_MEMORY_SIZE =
    static_cast<int>(sizeof(TempStorage));

  //---------------------------------------------------------------------
  // Per thread data
  //---------------------------------------------------------------------

  bool ping;                     // true: store via iterators; false: store via raw pointers
  _TempStorage &storage;         // reference to the CTA's shared temp storage
  KeysLoadIt keys_in;
  ItemsLoadIt items_in;
  OffsetT keys_count;            // total number of keys in the input
  KeyIteratorT keys_out_it;
  ValueIteratorT items_out_it;
  KeyT *keys_out_raw;
  ValueT *items_out_raw;
  CompareOpT compare_op;

  /// Binds the agent to its shared storage, inputs, outputs, and comparator.
  __device__ __forceinline__ AgentBlockSort(bool ping_,
                                            TempStorage &storage_,
                                            KeysLoadIt keys_in_,
                                            ItemsLoadIt items_in_,
                                            OffsetT keys_count_,
                                            KeyIteratorT keys_out_it_,
                                            ValueIteratorT items_out_it_,
                                            KeyT *keys_out_raw_,
                                            ValueT *items_out_raw_,
                                            CompareOpT compare_op_)
      : ping(ping_)
      , storage(storage_.Alias())
      , keys_in(keys_in_)
      , items_in(items_in_)
      , keys_count(keys_count_)
      , keys_out_it(keys_out_it_)
      , items_out_it(items_out_it_)
      , keys_out_raw(keys_out_raw_)
      , items_out_raw(items_out_raw_)
      , compare_op(compare_op_)
  {
  }

  /// Sorts the tile assigned to this CTA. Assumes a 1-D grid with one block
  /// per tile (gridDim.x tiles in total).
  __device__ __forceinline__ void Process()
  {
    auto tile_idx = static_cast<OffsetT>(blockIdx.x);
    auto num_tiles = static_cast<OffsetT>(gridDim.x);
    auto tile_base = tile_idx * ITEMS_PER_TILE;
    int items_in_tile = (cub::min)(keys_count - tile_base, int{ITEMS_PER_TILE});
    if (tile_idx < num_tiles - 1)
    {
      // Interior tiles are always full: take the unguarded fast path.
      consume_tile<false>(tile_base, ITEMS_PER_TILE);
    }
    else
    {
      // The final tile may be partial: take the guarded path.
      consume_tile<true>(tile_base, items_in_tile);
    }
  }

  /// Loads, sorts, and stores a single tile.
  ///
  /// \tparam IS_LAST_TILE  When true, loads/sorts/stores are guarded by
  ///                       \p num_remaining (the tile may be partial).
  /// \param tile_base      Global offset of the first item of this tile.
  /// \param num_remaining  Number of valid items in this tile.
  template <bool IS_LAST_TILE>
  __device__ __forceinline__ void consume_tile(OffsetT tile_base,
                                               int num_remaining)
  {
    ValueT items_local[ITEMS_PER_THREAD];
    if (!KEYS_ONLY)
    {
      if (IS_LAST_TILE)
      {
        // Guarded load; out-of-range slots are filled with the tile's first
        // item so every register holds a defined value.
        BlockLoadItems(storage.load_items)
          .Load(items_in + tile_base,
                items_local,
                num_remaining,
                *(items_in + tile_base));
      }
      else
      {
        BlockLoadItems(storage.load_items).Load(items_in + tile_base, items_local);
      }
      // load_items and load_keys share the union: sync before reuse.
      CTA_SYNC();
    }

    KeyT keys_local[ITEMS_PER_THREAD];
    if (IS_LAST_TILE)
    {
      // Guarded load with the first key as the out-of-range filler.
      BlockLoadKeys(storage.load_keys)
        .Load(keys_in + tile_base,
              keys_local,
              num_remaining,
              *(keys_in + tile_base));
    }
    else
    {
      BlockLoadKeys(storage.load_keys)
        .Load(keys_in + tile_base, keys_local);
    }
    // Shared storage is reused by the sort: sync before reuse.
    CTA_SYNC();

    if (IS_LAST_TILE)
    {
      // Guarded sort; keys_local[0] pads the out-of-range key slots.
      BlockMergeSortT(storage.block_merge)
        .Sort(keys_local, items_local, compare_op, num_remaining, keys_local[0]);
    }
    else
    {
      BlockMergeSortT(storage.block_merge).Sort(keys_local, items_local, compare_op);
    }
    // Shared storage is reused by the stores: sync before reuse.
    CTA_SYNC();

    if (ping)
    {
      // "Ping" destination: iterator-based outputs.
      if (IS_LAST_TILE)
      {
        BlockStoreKeysIt(storage.store_keys_it)
          .Store(keys_out_it + tile_base, keys_local, num_remaining);
      }
      else
      {
        BlockStoreKeysIt(storage.store_keys_it)
          .Store(keys_out_it + tile_base, keys_local);
      }

      if (!KEYS_ONLY)
      {
        // Key and item stores share the union: sync between them.
        CTA_SYNC();
        if (IS_LAST_TILE)
        {
          BlockStoreItemsIt(storage.store_items_it)
            .Store(items_out_it + tile_base, items_local, num_remaining);
        }
        else
        {
          BlockStoreItemsIt(storage.store_items_it)
            .Store(items_out_it + tile_base, items_local);
        }
      }
    }
    else
    {
      // "Pong" destination: raw-pointer outputs.
      if (IS_LAST_TILE)
      {
        BlockStoreKeysRaw(storage.store_keys_raw)
          .Store(keys_out_raw + tile_base, keys_local, num_remaining);
      }
      else
      {
        BlockStoreKeysRaw(storage.store_keys_raw)
          .Store(keys_out_raw + tile_base, keys_local);
      }

      if (!KEYS_ONLY)
      {
        // Key and item stores share the union: sync between them.
        CTA_SYNC();
        if (IS_LAST_TILE)
        {
          BlockStoreItemsRaw(storage.store_items_raw)
            .Store(items_out_raw + tile_base, items_local, num_remaining);
        }
        else
        {
          BlockStoreItemsRaw(storage.store_items_raw)
            .Store(items_out_raw + tile_base, items_local);
        }
      }
    }
  }
};
/**
* \brief This agent is responsible for partitioning a merge path into equal segments
*
* There are two sorted arrays to be merged into one array. If the first array
* is partitioned between parallel workers by slicing it into ranges of equal
* size, there could be a significant workload imbalance. The imbalance is
* caused by the fact that the distribution of elements from the second array
* is unknown beforehand. Instead, the MergePath is partitioned between workers.
* This approach guarantees an equal amount of work being assigned to each worker.
*
* This approach is outlined in the paper:
* Odeh et al, "Merge Path - Parallel Merging Made Simple"
* doi:10.1109/IPDPSW.2012.202
*/
template <
  typename KeyIteratorT,
  typename OffsetT,
  typename CompareOpT,
  typename KeyT>
struct AgentPartition
{
  bool ping;                           // true: read keys from keys_ping; false: from keys_pong
  KeyIteratorT keys_ping;              // iterator-based key buffer
  KeyT *keys_pong;                     // raw-pointer key buffer
  OffsetT keys_count;                  // total number of keys
  OffsetT partition_idx;               // index of the partition this agent computes
  OffsetT *merge_partitions;           // output: merge-path split point per partition
  CompareOpT compare_op;
  OffsetT target_merged_tiles_number;  // tiles per merged group after this pass (power of two)
  int items_per_tile;

  /// Binds the agent to its inputs and the partition it is responsible for.
  __device__ __forceinline__ AgentPartition(bool ping,
                                            KeyIteratorT keys_ping,
                                            KeyT *keys_pong,
                                            OffsetT keys_count,
                                            OffsetT partition_idx,
                                            OffsetT *merge_partitions,
                                            CompareOpT compare_op,
                                            OffsetT target_merged_tiles_number,
                                            int items_per_tile)
      : ping(ping)
      , keys_ping(keys_ping)
      , keys_pong(keys_pong)
      , keys_count(keys_count)
      , partition_idx(partition_idx)
      , merge_partitions(merge_partitions)
      , compare_op(compare_op)
      , target_merged_tiles_number(target_merged_tiles_number)
      , items_per_tile(items_per_tile)
  {}

  /// Computes one merge-path split point and writes it to
  /// merge_partitions[partition_idx]. The split is the global index (into
  /// the first of the two sequences being merged) where this partition's
  /// diagonal crosses the merge path.
  __device__ __forceinline__ void Process()
  {
    // Each of the two sequences being merged spans this many tiles.
    OffsetT merged_tiles_number = target_merged_tiles_number / 2;

    // target_merged_tiles_number is a power of two.
    OffsetT mask = target_merged_tiles_number - 1;

    // The first tile number in the tiles group being merged, equal to:
    // target_merged_tiles_number * (partition_idx / target_merged_tiles_number)
    OffsetT list = ~mask & partition_idx;
    OffsetT start = items_per_tile * list;
    OffsetT size = items_per_tile * merged_tiles_number;

    // Tile number within the tile group being merged, equal to:
    // partition_idx % target_merged_tiles_number
    OffsetT local_tile_idx = mask & partition_idx;

    // Bounds of the two adjacent sorted runs, clamped to the input size
    // (runs at the end of the data may be partial or empty).
    OffsetT keys1_beg = (cub::min)(keys_count, start);
    OffsetT keys1_end = (cub::min)(keys_count, start + size);
    OffsetT keys2_beg = keys1_end;
    OffsetT keys2_end = (cub::min)(keys_count, keys2_beg + size);

    // Diagonal of the merge path assigned to this partition, clamped to the
    // total number of elements in both runs.
    OffsetT partition_at = (cub::min)(keys2_end - keys1_beg,
                                      items_per_tile * local_tile_idx);

    // Binary-search the merge path on whichever buffer currently holds the
    // keys (ping/pong alternates between merge passes).
    OffsetT partition_diag = ping ? MergePath<KeyT>(keys_ping + keys1_beg,
                                                    keys_ping + keys2_beg,
                                                    keys1_end - keys1_beg,
                                                    keys2_end - keys2_beg,
                                                    partition_at,
                                                    compare_op)
                                  : MergePath<KeyT>(keys_pong + keys1_beg,
                                                    keys_pong + keys2_beg,
                                                    keys1_end - keys1_beg,
                                                    keys2_end - keys2_beg,
                                                    partition_at,
                                                    compare_op);

    merge_partitions[partition_idx] = keys1_beg + partition_diag;
  }
};
/// \brief The agent is responsible for merging N consecutive sorted arrays into N/2 sorted arrays.
///
/// One CTA per output tile. Tile boundaries within the two runs being merged
/// were precomputed by AgentPartition (merge_partitions). Buffers alternate
/// between iterator-based ("ping") and raw-pointer ("pong") storage on each
/// merge pass, selected by the \p ping flag.
template <
  typename Policy,
  typename KeyIteratorT,
  typename ValueIteratorT,
  typename OffsetT,
  typename CompareOpT,
  typename KeyT,
  typename ValueT>
struct AgentMerge
{
  //---------------------------------------------------------------------
  // Types and constants
  //---------------------------------------------------------------------

  // Load iterators honoring the policy's cache-load modifier, one pair per
  // buffer (iterator-based "ping" and raw-pointer "pong").
  using KeysLoadPingIt = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeyIteratorT>::type;
  using ItemsLoadPingIt = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ValueIteratorT>::type;
  using KeysLoadPongIt = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeyT *>::type;
  using ItemsLoadPongIt = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ValueT *>::type;

  // Output of a pass reading "pong" goes to the iterator buffer and vice
  // versa — hence the crossed naming below.
  using KeysOutputPongIt = KeyIteratorT;
  using ItemsOutputPongIt = ValueIteratorT;
  using KeysOutputPingIt = KeyT*;
  using ItemsOutputPingIt = ValueT*;

  using BlockStoreKeysPong = typename BlockStoreType<Policy, KeysOutputPongIt>::type;
  using BlockStoreItemsPong = typename BlockStoreType<Policy, ItemsOutputPongIt>::type;
  using BlockStoreKeysPing = typename BlockStoreType<Policy, KeysOutputPingIt>::type;
  using BlockStoreItemsPing = typename BlockStoreType<Policy, ItemsOutputPingIt>::type;

  /// Shared memory reused (unioned) across the merge and store phases;
  /// CTA_SYNC() calls in consume_tile separate the phases. The staging
  /// arrays carry one extra slot of padding beyond ITEMS_PER_TILE.
  union _TempStorage
  {
    typename BlockStoreKeysPing::TempStorage store_keys_ping;
    typename BlockStoreItemsPing::TempStorage store_items_ping;
    typename BlockStoreKeysPong::TempStorage store_keys_pong;
    typename BlockStoreItemsPong::TempStorage store_items_pong;

    KeyT keys_shared[Policy::ITEMS_PER_TILE + 1];
    ValueT items_shared[Policy::ITEMS_PER_TILE + 1];
  };

  /// Alias wrapper allowing storage to be unioned
  struct TempStorage : Uninitialized<_TempStorage> {};

  // Keys-only sorting is requested by passing NullType as the value type.
  static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;

  static constexpr int BLOCK_THREADS = Policy::BLOCK_THREADS;
  static constexpr int ITEMS_PER_THREAD = Policy::ITEMS_PER_THREAD;
  static constexpr int ITEMS_PER_TILE = Policy::ITEMS_PER_TILE;
  // Shared memory footprint of this agent, in bytes.
  static constexpr int SHARED_MEMORY_SIZE =
    static_cast<int>(sizeof(TempStorage));

  //---------------------------------------------------------------------
  // Per thread data
  //---------------------------------------------------------------------

  bool ping;                          // true: read ping buffers, write ping outputs
  _TempStorage& storage;              // reference to the CTA's shared temp storage

  KeysLoadPingIt keys_in_ping;
  ItemsLoadPingIt items_in_ping;
  KeysLoadPongIt keys_in_pong;
  ItemsLoadPongIt items_in_pong;

  OffsetT keys_count;                 // total number of keys

  KeysOutputPongIt keys_out_pong;
  ItemsOutputPongIt items_out_pong;
  KeysOutputPingIt keys_out_ping;
  ItemsOutputPingIt items_out_ping;

  CompareOpT compare_op;
  OffsetT *merge_partitions;          // split points produced by AgentPartition
  OffsetT target_merged_tiles_number; // tiles per merged group after this pass (power of two)

  //---------------------------------------------------------------------
  // Utility functions
  //---------------------------------------------------------------------

  /**
   * \brief Concatenates up to ITEMS_PER_THREAD elements from input{1,2} into output array
   *
   * Reads data in a coalesced fashion [BLOCK_THREADS * item + tid] and
   * stores the result in output[item].
   */
  template <bool IS_FULL_TILE, class T, class It1, class It2>
  __device__ __forceinline__ void
  gmem_to_reg(T (&output)[ITEMS_PER_THREAD],
              It1 input1,
              It2 input2,
              int count1,
              int count2)
  {
    if (IS_FULL_TILE)
    {
      // Full tile: every index is valid, no guards needed.
#pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        int idx = BLOCK_THREADS * item + threadIdx.x;
        // Indices below count1 come from input1; the rest from input2.
        output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
      }
    }
    else
    {
      // Partial tile: guard against reading past count1 + count2.
#pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        int idx = BLOCK_THREADS * item + threadIdx.x;
        if (idx < count1 + count2)
        {
          output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
        }
      }
    }
  }

  /// \brief Stores data in a coalesced fashion in[item] -> out[BLOCK_THREADS * item + tid]
  template <class T, class It>
  __device__ __forceinline__ void
  reg_to_shared(It output,
                T (&input)[ITEMS_PER_THREAD])
  {
#pragma unroll
    for (int item = 0; item < ITEMS_PER_THREAD; ++item)
    {
      int idx = BLOCK_THREADS * item + threadIdx.x;
      output[idx] = input[item];
    }
  }

  /// Merges this CTA's slice of two adjacent sorted runs into one output tile.
  ///
  /// \tparam IS_FULL_TILE  When true the tile holds exactly ITEMS_PER_TILE
  ///                       items and unguarded loads/stores are used.
  /// \param tid        Thread rank within the CTA.
  /// \param tile_idx   Global tile index (indexes merge_partitions).
  /// \param tile_base  Global offset of the first output item of this tile.
  /// \param count      Number of valid items in this tile (used by the
  ///                   guarded item stores).
  template <bool IS_FULL_TILE>
  __device__ __forceinline__ void
  consume_tile(int tid, OffsetT tile_idx, OffsetT tile_base, int count)
  {
    // Merge-path split points bounding this tile's slice of the first run,
    // precomputed by AgentPartition.
    OffsetT partition_beg = merge_partitions[tile_idx + 0];
    OffsetT partition_end = merge_partitions[tile_idx + 1];

    // Each of the two runs being merged spans this many tiles.
    OffsetT merged_tiles_number = target_merged_tiles_number / 2;

    // target_merged_tiles_number is a power of two.
    OffsetT mask = target_merged_tiles_number - 1;

    // The first tile number in the tiles group being merged, equal to:
    // target_merged_tiles_number * (tile_idx / target_merged_tiles_number)
    OffsetT list = ~mask & tile_idx;
    OffsetT start = ITEMS_PER_TILE * list;
    OffsetT size = ITEMS_PER_TILE * merged_tiles_number;

    // Diagonal of this tile within the merge of the two runs.
    OffsetT diag = ITEMS_PER_TILE * tile_idx - start;

    OffsetT keys1_beg = partition_beg;
    OffsetT keys1_end = partition_end;
    // The second run's slice is what remains of the diagonal interval after
    // subtracting the first run's contribution, clamped to the input size.
    OffsetT keys2_beg = (cub::min)(keys_count, 2 * start + size + diag - partition_beg);
    OffsetT keys2_end = (cub::min)(keys_count, 2 * start + size + diag + ITEMS_PER_TILE - partition_end);

    // Check if it's the last tile in the tile group being merged
    if (mask == (mask & tile_idx))
    {
      keys1_end = (cub::min)(keys_count, start + size);
      keys2_end = (cub::min)(keys_count, start + size * 2);
    }

    // number of keys per tile
    //
    int num_keys1 = static_cast<int>(keys1_end - keys1_beg);
    int num_keys2 = static_cast<int>(keys2_end - keys2_beg);

    // load keys1 & keys2
    KeyT keys_local[ITEMS_PER_THREAD];
    if (ping)
    {
      gmem_to_reg<IS_FULL_TILE>(keys_local,
                                keys_in_ping + keys1_beg,
                                keys_in_ping + keys2_beg,
                                num_keys1,
                                num_keys2);
    }
    else
    {
      gmem_to_reg<IS_FULL_TILE>(keys_local,
                                keys_in_pong + keys1_beg,
                                keys_in_pong + keys2_beg,
                                num_keys1,
                                num_keys2);
    }
    // Stage both key slices contiguously in shared memory:
    // [0, num_keys1) from run 1, [num_keys1, num_keys1 + num_keys2) from run 2.
    reg_to_shared(&storage.keys_shared[0], keys_local);

    // preload items into registers already
    //
    ValueT items_local[ITEMS_PER_THREAD];
    if (!KEYS_ONLY)
    {
      if (ping)
      {
        gmem_to_reg<IS_FULL_TILE>(items_local,
                                  items_in_ping + keys1_beg,
                                  items_in_ping + keys2_beg,
                                  num_keys1,
                                  num_keys2);
      }
      else
      {
        gmem_to_reg<IS_FULL_TILE>(items_local,
                                  items_in_pong + keys1_beg,
                                  items_in_pong + keys2_beg,
                                  num_keys1,
                                  num_keys2);
      }
    }
    // Make the staged keys visible to all threads before searching them.
    CTA_SYNC();

    // use binary search in shared memory
    // to find merge path for each of thread
    // we can use int type here, because the number of
    // items in shared memory is limited
    //
    int diag0_local = (cub::min)(num_keys1 + num_keys2, ITEMS_PER_THREAD * tid);

    int keys1_beg_local = MergePath<KeyT>(&storage.keys_shared[0],
                                          &storage.keys_shared[num_keys1],
                                          num_keys1,
                                          num_keys2,
                                          diag0_local,
                                          compare_op);
    int keys1_end_local = num_keys1;
    int keys2_beg_local = diag0_local - keys1_beg_local;
    int keys2_end_local = num_keys2;

    int num_keys1_local = keys1_end_local - keys1_beg_local;
    int num_keys2_local = keys2_end_local - keys2_beg_local;

    // perform serial merge
    //
    // indices[] records where each merged key came from in shared memory,
    // so the corresponding items can be gathered afterwards.
    int indices[ITEMS_PER_THREAD];

    SerialMerge(&storage.keys_shared[0],
                keys1_beg_local,
                keys2_beg_local + num_keys1,
                num_keys1_local,
                num_keys2_local,
                keys_local,
                indices,
                compare_op);
    // keys_shared is reused by the key stores below: sync before reuse.
    CTA_SYNC();

    // write keys
    //
    if (ping)
    {
      if (IS_FULL_TILE)
      {
        BlockStoreKeysPing(storage.store_keys_ping)
          .Store(keys_out_ping + tile_base, keys_local);
      }
      else
      {
        BlockStoreKeysPing(storage.store_keys_ping)
          .Store(keys_out_ping + tile_base, keys_local, num_keys1 + num_keys2);
      }
    }
    else
    {
      if (IS_FULL_TILE)
      {
        BlockStoreKeysPong(storage.store_keys_pong)
          .Store(keys_out_pong + tile_base, keys_local);
      }
      else
      {
        BlockStoreKeysPong(storage.store_keys_pong)
          .Store(keys_out_pong + tile_base, keys_local, num_keys1 + num_keys2);
      }
    }

    // if items are provided, merge them
    if (!KEYS_ONLY)
    {
      // Key store and item staging share the union: sync between phases.
      CTA_SYNC();
      reg_to_shared(&storage.items_shared[0], items_local);
      CTA_SYNC();

      // gather items from shared mem
      //
#pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        items_local[item] = storage.items_shared[indices[item]];
      }
      // items_shared is reused by the item stores below: sync before reuse.
      CTA_SYNC();

      // write from reg to gmem
      //
      if (ping)
      {
        if (IS_FULL_TILE)
        {
          BlockStoreItemsPing(storage.store_items_ping)
            .Store(items_out_ping + tile_base, items_local);
        }
        else
        {
          BlockStoreItemsPing(storage.store_items_ping)
            .Store(items_out_ping + tile_base, items_local, count);
        }
      }
      else
      {
        if (IS_FULL_TILE)
        {
          BlockStoreItemsPong(storage.store_items_pong)
            .Store(items_out_pong + tile_base, items_local);
        }
        else
        {
          BlockStoreItemsPong(storage.store_items_pong)
            .Store(items_out_pong + tile_base, items_local, count);
        }
      }
    }
  }

  /// Binds the agent to its shared storage, ping/pong buffers, comparator,
  /// and the partition table produced by AgentPartition.
  __device__ __forceinline__ AgentMerge(bool ping_,
                                        TempStorage &storage_,
                                        KeysLoadPingIt keys_in_ping_,
                                        ItemsLoadPingIt items_in_ping_,
                                        KeysLoadPongIt keys_in_pong_,
                                        ItemsLoadPongIt items_in_pong_,
                                        OffsetT keys_count_,
                                        KeysOutputPingIt keys_out_ping_,
                                        ItemsOutputPingIt items_out_ping_,
                                        KeysOutputPongIt keys_out_pong_,
                                        ItemsOutputPongIt items_out_pong_,
                                        CompareOpT compare_op_,
                                        OffsetT *merge_partitions_,
                                        OffsetT target_merged_tiles_number_)
      : ping(ping_)
      , storage(storage_.Alias())
      , keys_in_ping(keys_in_ping_)
      , items_in_ping(items_in_ping_)
      , keys_in_pong(keys_in_pong_)
      , items_in_pong(items_in_pong_)
      , keys_count(keys_count_)
      , keys_out_pong(keys_out_pong_)
      , items_out_pong(items_out_pong_)
      , keys_out_ping(keys_out_ping_)
      , items_out_ping(items_out_ping_)
      , compare_op(compare_op_)
      , merge_partitions(merge_partitions_)
      , target_merged_tiles_number(target_merged_tiles_number_)
  {}

  /// Merges the tile assigned to this CTA. Assumes a 1-D grid with one block
  /// per output tile (gridDim.x tiles in total).
  __device__ __forceinline__ void Process()
  {
    int tile_idx = static_cast<int>(blockIdx.x);
    int num_tiles = static_cast<int>(gridDim.x);
    OffsetT tile_base = OffsetT(tile_idx) * ITEMS_PER_TILE;
    int tid = static_cast<int>(threadIdx.x);
    int items_in_tile = static_cast<int>(
      (cub::min)(static_cast<OffsetT>(ITEMS_PER_TILE), keys_count - tile_base));

    if (tile_idx < num_tiles - 1)
    {
      // Interior tiles are always full (IS_FULL_TILE == true).
      consume_tile<true>(tid, tile_idx, tile_base, ITEMS_PER_TILE);
    }
    else
    {
      // The last tile may be partial (IS_FULL_TILE == false).
      consume_tile<false>(tid, tile_idx, tile_base, items_in_tile);
    }
  }
};
CUB_NAMESPACE_END