| #pragma once |
|
|
#include <cstddef>
#include <cstdint>
#if defined(_MSC_VER)
#include <intrin.h>
#endif
|
|
namespace c10::utils {

/// A fixed-size set of bits packed into a single integer word.
///
/// Bit indices run from 0 to NUM_BITS() - 1. Every index-taking member
/// requires `index < NUM_BITS()`; a larger index shifts past the word width,
/// which is undefined behavior.
struct bitset final {
 private:
#if defined(_MSC_VER)
  // 64-bit storage word handed to the MSVC _BitScanForward* intrinsics below.
  using bitset_type = int64_t;
#else
  // Matches the `long long` argument type of __builtin_ffsll used below.
  using bitset_type = long long int;
#endif

 public:
  /// Number of bits that can be stored (the bit width of the storage word).
  static constexpr size_t NUM_BITS() {
    return 8 * sizeof(bitset_type);
  }

  constexpr bitset() noexcept = default;
  constexpr bitset(const bitset&) noexcept = default;
  constexpr bitset(bitset&&) noexcept = default;
  bitset& operator=(const bitset&) noexcept = default;
  bitset& operator=(bitset&&) noexcept = default;
  ~bitset() = default;

  /// Sets bit `index` to 1. Requires `index < NUM_BITS()`.
  constexpr void set(size_t index) noexcept {
    bitset_ |= mask(index);
  }

  /// Clears bit `index` to 0. Requires `index < NUM_BITS()`.
  constexpr void unset(size_t index) noexcept {
    bitset_ &= ~mask(index);
  }

  /// Returns whether bit `index` is 1. Requires `index < NUM_BITS()`.
  constexpr bool get(size_t index) const noexcept {
    return (bitset_ & mask(index)) != 0;
  }

  /// Returns true when no bit is set.
  constexpr bool is_entirely_unset() const noexcept {
    return 0 == bitset_;
  }

  /// Calls `func(index)` once for every set bit, in ascending index order.
  /// Works on a copy, so `*this` is left unchanged.
  template <class Func>
  void for_each_set_bit(Func&& func) const {
    bitset cur = *this;
    size_t index = cur.find_first_set();
    while (0 != index) {
      // find_first_set() is one-based; convert to a zero-based bit index.
      index -= 1;
      func(index);
      cur.unset(index);
      index = cur.find_first_set();
    }
  }

 private:
  /// Single-bit mask for `index`, built in the storage type so the shift is
  /// always done at the full word width.
  static constexpr bitset_type mask(size_t index) noexcept {
    return static_cast<bitset_type>(1) << index;
  }

  /// Returns the one-based index of the lowest set bit (1 when bit 0 is set),
  /// or 0 when no bit is set.
  size_t find_first_set() const {
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64))
    unsigned long result;
    if (0 == _BitScanForward64(&result, bitset_)) {
      return 0;
    }
    return result + 1;
#elif defined(_MSC_VER) && defined(_M_IX86)
    // Only a 32-bit scan is used here: scan the low word, then the high word.
    unsigned long result;
    const auto low = static_cast<uint32_t>(bitset_);
    if (low != 0) {
      // low is nonzero, so the scan always finds a bit.
      _BitScanForward(&result, low);
      return result + 1;
    }
    const auto high = static_cast<uint32_t>(bitset_ >> 32);
    if (0 == _BitScanForward(&result, high)) {
      // Both halves are zero: report "no bit set" (0). Returning anything
      // else would make for_each_set_bit() visit a phantom bit forever.
      return 0;
    }
    return result + 33;
#else
    return static_cast<size_t>(__builtin_ffsll(bitset_));
#endif
  }

  /// Two bitsets are equal when exactly the same bits are set.
  friend bool operator==(bitset lhs, bitset rhs) noexcept {
    return lhs.bitset_ == rhs.bitset_;
  }

  // Storage word: bit i of the set is bit i of this integer.
  bitset_type bitset_{0};
};

/// Two bitsets differ when at least one bit is set in only one of them.
inline bool operator!=(bitset lhs, bitset rhs) noexcept {
  return !(lhs == rhs);
}

} // namespace c10::utils
|
|