|
|
#pragma once |
|
|
|
|
|
#include <c10/macros/Macros.h>
#include <c10/macros/Export.h>
#include <c10/util/TypeTraits.h>
#include <c10/util/TypeList.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/order_preserving_flat_hash_map.h>
#include <c10/util/Optional.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/jit_type_base.h>

#include <cstddef>
#include <iterator>
|
|
|
|
|
namespace c10 { |
|
|
// Forward declarations. IValue and Type are defined elsewhere (not in this
// header); Dict is defined later in this header. They are needed here because
// DictImpl / DictEntryRef / DictIterator below refer to them.
struct IValue;

template<class Key, class Value> class Dict;

// NOTE(review): Type is not referenced directly by the code in this header
// (TypePtr presumably comes from ATen/core/jit_type_base.h) — may be redundant.
struct Type;
|
|
|
|
|
namespace impl {

// The key types a typed c10::Dict<Key, Value> accepts. Dict's static_assert
// checks Key against this list; the fully type-erased Dict<IValue, IValue> is
// permitted separately by that same static_assert.
// NOTE(review): adding a type here presumably also requires DictKeyHash /
// DictKeyEqualTo to support it — confirm against their definitions.
using valid_dict_key_types = guts::typelist::typelist<
  int64_t,
  std::string,
  double,
  c10::complex<double>,
  bool,
  at::Tensor
>;

} // namespace impl
|
|
|
|
|
namespace detail {

// Hash functor for the type-erased IValue keys stored in DictImpl's map.
// Only declared here; its definition is out of line (not in this header).
struct DictKeyHash {
  size_t operator()(const IValue& ivalue) const;
};

// Equality functor paired with DictKeyHash for DictImpl's map.
// Only declared here; its definition is out of line (not in this header).
struct DictKeyEqualTo {
  bool operator()(const IValue& lhs, const IValue& rhs) const;
};

// Reference-counted storage behind c10::Dict.
//
// Derives from intrusive_ptr_target so it can be owned by c10::intrusive_ptr.
// Dict<Key, Value> holds an intrusive_ptr<DictImpl>, so copies of a Dict
// handle share one DictImpl. Keys and values are stored type-erased as IValue;
// their runtime element types are recorded in `elementTypes`.
struct DictImpl final : public c10::intrusive_ptr_target {
  // Underlying map, keyed with DictKeyHash / DictKeyEqualTo. Per its name and
  // header (order_preserving_flat_hash_map.h) it is expected to preserve
  // insertion order — NOTE(review): confirm in that header.
  using dict_map_type = ska_ordered::order_preserving_flat_hash_map<IValue, IValue, DictKeyHash, DictKeyEqualTo>;

  // Runtime key and value types of the stored elements.
  struct DictElementTypes final {
    TypePtr keyType;
    TypePtr valueType;
  };

  // Both arguments are taken by value and moved into the members.
  explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_)
  : dict(std::move(dict_))
  , elementTypes(std::move(elementTypes_)) {}

  // The entries, stored as IValue -> IValue.
  dict_map_type dict;

  // Key/value types describing the entries of `dict`.
  DictElementTypes elementTypes;

  // Returns a newly allocated DictImpl; defined out of line.
  // NOTE(review): presumably copies `dict` and `elementTypes` into separate
  // storage — confirm in the definition.
  intrusive_ptr<DictImpl> copy() const;

