// thrust/install/include/cub/block/block_merge_sort.cuh
// camenduru's picture
// thanks to nvidia ❤
// 0dc1b04
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cub/thread/thread_sort.cuh>
#include <cub/util_math.cuh>
#include <cub/util_namespace.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
CUB_NAMESPACE_BEGIN
// Additional details of the Merge-Path Algorithm can be found in:
// S. Odeh, O. Green, Z. Mwassi, O. Shmueli, Y. Birk, "Merge Path - Parallel
// Merging Made Simple", Multithreaded Architectures and Applications (MTAAP)
// Workshop, IEEE 26th International Parallel & Distributed Processing
// Symposium (IPDPS), 2012
/**
 * @brief Locates where diagonal `diag` of the merge grid crosses the merge
 *        path of two sorted sequences.
 *
 * Binary-searches the feasible split points in `keys1` for the number of
 * `keys1` elements that precede output position `diag` when `keys1` and
 * `keys2` are merged. An element of `keys2` is placed before an element of
 * `keys1` only when `binary_pred(key2, key1)` is true, so on ties elements
 * of `keys1` come first.
 *
 * @tparam KeyT         Type each probed key is loaded into
 * @tparam KeyIteratorT Random-access iterator (or pointer) over the keys
 * @tparam OffsetT      Integral offset type
 * @tparam BinaryPred   "Less-than" predicate
 *
 * @param keys1        First sorted sequence
 * @param keys2        Second sorted sequence
 * @param keys1_count  Number of elements in `keys1`
 * @param keys2_count  Number of elements in `keys2`
 * @param diag         Diagonal index (number of merged outputs before the split)
 * @param binary_pred  Comparison predicate
 *
 * @return Number of elements taken from `keys1`; the remaining
 *         `diag - result` elements come from `keys2`.
 */
template <typename KeyT,
          typename KeyIteratorT,
          typename OffsetT,
          typename BinaryPred>
__device__ __forceinline__ OffsetT MergePath(KeyIteratorT keys1,
                                             KeyIteratorT keys2,
                                             OffsetT keys1_count,
                                             OffsetT keys2_count,
                                             OffsetT diag,
                                             BinaryPred binary_pred)
{
  // Smallest and largest split in keys1 that this diagonal allows.
  OffsetT lo = (diag >= keys2_count) ? OffsetT(diag - keys2_count) : OffsetT(0);
  OffsetT hi = (cub::min)(diag, keys1_count);

  while (lo < hi)
  {
    const OffsetT pivot = cub::MidPoint<OffsetT>(lo, hi);

    KeyT candidate1 = keys1[pivot];
    KeyT candidate2 = keys2[diag - 1 - pivot];

    // keys2 candidate strictly first -> the split is at or below `pivot`;
    // otherwise it is strictly above `pivot`.
    if (binary_pred(candidate2, candidate1))
    {
      hi = pivot;
    }
    else
    {
      lo = pivot + 1;
    }
  }

  return lo;
}
/**
 * @brief Sequentially merges two sorted runs held in `keys_shared`.
 *
 * Emits exactly ITEMS_PER_THREAD keys, consuming from the run starting at
 * `keys1_beg` (length `keys1_count`) and the run starting at `keys2_beg`
 * (length `keys2_count`). The second run's head is emitted only when
 * `compare_op(head2, head1)` is true, so on ties the first run wins.
 *
 * @note After every emitted key the next key of the consumed run is read
 *       ahead, even past that run's end, so `keys_shared` must be readable
 *       at those positions.
 *
 * @param[in]  keys_shared  Buffer holding both runs
 * @param[in]  keys1_beg    Start index of the first run
 * @param[in]  keys2_beg    Start index of the second run
 * @param[in]  keys1_count  Length of the first run (may be <= 0)
 * @param[in]  keys2_count  Length of the second run (may be <= 0)
 * @param[out] output       Merged keys
 * @param[out] indices      Index in `keys_shared` that each output key came from
 * @param[in]  compare_op   "Less-than" comparison
 */
