| /****************************************************************************** | |
| * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. | |
| * | |
| * Redistribution and use in source and binary forms, with or without | |
| * modification, are permitted provided that the following conditions are met: | |
| * * Redistributions of source code must retain the above copyright | |
| * notice, this list of conditions and the following disclaimer. | |
| * * Redistributions in binary form must reproduce the above copyright | |
| * notice, this list of conditions and the following disclaimer in the | |
| * documentation and/or other materials provided with the distribution. | |
| * * Neither the name of the NVIDIA CORPORATION nor the | |
| * names of its contributors may be used to endorse or promote products | |
| * derived from this software without specific prior written permission. | |
| * | |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | |
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | |
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | |
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | |
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| * | |
| ******************************************************************************/ | |
| #pragma once | |
| #include "../config.cuh" | |
| #include "../util_type.cuh" | |
| #include "../util_namespace.cuh" | |
| #include "../block/block_load.cuh" | |
| #include "../block/block_store.cuh" | |
| #include "../block/block_merge_sort.cuh" | |
| #include <thrust/system/cuda/detail/core/util.h> | |
| CUB_NAMESPACE_BEGIN | |
/**
 * \brief Compile-time tuning parameters shared by the merge sort agents.
 *
 * The agents in this file read the values through the static members below
 * (for example Policy::BLOCK_THREADS and Policy::ITEMS_PER_THREAD are handed
 * to BlockMergeSort, and Policy::ITEMS_PER_TILE sizes the shared staging
 * buffers of the merge agent).
 *
 * \tparam BLOCK_THREADS_ARG     Value exposed as BLOCK_THREADS
 * \tparam ITEMS_PER_THREAD_ARG  Value exposed as ITEMS_PER_THREAD (default 1)
 * \tparam LOAD_ALGORITHM_ARG    Value exposed as LOAD_ALGORITHM
 * \tparam LOAD_MODIFIER_ARG     Value exposed as LOAD_MODIFIER
 * \tparam STORE_ALGORITHM_ARG   Value exposed as STORE_ALGORITHM
 *
 * \note The template parameters deliberately avoid the `_UPPERCASE` spelling:
 *       identifiers that start with an underscore followed by an uppercase
 *       letter are reserved for the implementation in C++.
 */
template <
  int                      BLOCK_THREADS_ARG,
  int                      ITEMS_PER_THREAD_ARG = 1,
  cub::BlockLoadAlgorithm  LOAD_ALGORITHM_ARG   = cub::BLOCK_LOAD_DIRECT,
  cub::CacheLoadModifier   LOAD_MODIFIER_ARG    = cub::LOAD_LDG,
  cub::BlockStoreAlgorithm STORE_ALGORITHM_ARG  = cub::BLOCK_STORE_DIRECT>
struct AgentMergeSortPolicy
{
  static constexpr int BLOCK_THREADS    = BLOCK_THREADS_ARG;
  static constexpr int ITEMS_PER_THREAD = ITEMS_PER_THREAD_ARG;

  /// Number of items covered by one tile (BLOCK_THREADS * ITEMS_PER_THREAD)
  static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;

  static constexpr cub::BlockLoadAlgorithm  LOAD_ALGORITHM  = LOAD_ALGORITHM_ARG;
  static constexpr cub::CacheLoadModifier   LOAD_MODIFIER   = LOAD_MODIFIER_ARG;
  static constexpr cub::BlockStoreAlgorithm STORE_ALGORITHM = STORE_ALGORITHM_ARG;
};
/// \brief This agent is responsible for the initial in-tile sorting.
///
/// The tile handled by a thread block is selected by blockIdx.x. The tile is
/// loaded, sorted with BlockMergeSort and written either through the output
/// iterators (when `ping` is true) or to the raw pointers (when it is false).
template <typename Policy,
          typename KeyInputIteratorT,
          typename ValueInputIteratorT,
          typename KeyIteratorT,
          typename ValueIteratorT,
          typename OffsetT,
          typename CompareOpT,
          typename KeyT,
          typename ValueT>
struct AgentBlockSort
{
  //---------------------------------------------------------------------
  // Types and constants
  //---------------------------------------------------------------------

