thrust / dependencies /cub /test /catch2_test_warp_exchange.cu
camenduru's picture
thanks to nvidia ❤
8ae5fc5
/******************************************************************************
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <cub/util_macro.cuh>
#include <cub/warp/warp_exchange.cuh>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/reverse.h>
#include <thrust/sequence.h>
#include <type_traits>
#include "fill_striped.cuh"
// Has to go after all cub headers. Otherwise, this test won't catch unused
// variables in cub kernels.
#include "catch2_test_helper.h"
template <typename InputT, typename OutputT, int ItemsPerThread, typename = void>
struct exchange_data_t;
template <typename InputT, typename OutputT, int ItemsPerThread>
struct exchange_data_t<InputT,
OutputT,
ItemsPerThread,
typename std::enable_if<std::is_same<InputT, OutputT>::value>::type>
{
InputT input[ItemsPerThread];
OutputT (&output)[ItemsPerThread] = input;
template <int LogicalWarpThreads>
inline __device__ void
scatter(cub::WarpExchange<InputT, ItemsPerThread, LogicalWarpThreads> &exchange,
int (&ranks)[ItemsPerThread])
{
exchange.ScatterToStriped(input, ranks);
}
};
template <typename InputT, typename OutputT, int ItemsPerThread>
struct exchange_data_t<InputT,
OutputT,
ItemsPerThread,
typename std::enable_if<!std::is_same<InputT, OutputT>::value>::type>
{
InputT input[ItemsPerThread];
OutputT output[ItemsPerThread];
template <int LogicalWarpThreads>
inline __device__ void
scatter(cub::WarpExchange<InputT, ItemsPerThread, LogicalWarpThreads> &exchange,
int (&ranks)[ItemsPerThread])
{
exchange.ScatterToStriped(input, output, ranks);
}
};
template <int LOGICAL_WARP_THREADS,
int ITEMS_PER_THREAD,
int TOTAL_WARPS,
typename InputT,
typename OutputT>
__global__ void scatter_kernel(const InputT *input_data, OutputT *output_data)
{
using warp_exchange_t = cub::WarpExchange<InputT, ITEMS_PER_THREAD, LOGICAL_WARP_THREADS>;
using storage_t = typename warp_exchange_t::TempStorage;
constexpr int tile_size = ITEMS_PER_THREAD * LOGICAL_WARP_THREADS;
__shared__ storage_t temp_storage[TOTAL_WARPS];
const int tid = cub::RowMajorTid(blockDim.x, blockDim.y, blockDim.z);
// Get warp index
const int warp_id = tid / LOGICAL_WARP_THREADS;
const int lane_id = tid % LOGICAL_WARP_THREADS;
warp_exchange_t exchange(temp_storage[warp_id]);
exchange_data_t<InputT, OutputT, ITEMS_PER_THREAD> exchange_data;
// Reverse data
int ranks[ITEMS_PER_THREAD];
input_data += warp_id * tile_size;
output_data += warp_id * tile_size;
for (int item = 0; item < ITEMS_PER_THREAD; item++)
{
const auto item_idx = lane_id * ITEMS_PER_THREAD + item;
exchange_data.input[item] = input_data[item_idx];
ranks[item] = tile_size - 1 - item_idx;
}
exchange_data.scatter(exchange, ranks);
// Striped to blocked
for (int item = 0; item < ITEMS_PER_THREAD; item++)
{
output_data[item * LOGICAL_WARP_THREADS + lane_id] = exchange_data.output[item];
}
}
template <int LOGICAL_WARP_THREADS,
int ITEMS_PER_THREAD,
int TOTAL_WARPS,
typename InputT,
typename OutputT>
void warp_scatter_strided(thrust::device_vector<InputT> &in, thrust::device_vector<OutputT> &out)
{
scatter_kernel<LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, TOTAL_WARPS, InputT, OutputT>
<<<1, LOGICAL_WARP_THREADS * TOTAL_WARPS>>>(thrust::raw_pointer_cast(in.data()),
thrust::raw_pointer_cast(out.data()));
REQUIRE(cudaSuccess == cudaPeekAtLastError());
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
}
template <int LOGICAL_WARP_THREADS,
int ITEMS_PER_THREAD,
int TOTAL_WARPS,
typename InputT,
typename OutputT,
typename ActionT>
__global__ void kernel(const InputT *input_data, OutputT *output_data, ActionT action)
{
using warp_exchange_t = cub::WarpExchange<InputT, ITEMS_PER_THREAD, LOGICAL_WARP_THREADS>;
using storage_t = typename warp_exchange_t::TempStorage;
constexpr int tile_size = ITEMS_PER_THREAD * LOGICAL_WARP_THREADS;
__shared__ storage_t temp_storage[TOTAL_WARPS];
const int tid = cub::RowMajorTid(blockDim.x, blockDim.y, blockDim.z);
// Get warp index
const int warp_id = tid / LOGICAL_WARP_THREADS;
const int lane_id = tid % LOGICAL_WARP_THREADS;
warp_exchange_t exchange(temp_storage[warp_id]);
exchange_data_t<InputT, OutputT, ITEMS_PER_THREAD> exchange_data;
input_data += warp_id * tile_size;
output_data += warp_id * tile_size;
for (int item = 0; item < ITEMS_PER_THREAD; item++)
{
exchange_data.input[item] = input_data[lane_id * ITEMS_PER_THREAD + item];
}
action(exchange_data.input, exchange_data.output, exchange);
for (int item = 0; item < ITEMS_PER_THREAD; item++)
{
output_data[lane_id * ITEMS_PER_THREAD + item] = exchange_data.output[item];
}
}
template <int LOGICAL_WARP_THREADS,
int ITEMS_PER_THREAD,
int TOTAL_WARPS,
typename InputT,
typename OutputT,
typename ActionT>
void warp_exchange(thrust::device_vector<InputT> &in,
thrust::device_vector<OutputT> &out,
ActionT action)
{
kernel<LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, TOTAL_WARPS, InputT, OutputT, ActionT>
<<<1, LOGICAL_WARP_THREADS * TOTAL_WARPS>>>(thrust::raw_pointer_cast(in.data()),
thrust::raw_pointer_cast(out.data()),
action);
REQUIRE(cudaSuccess == cudaPeekAtLastError());
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
}
struct blocked_to_striped
{
template <typename InputT,
typename OutputT,
int LogicalWarpThreads,
int ItemsPerThread,
int ITEMS_PER_THREAD>
__device__ void operator()(InputT (&input)[ITEMS_PER_THREAD],
OutputT (&output)[ITEMS_PER_THREAD],
cub::WarpExchange<InputT, ItemsPerThread, LogicalWarpThreads> &exchange)
{
exchange.BlockedToStriped(input, output);
}
};
struct striped_to_blocked
{
template <typename InputT,
typename OutputT,
int LogicalWarpThreads,
int ItemsPerThread,
int ITEMS_PER_THREAD>
__device__ void operator()(InputT (&input)[ITEMS_PER_THREAD],
OutputT (&output)[ITEMS_PER_THREAD],
cub::WarpExchange<InputT, ItemsPerThread, LogicalWarpThreads> &exchange)
{
exchange.StripedToBlocked(input, output);
}
};
template <typename T>
thrust::host_vector<T> compute_host_reference(const thrust::device_vector<T> &d_input,
int tile_size)
{
thrust::host_vector<T> input = d_input;
int num_warps = CUB_QUOTIENT_CEILING(static_cast<int>(d_input.size()), tile_size);
for (int warp_id = 0; warp_id < num_warps; warp_id++)
{
const int warp_data_begin = tile_size * warp_id;
const int warp_data_end = warp_data_begin + tile_size;
thrust::reverse(input.begin() + warp_data_begin, input.begin() + warp_data_end);
}
return input;
}
using inout_types = c2h::type_list<c2h::pair<std::uint16_t, std::int64_t>,
c2h::pair<std::uint16_t, std::uint32_t>,
c2h::pair<std::int32_t, std::int32_t>,
c2h::pair<std::int64_t, std::int64_t>,
c2h::pair<uchar3, uchar3>,
c2h::pair<ulonglong4, ulonglong4>>;
using logical_warp_threads = c2h::enum_type_list<int, 4, 16, 32>;
using items_per_thread = c2h::enum_type_list<int, 1, 4, 7>;
template <int logical_warp_threads>
struct total_warps_t
{
private:
static constexpr int max_warps = 2;
static constexpr bool is_arch_warp = (logical_warp_threads == CUB_WARP_THREADS(0));
static constexpr bool is_pow_of_two = ((logical_warp_threads & (logical_warp_threads - 1)) == 0);
static constexpr int total_warps = (is_arch_warp || is_pow_of_two) ? max_warps : 1;
public:
static constexpr int value() { return total_warps; }
};
template <class TestType>
struct params_t
{
using in_type = typename c2h::first<c2h::get<0, TestType>>;
using out_type = typename c2h::second<c2h::get<0, TestType>>;
static constexpr int logical_warp_threads = c2h::get<1, TestType>::value;
static constexpr int items_per_thread = c2h::get<2, TestType>::value;
static constexpr int total_warps = total_warps_t<logical_warp_threads>::value();
static constexpr int tile_size = logical_warp_threads * items_per_thread;
static constexpr int total_item_count = total_warps * tile_size;
};
CUB_TEST("Scatter to striped works",
"[exchange][warp]",
inout_types,
logical_warp_threads,
items_per_thread)
{
using params = params_t<TestType>;
using in_type = typename params::in_type;
using out_type = typename params::out_type;
thrust::device_vector<out_type> d_out(params::total_item_count);
thrust::device_vector<in_type> d_in(params::total_item_count);
c2h::gen(c2h::modulo_t{d_in.size()}, d_in);
warp_scatter_strided<params::logical_warp_threads, params::items_per_thread, params::total_warps>(
d_in,
d_out);
auto h_expected_output = compute_host_reference(d_in, params::tile_size);
REQUIRE(h_expected_output == d_out);
}
CUB_TEST("Blocked to striped works",
"[exchange][warp]",
inout_types,
logical_warp_threads,
items_per_thread)
{
using params = params_t<TestType>;
using in_type = typename params::in_type;
using out_type = typename params::out_type;
thrust::device_vector<out_type> d_out(params::total_item_count, out_type{});
thrust::device_vector<in_type> d_in(params::total_item_count);
c2h::gen(c2h::modulo_t{d_in.size()}, d_in);
warp_exchange<params::logical_warp_threads, params::items_per_thread, params::total_warps>(
d_in,
d_out,
blocked_to_striped{});
thrust::host_vector<out_type> h_expected_output(d_out.size());
fill_striped<params::logical_warp_threads,
params::items_per_thread,
params::logical_warp_threads * params::total_warps>(h_expected_output.begin());
REQUIRE(h_expected_output == d_out);
}
CUB_TEST("Striped to blocked works",
"[exchange][warp]",
inout_types,
logical_warp_threads,
items_per_thread)
{
using params = params_t<TestType>;
using in_type = typename params::in_type;
using out_type = typename params::out_type;
thrust::device_vector<out_type> d_out(params::total_item_count, out_type{});
thrust::host_vector<in_type> h_in(params::total_item_count);
fill_striped<params::logical_warp_threads,
params::items_per_thread,
params::logical_warp_threads * params::total_warps>(h_in.begin());
thrust::device_vector<in_type> d_in = h_in;
warp_exchange<params::logical_warp_threads, params::items_per_thread, params::total_warps>(
d_in,
d_out,
striped_to_blocked{});
thrust::device_vector<out_type> d_expected_output(d_out.size());
c2h::gen(c2h::modulo_t{d_out.size()}, d_expected_output);
REQUIRE(d_expected_output == d_out);
}