template <typename KeyT, typename CompareOp, int ITEMS_PER_THREAD>
__device__ __forceinline__ void SerialMerge(KeyT *keys_shared,
                                            int keys1_beg,
                                            int keys2_beg,
                                            int keys1_count,
                                            int keys2_count,
                                            KeyT (&output)[ITEMS_PER_THREAD],
                                            int (&indices)[ITEMS_PER_THREAD],
                                            CompareOp compare_op)
{
  const int keys1_end = keys1_beg + keys1_count;
  const int keys2_end = keys2_beg + keys2_count;

  // Look-ahead copies of each run's current head.
  KeyT head1 = keys_shared[keys1_beg];
  KeyT head2 = keys_shared[keys2_beg];

  #pragma unroll
  for (int i = 0; i < ITEMS_PER_THREAD; ++i)
  {
    const bool run2_has_items = keys2_beg < keys2_end;
    const bool run1_exhausted = keys1_beg >= keys1_end;

    // compare_op is only evaluated while both runs still have items.
    const bool take_second =
      run2_has_items && (run1_exhausted || compare_op(head2, head1));

    if (take_second)
    {
      output[i]  = head2;
      indices[i] = keys2_beg;
      ++keys2_beg;
      head2 = keys_shared[keys2_beg];
    }
    else
    {
      output[i]  = head1;
      indices[i] = keys1_beg;
      ++keys1_beg;
      head1 = keys_shared[keys1_beg];
    }
  }
}
/**
 * @brief Generalized merge sort algorithm
 *
 * This class is used to reduce code duplication. Warp and Block merge sort
 * differ only in how they compute thread index and how they synchronize
 * threads. Since synchronization might require access to custom data
 * (like member mask), CRTP is used.
 *
 * @par
 * The code snippet below illustrates the way this class can be used.
 * @par
 * @code
 * #include <cub/cub.cuh> // or equivalently <cub/block/block_merge_sort.cuh>
 *
 * constexpr int BLOCK_THREADS = 256;
 * constexpr int ITEMS_PER_THREAD = 9;
 *
 * class BlockMergeSort : public BlockMergeSortStrategy<int,
 *                                                      cub::NullType,
 *                                                      BLOCK_THREADS,
 *                                                      ITEMS_PER_THREAD,
 *                                                      BlockMergeSort>
 * {
 *   using BlockMergeSortStrategyT =
 *     BlockMergeSortStrategy<int,
 *                            cub::NullType,
 *                            BLOCK_THREADS,
 *                            ITEMS_PER_THREAD,
 *                            BlockMergeSort>;
 * public:
 *   __device__ __forceinline__ explicit BlockMergeSort(
 *     typename BlockMergeSortStrategyT::TempStorage &temp_storage)
 *       : BlockMergeSortStrategyT(temp_storage, threadIdx.x)
 *   {}
 *
 *   __device__ __forceinline__ void SyncImplementation() const
 *   {
 *     __syncthreads();
 *   }
 * };
 * @endcode
 *
 * @tparam KeyT
 *   KeyT type
 *
 * @tparam ValueT
 *   ValueT type. cub::NullType indicates a keys-only sort
 *
 * @tparam NUM_THREADS
 *   Number of threads cooperating in the sort. Must be a power of two
 *   (enforced by a static_assert).
 *
 * @tparam ITEMS_PER_THREAD
 *   Number of keys (and values) owned by each thread
 *
 * @tparam SynchronizationPolicy
 *   Provides a way of synchronizing threads. Should be derived from
 *   `BlockMergeSortStrategy` and must provide
 *   `void SyncImplementation() const`.
 */
template <typename KeyT,
          typename ValueT,
          int NUM_THREADS,
          int ITEMS_PER_THREAD,
          typename SynchronizationPolicy>
class BlockMergeSortStrategy
{
  static_assert(PowerOfTwo<NUM_THREADS>::VALUE,
                "NUM_THREADS must be a power of two");

private:

  static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * NUM_THREADS;

  // Whether or not there are values to be carried along with keys
  static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;