  /// True when no values accompany the keys (ValueT is NullType)
  static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;

  using BlockMergeSortT =
    BlockMergeSort<KeyT, Policy::BLOCK_THREADS, Policy::ITEMS_PER_THREAD, ValueT>;

  using KeysLoadIt =
    typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeyInputIteratorT>::type;
  using ItemsLoadIt =
    typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ValueInputIteratorT>::type;

  using BlockLoadKeys  = typename cub::BlockLoadType<Policy, KeysLoadIt>::type;
  using BlockLoadItems = typename cub::BlockLoadType<Policy, ItemsLoadIt>::type;

  using BlockStoreKeysIt   = typename cub::BlockStoreType<Policy, KeyIteratorT>::type;
  using BlockStoreItemsIt  = typename cub::BlockStoreType<Policy, ValueIteratorT>::type;
  using BlockStoreKeysRaw  = typename cub::BlockStoreType<Policy, KeyT *>::type;
  using BlockStoreItemsRaw = typename cub::BlockStoreType<Policy, ValueT *>::type;

  /// Temporary storage of every collective used by this agent. The members
  /// share the same memory (union), so each phase is separated by CTA_SYNC().
  union _TempStorage
  {
    typename BlockLoadKeys::TempStorage      load_keys;
    typename BlockLoadItems::TempStorage     load_items;
    typename BlockStoreKeysIt::TempStorage   store_keys_it;
    typename BlockStoreItemsIt::TempStorage  store_items_it;
    typename BlockStoreKeysRaw::TempStorage  store_keys_raw;
    typename BlockStoreItemsRaw::TempStorage store_items_raw;
    typename BlockMergeSortT::TempStorage    block_merge;
  };

  /// Alias wrapper allowing storage to be unioned
  struct TempStorage : Uninitialized<_TempStorage> {};

  static constexpr int BLOCK_THREADS    = Policy::BLOCK_THREADS;
  static constexpr int ITEMS_PER_THREAD = Policy::ITEMS_PER_THREAD;
  static constexpr int ITEMS_PER_TILE   = Policy::ITEMS_PER_TILE;
  static constexpr int SHARED_MEMORY_SIZE =
    static_cast<int>(sizeof(TempStorage));

  //---------------------------------------------------------------------
  // Per thread data
  //---------------------------------------------------------------------

  bool ping;                    ///< Selects the destination: iterators (true) or raw pointers (false)
  _TempStorage &storage;        ///< Aliased temporary storage
  KeysLoadIt keys_in;           ///< Source of the unsorted keys
  ItemsLoadIt items_in;         ///< Source of the unsorted items
  OffsetT keys_count;           ///< Total number of keys
  KeyIteratorT keys_out_it;     ///< Key destination used when ping is true
  ValueIteratorT items_out_it;  ///< Item destination used when ping is true
  KeyT *keys_out_raw;           ///< Key destination used when ping is false
  ValueT *items_out_raw;        ///< Item destination used when ping is false
  CompareOpT compare_op;        ///< Comparison functor handed to BlockMergeSort

  /// Stores all arguments in the corresponding members; no work is done here.
  __device__ __forceinline__ AgentBlockSort(bool ping_,
                                            TempStorage &storage_,
                                            KeysLoadIt keys_in_,
                                            ItemsLoadIt items_in_,
                                            OffsetT keys_count_,
                                            KeyIteratorT keys_out_it_,
                                            ValueIteratorT items_out_it_,
                                            KeyT *keys_out_raw_,
                                            ValueT *items_out_raw_,
                                            CompareOpT compare_op_)
      : ping(ping_)
      , storage(storage_.Alias())
      , keys_in(keys_in_)
      , items_in(items_in_)
      , keys_count(keys_count_)
      , keys_out_it(keys_out_it_)
      , items_out_it(items_out_it_)
      , keys_out_raw(keys_out_raw_)
      , items_out_raw(items_out_raw_)
      , compare_op(compare_op_)
  {
  }

