Spaces:
Sleeping
Sleeping
| /*************************************************************************************************** | |
| * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| * SPDX-License-Identifier: BSD-3-Clause | |
| * | |
| * Redistribution and use in source and binary forms, with or without | |
| * modification, are permitted provided that the following conditions are met: | |
| * | |
| * 1. Redistributions of source code must retain the above copyright notice, this | |
| * list of conditions and the following disclaimer. | |
| * | |
| * 2. Redistributions in binary form must reproduce the above copyright notice, | |
| * this list of conditions and the following disclaimer in the documentation | |
| * and/or other materials provided with the distribution. | |
| * | |
| * 3. Neither the name of the copyright holder nor the names of its | |
| * contributors may be used to endorse or promote products derived from | |
| * this software without specific prior written permission. | |
| * | |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
| * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
| * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
| * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
| * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
| * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
| * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| * | |
| **************************************************************************************************/ | |
| /*! \file | |
| \brief Statically sized array of elements that accommodates subbyte trivial types | |
| in a packed storage. | |
| */ | |
| namespace cute | |
| { | |
| // | |
| // Underlying subbyte storage type | |
| // | |
| template <class T> | |
| using subbyte_storage_type_t = conditional_t<(cute::sizeof_bits_v<T> <= 8), uint8_t, | |
| conditional_t<(cute::sizeof_bits_v<T> <= 16), uint16_t, | |
| conditional_t<(cute::sizeof_bits_v<T> <= 32), uint32_t, | |
| conditional_t<(cute::sizeof_bits_v<T> <= 64), uint64_t, | |
| conditional_t<(cute::sizeof_bits_v<T> <= 128), uint128_t, | |
| T>>>>>; | |
// Forward declarations: subbyte_reference befriends subbyte_iterator, and
// subbyte_iterator befriends swizzle_ptr (not defined in this portion of the file).
template <class Element>
struct subbyte_iterator;

template <class, class>
struct swizzle_ptr;
| // | |
| // subbyte_reference | |
| // Proxy object for sub-byte element references | |
| // | |
| template <class T> | |
| struct subbyte_reference | |
| { | |
| // Iterator Element type (const or non-const) | |
| using element_type = T; | |
| // Iterator Value type without type qualifier. | |
| using value_type = remove_cv_t<T>; | |
| // Storage type (const or non-const) | |
| using storage_type = conditional_t<(is_const_v<T>), subbyte_storage_type_t<T> const, subbyte_storage_type_t<T>>; | |
| static_assert(sizeof_bits_v<storage_type> % 8 == 0, "Storage type is not supported"); | |
| static_assert(sizeof_bits_v<element_type> <= sizeof_bits_v<storage_type>, | |
| "Size of Element must not be greater than Storage."); | |
| private: | |
| // Bitmask for covering one item | |
| static constexpr storage_type BitMask = storage_type(storage_type(-1) >> (sizeof_bits_v<storage_type> - sizeof_bits_v<element_type>)); | |
| // Flag for fast branching on straddled elements | |
| static constexpr bool is_storage_unaligned = ((sizeof_bits_v<storage_type> % sizeof_bits_v<element_type>) != 0); | |
| friend struct subbyte_iterator<T>; | |
| // Pointer to storage element | |
| storage_type* ptr_ = nullptr; | |
| // Bit index of value_type starting position within storage_type element. | |
| // RI: 0 <= idx_ < sizeof_bit<storage_type> | |
| uint8_t idx_ = 0; | |
| // Ctor | |
| template <class PointerType> | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_reference(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast<storage_type*>(ptr)), idx_(idx) {} | |
| public: | |
| // Copy Ctor | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_reference(subbyte_reference const& other) { | |
| *this = element_type(other); | |
| } | |
| // Copy Assignment | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_reference& operator=(subbyte_reference const& other) { | |
| return *this = element_type(other); | |
| } | |
| // Assignment | |
| template <class T_ = element_type> | |
| CUTE_HOST_DEVICE constexpr | |
| enable_if_t<!is_const_v<T_>, subbyte_reference&> operator=(element_type x) | |
| { | |
| static_assert(is_same_v<T_, element_type>, "Do not specify template arguments!"); | |
| storage_type item = (reinterpret_cast<storage_type const&>(x) & BitMask); | |
| // Update the current storage element | |
| storage_type bit_mask_0 = storage_type(BitMask << idx_); | |
| ptr_[0] = storage_type((ptr_[0] & ~bit_mask_0) | (item << idx_)); | |
| // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) | |
| if (is_storage_unaligned && idx_ + sizeof_bits_v<value_type> > sizeof_bits_v<storage_type>) { | |
| uint8_t straddle_bits = uint8_t(sizeof_bits_v<storage_type> - idx_); | |
| storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); | |
| // Update the next storage element | |
| ptr_[1] = storage_type((ptr_[1] & ~bit_mask_1) | (item >> straddle_bits)); | |
| } | |
| return *this; | |
| } | |
| // Comparison of referenced values | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator==(subbyte_reference const& x, subbyte_reference const& y) { return x.get() == y.get(); } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator!=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() != y.get(); } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator< (subbyte_reference const& x, subbyte_reference const& y) { return x.get() < y.get(); } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator> (subbyte_reference const& x, subbyte_reference const& y) { return x.get() > y.get(); } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator<=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() <= y.get(); } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator>=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() >= y.get(); } | |
| // Value | |
| CUTE_HOST_DEVICE | |
| element_type get() const | |
| { | |
| if constexpr (is_same_v<bool, value_type>) { // Extract to bool -- potentially faster impl | |
| return bool((*ptr_) & (BitMask << idx_)); | |
| } else { // Extract to element_type | |
| // Extract from the current storage element | |
| auto item = storage_type((ptr_[0] >> idx_) & BitMask); | |
| // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) | |
| if (is_storage_unaligned && idx_ + sizeof_bits_v<value_type> > sizeof_bits_v<storage_type>) { | |
| uint8_t straddle_bits = uint8_t(sizeof_bits_v<storage_type> - idx_); | |
| storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); | |
| // Extract from the next storage element | |
| item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits); | |
| } | |
| return reinterpret_cast<element_type&>(item); | |
| } | |
| } | |
| // Extract to type element_type | |
| CUTE_HOST_DEVICE constexpr | |
| operator element_type() const { | |
| return get(); | |
| } | |
| // Address | |
| subbyte_iterator<T> operator&() const { | |
| return {ptr_, idx_}; | |
| } | |
| }; | |
| // | |
| // subbyte_iterator | |
| // Random-access iterator over subbyte references | |
| // | |
| template <class T> | |
| struct subbyte_iterator | |
| { | |
| // Iterator Element type (const or non-const) | |
| using element_type = T; | |
| // Iterator Value type without type qualifier. | |
| using value_type = remove_cv_t<T>; | |
| // Storage type (const or non-const) | |
| using storage_type = conditional_t<(is_const_v<T>), subbyte_storage_type_t<T> const, subbyte_storage_type_t<T>>; | |
| // Reference proxy type | |
| using reference = subbyte_reference<element_type>; | |
| static_assert(sizeof_bits_v<storage_type> % 8 == 0, "Storage type is not supported"); | |
| static_assert(sizeof_bits_v<element_type> <= sizeof_bits_v<storage_type>, | |
| "Size of Element must not be greater than Storage."); | |
| private: | |
| template <class, class> friend struct swizzle_ptr; | |
| // Pointer to storage element | |
| storage_type* ptr_ = nullptr; | |
| // Bit index of value_type starting position within storage_type element. | |
| // RI: 0 <= idx_ < sizeof_bit<storage_type> | |
| uint8_t idx_ = 0; | |
| public: | |
| // Ctor | |
| subbyte_iterator() = default; | |
| // Ctor | |
| template <class PointerType> | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_iterator(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast<storage_type*>(ptr)), idx_(idx) { } | |
| CUTE_HOST_DEVICE constexpr | |
| reference operator*() const { | |
| return reference(ptr_, idx_); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_iterator& operator+=(uint64_t k) { | |
| k = sizeof_bits_v<value_type> * k + idx_; | |
| ptr_ += k / sizeof_bits_v<storage_type>; | |
| idx_ = k % sizeof_bits_v<storage_type>; | |
| return *this; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_iterator operator+(uint64_t k) const { | |
| return subbyte_iterator(ptr_, idx_) += k; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| reference operator[](uint64_t k) const { | |
| return *(*this + k); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_iterator& operator++() { | |
| idx_ += sizeof_bits_v<value_type>; | |
| if (idx_ >= sizeof_bits_v<storage_type>) { | |
| ++ptr_; | |
| idx_ -= sizeof_bits_v<storage_type>; | |
| } | |
| return *this; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_iterator operator++(int) { | |
| subbyte_iterator ret(*this); | |
| ++(*this); | |
| return ret; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_iterator& operator--() { | |
| if (idx_ >= sizeof_bits_v<value_type>) { | |
| idx_ -= sizeof_bits_v<value_type>; | |
| } else { | |
| --ptr_; | |
| idx_ += sizeof_bits_v<storage_type> - sizeof_bits_v<value_type>; | |
| } | |
| return *this; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| subbyte_iterator operator--(int) { | |
| subbyte_iterator ret(*this); | |
| --(*this); | |
| return ret; | |
| } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator==(subbyte_iterator const& x, subbyte_iterator const& y) { | |
| return x.ptr_ == y.ptr_ && x.idx_ == y.idx_; | |
| } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator< (subbyte_iterator const& x, subbyte_iterator const& y) { | |
| return x.ptr_ < y.ptr_ || (x.ptr_ == y.ptr_ && x.idx_ < y.idx_); | |
| } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x == y); } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator<=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(y < x); } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator> (subbyte_iterator const& x, subbyte_iterator const& y) { return (y < x); } | |
| CUTE_HOST_DEVICE constexpr friend | |
| bool operator>=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x < y); } | |
| // Conversion to raw pointer with loss of subbyte index | |
| CUTE_HOST_DEVICE constexpr friend | |
| T* raw_pointer_cast(subbyte_iterator const& x) { | |
| assert(x.idx_ == 0); | |
| return reinterpret_cast<T*>(x.ptr_); | |
| } | |
| // Conversion to NewT_ with possible loss of subbyte index | |
| template <class NewT_> | |
| CUTE_HOST_DEVICE constexpr friend | |
| auto recast_ptr(subbyte_iterator const& x) { | |
| using NewT = conditional_t<(is_const_v<T>), NewT_ const, NewT_>; | |
| if constexpr (cute::is_subbyte_v<NewT>) { // Making subbyte_iter, preserve the subbyte idx | |
| return subbyte_iterator<NewT>(x.ptr_, x.idx_); | |
| } else { // Not subbyte, assume/assert subbyte idx 0 | |
| return reinterpret_cast<NewT*>(raw_pointer_cast(x)); | |
| } | |
| CUTE_GCC_UNREACHABLE; | |
| } | |
| CUTE_HOST_DEVICE friend void print(subbyte_iterator x) { | |
| printf("subptr[%db](%p.%u)", int(sizeof_bits_v<T>), x.ptr_, x.idx_); | |
| } | |
| }; | |
| // | |
| // array_subbyte | |
| // Statically sized array for non-byte-aligned data types | |
| // | |
| template <class T, size_t N> | |
| struct array_subbyte | |
| { | |
| using element_type = T; | |
| using value_type = remove_cv_t<T>; | |
| using pointer = element_type*; | |
| using const_pointer = element_type const*; | |
| using size_type = size_t; | |
| using difference_type = ptrdiff_t; | |
| // | |
| // References | |
| // | |
| using reference = subbyte_reference<element_type>; | |
| using const_reference = subbyte_reference<element_type const>; | |
| // | |
| // Iterators | |
| // | |
| using iterator = subbyte_iterator<element_type>; | |
| using const_iterator = subbyte_iterator<element_type const>; | |
| // Storage type (const or non-const) | |
| using storage_type = conditional_t<(is_const_v<T>), subbyte_storage_type_t<T> const, subbyte_storage_type_t<T>>; | |
| static_assert(sizeof_bits_v<storage_type> % 8 == 0, "Storage type is not supported"); | |
| private: | |
| // Number of storage elements, ceil_div | |
| static constexpr size_type StorageElements = (N * sizeof_bits_v<value_type> + sizeof_bits_v<storage_type> - 1) / sizeof_bits_v<storage_type>; | |
| // Internal storage | |
| storage_type storage[StorageElements]; | |
| public: | |
| constexpr | |
| array_subbyte() = default; | |
| CUTE_HOST_DEVICE constexpr | |
| array_subbyte(array_subbyte const& x) { | |
| CUTE_UNROLL | |
| for (size_type i = 0; i < StorageElements; ++i) { | |
| storage[i] = x.storage[i]; | |
| } | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| size_type size() const { | |
| return N; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| size_type max_size() const { | |
| return N; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| bool empty() const { | |
| return !N; | |
| } | |
| // Efficient clear method | |
| CUTE_HOST_DEVICE constexpr | |
| void clear() { | |
| CUTE_UNROLL | |
| for (size_type i = 0; i < StorageElements; ++i) { | |
| storage[i] = storage_type(0); | |
| } | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| void fill(T const& value) { | |
| CUTE_UNROLL | |
| for (size_type i = 0; i < N; ++i) { | |
| at(i) = value; | |
| } | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| reference at(size_type pos) { | |
| return iterator(storage)[pos]; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| const_reference at(size_type pos) const { | |
| return const_iterator(storage)[pos]; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| reference operator[](size_type pos) { | |
| return at(pos); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| const_reference operator[](size_type pos) const { | |
| return at(pos); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| reference front() { | |
| return at(0); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| const_reference front() const { | |
| return at(0); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| reference back() { | |
| return at(N-1); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| const_reference back() const { | |
| return at(N-1); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| pointer data() { | |
| return reinterpret_cast<pointer>(storage); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| const_pointer data() const { | |
| return reinterpret_cast<const_pointer>(storage); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| storage_type* raw_data() { | |
| return storage; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| storage_type const* raw_data() const { | |
| return storage; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| iterator begin() { | |
| return iterator(storage); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| const_iterator begin() const { | |
| return const_iterator(storage); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| const_iterator cbegin() const { | |
| return begin(); | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| iterator end() { | |
| return iterator(storage) + N; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| const_iterator end() const { | |
| return const_iterator(storage) + N; | |
| } | |
| CUTE_HOST_DEVICE constexpr | |
| const_iterator cend() const { | |
| return end(); | |
| } | |
| // | |
| // Comparison operators | |
| // | |
| }; | |
| // | |
| // Operators | |
| // | |
| template <class T, size_t N> | |
| CUTE_HOST_DEVICE constexpr | |
| void clear(array_subbyte<T,N>& a) | |
| { | |
| a.clear(); | |
| } | |
| template <class T, size_t N> | |
| CUTE_HOST_DEVICE constexpr | |
| void fill(array_subbyte<T,N>& a, T const& value) | |
| { | |
| a.fill(value); | |
| } | |
| } // namespace cute | |
| // | |
| // Specialize tuple-related functionality for cute::array_subbyte | |
| // | |
| namespace cute | |
| { | |
| template <size_t I, class T, size_t N> | |
| CUTE_HOST_DEVICE constexpr | |
| T& get(array_subbyte<T,N>& a) | |
| { | |
| static_assert(I < N, "Index out of range"); | |
| return a[I]; | |
| } | |
| template <size_t I, class T, size_t N> | |
| CUTE_HOST_DEVICE constexpr | |
| T const& get(array_subbyte<T,N> const& a) | |
| { | |
| static_assert(I < N, "Index out of range"); | |
| return a[I]; | |
| } | |
| template <size_t I, class T, size_t N> | |
| CUTE_HOST_DEVICE constexpr | |
| T&& get(array_subbyte<T,N>&& a) | |
| { | |
| static_assert(I < N, "Index out of range"); | |
| return cute::move(a[I]); | |
| } | |
| } // end namespace cute | |
| namespace CUTE_STL_NAMESPACE | |
| { | |
| template <class T> | |
| struct is_reference<cute::subbyte_reference<T>> | |
| : CUTE_STL_NAMESPACE::true_type | |
| {}; | |
| template <class T, size_t N> | |
| struct tuple_size<cute::array_subbyte<T,N>> | |
| : CUTE_STL_NAMESPACE::integral_constant<size_t, N> | |
| {}; | |
| template <size_t I, class T, size_t N> | |
| struct tuple_element<I, cute::array_subbyte<T,N>> | |
| { | |
| using type = T; | |
| }; | |
| template <class T, size_t N> | |
| struct tuple_size<const cute::array_subbyte<T,N>> | |
| : CUTE_STL_NAMESPACE::integral_constant<size_t, N> | |
| {}; | |
| template <size_t I, class T, size_t N> | |
| struct tuple_element<I, const cute::array_subbyte<T,N>> | |
| { | |
| using type = T; | |
| }; | |
| } // end namespace CUTE_STL_NAMESPACE | |
| namespace std | |
| { | |
| template <class... _Tp> | |
| struct tuple_size; | |
| template <size_t _Ip, class... _Tp> | |
| struct tuple_element; | |
| template <class T, size_t N> | |
| struct tuple_size<cute::array_subbyte<T,N>> | |
| : CUTE_STL_NAMESPACE::integral_constant<size_t, N> | |
| {}; | |
| template <size_t I, class T, size_t N> | |
| struct tuple_element<I, cute::array_subbyte<T,N>> | |
| { | |
| using type = T; | |
| }; | |
| template <class T, size_t N> | |
| struct tuple_size<const cute::array_subbyte<T,N>> | |
| : CUTE_STL_NAMESPACE::integral_constant<size_t, N> | |
| {}; | |
| template <size_t I, class T, size_t N> | |
| struct tuple_element<I, const cute::array_subbyte<T,N>> | |
| { | |
| using type = T; | |
| }; | |
| } // end namespace std | |