  /// Shared memory type required by this thread block.
  /// The extra trailing element presumably keeps SerialMerge's look-ahead read
  /// (one past the last merged position) in bounds.
  union _TempStorage
  {
    KeyT keys_shared[ITEMS_PER_TILE + 1];
    ValueT items_shared[ITEMS_PER_TILE + 1];
  }; // union TempStorage

  /// Shared storage reference
  _TempStorage &temp_storage;

  /// Internal storage allocator
  __device__ __forceinline__ _TempStorage& PrivateStorage()
  {
    __shared__ _TempStorage private_storage;
    return private_storage;
  }

  const unsigned int linear_tid;

public:
  /// \smemstorage{BlockMergeSort}
  struct TempStorage : Uninitialized<_TempStorage> {};

  BlockMergeSortStrategy() = delete;
  explicit __device__ __forceinline__
  BlockMergeSortStrategy(unsigned int linear_tid)
      : temp_storage(PrivateStorage())
      , linear_tid(linear_tid)
  {}

  __device__ __forceinline__ BlockMergeSortStrategy(TempStorage &temp_storage,
                                                    unsigned int linear_tid)
      : temp_storage(temp_storage.Alias())
      , linear_tid(linear_tid)
  {}

  __device__ __forceinline__ unsigned int get_linear_tid() const
  {
    return linear_tid;
  }