  // Exported (TORCH_API) equality between two DictImpls, defined out of line.
  // NOTE(review): whether elementTypes participate in the comparison is not
  // visible here.
  friend TORCH_API bool operator==(const DictImpl& lhs, const DictImpl& rhs);
};

} // namespace detail
|
|
|
|
|
namespace impl { |
|
|
// Forward declaration: DictEntryRef (below) names DictIterator as a friend
// before DictIterator itself is defined.
template<class Key, class Value, class Iterator> class DictIterator;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<class Key, class Value, class Iterator> |
|
|
class DictEntryRef final { |
|
|
public: |
|
|
explicit DictEntryRef(Iterator iterator) |
|
|
: iterator_(std::move(iterator)) {} |
|
|
|
|
|
decltype(auto) key() const { |
|
|
return iterator_->first.template to<Key>(); |
|
|
} |
|
|
|
|
|
decltype(auto) value() const { |
|
|
return iterator_->second.template to<Value>(); |
|
|
} |
|
|
|
|
|
template<class Value_> |
|
|
void setValue(Value_&& value) const { |
|
|
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of setValue()"); |
|
|
iterator_->second = Value(std::forward<Value_>(value)); |
|
|
} |
|
|
|
|
|
private: |
|
|
|
|
|
|
|
|
|
|
|
DictEntryRef(const DictEntryRef&) = default; |
|
|
DictEntryRef& operator=(const DictEntryRef&) = default; |
|
|
DictEntryRef(DictEntryRef&&) noexcept = default; |
|
|
DictEntryRef& operator=(DictEntryRef&& rhs) & noexcept = default; |
|
|
|
|
|
Iterator iterator_; |
|
|
friend class DictIterator<Key, Value, Iterator>; |
|
|
friend class Dict<Key, Value>; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
template<class Key, class Value, class Iterator> |
|
|
class DictIterator final : public std::iterator<std::forward_iterator_tag, DictEntryRef<Key, Value, Iterator>> { |
|
|
public: |
|
|
explicit DictIterator() = default; |
|
|
~DictIterator() = default; |
|
|
|
|
|
DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {} |
|
|
DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {} |
|
|
DictIterator& operator=(const DictIterator& rhs) { |
|
|
entryRef_ = rhs.entryRef_; |
|
|
return *this; |
|
|
} |
|
|
DictIterator& operator=(DictIterator&& rhs) noexcept { |
|
|
entryRef_ = std::move(rhs.entryRef_); |
|
|
return *this; |
|
|
} |
|
|
|
|
|
DictIterator& operator++() { |
|
|
++entryRef_.iterator_; |
|
|
return *this; |
|
|
} |
|
|
|
|
|
DictIterator operator++(int) { |
|
|
DictIterator copy(*this); |
|
|
++*this; |
|
|
return copy; |
|
|
} |
|
|
|
|
|
const DictEntryRef<Key, Value, Iterator>& operator*() const { |
|
|
return entryRef_; |
|
|
} |
|
|
|
|
|
const DictEntryRef<Key, Value, Iterator>* operator->() const { |
|
|
return &entryRef_; |
|
|
} |
|
|
|
|
|
friend typename std::iterator<std::random_access_iterator_tag, DictEntryRef<Key, Value, Iterator>>::difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) { |
|
|
return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_; |
|
|
} |
|
|
|
|
|
private: |
|
|
explicit DictIterator(Iterator iterator): entryRef_(std::move(iterator)) {} |
|
|
|
|
|
const Iterator& get_iterator_() const { |
|
|
return entryRef_.iterator_; |
|
|
} |
|
|
|
|
|
friend bool operator==(const DictIterator& lhs, const DictIterator& rhs) { |
|
|
return lhs.get_iterator_() == rhs.get_iterator_(); |
|
|
} |
|
|
|
|
|
friend bool operator!=(const DictIterator& lhs, const DictIterator& rhs) { |
|
|
return lhs.get_iterator_() != rhs.get_iterator_(); |
|
|
} |
|
|
|
|
|
friend bool operator<(const DictIterator& lhs, const DictIterator& rhs) { |
|
|
return lhs.get_iterator_() < rhs.get_iterator_(); |
|
|
} |
|
|
|
|
|
friend bool operator<=(const DictIterator& lhs, const DictIterator& rhs) { |
|
|
return lhs.get_iterator_() <= rhs.get_iterator_(); |
|
|
} |
|
|
|
|
|
friend bool operator>(const DictIterator& lhs, const DictIterator& rhs) { |
|
|
return lhs.get_iterator_() > rhs.get_iterator_(); |
|
|
} |
|
|
|
|
|
friend bool operator>=(const DictIterator& lhs, const DictIterator& rhs) { |
|
|
return lhs.get_iterator_() >= rhs.get_iterator_(); |
|
|
} |
|
|
|
|
|
DictEntryRef<Key, Value, Iterator> entryRef_; |
|
|
|
|
|
friend class DictIterator<Key, Value, typename c10::detail::DictImpl::dict_map_type::iterator>; |
|
|
friend class Dict<Key, Value>; |
|
|
}; |
|
|
|
|
|
// Conversions between the type-erased Dict<IValue, IValue> and a typed
// Dict<Key, Value>. Both are declared friends of Dict, giving them access to
// Dict's private members (e.g. the constructor taking an intrusive_ptr<DictImpl>).
// Only declared here; NOTE(review): definitions presumably live in Dict_inl.h.
template<class Key, class Value> Dict<Key, Value> toTypedDict(Dict<IValue, IValue> dict);

template<class Key, class Value> Dict<IValue, IValue> toGenericDict(Dict<Key, Value> dict);
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<class Key, class Value> |
|
|
class Dict final { |
|
|
private: |
|
|
static_assert((std::is_same<IValue, Key>::value && std::is_same<IValue, Value>::value) || guts::typelist::contains<impl::valid_dict_key_types, Key>::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string."); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
c10::intrusive_ptr<detail::DictImpl> impl_; |
|
|
|
|
|
explicit Dict(c10::intrusive_ptr<detail::DictImpl>&& impl); |
|
|
friend struct IValue; |
|
|
template<class K, class V> friend Dict<K, V> impl::toTypedDict(Dict<IValue, IValue>); |
|
|
template<class K, class V> friend Dict<IValue, IValue> impl::toGenericDict(Dict<K, V>); |
|
|
|
|
|
public: |
|
|
using key_type = Key; |
|
|
using mapped_type = Value; |
|
|
using size_type = typename detail::DictImpl::dict_map_type::size_type; |
|
|
using iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explicit Dict(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explicit Dict(TypePtr keyType, TypePtr valueType); |
|
|
|
|
|
~Dict() = default; |
|
|
|
|
|
Dict(const Dict&) = default; |
|
|
Dict& operator=(const Dict&) = default; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Dict copy() const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
iterator begin() const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
iterator end() const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool empty() const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
size_type size() const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void clear() const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<class Key_, class Value_> |
|
|
std::pair<iterator, bool> insert(Key_&& key, Value_&& value) const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<class Key_, class Value_> |
|
|
std::pair<iterator, bool> insert_or_assign(Key_&& key, Value_&& value) const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void erase(iterator iter) const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
C10_NODISCARD size_t erase(const Key& key) const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Value at(const Key& key) const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
iterator find(const Key& key) const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool contains(const Key& key) const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void reserve(size_type count) const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <class Key_, class Value_> |
|
|
friend bool operator==( |
|
|
const Dict<Key_, Value_>& lhs, |
|
|
const Dict<Key_, Value_>& rhs); |
|
|
template <class Key_, class Value_> |
|
|
friend bool operator!=( |
|
|
const Dict<Key_, Value_>& lhs, |
|
|
const Dict<Key_, Value_>& rhs); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool is(const Dict& rhs) const; |
|
|
|
|
|
|
|
|
|
|
|
TypePtr keyType() const; |
|
|
TypePtr valueType() const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void unsafeSetKeyType(TypePtr t); |
|
|
void unsafeSetValueType(TypePtr t); |
|
|
}; |
|
|
|
|
|
namespace impl {

// Type-erased dict: both keys and values are IValue. Dict's static_assert
// explicitly permits this instantiation. NOTE(review): the actual element
// types are presumably tracked at runtime via DictImpl::elementTypes (see the
// Dict(TypePtr, TypePtr) constructor).
using GenericDict = Dict<IValue, IValue>;

} // namespace impl
|
|
} |
|
|
|
|
|
namespace torch {
// Alias so c10::Dict is also reachable as torch::Dict.
template<class Key, class Value> using Dict = c10::Dict<Key, Value>;
} // namespace torch
|
|
|
|
|
#include <ATen/core/Dict_inl.h> |
|
|
|