| | from typing import Any |
| | from argparse import Namespace |
| | import typing |
| |
|
| |
|
class DotDict(Namespace):
    """A simple class that builds upon `argparse.Namespace`
    in order to make chained attributes possible.

    Reading a missing attribute on a regular node creates a *placeholder*
    child and stores it on the node, so ``dd.a.b = 1`` works without
    creating ``dd.a`` first. Reading a missing attribute on a placeholder
    that is still empty (``dd.a.b`` when nothing was ever assigned) raises
    ``AttributeError`` and removes the placeholder from its parent again.

    Three bookkeeping entries are kept in the instance ``__dict__`` next to
    the user data: ``_temp`` (is this node a placeholder), ``_key`` (the
    name under which the parent stores this node) and ``_parent`` (the
    owning node). They are excluded from ``repr()`` and from equality.
    """

    # Names reserved for internal bookkeeping; never treated as user data.
    _BOOKKEEPING = frozenset({"_temp", "_key", "_parent"})

    def __init__(self, temp=False, key=None, parent=None) -> None:
        """Create an empty node.

        Args:
            temp: True when the node is a placeholder created by attribute
                access on ``parent``.
            key: Attribute name under which ``parent`` stores this node.
            parent: The node that created this placeholder, if any.
        """
        self._temp = temp
        self._key = key
        self._parent = parent

    def _data(self) -> typing.Dict[str, Any]:
        """Return the instance attributes minus the bookkeeping entries."""
        return {
            name: value
            for name, value in self.__dict__.items()
            if name not in self._BOOKKEEPING
        }

    def __eq__(self, other):
        """Compare stored data only.

        The bookkeeping entries are ignored: ``_parent`` is a back-reference,
        so comparing it would recurse forever for trees built by chaining,
        and ``_temp``/``_key`` only describe how a node was created.
        """
        if not isinstance(other, DotDict):
            return NotImplemented
        return self._data() == other._data()

    def __getattr__(self, __name: str) -> Any:
        """Handle access to an attribute that normal lookup did not find.

        Returns:
            A new placeholder ``DotDict`` stored under ``__name``.

        Raises:
            AttributeError: for dunder names and bookkeeping names, and when
                this node is itself a still-empty placeholder.
        """
        state = self.__dict__

        # Protocol probes (``__deepcopy__``, ``__getstate__``, ...) must see a
        # missing attribute, not a placeholder. A missing bookkeeping name
        # means ``__init__`` never ran (copy/pickle build instances that way);
        # recursing into ``self._temp`` here would never terminate.
        if __name in self._BOOKKEEPING or (
            __name.startswith("__") and __name.endswith("__")
        ):
            raise AttributeError("No attribute '%s'" % __name)

        if state.get("_temp", False):
            if not self._data():
                # Empty placeholder: the chained read failed. Detach it from
                # the parent, but only if the parent still refers to it, so a
                # repeated failed read cannot raise KeyError or remove a
                # different node that now lives under the same key.
                parent = state.get("_parent")
                key = state.get("_key")
                if parent is not None and parent.__dict__.get(key) is self:
                    del parent.__dict__[key]
                raise AttributeError("No attribute '%s'" % __name)
            # The placeholder has received values in the meantime, so it is a
            # real node now; discarding it would silently lose that data.
            state["_temp"] = False

        child = DotDict(temp=True, key=__name, parent=self)
        state[__name] = child
        return child

    def __repr__(self) -> str:
        """Return ``DotDict(key=value, ...)`` listing the public attributes.

        Names starting with an underscore are hidden. This also keeps the
        ``_parent`` back-reference out of the output, which would otherwise
        make nested reprs recurse without end.
        """
        shown = ", ".join(
            "%s=%r" % (key, val)
            for key, val in self.__dict__.items()
            if not key.startswith("_")
        )
        return "DotDict(%s)" % shown

    @classmethod
    def from_dict(cls, original: typing.Mapping[str, Any]) -> "DotDict":
        """Create a DotDict from a (possibly nested) dict `original`.

        Nested mappings are converted recursively; all other values are
        stored unchanged. Nodes are created with ``cls()`` so a subclass
        produces a tree of its own type.

        Warning: this method should not be used on very deeply nested inputs,
        since it's recursively traversing the nested dictionary values.
        """
        dd = cls()
        for key, value in original.items():
            if isinstance(value, typing.Mapping):
                value = cls.from_dict(value)
            setattr(dd, key, value)
        return dd