  /**
   * @brief Sorts items partitioned across a CUDA thread block using
   *        a merge sorting method.
   *
   * @par
   * Sort is not guaranteed to be stable. That is, suppose that i and j are
   * equivalent: neither one is less than the other. It is not guaranteed
   * that the relative order of these two elements will be preserved by sort.
   *
   * @tparam CompareOp
   *   functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
   *   `CompareOp` is a model of [Strict Weak Ordering].
   *
   * @param[in,out] keys
   *   Keys to sort
   *
   * @param[in] compare_op
   *   Comparison function object which returns true if the first argument is
   *   ordered before the second
   *
   * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
   */
  template <typename CompareOp>
  __device__ __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
                                       CompareOp compare_op)
  {
    ValueT items[ITEMS_PER_THREAD];
    Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
  }

  /**
   * @brief Sorts items partitioned across a CUDA thread block using
   *        a merge sorting method.
   *
   * @par
   * - Sort is not guaranteed to be stable. That is, suppose that `i` and `j`
   *   are equivalent: neither one is less than the other. It is not guaranteed
   *   that the relative order of these two elements will be preserved by sort.
   * - The value of `oob_default` is assigned to all elements that are out of
   *   `valid_items` boundaries. It's expected that `oob_default` is ordered
   *   after any value in the `valid_items` boundaries. The algorithm always
   *   sorts a fixed amount of elements, which is equal to
   *   `ITEMS_PER_THREAD * NUM_THREADS`. If there is a value that is ordered
   *   after `oob_default`, it won't be placed within `valid_items` boundaries.
   *
   * @tparam CompareOp
   *   functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
   *   `CompareOp` is a model of [Strict Weak Ordering].
   *
   * @param[in,out] keys
   *   Keys to sort
   *
   * @param[in] compare_op
   *   Comparison function object which returns true if the first argument is
   *   ordered before the second
   *
   * @param[in] valid_items
   *   Number of valid items to sort
   *
   * @param[in] oob_default
   *   Default value to assign out-of-bound items
   *
   * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
   */
  template <typename CompareOp>
  __device__ __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
                                       CompareOp compare_op,
                                       int valid_items,
                                       KeyT oob_default)
  {
    ValueT items[ITEMS_PER_THREAD];
    Sort<CompareOp, true>(keys, items, compare_op, valid_items, oob_default);
  }

  /**
   * @brief Sorts items partitioned across a CUDA thread block using a merge sorting method.
   *
   * @par
   * Sort is not guaranteed to be stable. That is, suppose that `i` and `j` are
   * equivalent: neither one is less than the other. It is not guaranteed
   * that the relative order of these two elements will be preserved by sort.
   *
   * @tparam CompareOp
   *   functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
   *   `CompareOp` is a model of [Strict Weak Ordering].
   *
   * @param[in,out] keys
   *   Keys to sort
   *
   * @param[in,out] items
   *   Values to sort
   *
   * @param[in] compare_op
   *   Comparison function object which returns true if the first argument is
   *   ordered before the second
   *
   * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
   */
  template <typename CompareOp>
  __device__ __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
                                       ValueT (&items)[ITEMS_PER_THREAD],
                                       CompareOp compare_op)
  {
    Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
  }

  /**
   * @brief Sorts items partitioned across a CUDA thread block using
   *        a merge sorting method.
   *
   * @par
   * - Sort is not guaranteed to be stable. That is, suppose that `i` and `j`
   *   are equivalent: neither one is less than the other. It is not guaranteed
   *   that the relative order of these two elements will be preserved by sort.
   * - The value of `oob_default` is assigned to all elements that are out of
   *   `valid_items` boundaries. It's expected that `oob_default` is ordered
   *   after any value in the `valid_items` boundaries. The algorithm always
   *   sorts a fixed amount of elements, which is equal to
   *   `ITEMS_PER_THREAD * NUM_THREADS`. If there is a value that is ordered
   *   after `oob_default`, it won't be placed within `valid_items` boundaries.
   *
   * @tparam CompareOp
   *   functor type having member `bool operator()(KeyT lhs, KeyT rhs)`
   *   `CompareOp` is a model of [Strict Weak Ordering].
   *
   * @tparam IS_LAST_TILE
   *   True if `valid_items` isn't equal to the `ITEMS_PER_TILE`
   *
   * @param[in,out] keys
   *   Keys to sort
   *
   * @param[in,out] items
   *   Values to sort
   *
   * @param[in] compare_op
   *   Comparison function object which returns true if the first argument is
   *   ordered before the second
   *
   * @param[in] valid_items
   *   Number of valid items to sort
   *
   * @param[in] oob_default
   *   Default value to assign out-of-bound items
   *
   * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
   */
  template <typename CompareOp,
            bool IS_LAST_TILE = true>
  __device__ __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
                                       ValueT (&items)[ITEMS_PER_THREAD],
                                       CompareOp compare_op,
                                       int valid_items,
                                       KeyT oob_default)
  {
    // Tile offset of this thread's first item. Kept signed so that comparisons
    // against `valid_items` are not silently performed in unsigned arithmetic.
    const int thread_offset = ITEMS_PER_THREAD * static_cast<int>(linear_tid);

    if (IS_LAST_TILE)
    {
      // if last tile, find valid max_key
      // and fill the remaining keys with it.
      // The scan starts at item 0 so that a thread whose items are all out of
      // range also overwrites keys[0]; otherwise that stale key could end up
      // in the out-of-range part of the sorted output.
      //
      KeyT max_key = oob_default;

      #pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        if (thread_offset + item < valid_items)
        {
          max_key = compare_op(max_key, keys[item]) ? keys[item] : max_key;
        }
        else
        {
          keys[item] = max_key;
        }
      }
    }

    // if first element of thread is in input range, stable sort items
    //
    if (!IS_LAST_TILE || thread_offset < valid_items)
    {
      StableOddEvenSort(keys, items, compare_op);
    }

    // each thread has sorted keys
    // merge sort keys in shared memory
    //
    for (int target_merged_threads_number = 2;
         target_merged_threads_number <= NUM_THREADS;
         target_merged_threads_number *= 2)
    {
      int merged_threads_number = target_merged_threads_number / 2;
      int mask                  = target_merged_threads_number - 1;

      Sync();

      // store keys in shmem
      //
      #pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        int idx = thread_offset + item;
        temp_storage.keys_shared[idx] = keys[item];
      }

      Sync();

      int indices[ITEMS_PER_THREAD];

      int first_thread_idx_in_thread_group_being_merged = ~mask & linear_tid;
      int start = ITEMS_PER_THREAD * first_thread_idx_in_thread_group_being_merged;
      int size  = ITEMS_PER_THREAD * merged_threads_number;

      int thread_idx_in_thread_group_being_merged = mask & linear_tid;

      int diag =
        (cub::min)(valid_items,
                   ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged);

      int keys1_beg = (cub::min)(valid_items, start);
      int keys1_end = (cub::min)(valid_items, keys1_beg + size);
      int keys2_beg = keys1_end;
      int keys2_end = (cub::min)(valid_items, keys2_beg + size);

      int keys1_count = keys1_end - keys1_beg;
      int keys2_count = keys2_end - keys2_beg;

      int partition_diag = MergePath<KeyT>(&temp_storage.keys_shared[keys1_beg],
                                           &temp_storage.keys_shared[keys2_beg],
                                           keys1_count,
                                           keys2_count,
                                           diag,
                                           compare_op);

      int keys1_beg_loc   = keys1_beg + partition_diag;
      int keys1_end_loc   = keys1_end;
      int keys2_beg_loc   = keys2_beg + diag - partition_diag;
      int keys2_end_loc   = keys2_end;
      int keys1_count_loc = keys1_end_loc - keys1_beg_loc;
      int keys2_count_loc = keys2_end_loc - keys2_beg_loc;
      SerialMerge(&temp_storage.keys_shared[0],
                  keys1_beg_loc,
                  keys2_beg_loc,
                  keys1_count_loc,
                  keys2_count_loc,
                  keys,
                  indices,
                  compare_op);

      if (!KEYS_ONLY)
      {
        Sync();

        // store items in shmem
        //
        #pragma unroll
        for (int item = 0; item < ITEMS_PER_THREAD; ++item)
        {
          int idx = thread_offset + item;
          temp_storage.items_shared[idx] = items[item];
        }

        Sync();

        // gather items from shmem, following the key permutation recorded by
        // SerialMerge in `indices`
        //
        #pragma unroll
        for (int item = 0; item < ITEMS_PER_THREAD; ++item)
        {
          items[item] = temp_storage.items_shared[indices[item]];
        }
      }
    }
  } // Sort

  /**
   * @brief Sorts items partitioned across a CUDA thread block using
   *        a merge sorting method.
   *
   * @par
   * StableSort is stable: it preserves the relative ordering of equivalent
   * elements. That is, if `x` and `y` are elements such that `x` precedes `y`,
   * and if the two elements are equivalent (neither `x < y` nor `y < x`) then
   * a postcondition of StableSort is that `x` still precedes `y`.
   *
   * @tparam CompareOp
   *   functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
   *   `CompareOp` is a model of [Strict Weak Ordering].
   *
   * @param[in,out] keys
   *   Keys to sort
   *
   * @param[in] compare_op
   *   Comparison function object which returns true if the first argument is
   *   ordered before the second
   *
   * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
   */
  template <typename CompareOp>
  __device__ __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
                                             CompareOp compare_op)
  {
    Sort(keys, compare_op);
  }

  /**
   * @brief Sorts items partitioned across a CUDA thread block using
   *        a merge sorting method.
   *
   * @par
   * StableSort is stable: it preserves the relative ordering of equivalent
   * elements. That is, if `x` and `y` are elements such that `x` precedes `y`,
   * and if the two elements are equivalent (neither `x < y` nor `y < x`) then
   * a postcondition of StableSort is that `x` still precedes `y`.
   *
   * @tparam CompareOp
   *   functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
   *   `CompareOp` is a model of [Strict Weak Ordering].
   *
   * @param[in,out] keys
   *   Keys to sort
   *
   * @param[in,out] items
   *   Values to sort
   *
   * @param[in] compare_op
   *   Comparison function object which returns true if the first argument is
   *   ordered before the second
   *
   * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
   */
  template <typename CompareOp>
  __device__ __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
                                             ValueT (&items)[ITEMS_PER_THREAD],
                                             CompareOp compare_op)
  {
    Sort(keys, items, compare_op);
  }

  /**
   * @brief Sorts items partitioned across a CUDA thread block using
   *        a merge sorting method.
   *
   * @par
   * - StableSort is stable: it preserves the relative ordering of equivalent
   *   elements. That is, if `x` and `y` are elements such that `x` precedes
   *   `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`)
   *   then a postcondition of StableSort is that `x` still precedes `y`.
   * - The value of `oob_default` is assigned to all elements that are out of
   *   `valid_items` boundaries. It's expected that `oob_default` is ordered
   *   after any value in the `valid_items` boundaries. The algorithm always
   *   sorts a fixed amount of elements, which is equal to
   *   `ITEMS_PER_THREAD * NUM_THREADS`.
   *   If there is a value that is ordered after `oob_default`, it won't be
   *   placed within `valid_items` boundaries.
   *
   * @tparam CompareOp
   *   functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
   *   `CompareOp` is a model of [Strict Weak Ordering].
   *
   * @param[in,out] keys
   *   Keys to sort
   *
   * @param[in] compare_op
   *   Comparison function object which returns true if the first argument is
   *   ordered before the second
   *
   * @param[in] valid_items
   *   Number of valid items to sort
   *
   * @param[in] oob_default
   *   Default value to assign out-of-bound items
   *
   * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
   */
  template <typename CompareOp>
  __device__ __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
                                             CompareOp compare_op,
                                             int valid_items,
                                             KeyT oob_default)
  {
    Sort(keys, compare_op, valid_items, oob_default);
  }

  /**
   * @brief Sorts items partitioned across a CUDA thread block using
   *        a merge sorting method.
   *
   * @par
   * - StableSort is stable: it preserves the relative ordering of equivalent
   *   elements. That is, if `x` and `y` are elements such that `x` precedes
   *   `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`)
   *   then a postcondition of StableSort is that `x` still precedes `y`.
   * - The value of `oob_default` is assigned to all elements that are out of
   *   `valid_items` boundaries. It's expected that `oob_default` is ordered
   *   after any value in the `valid_items` boundaries. The algorithm always
   *   sorts a fixed amount of elements, which is equal to
   *   `ITEMS_PER_THREAD * NUM_THREADS`. If there is a value that is ordered
   *   after `oob_default`, it won't be placed within `valid_items` boundaries.
   *
   * @tparam CompareOp
   *   functor type having member `bool operator()(KeyT lhs, KeyT rhs)`.
   *   `CompareOp` is a model of [Strict Weak Ordering].
   *
   * @tparam IS_LAST_TILE
   *   True if `valid_items` isn't equal to the `ITEMS_PER_TILE`
   *
   * @param[in,out] keys
   *   Keys to sort
   *
   * @param[in,out] items
   *   Values to sort
   *
   * @param[in] compare_op
   *   Comparison function object which returns true if the first argument is
   *   ordered before the second
   *
   * @param[in] valid_items
   *   Number of valid items to sort
   *
   * @param[in] oob_default
   *   Default value to assign out-of-bound items
   *
   * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order
   */
  template <typename CompareOp,
            bool IS_LAST_TILE = true>
  __device__ __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
                                             ValueT (&items)[ITEMS_PER_THREAD],
                                             CompareOp compare_op,
                                             int valid_items,
                                             KeyT oob_default)
  {
    Sort<CompareOp, IS_LAST_TILE>(keys,
                                  items,
                                  compare_op,
                                  valid_items,
                                  oob_default);
  }

