| |
| |
|
|
| #pragma once |
|
|
| #include "ck/wrapper/utils/layout_utils.hpp" |
|
|
| |
| |
| namespace ck { |
| namespace wrapper { |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename Shape, typename UnrolledDescriptorType> |
| struct Layout |
| { |
| |
| |
| private: |
| static constexpr auto I0 = Number<0>{}; |
| static constexpr auto I1 = Number<1>{}; |
|
|
| |
| |
| |
| |
| |
| |
| template <typename... Ts> |
| __host__ __device__ constexpr static auto |
| GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape) |
| { |
| return generate_tuple( |
| [&](auto) { |
| if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime()) |
| { |
| |
| return index_t(0); |
| } |
| else |
| { |
| |
| return I0; |
| } |
| }, |
| Number<Tuple<Ts...>::Size()>{}); |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename Idx, typename... Ts> |
| __host__ __device__ constexpr static auto |
| GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape) |
| { |
| if constexpr(Idx::value == 0) |
| { |
| if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value) |
| { |
| |
| constexpr index_t merge_nelems = decltype(UnrollNestedTuple( |
| tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size(); |
| using LowerDimsSequence = |
| typename arithmetic_sequence_gen<0, merge_nelems, 1>::type; |
| return LowerDimsSequence::Reverse(); |
| } |
| else |
| { |
| |
| return Sequence<0>{}; |
| } |
| } |
| else |
| { |
| |
| using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{})); |
| const auto next_seq_val = PreviousSeqT::At(I0) + 1; |
| if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value) |
| { |
| constexpr index_t merge_nelems = decltype(UnrollNestedTuple( |
| tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size(); |
| using LowerDimsSequence = |
| typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>:: |
| type; |
| return LowerDimsSequence::Reverse(); |
| } |
| else |
| { |
| return Sequence<next_seq_val>{}; |
| } |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename... ShapeDims, typename... IdxDims> |
| __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape, |
| const Tuple<IdxDims...>& idx) |
| { |
| if constexpr(!IsNestedTuple(Tuple<IdxDims...>{})) |
| { |
| |
| return shape; |
| } |
| else |
| { |
| |
| |
| |
| auto aligned_shape = generate_tuple( |
| [&](auto i) { |
| if constexpr(is_detected<is_tuple, |
| tuple_element_t<i, Tuple<IdxDims...>>>::value) |
| { |
| return shape.At(i); |
| } |
| else |
| { |
| return make_tuple(shape.At(i)); |
| } |
| }, |
| Number<Tuple<IdxDims...>::Size()>{}); |
|
|
| |
| return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape), |
| UnrollNestedTuple<0, 1>(idx)); |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| template <typename... ShapeDims, typename DescriptorToMerge> |
| __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape, |
| const DescriptorToMerge& desc) |
| { |
| |
| const auto merge_elems = TupleReverse(UnrollNestedTuple(shape)); |
| |
| using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type; |
| const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); |
| const auto upper_dims = make_tuple(Sequence<0>{}); |
| |
| if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime()) |
| { |
| return transform_tensor_descriptor( |
| desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); |
| } |
| else |
| { |
| |
| |
| |
| return transform_tensor_descriptor( |
| desc, |
| make_tuple(make_merge_transform_v1_carry_check(merge_elems)), |
| lower_dims, |
| upper_dims); |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge> |
| __host__ __device__ constexpr static auto |
| CreateMergedDescriptor(const Tuple<ShapeDims...>& shape, |
| [[maybe_unused]] const Tuple<IdxDims...>& idxs, |
| DescriptorToMerge& desc) |
| { |
| const auto transforms = generate_tuple( |
| [&](auto i) { |
| |
| if constexpr(is_detected<is_tuple, |
| tuple_element_t<i, Tuple<ShapeDims...>>>::value && |
| !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value) |
| { |
| |
| |
| const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i))); |
| if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime()) |
| { |
| return make_merge_transform(merge_elems); |
| } |
| else |
| { |
| |
| |
| |
| return make_merge_transform_v1_carry_check(merge_elems); |
| } |
| } |
| else |
| { |
| |
| static_assert( |
| !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value && |
| is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value), |
| "Wrong Idx for layout()"); |
| |
| return make_pass_through_transform(shape.At(i)); |
| } |
| }, |
| Number<Tuple<ShapeDims...>::Size()>{}); |
|
|
| const auto lower_dims = |
| generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); }, |
| Number<Tuple<ShapeDims...>::Size()>{}); |
| const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; }, |
| Number<Tuple<ShapeDims...>::Size()>{}); |
|
|
| return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); |
| } |
|
|
| using Descriptor1dType = |
| remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>; |
| using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>; |
| |
|
|
| public: |
| using LayoutShape = Shape; |
| using LayoutUnrolledDescriptorType = UnrolledDescriptorType; |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename... ShapeDims, typename... IdxDims> |
| __host__ __device__ constexpr static auto |
| TransformDesc(const Tuple<ShapeDims...>& shape, |
| const Tuple<IdxDims...>& idxs, |
| const UnrolledDescriptorType& naive_descriptor) |
| { |
| if constexpr(Tuple<IdxDims...>::Size() == I1) |
| { |
| |
| return MakeMerge1d(shape, naive_descriptor); |
| } |
| else |
| { |
| |
| |
| |
| |
| static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(), |
| "Idx rank and Shape rank must be the same (except 1d)."); |
| |
| const auto aligned_shape = AlignShapeToIdx(shape, idxs); |
| |
| return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor); |
| } |
| } |
|
|
| using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc( |
| Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>; |
|
|
| __host__ __device__ constexpr auto GetElementSpaceSize() const |
| { |
| return unrolled_descriptor_.GetElementSpaceSize(); |
| } |
|
|
| __host__ __device__ Layout() = delete; |
|
|
| |
| |
| |
| |
| |
| |
| __host__ __device__ constexpr Layout(const Shape& shape, |
| const UnrolledDescriptorType& unnested_descriptor) |
| : unrolled_descriptor_(unnested_descriptor), shape_(shape) |
| { |
| |
| if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime()) |
| { |
| descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_); |
| merged_nests_descriptor_ = |
| TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_); |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| template <typename Idxs> |
| __host__ __device__ constexpr index_t operator()() const |
| { |
| static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(), |
| "Compiletime operator used on runtime layout."); |
| using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{})); |
| using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{})); |
| return TransformedDesc{}.CalculateOffset(UnrolledIdx{}); |
| } |
|
|
| |
| |
| |
| |
| |
| |
| template <typename... Ts> |
| __host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const |
| { |
| if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1) |
| { |
| |
| return descriptor_1d_.CalculateOffset(Idx); |
| } |
| else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size()) |
| { |
| |
| return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx)); |
| } |
| else |
| { |
| |
| const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_); |
| return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx)); |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| template <index_t IDim> |
| __host__ __device__ constexpr auto GetLength() const |
| { |
| const auto elem = shape_.At(Number<IDim>{}); |
| if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value) |
| { |
| const auto unrolled_element = UnrollNestedTuple(elem); |
| return TupleReduce<I0.value, unrolled_element.Size()>( |
| [](auto x, auto y) { return x * y; }, unrolled_element); |
| } |
| else |
| { |
| return elem; |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| __host__ __device__ constexpr auto GetLengths() const |
| { |
| const auto unrolled_shape = UnrollNestedTuple(shape_); |
| return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; }, |
| unrolled_shape); |
| } |
|
|
| |
| |
| |
| |
| |
| __host__ __device__ constexpr const Shape& GetShape() const { return shape_; } |
|
|
| |
| |
| |
| |
| |
| __host__ __device__ constexpr auto GetDefaultLengthsTuple() const |
| { |
| return generate_tuple([&](auto i) { return GetLength<i>(); }, Number<Shape::Size()>{}); |
| } |
|
|
| |
| |
| |
| |
| |
| __host__ __device__ constexpr auto GetDefaultStartIdxs() const |
| { |
| return GenerateDefaultIdxsTuple(shape_); |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| __host__ __device__ constexpr const MergedNestsDescriptorType& |
| GetMergedNestingDescriptor() const |
| { |
| return merged_nests_descriptor_; |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| __host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const |
| { |
| return descriptor_1d_; |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| __host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const |
| { |
| return unrolled_descriptor_; |
| } |
|
|
| |
| |
| private: |
| |
| UnrolledDescriptorType unrolled_descriptor_; |
| |
| Descriptor1dType descriptor_1d_; |
| |
| MergedNestsDescriptorType merged_nests_descriptor_; |
| |
| |
| |
| |
| const Shape shape_; |
| |
| }; |
|
|
| } |
| } |
|
|