|
|
#pragma once |
|
|
|
|
|
#include <c10/util/Array.h> |
|
|
#include <c10/util/TypeList.h> |
|
|
#include <array> |
|
|
#include <functional> |
|
|
#include <type_traits> |
|
|
|
|
|
namespace c10 { |
|
|
namespace guts { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <class Func> |
|
|
struct function_traits { |
|
|
static_assert( |
|
|
!std::is_same<Func, Func>::value, |
|
|
"In function_traits<Func>, Func must be a plain function type."); |
|
|
}; |
|
|
template <class Result, class... Args> |
|
|
struct function_traits<Result(Args...)> { |
|
|
using func_type = Result(Args...); |
|
|
using return_type = Result; |
|
|
using parameter_types = typelist::typelist<Args...>; |
|
|
static constexpr auto number_of_parameters = sizeof...(Args); |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Functor> |
|
|
struct infer_function_traits { |
|
|
using type = function_traits< |
|
|
c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>; |
|
|
}; |
|
|
|
|
|
template <typename Result, typename... Args> |
|
|
struct infer_function_traits<Result (*)(Args...)> { |
|
|
using type = function_traits<Result(Args...)>; |
|
|
}; |
|
|
|
|
|
template <typename Result, typename... Args> |
|
|
struct infer_function_traits<Result(Args...)> { |
|
|
using type = function_traits<Result(Args...)>; |
|
|
}; |
|
|
|
|
|
template <typename T> |
|
|
using infer_function_traits_t = typename infer_function_traits<T>::type; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Result, typename ArgList> |
|
|
struct make_function_traits { |
|
|
static_assert( |
|
|
false_t<ArgList>::value, |
|
|
"In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>."); |
|
|
}; |
|
|
|
|
|
template <typename Result, typename... Args> |
|
|
struct make_function_traits<Result, typelist::typelist<Args...>> { |
|
|
using type = function_traits<Result(Args...)>; |
|
|
}; |
|
|
|
|
|
template <typename Result, typename ArgList> |
|
|
using make_function_traits_t = |
|
|
typename make_function_traits<Result, ArgList>::type; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace detail { |
|
|
template < |
|
|
template <class> |
|
|
class Condition, |
|
|
size_t index, |
|
|
class Enable, |
|
|
class... Args> |
|
|
struct extract_arg_by_filtered_index_; |
|
|
template < |
|
|
template <class> |
|
|
class Condition, |
|
|
size_t index, |
|
|
class Head, |
|
|
class... Tail> |
|
|
struct extract_arg_by_filtered_index_< |
|
|
Condition, |
|
|
index, |
|
|
std::enable_if_t<!Condition<Head>::value>, |
|
|
Head, |
|
|
Tail...> { |
|
|
static decltype(auto) call(Head&& , Tail&&... tail) { |
|
|
return extract_arg_by_filtered_index_<Condition, index, void, Tail...>:: |
|
|
call(std::forward<Tail>(tail)...); |
|
|
} |
|
|
}; |
|
|
template < |
|
|
template <class> |
|
|
class Condition, |
|
|
size_t index, |
|
|
class Head, |
|
|
class... Tail> |
|
|
struct extract_arg_by_filtered_index_< |
|
|
Condition, |
|
|
index, |
|
|
std::enable_if_t<Condition<Head>::value && index != 0>, |
|
|
Head, |
|
|
Tail...> { |
|
|
static decltype(auto) call(Head&& , Tail&&... tail) { |
|
|
return extract_arg_by_filtered_index_<Condition, index - 1, void, Tail...>:: |
|
|
call(std::forward<Tail>(tail)...); |
|
|
} |
|
|
}; |
|
|
template <template <class> class Condition, size_t index> |
|
|
struct extract_arg_by_filtered_index_<Condition, index, void> { |
|
|
static void call() { |
|
|
static_assert( |
|
|
index != index, "extract_arg_by_filtered_index out of range."); |
|
|
} |
|
|
}; |
|
|
template < |
|
|
template <class> |
|
|
class Condition, |
|
|
size_t index, |
|
|
class Head, |
|
|
class... Tail> |
|
|
struct extract_arg_by_filtered_index_< |
|
|
Condition, |
|
|
index, |
|
|
std::enable_if_t<Condition<Head>::value && index == 0>, |
|
|
Head, |
|
|
Tail...> { |
|
|
static decltype(auto) call(Head&& head, Tail&&... ) { |
|
|
return std::forward<Head>(head); |
|
|
} |
|
|
}; |
|
|
} |
|
|
template <template <class> class Condition, size_t index, class... Args> |
|
|
decltype(auto) extract_arg_by_filtered_index(Args&&... args) { |
|
|
static_assert( |
|
|
is_type_condition<Condition>::value, |
|
|
"In extract_arg_by_filtered_index, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member."); |
|
|
return detail:: |
|
|
extract_arg_by_filtered_index_<Condition, index, void, Args...>::call( |
|
|
std::forward<Args>(args)...); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace detail { |
|
|
|
|
|
template <class ResultType, size_t num_results> |
|
|
struct filter_map_ { |
|
|
template < |
|
|
template <class> |
|
|
class Condition, |
|
|
class Mapper, |
|
|
class... Args, |
|
|
size_t... INDEX> |
|
|
static guts::array<ResultType, num_results> call( |
|
|
const Mapper& mapper, |
|
|
std::index_sequence<INDEX...>, |
|
|
Args&&... args) { |
|
|
return guts::array<ResultType, num_results>{ |
|
|
mapper(extract_arg_by_filtered_index<Condition, INDEX>( |
|
|
std::forward<Args>(args)...))...}; |
|
|
} |
|
|
}; |
|
|
template <class ResultType> |
|
|
struct filter_map_<ResultType, 0> { |
|
|
template < |
|
|
template <class> |
|
|
class Condition, |
|
|
class Mapper, |
|
|
class... Args, |
|
|
size_t... INDEX> |
|
|
static guts::array<ResultType, 0> call( |
|
|
const Mapper& , |
|
|
std::index_sequence<INDEX...>, |
|
|
Args&&... ) { |
|
|
return guts::array<ResultType, 0>{}; |
|
|
} |
|
|
}; |
|
|
} |
|
|
|
|
|
template < |
|
|
class ResultType, |
|
|
template <class> |
|
|
class Condition, |
|
|
class Mapper, |
|
|
class... Args> |
|
|
decltype(auto) filter_map(const Mapper& mapper, Args&&... args) { |
|
|
static_assert( |
|
|
is_type_condition<Condition>::value, |
|
|
"In filter_map<Result, Condition>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member."); |
|
|
|
|
|
static constexpr size_t num_results = |
|
|
typelist::count_if<Condition, typelist::typelist<Args...>>::value; |
|
|
return detail::filter_map_<ResultType, num_results>:: |
|
|
template call<Condition, Mapper, Args...>( |
|
|
mapper, |
|
|
std::make_index_sequence<num_results>(), |
|
|
std::forward<Args>(args)...); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <size_t Start, size_t N, size_t... Is> |
|
|
struct make_offset_index_sequence_impl |
|
|
: make_offset_index_sequence_impl<Start, N - 1, Start + N - 1, Is...> { |
|
|
static_assert( |
|
|
static_cast<int>(Start) >= 0, |
|
|
"make_offset_index_sequence: Start < 0"); |
|
|
static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0"); |
|
|
}; |
|
|
|
|
|
template <size_t Start, size_t... Is> |
|
|
struct make_offset_index_sequence_impl<Start, 0, Is...> { |
|
|
typedef std::index_sequence<Is...> type; |
|
|
}; |
|
|
|
|
|
template <size_t Start, size_t N> |
|
|
using make_offset_index_sequence = |
|
|
typename make_offset_index_sequence_impl<Start, N>::type; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <class Tuple, size_t... Is> |
|
|
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) { |
|
|
return std::tuple<std::tuple_element_t<Is, Tuple>...>(std::get<Is>(t)...); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <class Tuple, int N, class Enable = void> |
|
|
struct TupleTake {}; |
|
|
|
|
|
template <class Tuple, int N> |
|
|
struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> { |
|
|
static auto call(Tuple t) { |
|
|
constexpr size_t size = std::tuple_size<Tuple>(); |
|
|
static_assert(N <= size, "tuple_take: N > size"); |
|
|
return tuple_elements(t, std::make_index_sequence<N>{}); |
|
|
} |
|
|
}; |
|
|
|
|
|
template <class Tuple, int N> |
|
|
struct TupleTake < Tuple, |
|
|
N, std::enable_if_t<N<0, void>> { |
|
|
static auto call(Tuple t) { |
|
|
constexpr size_t size = std::tuple_size<Tuple>(); |
|
|
static_assert(-N <= size, "tuple_take: -N > size"); |
|
|
return tuple_elements(t, make_offset_index_sequence<size + N, -N>{}); |
|
|
} |
|
|
}; |
|
|
|
|
|
template <class Tuple, int N> |
|
|
auto tuple_take(Tuple t) { |
|
|
return TupleTake<Tuple, N>::call(t); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <class Tuple, size_t Start, size_t N> |
|
|
constexpr auto tuple_slice(Tuple t) { |
|
|
constexpr size_t size = std::tuple_size<Tuple>(); |
|
|
static_assert(Start + N <= size, "tuple_slice: Start + N > size"); |
|
|
return tuple_elements(t, make_offset_index_sequence<Start, N>{}); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace detail { |
|
|
template <class Mapper, class... Args, size_t... Indices> |
|
|
auto tuple_map( |
|
|
std::tuple<Args...>&& tuple, |
|
|
const Mapper& mapper, |
|
|
std::index_sequence<Indices...>) { |
|
|
return std::tuple<decltype(mapper(std::forward<Args>(std::get<Indices>( |
|
|
tuple))))...>(mapper(std::forward<Args>(std::get<Indices>(tuple)))...); |
|
|
} |
|
|
} |
|
|
|
|
|
template <class Mapper, class... Args> |
|
|
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) { |
|
|
return detail::tuple_map( |
|
|
std::move(tuple), mapper, std::index_sequence_for<Args...>()); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace detail { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template < |
|
|
size_t index, |
|
|
class HeadTuple, |
|
|
class... TailTuples, |
|
|
std::enable_if_t< |
|
|
index<std::tuple_size<HeadTuple>::value, int> = 0> decltype(auto) |
|
|
extract_tuple_element_by_index( |
|
|
HeadTuple&& head_tuple, |
|
|
TailTuples&&... ) { |
|
|
|
|
|
return std::get<index>(std::forward<HeadTuple>(head_tuple)); |
|
|
} |
|
|
|
|
|
template < |
|
|
size_t index, |
|
|
class HeadTuple, |
|
|
class... TailTuples, |
|
|
std::enable_if_t<index >= std::tuple_size<HeadTuple>::value, int> = 0> |
|
|
decltype(auto) extract_tuple_element_by_index( |
|
|
HeadTuple&& , |
|
|
TailTuples&&... tail_tuples) { |
|
|
|
|
|
return extract_tuple_element_by_index< |
|
|
index - std::tuple_size<HeadTuple>::value, |
|
|
TailTuples...>(std::forward<TailTuples>(tail_tuples)...); |
|
|
} |
|
|
|
|
|
static_assert( |
|
|
std::is_same< |
|
|
int&&, |
|
|
decltype(extract_tuple_element_by_index<2>( |
|
|
std::tuple<int32_t>(2), |
|
|
std::tuple<int32_t&&, int32_t>(std::declval<int32_t>(), 3)))>:: |
|
|
value, |
|
|
"extract_tuple_element_by_index should return rvalue references if the tuple contains them. It should not move them into a value"); |
|
|
|
|
|
template <class ConcatenatedTuple, class... Tuples, size_t... ElementIndices> |
|
|
auto tuple_concat(Tuples&&... tuples, std::index_sequence<ElementIndices...>) { |
|
|
return ConcatenatedTuple(extract_tuple_element_by_index<ElementIndices>( |
|
|
std::forward<Tuples>(tuples)...)...); |
|
|
} |
|
|
} |
|
|
|
|
|
template <class... Tuples> |
|
|
auto tuple_concat(Tuples&&... tuples) { |
|
|
using flattened_types = |
|
|
guts::typelist::concat_t<guts::typelist::from_tuple_t<Tuples>...>; |
|
|
using concatenated_tuple = guts::typelist::to_tuple_t<flattened_types>; |
|
|
constexpr size_t num_elements = guts::typelist::size<flattened_types>::value; |
|
|
return detail::tuple_concat<concatenated_tuple, Tuples...>( |
|
|
std::forward<Tuples>(tuples)..., |
|
|
std::make_index_sequence<num_elements>()); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <class... ISeqs> |
|
|
struct concat_iseq { |
|
|
static_assert( |
|
|
false_t<ISeqs...>::value, |
|
|
"In concat_iseq<T1, ...>, the T arguments each must be std::integer_sequence<...> with the same IntType."); |
|
|
}; |
|
|
template <> |
|
|
struct concat_iseq<> { |
|
|
using type = std::index_sequence<>; |
|
|
}; |
|
|
template <class IntType, IntType... Indices> |
|
|
struct concat_iseq<std::integer_sequence<IntType, Indices...>> { |
|
|
using type = std::integer_sequence<IntType, Indices...>; |
|
|
}; |
|
|
template < |
|
|
class IntType, |
|
|
IntType... Head1Indices, |
|
|
IntType... Head2Indices, |
|
|
class... TailISeqs> |
|
|
struct concat_iseq< |
|
|
std::integer_sequence<IntType, Head1Indices...>, |
|
|
std::integer_sequence<IntType, Head2Indices...>, |
|
|
TailISeqs...> { |
|
|
using type = typename concat_iseq< |
|
|
std::integer_sequence<IntType, Head1Indices..., Head2Indices...>, |
|
|
TailISeqs...>::type; |
|
|
}; |
|
|
template <class... ISeqs> |
|
|
using concat_iseq_t = typename concat_iseq<ISeqs...>::type; |
|
|
|
|
|
} |
|
|
} |
|
|
|