  /// Entry point: sorts the tile addressed by blockIdx.x. The block with the
  /// highest index is treated as the (possibly partial) last tile.
  __device__ __forceinline__ void Process()
  {
    auto this_tile  = static_cast<OffsetT>(blockIdx.x);
    auto tile_count = static_cast<OffsetT>(gridDim.x);
    auto first_item = this_tile * ITEMS_PER_TILE;

    int valid_items = (cub::min)(keys_count - first_item, int{ITEMS_PER_TILE});

    const bool is_last_tile = !(this_tile < tile_count - 1);

    if (is_last_tile)
    {
      consume_tile<true>(first_item, valid_items);
    }
    else
    {
      consume_tile<false>(first_item, ITEMS_PER_TILE);
    }
  }

  /// Loads, sorts and stores one tile.
  ///
  /// \tparam IS_LAST_TILE   When true, only `num_remaining` items are valid
  ///                        and the guarded load/sort/store overloads are used
  /// \param  tile_base      Offset of the first item of the tile
  /// \param  num_remaining  Number of valid items in the tile
  template <bool IS_LAST_TILE>
  __device__ __forceinline__ void consume_tile(OffsetT tile_base,
                                               int num_remaining)
  {
    // --- load items (skipped entirely for keys-only sorts) ---
    ValueT thread_items[ITEMS_PER_THREAD];

    if (!KEYS_ONLY)
    {
      if (!IS_LAST_TILE)
      {
        BlockLoadItems(storage.load_items).Load(items_in + tile_base, thread_items);
      }
      else
      {
        // Out-of-range slots are filled with the first item of the tile
        BlockLoadItems(storage.load_items)
          .Load(items_in + tile_base,
                thread_items,
                num_remaining,
                *(items_in + tile_base));
      }

      CTA_SYNC();
    }

    // --- load keys ---
    KeyT thread_keys[ITEMS_PER_THREAD];

    if (!IS_LAST_TILE)
    {
      BlockLoadKeys(storage.load_keys)
        .Load(keys_in + tile_base, thread_keys);
    }
    else
    {
      // Out-of-range slots are filled with the first key of the tile
      BlockLoadKeys(storage.load_keys)
        .Load(keys_in + tile_base,
              thread_keys,
              num_remaining,
              *(keys_in + tile_base));
    }

    CTA_SYNC();

    // --- sort the tile ---
    if (!IS_LAST_TILE)
    {
      BlockMergeSortT(storage.block_merge).Sort(thread_keys, thread_items, compare_op);
    }
    else
    {
      BlockMergeSortT(storage.block_merge)
        .Sort(thread_keys, thread_items, compare_op, num_remaining, thread_keys[0]);
    }

    CTA_SYNC();

    // --- store the sorted tile ---
    if (!ping)
    {
      // Destination: raw pointers
      if (!IS_LAST_TILE)
      {
        BlockStoreKeysRaw(storage.store_keys_raw)
          .Store(keys_out_raw + tile_base, thread_keys);
      }
      else
      {
        BlockStoreKeysRaw(storage.store_keys_raw)
          .Store(keys_out_raw + tile_base, thread_keys, num_remaining);
      }

      if (!KEYS_ONLY)
      {
        // The item store reuses the unioned storage of the key store
        CTA_SYNC();

        if (!IS_LAST_TILE)
        {
          BlockStoreItemsRaw(storage.store_items_raw)
            .Store(items_out_raw + tile_base, thread_items);
        }
        else
        {
          BlockStoreItemsRaw(storage.store_items_raw)
            .Store(items_out_raw + tile_base, thread_items, num_remaining);
        }
      }
    }
    else
    {
      // Destination: output iterators
      if (!IS_LAST_TILE)
      {
        BlockStoreKeysIt(storage.store_keys_it)
          .Store(keys_out_it + tile_base, thread_keys);
      }
      else
      {
        BlockStoreKeysIt(storage.store_keys_it)
          .Store(keys_out_it + tile_base, thread_keys, num_remaining);
      }

      if (!KEYS_ONLY)
      {
        // The item store reuses the unioned storage of the key store
        CTA_SYNC();

        if (!IS_LAST_TILE)
        {
          BlockStoreItemsIt(storage.store_items_it)
            .Store(items_out_it + tile_base, thread_items);
        }
        else
        {
          BlockStoreItemsIt(storage.store_items_it)
            .Store(items_out_it + tile_base, thread_items, num_remaining);
        }
      }
    }
  }
};
/**
 * \brief This agent is responsible for partitioning a merge path into equal segments
 *
 * There are two sorted arrays to be merged into one array. If the first array
 * is partitioned between parallel workers by slicing it into ranges of equal
 * size, there could be a significant workload imbalance. The imbalance is
 * caused by the fact that the distribution of elements from the second array
 * is unknown beforehand. Instead, the MergePath is partitioned between workers.
 * This approach guarantees an equal amount of work being assigned to each worker.
 *
 * This approach is outlined in the paper:
 * Odeh et al, "Merge Path - Parallel Merging Made Simple"
 * doi:10.1109/IPDPSW.2012.202
 *
 * One agent computes one split point: Process() writes the position in the
 * first of the two lists at which partition `partition_idx` starts into
 * `merge_partitions[partition_idx]`.
 */
