| | #pragma once |
| |
|
| | #include <c10/macros/Macros.h> |
| | #include <c10/util/C++17.h> |
| | #include <c10/util/Optional.h> |
| | #if defined(_MSC_VER) |
| | #include <intrin.h> |
| | #endif |
| |
|
| | namespace c10 { |
| | namespace utils { |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/**
 * A fixed-size set of bit flags stored in a single integer word.
 *
 * Bits are addressed by index in [0, NUM_BITS()). Besides single-bit
 * set/unset/get, the class can enumerate its set bits in ascending index
 * order (for_each_set_bit) using a find-first-set primitive: an MSVC
 * _BitScanForward intrinsic or GCC/Clang's __builtin_ffsll.
 */
struct bitset final {
 private:
#if defined(_MSC_VER)
  // MSVC: a 64-bit word, scanned with _BitScanForward64 (or two 32-bit
  // _BitScanForward calls on x86) in find_first_set().
  using bitset_type = int64_t;
#else
  // Other compilers: long long, the argument type of __builtin_ffsll.
  using bitset_type = long long int;
#endif

  /// Single-bit mask for position `index`. Requires index < NUM_BITS();
  /// shifting by NUM_BITS() or more is undefined behavior.
  static constexpr bitset_type mask(size_t index) noexcept {
    return static_cast<bitset_type>(1) << index;
  }

 public:
  /// Number of bits this bitset can hold (8 * sizeof of the storage word).
  static constexpr size_t NUM_BITS() {
    return 8 * sizeof(bitset_type);
  }

  /// Constructs a bitset with every bit unset.
  constexpr bitset() noexcept : bitset_(0) {}
  constexpr bitset(const bitset&) noexcept = default;
  constexpr bitset(bitset&&) noexcept = default;
  // NOTE(review): unlike the constructors, the assignment operators are not
  // constexpr; presumably a compiler-compatibility workaround -- confirm
  // before changing.
  bitset& operator=(const bitset&) noexcept = default;
  bitset& operator=(bitset&&) noexcept = default;

  /// Sets the bit at `index` (requires index < NUM_BITS()).
  constexpr void set(size_t index) noexcept {
    bitset_ |= mask(index);
  }

  /// Clears the bit at `index` (requires index < NUM_BITS()).
  constexpr void unset(size_t index) noexcept {
    bitset_ &= ~mask(index);
  }

  /// Returns whether the bit at `index` is set (requires index < NUM_BITS()).
  constexpr bool get(size_t index) const noexcept {
    return bitset_ & mask(index);
  }

  /// Returns true iff no bit is set.
  constexpr bool is_entirely_unset() const noexcept {
    return 0 == bitset_;
  }

  /// Calls `func(index)` once for every set bit, in ascending index order.
  /// The enumeration runs over a snapshot of the bits taken at call time.
  template <class Func>
  void for_each_set_bit(Func&& func) const {
    bitset cur = *this;
    size_t index = cur.find_first_set();
    while (0 != index) {
      // find_first_set() is 1-based (0 means "none"); convert to a 0-based
      // bit index.
      index -= 1;
      func(index);
      cur.unset(index);
      index = cur.find_first_set();
    }
  }

  /// Two bitsets are equal iff their underlying words are equal.
  friend bool operator==(bitset lhs, bitset rhs) noexcept {
    return lhs.bitset_ == rhs.bitset_;
  }

 private:
  /// Returns 1 + the index of the lowest set bit, or 0 if no bit is set
  /// (the same convention as ffs / __builtin_ffsll).
  size_t find_first_set() const {
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64))
    unsigned long result;
    bool has_bits_set = (0 != _BitScanForward64(&result, bitset_));
    if (!has_bits_set) {
      return 0;
    }
    return result + 1;
#elif defined(_MSC_VER) && defined(_M_IX86)
    // 32-bit x86 MSVC has no _BitScanForward64: scan the low word, then the
    // high word.
    unsigned long result;
    const uint32_t low_word = static_cast<uint32_t>(bitset_);
    if (0 != _BitScanForward(&result, low_word)) {
      return result + 1;
    }
    const uint32_t high_word = static_cast<uint32_t>(bitset_ >> 32);
    if (0 != _BitScanForward(&result, high_word)) {
      // Bit `result` of the high word is bit 32 + result overall; +1 for the
      // 1-based convention.
      return result + 33;
    }
    // No bit set in either half. This previously returned 32, which made
    // for_each_set_bit loop forever on an empty bitset.
    return 0;
#else
    return __builtin_ffsll(bitset_);
#endif
  }

  bitset_type bitset_;
};
| |
|
| | inline bool operator!=(bitset lhs, bitset rhs) noexcept { |
| | return !(lhs == rhs); |
| | } |
| |
|
| | } |
| | } |
| |
|