private:
  /// Barrier between cooperating threads, delegated to the CRTP derived class.
  __device__ __forceinline__ void Sync() const
  {
    static_cast<const SynchronizationPolicy*>(this)->SyncImplementation();
  }
};
/**
 * @brief The BlockMergeSort class provides methods for sorting items
 *        partitioned across a CUDA thread block using a merge sorting method.
 * @ingroup BlockModule
 *
 * @tparam KeyT
 *   Key type
 *
 * @tparam BLOCK_DIM_X
 *   The thread block length in threads along the X dimension
 *
 * @tparam ITEMS_PER_THREAD
 *   The number of items per thread
 *
 * @tparam ValueT
 *   **[optional]** Value type (default: `cub::NullType`, which indicates
 *   a keys-only sort)
 *
 * @tparam BLOCK_DIM_Y
 *   **[optional]** The thread block length in threads along the Y dimension
 *   (default: 1)
 *
 * @tparam BLOCK_DIM_Z
 *   **[optional]** The thread block length in threads along the Z dimension
 *   (default: 1)
 *
 * @par Overview
 *   BlockMergeSort arranges items into ascending order using a comparison
 *   functor with less-than semantics. Merge sort can handle arbitrary types
 *   and comparison functors, but is slower than BlockRadixSort when sorting
 *   arithmetic types into ascending/descending order.
 *
 * @par A Simple Example
 * @blockcollective{BlockMergeSort}
 * @par
 * The code snippet below illustrates a sort of 512 integer keys that are
 * partitioned across 128 threads, where each thread owns 4 consecutive items.
 * @par
 * @code
 * #include <cub/cub.cuh>  // or equivalently <cub/block/block_merge_sort.cuh>
 *
 * struct CustomLess
 * {
 *   template <typename DataType>
 *   __device__ bool operator()(const DataType &lhs, const DataType &rhs)
 *   {
 *     return lhs < rhs;
 *   }
 * };
 *
 * __global__ void ExampleKernel(...)
 * {
 *     // Specialize BlockMergeSort for a 1D block of 128 threads owning 4 integer items each
 *     typedef cub::BlockMergeSort<int, 128, 4> BlockMergeSort;
 *
 *     // Allocate shared memory for BlockMergeSort
 *     __shared__ typename BlockMergeSort::TempStorage temp_storage_shuffle;
 *
 *     // Obtain a segment of consecutive items that are blocked across threads
 *     int thread_keys[4];
 *     ...
 *
 *     BlockMergeSort(temp_storage_shuffle).Sort(thread_keys, CustomLess());
 *     ...
 * }
 * @endcode
 * @par
 * Suppose the set of input `thread_keys` across the block of threads is
 * `{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`.
 * The corresponding output `thread_keys` in those threads will be
 * `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`.
 *
 * @par Re-using dynamically allocated shared memory
 * The following example under the examples/block folder illustrates usage of
 * dynamically allocated shared memory with BlockReduce and how to re-purpose
 * the same memory region:
 * <a href="../../examples/block/example_block_reduce_dyn_smem.cu">example_block_reduce_dyn_smem.cu</a>
 *
 * This example can be easily adapted to the storage required by BlockMergeSort.
 */
