| | #pragma once |
| |
|
| | #include <thrust/detail/config.h> |
| |
|
| | #if THRUST_CPP_DIALECT >= 2014 |
| |
|
| | #include <thrust/device_allocator.h> |
| | #include <thrust/future.h> |
| |
|
| | #include <unittest/unittest.h> |
| |
|
| | #include <string> |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | namespace testing |
| | { |
| |
|
| | namespace async |
| | { |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | template <typename AlgoDef> |
| | struct test_policy_overloads |
| | { |
| | using algo_def = AlgoDef; |
| | using input_type = typename algo_def::input_type; |
| | using output_type = typename algo_def::output_type; |
| | using postfix_args_type = typename algo_def::postfix_args_type; |
| |
|
| | static constexpr std::size_t num_postfix_arg_sets = |
| | std::tuple_size<postfix_args_type>::value; |
| |
|
| | |
| | static void run(std::size_t num_values) |
| | { |
| | test_postfix_overloads(num_values); |
| | } |
| |
|
| | private: |
| | template <std::size_t Size> |
| | using size_const = std::integral_constant<std::size_t, Size>; |
| |
|
| | |
| | |
| | template <std::size_t PostfixIdx = 0> |
| | static void test_postfix_overloads(std::size_t const num_values, |
| | size_const<PostfixIdx> = {}) |
| | { |
| | static_assert(PostfixIdx < num_postfix_arg_sets, "Internal error."); |
| |
|
| | run_basic_policy_tests<PostfixIdx>(num_values); |
| | run_after_future_tests<PostfixIdx>(num_values); |
| |
|
| | |
| | test_postfix_overloads(num_values, size_const<PostfixIdx + 1>{}); |
| | } |
| |
|
| | static void test_postfix_overloads(std::size_t const, |
| | size_const<num_postfix_arg_sets>) |
| | { |
| | |
| | } |
| |
|
| | |
| | |
| | |
| | template <std::size_t PostfixIdx> |
| | static void run_basic_policy_tests(std::size_t const num_values) |
| | { |
| | |
| | |
| | auto using_default_stream = [](auto& e) { |
| | ASSERT_NOT_EQUAL(thrust::cuda_cub::default_stream(), |
| | e.stream().native_handle()); |
| | }; |
| |
|
| | |
| | |
| | thrust::system::cuda::detail::unique_stream test_stream{}; |
| | auto using_test_stream = [&test_stream](auto& e) { |
| | ASSERT_EQUAL(test_stream.native_handle(), e.stream().native_handle()); |
| | }; |
| |
|
| | |
| | basic_policy_test<PostfixIdx>("(no policy)", |
| | std::make_tuple(), |
| | using_default_stream, |
| | num_values); |
| |
|
| | basic_policy_test<PostfixIdx>("thrust::device", |
| | std::make_tuple(thrust::device), |
| | using_default_stream, |
| | num_values); |
| |
|
| | basic_policy_test<PostfixIdx>( |
| | "thrust::device(thrust::device_allocator<void>{})", |
| | std::make_tuple(thrust::device(thrust::device_allocator<void>{})), |
| | using_default_stream, |
| | num_values); |
| |
|
| | basic_policy_test<PostfixIdx>("thrust::device.on(test_stream.get())", |
| | std::make_tuple( |
| | thrust::device.on(test_stream.get())), |
| | using_test_stream, |
| | num_values); |
| |
|
| | basic_policy_test<PostfixIdx>( |
| | "thrust::device(thrust::device_allocator<void>{}).on(test_stream.get())", |
| | std::make_tuple( |
| | thrust::device(thrust::device_allocator<void>{}).on(test_stream.get())), |
| | using_test_stream, |
| | num_values); |
| | } |
| |
|
| | |
| | |
| | template <std::size_t PostfixIdx, |
| | typename PrefixArgTuple, |
| | typename ValidateEvent> |
| | static void basic_policy_test(std::string const &policy_desc, |
| | PrefixArgTuple &&prefix_tuple_ref, |
| | ValidateEvent const &validate, |
| | std::size_t num_values) |
| | try |
| | { |
| | |
| | |
| | using prefix_tuple_type = thrust::remove_cvref_t<PrefixArgTuple>; |
| | prefix_tuple_type const prefix_tuple = THRUST_FWD(prefix_tuple_ref); |
| |
|
| | using postfix_tuple_type = |
| | std::tuple_element_t<PostfixIdx, postfix_args_type>; |
| | postfix_tuple_type const postfix_tuple = get_postfix_tuple<PostfixIdx>(); |
| |
|
| | |
| | constexpr auto prefix_tuple_size = std::tuple_size<prefix_tuple_type>{}; |
| | constexpr auto postfix_tuple_size = std::tuple_size<postfix_tuple_type>{}; |
| | using prefix_index_seq = std::make_index_sequence<prefix_tuple_size>; |
| | using postfix_index_seq = std::make_index_sequence<postfix_tuple_size>; |
| |
|
| | |
| | |
| | input_type input_a = algo_def::generate_input(num_values); |
| | input_type input_b = algo_def::generate_input(num_values); |
| | input_type input_c = algo_def::generate_input(num_values); |
| | input_type input_d = algo_def::generate_input(num_values); |
| | input_type input_ref = algo_def::generate_input(num_values); |
| |
|
| | output_type output_a = algo_def::generate_output(num_values, input_a); |
| | output_type output_b = algo_def::generate_output(num_values, input_b); |
| | output_type output_c = algo_def::generate_output(num_values, input_c); |
| | output_type output_d = algo_def::generate_output(num_values, input_d); |
| | output_type output_ref = algo_def::generate_output(num_values, input_ref); |
| |
|
| | |
| | |
| | auto e_a = algo_def::invoke_async(prefix_tuple, |
| | prefix_index_seq{}, |
| | input_a, |
| | output_a, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| | auto e_b = algo_def::invoke_async(prefix_tuple, |
| | prefix_index_seq{}, |
| | input_b, |
| | output_b, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| | auto e_c = algo_def::invoke_async(prefix_tuple, |
| | prefix_index_seq{}, |
| | input_c, |
| | output_c, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| | auto e_d = algo_def::invoke_async(prefix_tuple, |
| | prefix_index_seq{}, |
| | input_d, |
| | output_d, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| |
|
| | |
| | algo_def::invoke_reference(input_ref, |
| | output_ref, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| |
|
| | |
| | algo_def::compare_outputs(e_a, output_ref, output_a); |
| | algo_def::compare_outputs(e_b, output_ref, output_b); |
| | algo_def::compare_outputs(e_c, output_ref, output_c); |
| | algo_def::compare_outputs(e_d, output_ref, output_d); |
| |
|
| | validate(e_a); |
| | validate(e_b); |
| | validate(e_c); |
| | validate(e_d); |
| | } |
| | catch (unittest::UnitTestException &exc) |
| | { |
| | |
| | |
| | using overload_t = std::tuple_element_t<PostfixIdx, postfix_args_type>; |
| |
|
| | std::string const overload_desc = |
| | unittest::demangle(typeid(overload_t).name()); |
| | std::string const input_desc = |
| | unittest::demangle(typeid(input_type).name()); |
| | std::string const output_desc = |
| | unittest::demangle(typeid(output_type).name()); |
| |
|
| | exc << "\n" |
| | << " - algo_def::description = " << algo_def::description() << "\n" |
| | << " - test = basic_policy\n" |
| | << " - policy = " << policy_desc << "\n" |
| | << " - input_type = " << input_desc << "\n" |
| | << " - output_type = " << output_desc << "\n" |
| | << " - tuple of trailing arguments = " << overload_desc << "\n" |
| | << " - num_values = " << num_values; |
| | throw; |
| | } |
| |
|
| | |
| | |
| | template <std::size_t PostfixIdx> |
| | static void run_after_future_tests(std::size_t const num_values) |
| | try |
| | { |
| | using postfix_tuple_type = |
| | std::tuple_element_t<PostfixIdx, postfix_args_type>; |
| | postfix_tuple_type const postfix_tuple = get_postfix_tuple<PostfixIdx>(); |
| |
|
| | |
| | |
| | |
| | constexpr auto postfix_tuple_size = std::tuple_size<postfix_tuple_type>{}; |
| | using prefix_index_seq = std::make_index_sequence<1>; |
| | using postfix_index_seq = std::make_index_sequence<postfix_tuple_size>; |
| |
|
| | |
| | |
| | input_type input_a = algo_def::generate_input(num_values); |
| | input_type input_b = algo_def::generate_input(num_values); |
| | input_type input_c = algo_def::generate_input(num_values); |
| | input_type input_tmp = algo_def::generate_input(num_values); |
| | input_type input_ref = algo_def::generate_input(num_values); |
| |
|
| | output_type output_a = algo_def::generate_output(num_values, input_a); |
| | output_type output_b = algo_def::generate_output(num_values, input_b); |
| | output_type output_c = algo_def::generate_output(num_values, input_c); |
| | output_type output_tmp = algo_def::generate_output(num_values, input_tmp); |
| | output_type output_ref = algo_def::generate_output(num_values, input_ref); |
| |
|
| | auto e_a = algo_def::invoke_async(std::make_tuple(thrust::device), |
| | prefix_index_seq{}, |
| | input_a, |
| | output_a, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| | ASSERT_EQUAL(true, e_a.valid_stream()); |
| | auto const stream_a = e_a.stream().native_handle(); |
| |
|
| | |
| | ASSERT_NOT_EQUAL_QUIET(thrust::cuda_cub::default_stream(), stream_a); |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | auto e_b = |
| | algo_def::invoke_async(std::forward_as_tuple(thrust::device.after(e_a)), |
| | prefix_index_seq{}, |
| | input_b, |
| | output_b, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| | ASSERT_EQUAL(true, e_b.valid_stream()); |
| | auto const stream_b = e_b.stream().native_handle(); |
| |
|
| | |
| | ASSERT_EQUAL_QUIET(stream_a, stream_b); |
| |
|
| | |
| | ASSERT_THROWS_EQUAL(auto x = algo_def::invoke_async( |
| | std::forward_as_tuple(thrust::device.after(e_a)), |
| | prefix_index_seq{}, |
| | input_tmp, |
| | output_tmp, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| | THRUST_UNUSED_VAR(x), |
| | thrust::event_error, |
| | thrust::event_error(thrust::event_errc::no_state)); |
| |
|
| | |
| | |
| | |
| | |
| | auto policy_after_e_b = thrust::device.after(e_b); |
| | auto policy_after_e_b_tuple = std::forward_as_tuple(policy_after_e_b); |
| | auto e_c = |
| | algo_def::invoke_async(policy_after_e_b_tuple, |
| | prefix_index_seq{}, |
| | input_c, |
| | output_c, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| | ASSERT_EQUAL(true, e_c.valid_stream()); |
| | auto const stream_c = e_c.stream().native_handle(); |
| |
|
| | |
| | ASSERT_EQUAL_QUIET(stream_b, stream_c); |
| |
|
| | |
| | ASSERT_THROWS_EQUAL( |
| | auto x = algo_def::invoke_async(policy_after_e_b_tuple, |
| | prefix_index_seq{}, |
| | input_tmp, |
| | output_tmp, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| | THRUST_UNUSED_VAR(x), |
| | thrust::event_error, |
| | thrust::event_error(thrust::event_errc::no_state)); |
| |
|
| | |
| | algo_def::invoke_reference(input_ref, |
| | output_ref, |
| | postfix_tuple, |
| | postfix_index_seq{}); |
| |
|
| | |
| | |
| | |
| | algo_def::compare_outputs(e_c, output_ref, output_a); |
| | algo_def::compare_outputs(e_c, output_ref, output_b); |
| | algo_def::compare_outputs(e_c, output_ref, output_c); |
| | } |
| | catch (unittest::UnitTestException &exc) |
| | { |
| | |
| | |
| | using postfix_t = std::tuple_element_t<PostfixIdx, postfix_args_type>; |
| |
|
| | std::string const postfix_desc = |
| | unittest::demangle(typeid(postfix_t).name()); |
| | std::string const input_desc = |
| | unittest::demangle(typeid(input_type).name()); |
| | std::string const output_desc = |
| | unittest::demangle(typeid(output_type).name()); |
| |
|
| | exc << "\n" |
| | << " - algo_def::description = " << algo_def::description() << "\n" |
| | << " - test = after_future\n" |
| | << " - input_type = " << input_desc << "\n" |
| | << " - output_type = " << output_desc << "\n" |
| | << " - tuple of trailing arguments = " << postfix_desc << "\n" |
| | << " - num_values = " << num_values; |
| | throw; |
| | } |
| |
|
| | |
| | |
| | template <std::size_t PostfixIdx> |
| | static auto get_postfix_tuple() |
| | { |
| | return std::get<PostfixIdx>(algo_def::generate_postfix_args()); |
| | } |
| | }; |
| |
|
| | } |
| | } |
| |
|
| | #endif |
| |
|