template <
  typename KeyIteratorT,
  typename OffsetT,
  typename CompareOpT,
  typename KeyT>
struct AgentPartition
{
  bool ping;                           ///< Read keys from keys_ping (true) or keys_pong (false)
  KeyIteratorT keys_ping;              ///< Key source used when ping is true
  KeyT *keys_pong;                     ///< Key source used when ping is false
  OffsetT keys_count;                  ///< Total number of keys
  OffsetT partition_idx;               ///< Index of the partition computed by this agent
  OffsetT *merge_partitions;           ///< Output array of split points
  CompareOpT compare_op;               ///< Comparison functor handed to MergePath
  OffsetT target_merged_tiles_number;  ///< Tiles per merged group; must be a power of two
  int items_per_tile;                  ///< Number of items in one tile

  /// Stores all arguments in the corresponding members; no work is done here.
  __device__ __forceinline__ AgentPartition(bool ping,
                                            KeyIteratorT keys_ping,
                                            KeyT *keys_pong,
                                            OffsetT keys_count,
                                            OffsetT partition_idx,
                                            OffsetT *merge_partitions,
                                            CompareOpT compare_op,
                                            OffsetT target_merged_tiles_number,
                                            int items_per_tile)
      : ping(ping)
      , keys_ping(keys_ping)
      , keys_pong(keys_pong)
      , keys_count(keys_count)
      , partition_idx(partition_idx)
      , merge_partitions(merge_partitions)
      , compare_op(compare_op)
      , target_merged_tiles_number(target_merged_tiles_number)
      , items_per_tile(items_per_tile)
  {}