template <typename KeyT,
          int BLOCK_DIM_X,
          int ITEMS_PER_THREAD,
          typename ValueT = NullType,
          int BLOCK_DIM_Y = 1,
          int BLOCK_DIM_Z = 1>
class BlockMergeSort
    : public BlockMergeSortStrategy<KeyT,
                                    ValueT,
                                    BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
                                    ITEMS_PER_THREAD,
                                    BlockMergeSort<KeyT,
                                                   BLOCK_DIM_X,
                                                   ITEMS_PER_THREAD,
                                                   ValueT,
                                                   BLOCK_DIM_Y,
                                                   BLOCK_DIM_Z>>
{
private:
  /// Total number of threads in the (up to 3D) thread block
  static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z;

  /// Number of keys handled by one sort call
  static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;

  /// CRTP base class that implements the merge sort itself
  using BlockMergeSortStrategyT = BlockMergeSortStrategy<KeyT,
                                                         ValueT,
                                                         BLOCK_THREADS,
                                                         ITEMS_PER_THREAD,
                                                         BlockMergeSort>;

  /// Row-major linear index of the calling thread within the block
  static __device__ __forceinline__ unsigned int LinearTid()
  {
    return RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z);
  }

public:
  /// Constructs a sorter backed by the base class's private `__shared__`
  /// temporary storage.
  __device__ __forceinline__ BlockMergeSort()
      : BlockMergeSortStrategyT(LinearTid())
  {}

  /// Constructs a sorter backed by caller-provided temporary storage.
  __device__ __forceinline__ explicit BlockMergeSort(
    typename BlockMergeSortStrategyT::TempStorage &temp_storage)
      : BlockMergeSortStrategyT(temp_storage, LinearTid())
  {}

private:
  /// Block-wide barrier that the CRTP base invokes through its Sync().
  __device__ __forceinline__ void SyncImplementation() const
  {
    CTA_SYNC();
  }

  friend BlockMergeSortStrategyT;
};
CUB_NAMESPACE_END