File size: 3,609 Bytes
593db5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
"""
A subclass of MutableAttr that has defaultdict support.

AttrDefault keeps its items in an underlying mapping and, when a
default_factory is configured, creates and stores a value for a missing
key on item lookup (optionally passing the missing key to the factory).
"""
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
import six
from attrdict.mixins import MutableAttr
# Public names exported by a star-import of this module.
__all__ = ['AttrDefault']
class AttrDefault(MutableAttr):
    """
    An implementation of MutableAttr with defaultdict support.

    Items are stored in an underlying mapping. When ``default_factory`` is
    not None, looking up a missing key via ``self[key]`` calls the factory,
    stores the result under ``key`` and returns it, like
    ``collections.defaultdict``. With ``pass_key=True`` the factory is
    called with the missing key as its only argument.

    NOTE(review): ``__contains__`` and ``get`` are not defined in this
    class. If they are the ``collections.abc.Mapping`` mixin versions,
    they are implemented via ``self[key]``. In that case ``key in obj``
    and ``obj.get(key)`` will *create* missing keys whenever a
    default_factory is set. Confirm against attrdict.mixins before
    treating membership tests as side-effect free.
    """

    def __init__(self, default_factory=None, items=None, sequence_type=tuple,
                 pass_key=False):
        """
        Create a new AttrDefault.

        :param default_factory: callable that produces values for missing
            keys, or None (the default) to raise KeyError for missing keys.
        :param items: initial contents. A Mapping is stored as-is and is
            not copied, so it is shared with the caller. Any other value is
            converted with ``dict(items)``. None means a new empty dict.
        :param sequence_type: stored and reported by ``_configuration``.
            It is not used directly in this class; presumably MutableAttr
            consumes it when wrapping values. TODO confirm.
        :param pass_key: if True, missing values are built with
            ``default_factory(key)`` instead of ``default_factory()``.
        """
        if items is None:
            items = {}
        elif not isinstance(items, Mapping):
            items = dict(items)

        # Internal state is written through _setattr (provided by the
        # MutableAttr base). Presumably this is because plain attribute
        # assignment is redirected to item storage; verify in attrdict.mixins.
        self._setattr('_default_factory', default_factory)
        self._setattr('_mapping', items)
        self._setattr('_sequence_type', sequence_type)
        self._setattr('_pass_key', pass_key)
        self._setattr('_allow_invalid_attributes', False)

    def _configuration(self):
        """
        The configuration for an AttrDefault instance.

        :return: ``(sequence_type, default_factory, pass_key)``, in the
            order that ``_constructor`` unpacks it.
        """
        return self._sequence_type, self._default_factory, self._pass_key

    def __getitem__(self, key):
        """
        Access a value associated with a key.

        If ``key`` is missing and a default_factory is set, a default value
        is created, stored and returned via ``__missing__``. Otherwise a
        missing key raises KeyError. Values are returned exactly as stored
        in the underlying mapping; no wrapping is applied here.
        """
        if key in self._mapping:
            return self._mapping[key]
        if self._default_factory is not None:
            return self.__missing__(key)
        raise KeyError(key)

    def __setitem__(self, key, value):
        """
        Add a key-value pair to the instance.
        """
        self._mapping[key] = value

    def __delitem__(self, key):
        """
        Delete a key-value pair.

        :raises KeyError: if ``key`` is not present in the underlying mapping.
        """
        del self._mapping[key]

    def __len__(self):
        """
        Return the number of items in the underlying mapping.
        """
        return len(self._mapping)

    def __iter__(self):
        """
        Iterate through the keys of the underlying mapping.
        """
        return iter(self._mapping)

    def __missing__(self, key):
        """
        Create, store and return a default value for a missing ``key``.

        Mirrors ``collections.defaultdict.__missing__``: if no
        default_factory is configured, KeyError is raised.

        :raises KeyError: if ``default_factory`` is None.
        """
        factory = self._default_factory
        if factory is None:
            # Without a factory the key is genuinely missing. Raise KeyError
            # instead of failing with "'NoneType' object is not callable".
            raise KeyError(key)

        value = factory(key) if self._pass_key else factory()
        self[key] = value
        return value

    def __repr__(self):
        """
        Return a string representation of the object.

        Format: ``AttrDefault(<default_factory>, <pass_key>, <mapping>)``.

        NOTE(review): this argument order differs from ``__init__``
        ``(default_factory, items, sequence_type, pass_key)``, so the string
        does not round-trip through ``eval``. It is left unchanged to keep
        existing output stable.
        """
        return six.u(
            "AttrDefault({default_factory}, {pass_key}, {mapping})"
        ).format(
            default_factory=repr(self._default_factory),
            pass_key=repr(self._pass_key),
            mapping=repr(self._mapping),
        )

    def __getstate__(self):
        """
        Serialize the object.

        :return: ``(default_factory, mapping, sequence_type, pass_key,
            allow_invalid_attributes)``, the layout ``__setstate__`` expects.
        """
        return (
            self._default_factory,
            self._mapping,
            self._sequence_type,
            self._pass_key,
            self._allow_invalid_attributes,
        )

    def __setstate__(self, state):
        """
        Deserialize the object from the tuple produced by ``__getstate__``.
        """
        (default_factory, mapping, sequence_type, pass_key,
         allow_invalid_attributes) = state

        self._setattr('_default_factory', default_factory)
        self._setattr('_mapping', mapping)
        self._setattr('_sequence_type', sequence_type)
        self._setattr('_pass_key', pass_key)
        self._setattr('_allow_invalid_attributes', allow_invalid_attributes)

    @classmethod
    def _constructor(cls, mapping, configuration):
        """
        A standardized constructor.

        :param mapping: the items for the new instance.
        :param configuration: a ``(sequence_type, default_factory,
            pass_key)`` tuple as returned by ``_configuration``.
        """
        sequence_type, default_factory, pass_key = configuration
        return cls(default_factory, mapping, sequence_type=sequence_type,
                   pass_key=pass_key)
|