  /// Computes the split point of partition `partition_idx` and stores it in
  /// `merge_partitions[partition_idx]`.
  __device__ __forceinline__ void Process()
  {
    OffsetT merged_tiles_number = target_merged_tiles_number / 2;

    // target_merged_tiles_number is a power of two.
    OffsetT mask = target_merged_tiles_number - 1;

    // The first tile number in the tiles group being merged, equal to:
    // target_merged_tiles_number * (partition_idx / target_merged_tiles_number)
    OffsetT list  = ~mask & partition_idx;
    OffsetT start = items_per_tile * list;
    OffsetT size  = items_per_tile * merged_tiles_number;

    // Tile number within the tile group being merged, equal to:
    // partition_idx % target_merged_tiles_number
    OffsetT local_tile_idx = mask & partition_idx;

    // Bounds of the two lists, clamped to keys_count. Each end is derived
    // from the already clamped begin plus min(remaining, size). This yields
    // min(keys_count, begin + size) without ever forming `begin + size`,
    // which could overflow OffsetT for inputs close to its maximum.
    OffsetT keys1_beg  = (cub::min)(keys_count, start);
    OffsetT keys1_room = keys_count - keys1_beg;
    OffsetT keys1_end  = keys1_beg + (cub::min)(keys1_room, size);

    OffsetT keys2_beg  = keys1_end;
    OffsetT keys2_room = keys_count - keys2_beg;
    OffsetT keys2_end  = keys2_beg + (cub::min)(keys2_room, size);

    // Diagonal of the merge path at which this partition starts, limited to
    // the total number of keys in both lists.
    OffsetT partition_at = (cub::min)(keys2_end - keys1_beg,
                                      items_per_tile * local_tile_idx);

    OffsetT partition_diag = ping ? MergePath<KeyT>(keys_ping + keys1_beg,
                                                    keys_ping + keys2_beg,
                                                    keys1_end - keys1_beg,
                                                    keys2_end - keys2_beg,
                                                    partition_at,
                                                    compare_op)
                                  : MergePath<KeyT>(keys_pong + keys1_beg,
                                                    keys_pong + keys2_beg,
                                                    keys1_end - keys1_beg,
                                                    keys2_end - keys2_beg,
                                                    partition_at,
                                                    compare_op);

    merge_partitions[partition_idx] = keys1_beg + partition_diag;
  }
};
/// \brief The agent is responsible for merging N consecutive sorted arrays into N/2 sorted arrays.
///
/// One thread block produces the output tile selected by blockIdx.x. The
/// input range of that tile is delimited by two consecutive entries of
/// `merge_partitions`. `ping` selects the direction of the pass: when it is
/// true the `*_in_ping` inputs are read and the `*_out_ping` outputs are
/// written, otherwise the `*_pong` counterparts are used.
template <
  typename Policy,
  typename KeyIteratorT,
  typename ValueIteratorT,
  typename OffsetT,
  typename CompareOpT,
  typename KeyT,
  typename ValueT>
struct AgentMerge
{
  //---------------------------------------------------------------------
  // Types and constants
  //---------------------------------------------------------------------

  using KeysLoadPingIt  = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeyIteratorT>::type;
  using ItemsLoadPingIt = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ValueIteratorT>::type;
  using KeysLoadPongIt  = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeyT *>::type;
  using ItemsLoadPongIt = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ValueT *>::type;

  using KeysOutputPongIt  = KeyIteratorT;
  using ItemsOutputPongIt = ValueIteratorT;
  using KeysOutputPingIt  = KeyT*;
  using ItemsOutputPingIt = ValueT*;

  using BlockStoreKeysPong  = typename BlockStoreType<Policy, KeysOutputPongIt>::type;
  using BlockStoreItemsPong = typename BlockStoreType<Policy, ItemsOutputPongIt>::type;
  using BlockStoreKeysPing  = typename BlockStoreType<Policy, KeysOutputPingIt>::type;
  using BlockStoreItemsPing = typename BlockStoreType<Policy, ItemsOutputPingIt>::type;

  /// Temporary storage: the BlockStore scratch areas and the key/item staging
  /// buffers used by the merge share the same memory (union), so every change
  /// of the active member is separated by CTA_SYNC().
  union _TempStorage
  {
    typename BlockStoreKeysPing::TempStorage  store_keys_ping;
    typename BlockStoreItemsPing::TempStorage store_items_ping;
    typename BlockStoreKeysPong::TempStorage  store_keys_pong;
    typename BlockStoreItemsPong::TempStorage store_items_pong;

    // One extra element: the merge may read one slot past the tile
    // (presumably inside SerialMerge -- confirm against its definition).
    KeyT   keys_shared[Policy::ITEMS_PER_TILE + 1];
    ValueT items_shared[Policy::ITEMS_PER_TILE + 1];
  };

  /// Alias wrapper allowing storage to be unioned
  struct TempStorage : Uninitialized<_TempStorage> {};

  /// True when no values accompany the keys (ValueT is NullType)
  static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;

  static constexpr int BLOCK_THREADS    = Policy::BLOCK_THREADS;
  static constexpr int ITEMS_PER_THREAD = Policy::ITEMS_PER_THREAD;
  static constexpr int ITEMS_PER_TILE   = Policy::ITEMS_PER_TILE;
  static constexpr int SHARED_MEMORY_SIZE =
    static_cast<int>(sizeof(TempStorage));

  //---------------------------------------------------------------------
  // Per thread data
  //---------------------------------------------------------------------

  bool ping;                           ///< Direction of the pass, see the class description
  _TempStorage& storage;               ///< Aliased temporary storage
  KeysLoadPingIt keys_in_ping;         ///< Key source when ping is true
  ItemsLoadPingIt items_in_ping;       ///< Item source when ping is true
  KeysLoadPongIt keys_in_pong;         ///< Key source when ping is false
  ItemsLoadPongIt items_in_pong;       ///< Item source when ping is false
  OffsetT keys_count;                  ///< Total number of keys
  KeysOutputPongIt keys_out_pong;      ///< Key destination when ping is false
  ItemsOutputPongIt items_out_pong;    ///< Item destination when ping is false
  KeysOutputPingIt keys_out_ping;      ///< Key destination when ping is true
  ItemsOutputPingIt items_out_ping;    ///< Item destination when ping is true
  CompareOpT compare_op;               ///< Comparison functor
  OffsetT *merge_partitions;           ///< Split points; entries tile_idx and tile_idx + 1 are read
  OffsetT target_merged_tiles_number;  ///< Tiles per merged group; must be a power of two

  //---------------------------------------------------------------------
  // Utility functions
  //---------------------------------------------------------------------

  /**
   * \brief Concatenates up to ITEMS_PER_THREAD elements from input{1,2} into output array
   *
   * Reads data in a coalesced fashion [BLOCK_THREADS * item + tid] and
   * stores the result in output[item].
   *
   * When IS_FULL_TILE is false, slots whose index is not below
   * count1 + count2 are left unwritten.
   */
  template <bool IS_FULL_TILE, class T, class It1, class It2>
  __device__ __forceinline__ void
  gmem_to_reg(T (&output)[ITEMS_PER_THREAD],
              It1 input1,
              It2 input2,
              int count1,
              int count2)
  {
    if (IS_FULL_TILE)
    {
      #pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        int idx = BLOCK_THREADS * item + threadIdx.x;
        output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
      }
    }
    else
    {
      #pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        int idx = BLOCK_THREADS * item + threadIdx.x;
        if (idx < count1 + count2)
        {
          output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
        }
      }
    }
  }

  /// \brief Stores data in a coalesced fashion in[item] -> out[BLOCK_THREADS * item + tid]
  template <class T, class It>
  __device__ __forceinline__ void
  reg_to_shared(It output,
                T (&input)[ITEMS_PER_THREAD])
  {
    #pragma unroll
    for (int item = 0; item < ITEMS_PER_THREAD; ++item)
    {
      int idx = BLOCK_THREADS * item + threadIdx.x;
      output[idx] = input[item];
    }
  }

  /**
   * \brief Merges the inputs of one output tile and stores the result.
   *
   * \tparam IS_FULL_TILE  When false, the guarded load/store variants are used
   * \param  tid           Thread index within the block
   * \param  tile_idx      Index of the output tile
   * \param  tile_base     Offset of the first output item of the tile
   * \param  count         Number of valid items passed to the guarded item store
   */
  template <bool IS_FULL_TILE>
  __device__ __forceinline__ void
  consume_tile(int tid, OffsetT tile_idx, OffsetT tile_base, int count)
  {
    OffsetT partition_beg = merge_partitions[tile_idx + 0];
    OffsetT partition_end = merge_partitions[tile_idx + 1];

    // target_merged_tiles_number is a power of two.
    OffsetT merged_tiles_number = target_merged_tiles_number / 2;
    OffsetT mask                = target_merged_tiles_number - 1;

    // The first tile number in the tiles group being merged, equal to:
    // target_merged_tiles_number * (tile_idx / target_merged_tiles_number)
    OffsetT list  = ~mask & tile_idx;
    OffsetT start = ITEMS_PER_TILE * list;
    OffsetT size  = ITEMS_PER_TILE * merged_tiles_number;

    // Number of output items of this group that precede the tile
    OffsetT diag = ITEMS_PER_TILE * tile_idx - start;

    // Bounds of the two lists being merged, clamped to keys_count:
    //   list2_beg == min(keys_count, start + size)
    //   list2_end == min(keys_count, start + 2 * size)
    // Each bound is derived from the previous, already clamped one, so no
    // intermediate sum can exceed keys_count. Forming `start + 2 * size`
    // or `2 * start + size + diag` directly could overflow OffsetT for
    // inputs close to its maximum.
    OffsetT list1_beg  = (cub::min)(keys_count, start);
    OffsetT list1_room = keys_count - list1_beg;
    OffsetT list2_beg  = list1_beg + (cub::min)(list1_room, size);
    OffsetT list2_room = keys_count - list2_beg;
    OffsetT list2_end  = list2_beg + (cub::min)(list2_room, size);

    // Progress along the second list at the beginning / end of this tile:
    // the path length (diag) minus the part of it taken from the first list.
    OffsetT keys2_skip_beg = diag - (partition_beg - start);
    OffsetT keys2_skip_end = diag - (partition_end - start) + ITEMS_PER_TILE;

    OffsetT keys1_beg = partition_beg;
    OffsetT keys1_end = partition_end;
    OffsetT keys2_beg = list2_beg + (cub::min)(list2_room, keys2_skip_beg);
    OffsetT keys2_end = list2_beg + (cub::min)(list2_room, keys2_skip_end);

    // Check if it's the last tile in the tile group being merged
    if (mask == (mask & tile_idx))
    {
      keys1_end = list2_beg;
      keys2_end = list2_end;
    }

    // number of keys per tile
    //
    int num_keys1 = static_cast<int>(keys1_end - keys1_beg);
    int num_keys2 = static_cast<int>(keys2_end - keys2_beg);

    // load keys1 & keys2
    KeyT keys_local[ITEMS_PER_THREAD];
    if (ping)
    {
      gmem_to_reg<IS_FULL_TILE>(keys_local,
                                keys_in_ping + keys1_beg,
                                keys_in_ping + keys2_beg,
                                num_keys1,
                                num_keys2);
    }
    else
    {
      gmem_to_reg<IS_FULL_TILE>(keys_local,
                                keys_in_pong + keys1_beg,
                                keys_in_pong + keys2_beg,
                                num_keys1,
                                num_keys2);
    }
    reg_to_shared(&storage.keys_shared[0], keys_local);

    // preload items into registers already
    //
    ValueT items_local[ITEMS_PER_THREAD];
    if (!KEYS_ONLY)
    {
      if (ping)
      {
        gmem_to_reg<IS_FULL_TILE>(items_local,
                                  items_in_ping + keys1_beg,
                                  items_in_ping + keys2_beg,
                                  num_keys1,
                                  num_keys2);
      }
      else
      {
        gmem_to_reg<IS_FULL_TILE>(items_local,
                                  items_in_pong + keys1_beg,
                                  items_in_pong + keys2_beg,
                                  num_keys1,
                                  num_keys2);
      }
    }

    // keys_shared must be completely written before it is searched
    CTA_SYNC();

    // use binary search in shared memory
    // to find merge path for each of thread
    // we can use int type here, because the number of
    // items in shared memory is limited
    //
    int diag0_local = (cub::min)(num_keys1 + num_keys2, ITEMS_PER_THREAD * tid);

    int keys1_beg_local = MergePath<KeyT>(&storage.keys_shared[0],
                                          &storage.keys_shared[num_keys1],
                                          num_keys1,
                                          num_keys2,
                                          diag0_local,
                                          compare_op);
    int keys1_end_local = num_keys1;
    int keys2_beg_local = diag0_local - keys1_beg_local;
    int keys2_end_local = num_keys2;

    int num_keys1_local = keys1_end_local - keys1_beg_local;
    int num_keys2_local = keys2_end_local - keys2_beg_local;

    // perform serial merge
    //
    int indices[ITEMS_PER_THREAD];

    SerialMerge(&storage.keys_shared[0],
                keys1_beg_local,
                keys2_beg_local + num_keys1,
                num_keys1_local,
                num_keys2_local,
                keys_local,
                indices,
                compare_op);

    // keys_shared is unioned with the BlockStore storage used next
    CTA_SYNC();

    // write keys
    //
    if (ping)
    {
      if (IS_FULL_TILE)
      {
        BlockStoreKeysPing(storage.store_keys_ping)
          .Store(keys_out_ping + tile_base, keys_local);
      }
      else
      {
        BlockStoreKeysPing(storage.store_keys_ping)
          .Store(keys_out_ping + tile_base, keys_local, num_keys1 + num_keys2);
      }
    }
    else
    {
      if (IS_FULL_TILE)
      {
        BlockStoreKeysPong(storage.store_keys_pong)
          .Store(keys_out_pong + tile_base, keys_local);
      }
      else
      {
        BlockStoreKeysPong(storage.store_keys_pong)
          .Store(keys_out_pong + tile_base, keys_local, num_keys1 + num_keys2);
      }
    }

    // if items are provided, merge them
    if (!KEYS_ONLY)
    {
      CTA_SYNC();

      reg_to_shared(&storage.items_shared[0], items_local);

      CTA_SYNC();

      // gather items from shared mem
      // (indices[] was filled by SerialMerge above)
      //
      #pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        items_local[item] = storage.items_shared[indices[item]];
      }

      CTA_SYNC();

      // write from reg to gmem
      //
      if (ping)
      {
        if (IS_FULL_TILE)
        {
          BlockStoreItemsPing(storage.store_items_ping)
            .Store(items_out_ping + tile_base, items_local);
        }
        else
        {
          BlockStoreItemsPing(storage.store_items_ping)
            .Store(items_out_ping + tile_base, items_local, count);
        }
      }
      else
      {
        if (IS_FULL_TILE)
        {
          BlockStoreItemsPong(storage.store_items_pong)
            .Store(items_out_pong + tile_base, items_local);
        }
        else
        {
          BlockStoreItemsPong(storage.store_items_pong)
            .Store(items_out_pong + tile_base, items_local, count);
        }
      }
    }
  }

  /// Stores all arguments in the corresponding members; no work is done here.
  __device__ __forceinline__ AgentMerge(bool ping_,
                                        TempStorage &storage_,
                                        KeysLoadPingIt keys_in_ping_,
                                        ItemsLoadPingIt items_in_ping_,
                                        KeysLoadPongIt keys_in_pong_,
                                        ItemsLoadPongIt items_in_pong_,
                                        OffsetT keys_count_,
                                        KeysOutputPingIt keys_out_ping_,
                                        ItemsOutputPingIt items_out_ping_,
                                        KeysOutputPongIt keys_out_pong_,
                                        ItemsOutputPongIt items_out_pong_,
                                        CompareOpT compare_op_,
                                        OffsetT *merge_partitions_,
                                        OffsetT target_merged_tiles_number_)
      : ping(ping_)
      , storage(storage_.Alias())
      , keys_in_ping(keys_in_ping_)
      , items_in_ping(items_in_ping_)
      , keys_in_pong(keys_in_pong_)
      , items_in_pong(items_in_pong_)
      , keys_count(keys_count_)
      , keys_out_pong(keys_out_pong_)
      , items_out_pong(items_out_pong_)
      , keys_out_ping(keys_out_ping_)
      , items_out_ping(items_out_ping_)
      , compare_op(compare_op_)
      , merge_partitions(merge_partitions_)
      , target_merged_tiles_number(target_merged_tiles_number_)
  {}

  /// Entry point: merges the output tile addressed by blockIdx.x. Every
  /// block except the one with the highest index processes a full tile.
  __device__ __forceinline__ void Process()
  {
    int tile_idx      = static_cast<int>(blockIdx.x);
    int num_tiles     = static_cast<int>(gridDim.x);
    OffsetT tile_base = OffsetT(tile_idx) * ITEMS_PER_TILE;
    int tid           = static_cast<int>(threadIdx.x);
    int items_in_tile = static_cast<int>(
      (cub::min)(static_cast<OffsetT>(ITEMS_PER_TILE), keys_count - tile_base));

    if (tile_idx < num_tiles - 1)
    {
      consume_tile<true>(tid, tile_idx, tile_base, ITEMS_PER_TILE);
    }
    else
    {
      consume_tile<false>(tid, tile_idx, tile_base, items_in_tile);
    }
  }
};
| CUB_NAMESPACE_END | |