Tiger14n commited on
Commit
914e267
·
1 Parent(s): 89d40d1
omegaconf/__init__.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import Container, DictKeyType, Node, SCMode, UnionNode
2
+ from .dictconfig import DictConfig
3
+ from .errors import (
4
+ KeyValidationError,
5
+ MissingMandatoryValue,
6
+ ReadonlyConfigError,
7
+ UnsupportedValueType,
8
+ ValidationError,
9
+ )
10
+ from .listconfig import ListConfig
11
+ from .nodes import (
12
+ AnyNode,
13
+ BooleanNode,
14
+ BytesNode,
15
+ EnumNode,
16
+ FloatNode,
17
+ IntegerNode,
18
+ PathNode,
19
+ StringNode,
20
+ ValueNode,
21
+ )
22
+ from .omegaconf import (
23
+ II,
24
+ MISSING,
25
+ SI,
26
+ OmegaConf,
27
+ Resolver,
28
+ flag_override,
29
+ open_dict,
30
+ read_write,
31
+ )
32
+ from .version import __version__
33
+
34
+ __all__ = [
35
+ "__version__",
36
+ "MissingMandatoryValue",
37
+ "ValidationError",
38
+ "ReadonlyConfigError",
39
+ "UnsupportedValueType",
40
+ "KeyValidationError",
41
+ "Container",
42
+ "UnionNode",
43
+ "ListConfig",
44
+ "DictConfig",
45
+ "DictKeyType",
46
+ "OmegaConf",
47
+ "Resolver",
48
+ "SCMode",
49
+ "flag_override",
50
+ "read_write",
51
+ "open_dict",
52
+ "Node",
53
+ "ValueNode",
54
+ "AnyNode",
55
+ "IntegerNode",
56
+ "StringNode",
57
+ "BytesNode",
58
+ "PathNode",
59
+ "BooleanNode",
60
+ "EnumNode",
61
+ "FloatNode",
62
+ "MISSING",
63
+ "SI",
64
+ "II",
65
+ ]
omegaconf/_impl.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from omegaconf import MISSING, Container, DictConfig, ListConfig, Node, ValueNode
4
+ from omegaconf.errors import ConfigTypeError, InterpolationToMissingValueError
5
+
6
+ from ._utils import _DEFAULT_MARKER_, _get_value
7
+
8
+
9
+ def _resolve_container_value(cfg: Container, key: Any) -> None:
10
+ node = cfg._get_child(key)
11
+ assert isinstance(node, Node)
12
+ if node._is_interpolation():
13
+ try:
14
+ resolved = node._dereference_node()
15
+ except InterpolationToMissingValueError:
16
+ node._set_value(MISSING)
17
+ else:
18
+ if isinstance(resolved, Container):
19
+ _resolve(resolved)
20
+ if isinstance(resolved, Container) and isinstance(node, ValueNode):
21
+ cfg[key] = resolved
22
+ else:
23
+ node._set_value(_get_value(resolved))
24
+ else:
25
+ _resolve(node)
26
+
27
+
28
+ def _resolve(cfg: Node) -> Node:
29
+ assert isinstance(cfg, Node)
30
+ if cfg._is_interpolation():
31
+ try:
32
+ resolved = cfg._dereference_node()
33
+ except InterpolationToMissingValueError:
34
+ cfg._set_value(MISSING)
35
+ else:
36
+ cfg._set_value(resolved._value())
37
+
38
+ if isinstance(cfg, DictConfig):
39
+ for k in cfg.keys():
40
+ _resolve_container_value(cfg, k)
41
+
42
+ elif isinstance(cfg, ListConfig):
43
+ for i in range(len(cfg)):
44
+ _resolve_container_value(cfg, i)
45
+
46
+ return cfg
47
+
48
+
49
+ def select_value(
50
+ cfg: Container,
51
+ key: str,
52
+ *,
53
+ default: Any = _DEFAULT_MARKER_,
54
+ throw_on_resolution_failure: bool = True,
55
+ throw_on_missing: bool = False,
56
+ absolute_key: bool = False,
57
+ ) -> Any:
58
+ node = select_node(
59
+ cfg=cfg,
60
+ key=key,
61
+ throw_on_resolution_failure=throw_on_resolution_failure,
62
+ throw_on_missing=throw_on_missing,
63
+ absolute_key=absolute_key,
64
+ )
65
+
66
+ node_not_found = node is None
67
+ if node_not_found or node._is_missing():
68
+ if default is not _DEFAULT_MARKER_:
69
+ return default
70
+ else:
71
+ return None
72
+
73
+ return _get_value(node)
74
+
75
+
76
+ def select_node(
77
+ cfg: Container,
78
+ key: str,
79
+ *,
80
+ throw_on_resolution_failure: bool = True,
81
+ throw_on_missing: bool = False,
82
+ absolute_key: bool = False,
83
+ ) -> Any:
84
+ try:
85
+ # for non relative keys, the interpretation can be:
86
+ # 1. relative to cfg
87
+ # 2. relative to the config root
88
+ # This is controlled by the absolute_key flag. By default, such keys are relative to cfg.
89
+ if not absolute_key and not key.startswith("."):
90
+ key = f".{key}"
91
+
92
+ cfg, key = cfg._resolve_key_and_root(key)
93
+ _root, _last_key, node = cfg._select_impl(
94
+ key,
95
+ throw_on_missing=throw_on_missing,
96
+ throw_on_resolution_failure=throw_on_resolution_failure,
97
+ )
98
+ except ConfigTypeError:
99
+ return None
100
+
101
+ return node
omegaconf/_utils.py ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ import pathlib
4
+ import re
5
+ import string
6
+ import sys
7
+ import types
8
+ import warnings
9
+ from contextlib import contextmanager
10
+ from enum import Enum
11
+ from textwrap import dedent
12
+ from typing import (
13
+ Any,
14
+ Dict,
15
+ Iterator,
16
+ List,
17
+ Optional,
18
+ Tuple,
19
+ Type,
20
+ Union,
21
+ get_type_hints,
22
+ )
23
+
24
+ import yaml
25
+
26
+ from .errors import (
27
+ ConfigIndexError,
28
+ ConfigTypeError,
29
+ ConfigValueError,
30
+ GrammarParseError,
31
+ OmegaConfBaseException,
32
+ ValidationError,
33
+ )
34
+ from .grammar_parser import SIMPLE_INTERPOLATION_PATTERN, parse
35
+
36
+ try:
37
+ import dataclasses
38
+
39
+ except ImportError: # pragma: no cover
40
+ dataclasses = None # type: ignore # pragma: no cover
41
+
42
+ try:
43
+ import attr
44
+
45
+ except ImportError: # pragma: no cover
46
+ attr = None # type: ignore # pragma: no cover
47
+
48
+ NoneType: Type[None] = type(None)
49
+
50
+ BUILTIN_VALUE_TYPES: Tuple[Type[Any], ...] = (
51
+ int,
52
+ float,
53
+ bool,
54
+ str,
55
+ bytes,
56
+ NoneType,
57
+ )
58
+
59
+ # Regexprs to match key paths like: a.b, a[b], ..a[c].d, etc.
60
+ # We begin by matching the head (in these examples: a, a, ..a).
61
+ # This can be read as "dots followed by any character but `.` or `[`"
62
+ # Note that a key starting with brackets, like [a], is purposedly *not*
63
+ # matched here and will instead be handled in the next regex below (this
64
+ # is to keep this regex simple).
65
+ KEY_PATH_HEAD = re.compile(r"(\.)*[^.[]*")
66
+ # Then we match other keys. The following expression matches one key and can
67
+ # be read as a choice between two syntaxes:
68
+ # - `.` followed by anything except `.` or `[` (ex: .b, .d)
69
+ # - `[` followed by anything then `]` (ex: [b], [c])
70
+ KEY_PATH_OTHER = re.compile(r"\.([^.[]*)|\[(.*?)\]")
71
+
72
+
73
+ # source: https://yaml.org/type/bool.html
74
+ YAML_BOOL_TYPES = [
75
+ "y",
76
+ "Y",
77
+ "yes",
78
+ "Yes",
79
+ "YES",
80
+ "n",
81
+ "N",
82
+ "no",
83
+ "No",
84
+ "NO",
85
+ "true",
86
+ "True",
87
+ "TRUE",
88
+ "false",
89
+ "False",
90
+ "FALSE",
91
+ "on",
92
+ "On",
93
+ "ON",
94
+ "off",
95
+ "Off",
96
+ "OFF",
97
+ ]
98
+
99
+
100
+ class Marker:
101
+ def __init__(self, desc: str):
102
+ self.desc = desc
103
+
104
+ def __repr__(self) -> str:
105
+ return self.desc
106
+
107
+
108
+ # To be used as default value when `None` is not an option.
109
+ _DEFAULT_MARKER_: Any = Marker("_DEFAULT_MARKER_")
110
+
111
+
112
+ class OmegaConfDumper(yaml.Dumper): # type: ignore
113
+ str_representer_added = False
114
+
115
+ @staticmethod
116
+ def str_representer(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
117
+ with_quotes = yaml_is_bool(data) or is_int(data) or is_float(data)
118
+ return dumper.represent_scalar(
119
+ yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG,
120
+ data,
121
+ style=("'" if with_quotes else None),
122
+ )
123
+
124
+
125
+ def get_omega_conf_dumper() -> Type[OmegaConfDumper]:
126
+ if not OmegaConfDumper.str_representer_added:
127
+ OmegaConfDumper.add_representer(str, OmegaConfDumper.str_representer)
128
+ OmegaConfDumper.str_representer_added = True
129
+ return OmegaConfDumper
130
+
131
+
132
+ def yaml_is_bool(b: str) -> bool:
133
+ return b in YAML_BOOL_TYPES
134
+
135
+
136
+ def get_yaml_loader() -> Any:
137
+ class OmegaConfLoader(yaml.SafeLoader): # type: ignore
138
+ def construct_mapping(self, node: yaml.Node, deep: bool = False) -> Any:
139
+ keys = set()
140
+ for key_node, value_node in node.value:
141
+ if key_node.tag != yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG:
142
+ continue
143
+ if key_node.value in keys:
144
+ raise yaml.constructor.ConstructorError(
145
+ "while constructing a mapping",
146
+ node.start_mark,
147
+ f"found duplicate key {key_node.value}",
148
+ key_node.start_mark,
149
+ )
150
+ keys.add(key_node.value)
151
+ return super().construct_mapping(node, deep=deep)
152
+
153
+ loader = OmegaConfLoader
154
+ loader.add_implicit_resolver(
155
+ "tag:yaml.org,2002:float",
156
+ re.compile(
157
+ """^(?:
158
+ [-+]?[0-9]+(?:_[0-9]+)*\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
159
+ |[-+]?[0-9]+(?:_[0-9]+)*(?:[eE][-+]?[0-9]+)
160
+ |\\.[0-9]+(?:_[0-9]+)*(?:[eE][-+][0-9]+)?
161
+ |[-+]?[0-9]+(?:_[0-9]+)*(?::[0-5]?[0-9])+\\.[0-9_]*
162
+ |[-+]?\\.(?:inf|Inf|INF)
163
+ |\\.(?:nan|NaN|NAN))$""",
164
+ re.X,
165
+ ),
166
+ list("-+0123456789."),
167
+ )
168
+ loader.yaml_implicit_resolvers = {
169
+ key: [
170
+ (tag, regexp)
171
+ for tag, regexp in resolvers
172
+ if tag != "tag:yaml.org,2002:timestamp"
173
+ ]
174
+ for key, resolvers in loader.yaml_implicit_resolvers.items()
175
+ }
176
+
177
+ loader.add_constructor(
178
+ "tag:yaml.org,2002:python/object/apply:pathlib.Path",
179
+ lambda loader, node: pathlib.Path(*loader.construct_sequence(node)),
180
+ )
181
+ loader.add_constructor(
182
+ "tag:yaml.org,2002:python/object/apply:pathlib.PosixPath",
183
+ lambda loader, node: pathlib.PosixPath(*loader.construct_sequence(node)),
184
+ )
185
+ loader.add_constructor(
186
+ "tag:yaml.org,2002:python/object/apply:pathlib.WindowsPath",
187
+ lambda loader, node: pathlib.WindowsPath(*loader.construct_sequence(node)),
188
+ )
189
+
190
+ return loader
191
+
192
+
193
+ def _get_class(path: str) -> type:
194
+ from importlib import import_module
195
+
196
+ module_path, _, class_name = path.rpartition(".")
197
+ mod = import_module(module_path)
198
+ try:
199
+ klass: type = getattr(mod, class_name)
200
+ except AttributeError:
201
+ raise ImportError(f"Class {class_name} is not in module {module_path}")
202
+ return klass
203
+
204
+
205
+ def is_union_annotation(type_: Any) -> bool:
206
+ if sys.version_info >= (3, 10): # pragma: no cover
207
+ if isinstance(type_, types.UnionType):
208
+ return True
209
+ return getattr(type_, "__origin__", None) is Union
210
+
211
+
212
+ def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
213
+ """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
214
+ if is_union_annotation(type_):
215
+ args = type_.__args__
216
+ if NoneType in args:
217
+ optional = True
218
+ args = tuple(a for a in args if a is not NoneType)
219
+ else:
220
+ optional = False
221
+ if len(args) == 1:
222
+ return optional, args[0]
223
+ elif len(args) >= 2:
224
+ return optional, Union[args]
225
+ else:
226
+ assert False
227
+
228
+ if type_ is Any:
229
+ return True, Any
230
+
231
+ if type_ in (None, NoneType):
232
+ return True, NoneType
233
+
234
+ return False, type_
235
+
236
+
237
+ def _is_optional(obj: Any, key: Optional[Union[int, str]] = None) -> bool:
238
+ """Check `obj` metadata to see if the given node is optional."""
239
+ from .base import Container, Node
240
+
241
+ if key is not None:
242
+ assert isinstance(obj, Container)
243
+ obj = obj._get_node(key)
244
+ assert isinstance(obj, Node)
245
+ return obj._is_optional()
246
+
247
+
248
+ def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]:
249
+ import typing # lgtm [py/import-and-import-from]
250
+
251
+ forward = typing.ForwardRef if hasattr(typing, "ForwardRef") else typing._ForwardRef # type: ignore
252
+ if type(type_) is forward:
253
+ return _get_class(f"{module}.{type_.__forward_arg__}")
254
+ else:
255
+ if is_dict_annotation(type_):
256
+ kt, vt = get_dict_key_value_types(type_)
257
+ if kt is not None:
258
+ kt = _resolve_forward(kt, module=module)
259
+ if vt is not None:
260
+ vt = _resolve_forward(vt, module=module)
261
+ return Dict[kt, vt] # type: ignore
262
+ if is_list_annotation(type_):
263
+ et = get_list_element_type(type_)
264
+ if et is not None:
265
+ et = _resolve_forward(et, module=module)
266
+ return List[et] # type: ignore
267
+ if is_tuple_annotation(type_):
268
+ its = get_tuple_item_types(type_)
269
+ its = tuple(_resolve_forward(it, module=module) for it in its)
270
+ return Tuple[its] # type: ignore
271
+
272
+ return type_
273
+
274
+
275
+ def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]]:
276
+ """Check if obj is an instance of a subclass of Dict. If so, extract the Dict keys/values."""
277
+ from omegaconf.omegaconf import _maybe_wrap
278
+
279
+ is_type = isinstance(obj, type)
280
+ obj_type = obj if is_type else type(obj)
281
+ subclasses_dict = is_dict_subclass(obj_type)
282
+
283
+ if subclasses_dict:
284
+ warnings.warn(
285
+ f"Class `{obj_type.__name__}` subclasses `Dict`."
286
+ + " Subclassing `Dict` in Structured Config classes is deprecated,"
287
+ + " see github.com/omry/omegaconf/issues/663",
288
+ UserWarning,
289
+ stacklevel=9,
290
+ )
291
+
292
+ if is_type:
293
+ return None
294
+ elif subclasses_dict:
295
+ dict_subclass_data = {}
296
+ key_type, element_type = get_dict_key_value_types(obj_type)
297
+ for name, value in obj.items():
298
+ is_optional, type_ = _resolve_optional(element_type)
299
+ type_ = _resolve_forward(type_, obj.__module__)
300
+ try:
301
+ dict_subclass_data[name] = _maybe_wrap(
302
+ ref_type=type_,
303
+ is_optional=is_optional,
304
+ key=name,
305
+ value=value,
306
+ parent=parent,
307
+ )
308
+ except ValidationError as ex:
309
+ format_and_raise(
310
+ node=None, key=name, value=value, cause=ex, msg=str(ex)
311
+ )
312
+ return dict_subclass_data
313
+ else:
314
+ return None
315
+
316
+
317
+ def get_attr_class_fields(obj: Any) -> List["attr.Attribute[Any]"]:
318
+ is_type = isinstance(obj, type)
319
+ obj_type = obj if is_type else type(obj)
320
+ fields = attr.fields_dict(obj_type).values()
321
+ return [f for f in fields if f.metadata.get("omegaconf_ignore") is not True]
322
+
323
+
324
+ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]:
325
+ from omegaconf.omegaconf import OmegaConf, _maybe_wrap
326
+
327
+ flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
328
+
329
+ from omegaconf import MISSING
330
+
331
+ d = {}
332
+ is_type = isinstance(obj, type)
333
+ obj_type = obj if is_type else type(obj)
334
+ dummy_parent = OmegaConf.create({}, flags=flags)
335
+ dummy_parent._metadata.object_type = obj_type
336
+ resolved_hints = get_type_hints(obj_type)
337
+
338
+ for attrib in get_attr_class_fields(obj):
339
+ name = attrib.name
340
+ is_optional, type_ = _resolve_optional(resolved_hints[name])
341
+ type_ = _resolve_forward(type_, obj.__module__)
342
+ if not is_type:
343
+ value = getattr(obj, name)
344
+ else:
345
+ value = attrib.default
346
+ if value == attr.NOTHING:
347
+ value = MISSING
348
+ if is_union_annotation(type_) and not is_supported_union_annotation(type_):
349
+ e = ConfigValueError(
350
+ f"Unions of containers are not supported:\n{name}: {type_str(type_)}"
351
+ )
352
+ format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
353
+
354
+ try:
355
+ d[name] = _maybe_wrap(
356
+ ref_type=type_,
357
+ is_optional=is_optional,
358
+ key=name,
359
+ value=value,
360
+ parent=dummy_parent,
361
+ )
362
+ except (ValidationError, GrammarParseError) as ex:
363
+ format_and_raise(
364
+ node=dummy_parent, key=name, value=value, cause=ex, msg=str(ex)
365
+ )
366
+ d[name]._set_parent(None)
367
+ dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent)
368
+ if dict_subclass_data is not None:
369
+ d.update(dict_subclass_data)
370
+ return d
371
+
372
+
373
+ def get_dataclass_fields(obj: Any) -> List["dataclasses.Field[Any]"]:
374
+ fields = dataclasses.fields(obj)
375
+ return [f for f in fields if f.metadata.get("omegaconf_ignore") is not True]
376
+
377
+
378
+ def get_dataclass_data(
379
+ obj: Any, allow_objects: Optional[bool] = None
380
+ ) -> Dict[str, Any]:
381
+ from omegaconf.omegaconf import MISSING, OmegaConf, _maybe_wrap
382
+
383
+ flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
384
+ d = {}
385
+ is_type = isinstance(obj, type)
386
+ obj_type = get_type_of(obj)
387
+ dummy_parent = OmegaConf.create({}, flags=flags)
388
+ dummy_parent._metadata.object_type = obj_type
389
+ resolved_hints = get_type_hints(obj_type)
390
+ for field in get_dataclass_fields(obj):
391
+ name = field.name
392
+ is_optional, type_ = _resolve_optional(resolved_hints[field.name])
393
+ type_ = _resolve_forward(type_, obj.__module__)
394
+ has_default = field.default != dataclasses.MISSING
395
+ has_default_factory = field.default_factory != dataclasses.MISSING
396
+
397
+ if not is_type:
398
+ value = getattr(obj, name)
399
+ else:
400
+ if has_default:
401
+ value = field.default
402
+ elif has_default_factory:
403
+ value = field.default_factory() # type: ignore
404
+ else:
405
+ value = MISSING
406
+
407
+ if is_union_annotation(type_) and not is_supported_union_annotation(type_):
408
+ e = ConfigValueError(
409
+ f"Unions of containers are not supported:\n{name}: {type_str(type_)}"
410
+ )
411
+ format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
412
+ try:
413
+ d[name] = _maybe_wrap(
414
+ ref_type=type_,
415
+ is_optional=is_optional,
416
+ key=name,
417
+ value=value,
418
+ parent=dummy_parent,
419
+ )
420
+ except (ValidationError, GrammarParseError) as ex:
421
+ format_and_raise(
422
+ node=dummy_parent, key=name, value=value, cause=ex, msg=str(ex)
423
+ )
424
+ d[name]._set_parent(None)
425
+ dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent)
426
+ if dict_subclass_data is not None:
427
+ d.update(dict_subclass_data)
428
+ return d
429
+
430
+
431
+ def is_dataclass(obj: Any) -> bool:
432
+ from omegaconf.base import Node
433
+
434
+ if dataclasses is None or isinstance(obj, Node):
435
+ return False
436
+ return dataclasses.is_dataclass(obj)
437
+
438
+
439
+ def is_attr_class(obj: Any) -> bool:
440
+ from omegaconf.base import Node
441
+
442
+ if attr is None or isinstance(obj, Node):
443
+ return False
444
+ return attr.has(obj)
445
+
446
+
447
+ def is_structured_config(obj: Any) -> bool:
448
+ return is_attr_class(obj) or is_dataclass(obj)
449
+
450
+
451
+ def is_dataclass_frozen(type_: Any) -> bool:
452
+ return type_.__dataclass_params__.frozen # type: ignore
453
+
454
+
455
+ def is_attr_frozen(type_: type) -> bool:
456
+ # This is very hacky and probably fragile as well.
457
+ # Unfortunately currently there isn't an official API in attr that can detect that.
458
+ # noinspection PyProtectedMember
459
+ return type_.__setattr__ == attr._make._frozen_setattrs # type: ignore
460
+
461
+
462
+ def get_type_of(class_or_object: Any) -> Type[Any]:
463
+ type_ = class_or_object
464
+ if not isinstance(type_, type):
465
+ type_ = type(class_or_object)
466
+ assert isinstance(type_, type)
467
+ return type_
468
+
469
+
470
+ def is_structured_config_frozen(obj: Any) -> bool:
471
+ type_ = get_type_of(obj)
472
+
473
+ if is_dataclass(type_):
474
+ return is_dataclass_frozen(type_)
475
+ if is_attr_class(type_):
476
+ return is_attr_frozen(type_)
477
+ return False
478
+
479
+
480
+ def get_structured_config_init_field_names(obj: Any) -> List[str]:
481
+ fields: Union[List["dataclasses.Field[Any]"], List["attr.Attribute[Any]"]]
482
+ if is_dataclass(obj):
483
+ fields = get_dataclass_fields(obj)
484
+ elif is_attr_class(obj):
485
+ fields = get_attr_class_fields(obj)
486
+ else:
487
+ raise ValueError(f"Unsupported type: {type(obj).__name__}")
488
+ return [f.name for f in fields if f.init]
489
+
490
+
491
+ def get_structured_config_data(
492
+ obj: Any, allow_objects: Optional[bool] = None
493
+ ) -> Dict[str, Any]:
494
+ if is_dataclass(obj):
495
+ return get_dataclass_data(obj, allow_objects=allow_objects)
496
+ elif is_attr_class(obj):
497
+ return get_attr_data(obj, allow_objects=allow_objects)
498
+ else:
499
+ raise ValueError(f"Unsupported type: {type(obj).__name__}")
500
+
501
+
502
+ class ValueKind(Enum):
503
+ VALUE = 0
504
+ MANDATORY_MISSING = 1
505
+ INTERPOLATION = 2
506
+
507
+
508
+ def _is_missing_value(value: Any) -> bool:
509
+ from omegaconf import Node
510
+
511
+ if isinstance(value, Node):
512
+ value = value._value()
513
+ return _is_missing_literal(value)
514
+
515
+
516
+ def _is_missing_literal(value: Any) -> bool:
517
+ # Uses literal '???' instead of the MISSING const for performance reasons.
518
+ return isinstance(value, str) and value == "???"
519
+
520
+
521
+ def _is_none(
522
+ value: Any, resolve: bool = False, throw_on_resolution_failure: bool = True
523
+ ) -> bool:
524
+ from omegaconf import Node
525
+
526
+ if not isinstance(value, Node):
527
+ return value is None
528
+
529
+ if resolve:
530
+ value = value._maybe_dereference_node(
531
+ throw_on_resolution_failure=throw_on_resolution_failure
532
+ )
533
+ if not throw_on_resolution_failure and value is None:
534
+ # Resolution failure: consider that it is *not* None.
535
+ return False
536
+ assert isinstance(value, Node)
537
+
538
+ return value._is_none()
539
+
540
+
541
+ def get_value_kind(
542
+ value: Any, strict_interpolation_validation: bool = False
543
+ ) -> ValueKind:
544
+ """
545
+ Determine the kind of a value
546
+ Examples:
547
+ VALUE: "10", "20", True
548
+ MANDATORY_MISSING: "???"
549
+ INTERPOLATION: "${foo.bar}", "${foo.${bar}}", "${foo:bar}", "[${foo}, ${bar}]",
550
+ "ftp://${host}/path", "${foo:${bar}, [true], {'baz': ${baz}}}"
551
+
552
+ :param value: Input to classify.
553
+ :param strict_interpolation_validation: If `True`, then when `value` is a string
554
+ containing "${", it is parsed to validate the interpolation syntax. If `False`,
555
+ this parsing step is skipped: this is more efficient, but will not detect errors.
556
+ """
557
+
558
+ if _is_missing_value(value):
559
+ return ValueKind.MANDATORY_MISSING
560
+
561
+ if _is_interpolation(value, strict_interpolation_validation):
562
+ return ValueKind.INTERPOLATION
563
+
564
+ return ValueKind.VALUE
565
+
566
+
567
+ def _is_interpolation(v: Any, strict_interpolation_validation: bool = False) -> bool:
568
+ from omegaconf import Node
569
+
570
+ if isinstance(v, Node):
571
+ v = v._value()
572
+
573
+ if isinstance(v, str) and _is_interpolation_string(
574
+ v, strict_interpolation_validation
575
+ ):
576
+ return True
577
+ return False
578
+
579
+
580
+ def _is_interpolation_string(value: str, strict_interpolation_validation: bool) -> bool:
581
+ # We identify potential interpolations by the presence of "${" in the string.
582
+ # Note that escaped interpolations (ex: "esc: \${bar}") are identified as
583
+ # interpolations: this is intended, since they must be processed as interpolations
584
+ # for the string to be properly un-escaped.
585
+ # Keep in mind that invalid interpolations will only be detected when
586
+ # `strict_interpolation_validation` is True.
587
+ if "${" in value:
588
+ if strict_interpolation_validation:
589
+ # First try the cheap regex matching that detects common interpolations.
590
+ if SIMPLE_INTERPOLATION_PATTERN.match(value) is None:
591
+ # If no match, do the more expensive grammar parsing to detect errors.
592
+ parse(value)
593
+ return True
594
+ return False
595
+
596
+
597
+ def _is_special(value: Any) -> bool:
598
+ """Special values are None, MISSING, and interpolation."""
599
+ return _is_none(value) or get_value_kind(value) in (
600
+ ValueKind.MANDATORY_MISSING,
601
+ ValueKind.INTERPOLATION,
602
+ )
603
+
604
+
605
+ def is_float(st: str) -> bool:
606
+ try:
607
+ float(st)
608
+ return True
609
+ except ValueError:
610
+ return False
611
+
612
+
613
+ def is_int(st: str) -> bool:
614
+ try:
615
+ int(st)
616
+ return True
617
+ except ValueError:
618
+ return False
619
+
620
+
621
+ def is_primitive_list(obj: Any) -> bool:
622
+ return isinstance(obj, (list, tuple))
623
+
624
+
625
+ def is_primitive_dict(obj: Any) -> bool:
626
+ t = get_type_of(obj)
627
+ return t is dict
628
+
629
+
630
+ def is_dict_annotation(type_: Any) -> bool:
631
+ if type_ in (dict, Dict):
632
+ return True
633
+ origin = getattr(type_, "__origin__", None)
634
+ # type_dict is a bit hard to detect.
635
+ # this support is tentative, if it eventually causes issues in other areas it may be dropped.
636
+ if sys.version_info < (3, 7, 0): # pragma: no cover
637
+ typed_dict = hasattr(type_, "__base__") and type_.__base__ == Dict
638
+ return origin is Dict or type_ is Dict or typed_dict
639
+ else: # pragma: no cover
640
+ typed_dict = hasattr(type_, "__base__") and type_.__base__ == dict
641
+ return origin is dict or typed_dict
642
+
643
+
644
+ def is_list_annotation(type_: Any) -> bool:
645
+ if type_ in (list, List):
646
+ return True
647
+ origin = getattr(type_, "__origin__", None)
648
+ if sys.version_info < (3, 7, 0):
649
+ return origin is List or type_ is List # pragma: no cover
650
+ else:
651
+ return origin is list # pragma: no cover
652
+
653
+
654
+ def is_tuple_annotation(type_: Any) -> bool:
655
+ if type_ in (tuple, Tuple):
656
+ return True
657
+ origin = getattr(type_, "__origin__", None)
658
+ if sys.version_info < (3, 7, 0):
659
+ return origin is Tuple or type_ is Tuple # pragma: no cover
660
+ else:
661
+ return origin is tuple # pragma: no cover
662
+
663
+
664
+ def is_supported_union_annotation(obj: Any) -> bool:
665
+ """Currently only primitive types are supported in Unions, e.g. Union[int, str]"""
666
+ if not is_union_annotation(obj):
667
+ return False
668
+ args = obj.__args__
669
+ return all(is_primitive_type_annotation(arg) for arg in args)
670
+
671
+
672
+ def is_dict_subclass(type_: Any) -> bool:
673
+ return type_ is not None and isinstance(type_, type) and issubclass(type_, Dict)
674
+
675
+
676
+ def is_dict(obj: Any) -> bool:
677
+ return is_primitive_dict(obj) or is_dict_annotation(obj) or is_dict_subclass(obj)
678
+
679
+
680
+ def is_primitive_container(obj: Any) -> bool:
681
+ return is_primitive_list(obj) or is_primitive_dict(obj)
682
+
683
+
684
+ def get_list_element_type(ref_type: Optional[Type[Any]]) -> Any:
685
+ args = getattr(ref_type, "__args__", None)
686
+ if ref_type is not List and args is not None and args[0]:
687
+ element_type = args[0]
688
+ else:
689
+ element_type = Any
690
+ return element_type
691
+
692
+
693
+ def get_tuple_item_types(ref_type: Type[Any]) -> Tuple[Any, ...]:
694
+ args = getattr(ref_type, "__args__", None)
695
+ if args in (None, ()):
696
+ args = (Any, ...)
697
+ assert isinstance(args, tuple)
698
+ return args
699
+
700
+
701
+ def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]:
702
+ args = getattr(ref_type, "__args__", None)
703
+ if args is None:
704
+ bases = getattr(ref_type, "__orig_bases__", None)
705
+ if bases is not None and len(bases) > 0:
706
+ args = getattr(bases[0], "__args__", None)
707
+
708
+ key_type: Any
709
+ element_type: Any
710
+ if ref_type is None or ref_type == Dict:
711
+ key_type = Any
712
+ element_type = Any
713
+ else:
714
+ if args is not None:
715
+ key_type = args[0]
716
+ element_type = args[1]
717
+ else:
718
+ key_type = Any
719
+ element_type = Any
720
+
721
+ return key_type, element_type
722
+
723
+
724
+ def is_valid_value_annotation(type_: Any) -> bool:
725
+ _, type_ = _resolve_optional(type_)
726
+ return (
727
+ type_ is Any
728
+ or is_primitive_type_annotation(type_)
729
+ or is_structured_config(type_)
730
+ or is_container_annotation(type_)
731
+ or is_supported_union_annotation(type_)
732
+ )
733
+
734
+
735
+ def _valid_dict_key_annotation_type(type_: Any) -> bool:
736
+ from omegaconf import DictKeyType
737
+
738
+ return type_ is None or type_ is Any or issubclass(type_, DictKeyType.__args__) # type: ignore
739
+
740
+
741
+ def is_primitive_type_annotation(type_: Any) -> bool:
742
+ type_ = get_type_of(type_)
743
+ return issubclass(type_, (Enum, pathlib.Path)) or type_ in BUILTIN_VALUE_TYPES
744
+
745
+
746
+ def _get_value(value: Any) -> Any:
747
+ from .base import Container, UnionNode
748
+ from .nodes import ValueNode
749
+
750
+ if isinstance(value, ValueNode):
751
+ return value._value()
752
+ elif isinstance(value, Container):
753
+ boxed = value._value()
754
+ if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed):
755
+ return boxed
756
+ elif isinstance(value, UnionNode):
757
+ boxed = value._value()
758
+ if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed):
759
+ return boxed
760
+ else:
761
+ return _get_value(boxed) # pass through value of boxed node
762
+
763
+ # return primitives and regular OmegaConf Containers as is
764
+ return value
765
+
766
+
767
+ def get_type_hint(obj: Any, key: Any = None) -> Optional[Type[Any]]:
768
+ from omegaconf import Container, Node
769
+
770
+ if isinstance(obj, Container):
771
+ if key is not None:
772
+ obj = obj._get_node(key)
773
+ else:
774
+ if key is not None:
775
+ raise ValueError("Key must only be provided when obj is a container")
776
+
777
+ if isinstance(obj, Node):
778
+ ref_type = obj._metadata.ref_type
779
+ if obj._is_optional() and ref_type is not Any:
780
+ return Optional[ref_type] # type: ignore
781
+ else:
782
+ return ref_type
783
+ else:
784
+ return Any # type: ignore
785
+
786
+
787
+ def _raise(ex: Exception, cause: Exception) -> None:
788
+ # Set the environment variable OC_CAUSE=1 to get a stacktrace that includes the
789
+ # causing exception.
790
+ env_var = os.environ["OC_CAUSE"] if "OC_CAUSE" in os.environ else None
791
+ debugging = sys.gettrace() is not None
792
+ full_backtrace = (debugging and not env_var == "0") or (env_var == "1")
793
+ if full_backtrace:
794
+ ex.__cause__ = cause
795
+ else:
796
+ ex.__cause__ = None
797
+ raise ex.with_traceback(sys.exc_info()[2]) # set env var OC_CAUSE=1 for full trace
798
+
799
+
800
+ def format_and_raise(
801
+ node: Any,
802
+ key: Any,
803
+ value: Any,
804
+ msg: str,
805
+ cause: Exception,
806
+ type_override: Any = None,
807
+ ) -> None:
808
+ from omegaconf import OmegaConf
809
+ from omegaconf.base import Node
810
+
811
+ if isinstance(cause, AssertionError):
812
+ raise
813
+
814
+ if isinstance(cause, OmegaConfBaseException) and cause._initialized:
815
+ ex = cause
816
+ if type_override is not None:
817
+ ex = type_override(str(cause))
818
+ ex.__dict__ = copy.deepcopy(cause.__dict__)
819
+ _raise(ex, cause)
820
+
821
+ object_type: Optional[Type[Any]]
822
+ object_type_str: Optional[str] = None
823
+ ref_type: Optional[Type[Any]]
824
+ ref_type_str: Optional[str]
825
+
826
+ child_node: Optional[Node] = None
827
+ if node is None:
828
+ full_key = key if key is not None else ""
829
+ object_type = None
830
+ ref_type = None
831
+ ref_type_str = None
832
+ else:
833
+ if key is not None and not node._is_none():
834
+ child_node = node._get_node(key, validate_access=False)
835
+
836
+ try:
837
+ full_key = node._get_full_key(key=key)
838
+ except Exception as exc:
839
+ # Since we are handling an exception, raising a different one here would
840
+ # be misleading. Instead, we display it in the key.
841
+ full_key = f"<unresolvable due to {type(exc).__name__}: {exc}>"
842
+
843
+ object_type = OmegaConf.get_type(node)
844
+ object_type_str = type_str(object_type)
845
+
846
+ ref_type = get_type_hint(node)
847
+ ref_type_str = type_str(ref_type)
848
+
849
+ msg = string.Template(msg).safe_substitute(
850
+ REF_TYPE=ref_type_str,
851
+ OBJECT_TYPE=object_type_str,
852
+ KEY=key,
853
+ FULL_KEY=full_key,
854
+ VALUE=value,
855
+ VALUE_TYPE=type_str(type(value), include_module_name=True),
856
+ KEY_TYPE=f"{type(key).__name__}",
857
+ )
858
+
859
+ if ref_type not in (None, Any):
860
+ template = dedent(
861
+ """\
862
+ $MSG
863
+ full_key: $FULL_KEY
864
+ reference_type=$REF_TYPE
865
+ object_type=$OBJECT_TYPE"""
866
+ )
867
+ else:
868
+ template = dedent(
869
+ """\
870
+ $MSG
871
+ full_key: $FULL_KEY
872
+ object_type=$OBJECT_TYPE"""
873
+ )
874
+ s = string.Template(template=template)
875
+
876
+ message = s.substitute(
877
+ REF_TYPE=ref_type_str, OBJECT_TYPE=object_type_str, MSG=msg, FULL_KEY=full_key
878
+ )
879
+ exception_type = type(cause) if type_override is None else type_override
880
+ if exception_type == TypeError:
881
+ exception_type = ConfigTypeError
882
+ elif exception_type == IndexError:
883
+ exception_type = ConfigIndexError
884
+
885
+ ex = exception_type(f"{message}")
886
+ if issubclass(exception_type, OmegaConfBaseException):
887
+ ex._initialized = True
888
+ ex.msg = message
889
+ ex.parent_node = node
890
+ ex.child_node = child_node
891
+ ex.key = key
892
+ ex.full_key = full_key
893
+ ex.value = value
894
+ ex.object_type = object_type
895
+ ex.object_type_str = object_type_str
896
+ ex.ref_type = ref_type
897
+ ex.ref_type_str = ref_type_str
898
+
899
+ _raise(ex, cause)
900
+
901
+
902
+ def type_str(t: Any, include_module_name: bool = False) -> str:
903
+ is_optional, t = _resolve_optional(t)
904
+ if t is NoneType:
905
+ return str(t.__name__)
906
+ if t is Any:
907
+ return "Any"
908
+ if t is ...:
909
+ return "..."
910
+
911
+ if hasattr(t, "__name__"):
912
+ name = str(t.__name__)
913
+ elif getattr(t, "_name", None) is not None: # pragma: no cover
914
+ name = str(t._name)
915
+ elif getattr(t, "__origin__", None) is not None: # pragma: no cover
916
+ name = type_str(t.__origin__)
917
+ else:
918
+ name = str(t)
919
+ if name.startswith("typing."): # pragma: no cover
920
+ name = name[len("typing.") :]
921
+
922
+ args = getattr(t, "__args__", None)
923
+ if args is not None:
924
+ args = ", ".join(
925
+ [type_str(t, include_module_name=include_module_name) for t in t.__args__]
926
+ )
927
+ ret = f"{name}[{args}]"
928
+ else:
929
+ ret = name
930
+ if include_module_name:
931
+ if (
932
+ hasattr(t, "__module__")
933
+ and t.__module__ != "builtins"
934
+ and t.__module__ != "typing"
935
+ and not t.__module__.startswith("omegaconf.")
936
+ ):
937
+ module_prefix = str(t.__module__) + "."
938
+ else:
939
+ module_prefix = ""
940
+ ret = module_prefix + ret
941
+ if is_optional:
942
+ return f"Optional[{ret}]"
943
+ else:
944
+ return ret
945
+
946
+
947
+ def _ensure_container(target: Any, flags: Optional[Dict[str, bool]] = None) -> Any:
948
+ from omegaconf import OmegaConf
949
+
950
+ if is_primitive_container(target):
951
+ assert isinstance(target, (list, dict))
952
+ target = OmegaConf.create(target, flags=flags)
953
+ elif is_structured_config(target):
954
+ target = OmegaConf.structured(target, flags=flags)
955
+ elif not OmegaConf.is_config(target):
956
+ raise ValueError(
957
+ "Invalid input. Supports one of "
958
+ + "[dict,list,DictConfig,ListConfig,dataclass,dataclass instance,attr class,attr class instance]"
959
+ )
960
+
961
+ return target
962
+
963
+
964
+ def is_generic_list(type_: Any) -> bool:
965
+ """
966
+ Checks if a type is a generic list, for example:
967
+ list returns False
968
+ typing.List returns False
969
+ typing.List[T] returns True
970
+
971
+ :param type_: variable type
972
+ :return: bool
973
+ """
974
+ return is_list_annotation(type_) and get_list_element_type(type_) is not None
975
+
976
+
977
+ def is_generic_dict(type_: Any) -> bool:
978
+ """
979
+ Checks if a type is a generic dict, for example:
980
+ list returns False
981
+ typing.List returns False
982
+ typing.List[T] returns True
983
+
984
+ :param type_: variable type
985
+ :return: bool
986
+ """
987
+ return is_dict_annotation(type_) and len(get_dict_key_value_types(type_)) > 0
988
+
989
+
990
+ def is_container_annotation(type_: Any) -> bool:
991
+ return is_list_annotation(type_) or is_dict_annotation(type_)
992
+
993
+
994
+ def split_key(key: str) -> List[str]:
995
+ """
996
+ Split a full key path into its individual components.
997
+
998
+ This is similar to `key.split(".")` but also works with the getitem syntax:
999
+ "a.b" -> ["a", "b"]
1000
+ "a[b]" -> ["a", "b"]
1001
+ ".a.b[c].d" -> ["", "a", "b", "c", "d"]
1002
+ "[a].b" -> ["a", "b"]
1003
+ """
1004
+ # Obtain the first part of the key (in docstring examples: a, a, .a, '')
1005
+ first = KEY_PATH_HEAD.match(key)
1006
+ assert first is not None
1007
+ first_stop = first.span()[1]
1008
+
1009
+ # `tokens` will contain all elements composing the key.
1010
+ tokens = key[0:first_stop].split(".")
1011
+
1012
+ # Optimization in case `key` has no other component: we are done.
1013
+ if first_stop == len(key):
1014
+ return tokens
1015
+
1016
+ if key[first_stop] == "[" and not tokens[-1]:
1017
+ # This is a special case where the first key starts with brackets, e.g.
1018
+ # [a] or ..[a]. In that case there is an extra "" in `tokens` that we
1019
+ # need to get rid of:
1020
+ # [a] -> tokens = [""] but we would like []
1021
+ # ..[a] -> tokens = ["", "", ""] but we would like ["", ""]
1022
+ tokens.pop()
1023
+
1024
+ # Identify other key elements (in docstring examples: b, b, b/c/d, b)
1025
+ others = KEY_PATH_OTHER.findall(key[first_stop:])
1026
+
1027
+ # There are two groups in the `KEY_PATH_OTHER` regex: one for keys starting
1028
+ # with a dot (.b, .d) and one for keys starting with a bracket ([b], [c]).
1029
+ # Only one group can be non-empty.
1030
+ tokens += [dot_key if dot_key else bracket_key for dot_key, bracket_key in others]
1031
+
1032
+ return tokens
1033
+
1034
+
1035
+ # Similar to Python 3.7+'s `contextlib.nullcontext` (which should be used instead,
1036
+ # once support for Python 3.6 is dropped).
1037
+ @contextmanager
1038
+ def nullcontext(enter_result: Any = None) -> Iterator[Any]:
1039
+ yield enter_result
omegaconf/base.py ADDED
@@ -0,0 +1,962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import sys
3
+ from abc import ABC, abstractmethod
4
+ from collections import defaultdict
5
+ from dataclasses import dataclass, field
6
+ from enum import Enum
7
+ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
8
+
9
+ from antlr4 import ParserRuleContext
10
+
11
+ from ._utils import (
12
+ _DEFAULT_MARKER_,
13
+ NoneType,
14
+ ValueKind,
15
+ _get_value,
16
+ _is_interpolation,
17
+ _is_missing_value,
18
+ _is_special,
19
+ format_and_raise,
20
+ get_value_kind,
21
+ is_union_annotation,
22
+ is_valid_value_annotation,
23
+ split_key,
24
+ type_str,
25
+ )
26
+ from .errors import (
27
+ ConfigKeyError,
28
+ ConfigTypeError,
29
+ InterpolationKeyError,
30
+ InterpolationResolutionError,
31
+ InterpolationToMissingValueError,
32
+ InterpolationValidationError,
33
+ MissingMandatoryValue,
34
+ UnsupportedInterpolationType,
35
+ ValidationError,
36
+ )
37
+ from .grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
38
+ from .grammar_parser import parse
39
+ from .grammar_visitor import GrammarVisitor
40
+
41
+ DictKeyType = Union[str, bytes, int, Enum, float, bool]
42
+
43
+
44
+ @dataclass
45
+ class Metadata:
46
+
47
+ ref_type: Union[Type[Any], Any]
48
+
49
+ object_type: Union[Type[Any], Any]
50
+
51
+ optional: bool
52
+
53
+ key: Any
54
+
55
+ # Flags have 3 modes:
56
+ # unset : inherit from parent (None if no parent specifies)
57
+ # set to true: flag is true
58
+ # set to false: flag is false
59
+ flags: Optional[Dict[str, bool]] = None
60
+
61
+ # If True, when checking the value of a flag, if the flag is not set None is returned
62
+ # otherwise, the parent node is queried.
63
+ flags_root: bool = False
64
+
65
+ resolver_cache: Dict[str, Any] = field(default_factory=lambda: defaultdict(dict))
66
+
67
+ def __post_init__(self) -> None:
68
+ if self.flags is None:
69
+ self.flags = {}
70
+
71
+ @property
72
+ def type_hint(self) -> Union[Type[Any], Any]:
73
+ """Compute `type_hint` from `self.optional` and `self.ref_type`"""
74
+ # For compatibility with pickled OmegaConf objects created using older
75
+ # versions of OmegaConf, we store `ref_type` and `object_type`
76
+ # separately (rather than storing `type_hint` directly).
77
+ if self.optional:
78
+ return Optional[self.ref_type]
79
+ else:
80
+ return self.ref_type
81
+
82
+
83
+ @dataclass
84
+ class ContainerMetadata(Metadata):
85
+ key_type: Any = None
86
+ element_type: Any = None
87
+
88
+ def __post_init__(self) -> None:
89
+ if self.ref_type is None:
90
+ self.ref_type = Any
91
+ assert self.key_type is Any or isinstance(self.key_type, type)
92
+ if self.element_type is not None:
93
+ if not is_valid_value_annotation(self.element_type):
94
+ raise ValidationError(
95
+ f"Unsupported value type: '{type_str(self.element_type, include_module_name=True)}'"
96
+ )
97
+
98
+ if self.flags is None:
99
+ self.flags = {}
100
+
101
+
102
+ class Node(ABC):
103
+ _metadata: Metadata
104
+
105
+ _parent: Optional["Box"]
106
+ _flags_cache: Optional[Dict[str, Optional[bool]]]
107
+
108
+ def __init__(self, parent: Optional["Box"], metadata: Metadata):
109
+ self.__dict__["_metadata"] = metadata
110
+ self.__dict__["_parent"] = parent
111
+ self.__dict__["_flags_cache"] = None
112
+
113
+ def __getstate__(self) -> Dict[str, Any]:
114
+ # Overridden to ensure that the flags cache is cleared on serialization.
115
+ state_dict = copy.copy(self.__dict__)
116
+ del state_dict["_flags_cache"]
117
+ return state_dict
118
+
119
+ def __setstate__(self, state_dict: Dict[str, Any]) -> None:
120
+ self.__dict__.update(state_dict)
121
+ self.__dict__["_flags_cache"] = None
122
+
123
+ def _set_parent(self, parent: Optional["Box"]) -> None:
124
+ assert parent is None or isinstance(parent, Box)
125
+ self.__dict__["_parent"] = parent
126
+ self._invalidate_flags_cache()
127
+
128
+ def _invalidate_flags_cache(self) -> None:
129
+ self.__dict__["_flags_cache"] = None
130
+
131
+ def _get_parent(self) -> Optional["Box"]:
132
+ parent = self.__dict__["_parent"]
133
+ assert parent is None or isinstance(parent, Box)
134
+ return parent
135
+
136
+ def _get_parent_container(self) -> Optional["Container"]:
137
+ """
138
+ Like _get_parent, but returns the grandparent
139
+ in the case where `self` is wrapped by a UnionNode.
140
+ """
141
+ parent = self.__dict__["_parent"]
142
+ assert parent is None or isinstance(parent, Box)
143
+
144
+ if isinstance(parent, UnionNode):
145
+ grandparent = parent.__dict__["_parent"]
146
+ assert grandparent is None or isinstance(grandparent, Container)
147
+ return grandparent
148
+ else:
149
+ assert parent is None or isinstance(parent, Container)
150
+ return parent
151
+
152
+ def _set_flag(
153
+ self,
154
+ flags: Union[List[str], str],
155
+ values: Union[List[Optional[bool]], Optional[bool]],
156
+ ) -> "Node":
157
+ if isinstance(flags, str):
158
+ flags = [flags]
159
+
160
+ if values is None or isinstance(values, bool):
161
+ values = [values]
162
+
163
+ if len(values) == 1:
164
+ values = len(flags) * values
165
+
166
+ if len(flags) != len(values):
167
+ raise ValueError("Inconsistent lengths of input flag names and values")
168
+
169
+ for idx, flag in enumerate(flags):
170
+ value = values[idx]
171
+ if value is None:
172
+ assert self._metadata.flags is not None
173
+ if flag in self._metadata.flags:
174
+ del self._metadata.flags[flag]
175
+ else:
176
+ assert self._metadata.flags is not None
177
+ self._metadata.flags[flag] = value
178
+ self._invalidate_flags_cache()
179
+ return self
180
+
181
+ def _get_node_flag(self, flag: str) -> Optional[bool]:
182
+ """
183
+ :param flag: flag to inspect
184
+ :return: the state of the flag on this node.
185
+ """
186
+ assert self._metadata.flags is not None
187
+ return self._metadata.flags.get(flag)
188
+
189
+ def _get_flag(self, flag: str) -> Optional[bool]:
190
+ cache = self.__dict__["_flags_cache"]
191
+ if cache is None:
192
+ cache = self.__dict__["_flags_cache"] = {}
193
+
194
+ ret = cache.get(flag, _DEFAULT_MARKER_)
195
+ if ret is _DEFAULT_MARKER_:
196
+ ret = self._get_flag_no_cache(flag)
197
+ cache[flag] = ret
198
+ assert ret is None or isinstance(ret, bool)
199
+ return ret
200
+
201
+ def _get_flag_no_cache(self, flag: str) -> Optional[bool]:
202
+ """
203
+ Returns True if this config node flag is set
204
+ A flag is set if node.set_flag(True) was called
205
+ or one if it's parents is flag is set
206
+ :return:
207
+ """
208
+ flags = self._metadata.flags
209
+ assert flags is not None
210
+ if flag in flags and flags[flag] is not None:
211
+ return flags[flag]
212
+
213
+ if self._is_flags_root():
214
+ return None
215
+
216
+ parent = self._get_parent()
217
+ if parent is None:
218
+ return None
219
+ else:
220
+ # noinspection PyProtectedMember
221
+ return parent._get_flag(flag)
222
+
223
+ def _format_and_raise(
224
+ self,
225
+ key: Any,
226
+ value: Any,
227
+ cause: Exception,
228
+ msg: Optional[str] = None,
229
+ type_override: Any = None,
230
+ ) -> None:
231
+ format_and_raise(
232
+ node=self,
233
+ key=key,
234
+ value=value,
235
+ msg=str(cause) if msg is None else msg,
236
+ cause=cause,
237
+ type_override=type_override,
238
+ )
239
+ assert False
240
+
241
+ @abstractmethod
242
+ def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
243
+ ...
244
+
245
+ def _dereference_node(self) -> "Node":
246
+ node = self._dereference_node_impl(throw_on_resolution_failure=True)
247
+ assert node is not None
248
+ return node
249
+
250
+ def _maybe_dereference_node(
251
+ self,
252
+ throw_on_resolution_failure: bool = False,
253
+ memo: Optional[Set[int]] = None,
254
+ ) -> Optional["Node"]:
255
+ return self._dereference_node_impl(
256
+ throw_on_resolution_failure=throw_on_resolution_failure,
257
+ memo=memo,
258
+ )
259
+
260
+ def _dereference_node_impl(
261
+ self,
262
+ throw_on_resolution_failure: bool,
263
+ memo: Optional[Set[int]] = None,
264
+ ) -> Optional["Node"]:
265
+ if not self._is_interpolation():
266
+ return self
267
+
268
+ parent = self._get_parent_container()
269
+ if parent is None:
270
+ if throw_on_resolution_failure:
271
+ raise InterpolationResolutionError(
272
+ "Cannot resolve interpolation for a node without a parent"
273
+ )
274
+ return None
275
+ assert parent is not None
276
+ key = self._key()
277
+ return parent._resolve_interpolation_from_parse_tree(
278
+ parent=parent,
279
+ key=key,
280
+ value=self,
281
+ parse_tree=parse(_get_value(self)),
282
+ throw_on_resolution_failure=throw_on_resolution_failure,
283
+ memo=memo,
284
+ )
285
+
286
+ def _get_root(self) -> "Container":
287
+ root: Optional[Box] = self._get_parent()
288
+ if root is None:
289
+ assert isinstance(self, Container)
290
+ return self
291
+ assert root is not None and isinstance(root, Box)
292
+ while root._get_parent() is not None:
293
+ root = root._get_parent()
294
+ assert root is not None and isinstance(root, Box)
295
+ assert root is not None and isinstance(root, Container)
296
+ return root
297
+
298
+ def _is_missing(self) -> bool:
299
+ """
300
+ Check if the node's value is `???` (does *not* resolve interpolations).
301
+ """
302
+ return _is_missing_value(self)
303
+
304
+ def _is_none(self) -> bool:
305
+ """
306
+ Check if the node's value is `None` (does *not* resolve interpolations).
307
+ """
308
+ return self._value() is None
309
+
310
+ @abstractmethod
311
+ def __eq__(self, other: Any) -> bool:
312
+ ...
313
+
314
+ @abstractmethod
315
+ def __ne__(self, other: Any) -> bool:
316
+ ...
317
+
318
+ @abstractmethod
319
+ def __hash__(self) -> int:
320
+ ...
321
+
322
+ @abstractmethod
323
+ def _value(self) -> Any:
324
+ ...
325
+
326
+ @abstractmethod
327
+ def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
328
+ ...
329
+
330
+ @abstractmethod
331
+ def _is_optional(self) -> bool:
332
+ ...
333
+
334
+ @abstractmethod
335
+ def _is_interpolation(self) -> bool:
336
+ ...
337
+
338
+ def _key(self) -> Any:
339
+ return self._metadata.key
340
+
341
+ def _set_key(self, key: Any) -> None:
342
+ self._metadata.key = key
343
+
344
+ def _is_flags_root(self) -> bool:
345
+ return self._metadata.flags_root
346
+
347
+ def _set_flags_root(self, flags_root: bool) -> None:
348
+ if self._metadata.flags_root != flags_root:
349
+ self._metadata.flags_root = flags_root
350
+ self._invalidate_flags_cache()
351
+
352
+ def _has_ref_type(self) -> bool:
353
+ return self._metadata.ref_type is not Any
354
+
355
+
356
+ class Box(Node):
357
+ """
358
+ Base class for nodes that can contain other nodes.
359
+ Concrete subclasses include DictConfig, ListConfig, and UnionNode.
360
+ """
361
+
362
+ _content: Any
363
+
364
+ def __init__(self, parent: Optional["Box"], metadata: Metadata):
365
+ super().__init__(parent=parent, metadata=metadata)
366
+ self.__dict__["_content"] = None
367
+
368
+ def __copy__(self) -> Any:
369
+ # real shallow copy is impossible because of the reference to the parent.
370
+ return copy.deepcopy(self)
371
+
372
+ def _re_parent(self) -> None:
373
+ from .dictconfig import DictConfig
374
+ from .listconfig import ListConfig
375
+
376
+ # update parents of first level Config nodes to self
377
+
378
+ if isinstance(self, DictConfig):
379
+ content = self.__dict__["_content"]
380
+ if isinstance(content, dict):
381
+ for _key, value in self.__dict__["_content"].items():
382
+ if value is not None:
383
+ value._set_parent(self)
384
+ if isinstance(value, Box):
385
+ value._re_parent()
386
+ elif isinstance(self, ListConfig):
387
+ content = self.__dict__["_content"]
388
+ if isinstance(content, list):
389
+ for item in self.__dict__["_content"]:
390
+ if item is not None:
391
+ item._set_parent(self)
392
+ if isinstance(item, Box):
393
+ item._re_parent()
394
+ elif isinstance(self, UnionNode):
395
+ content = self.__dict__["_content"]
396
+ if isinstance(content, Node):
397
+ content._set_parent(self)
398
+ if isinstance(content, Box): # pragma: no cover
399
+ # No coverage here as support for containers inside
400
+ # UnionNode is not yet implemented
401
+ content._re_parent()
402
+
403
+
404
+ class Container(Box):
405
+ """
406
+ Container tagging interface
407
+ """
408
+
409
+ _metadata: ContainerMetadata
410
+
411
+ @abstractmethod
412
+ def _get_child(
413
+ self,
414
+ key: Any,
415
+ validate_access: bool = True,
416
+ validate_key: bool = True,
417
+ throw_on_missing_value: bool = False,
418
+ throw_on_missing_key: bool = False,
419
+ ) -> Union[Optional[Node], List[Optional[Node]]]:
420
+ ...
421
+
422
+ @abstractmethod
423
+ def _get_node(
424
+ self,
425
+ key: Any,
426
+ validate_access: bool = True,
427
+ validate_key: bool = True,
428
+ throw_on_missing_value: bool = False,
429
+ throw_on_missing_key: bool = False,
430
+ ) -> Union[Optional[Node], List[Optional[Node]]]:
431
+ ...
432
+
433
+ @abstractmethod
434
+ def __delitem__(self, key: Any) -> None:
435
+ ...
436
+
437
+ @abstractmethod
438
+ def __setitem__(self, key: Any, value: Any) -> None:
439
+ ...
440
+
441
+ @abstractmethod
442
+ def __iter__(self) -> Iterator[Any]:
443
+ ...
444
+
445
+ @abstractmethod
446
+ def __getitem__(self, key_or_index: Any) -> Any:
447
+ ...
448
+
449
+ def _resolve_key_and_root(self, key: str) -> Tuple["Container", str]:
450
+ orig = key
451
+ if not key.startswith("."):
452
+ return self._get_root(), key
453
+ else:
454
+ root: Optional[Container] = self
455
+ assert key.startswith(".")
456
+ while True:
457
+ assert root is not None
458
+ key = key[1:]
459
+ if not key.startswith("."):
460
+ break
461
+ root = root._get_parent_container()
462
+ if root is None:
463
+ raise ConfigKeyError(f"Error resolving key '{orig}'")
464
+
465
+ return root, key
466
+
467
+ def _select_impl(
468
+ self,
469
+ key: str,
470
+ throw_on_missing: bool,
471
+ throw_on_resolution_failure: bool,
472
+ memo: Optional[Set[int]] = None,
473
+ ) -> Tuple[Optional["Container"], Optional[str], Optional[Node]]:
474
+ """
475
+ Select a value using dot separated key sequence
476
+ """
477
+ from .omegaconf import _select_one
478
+
479
+ if key == "":
480
+ return self, "", self
481
+
482
+ split = split_key(key)
483
+ root: Optional[Container] = self
484
+ for i in range(len(split) - 1):
485
+ if root is None:
486
+ break
487
+
488
+ k = split[i]
489
+ ret, _ = _select_one(
490
+ c=root,
491
+ key=k,
492
+ throw_on_missing=throw_on_missing,
493
+ throw_on_type_error=throw_on_resolution_failure,
494
+ )
495
+ if isinstance(ret, Node):
496
+ ret = ret._maybe_dereference_node(
497
+ throw_on_resolution_failure=throw_on_resolution_failure,
498
+ memo=memo,
499
+ )
500
+
501
+ if ret is not None and not isinstance(ret, Container):
502
+ parent_key = ".".join(split[0 : i + 1])
503
+ child_key = split[i + 1]
504
+ raise ConfigTypeError(
505
+ f"Error trying to access {key}: node `{parent_key}` "
506
+ f"is not a container and thus cannot contain `{child_key}`"
507
+ )
508
+ root = ret
509
+
510
+ if root is None:
511
+ return None, None, None
512
+
513
+ last_key = split[-1]
514
+ value, _ = _select_one(
515
+ c=root,
516
+ key=last_key,
517
+ throw_on_missing=throw_on_missing,
518
+ throw_on_type_error=throw_on_resolution_failure,
519
+ )
520
+ if value is None:
521
+ return root, last_key, None
522
+
523
+ if memo is not None:
524
+ vid = id(value)
525
+ if vid in memo:
526
+ raise InterpolationResolutionError("Recursive interpolation detected")
527
+ # push to memo "stack"
528
+ memo.add(vid)
529
+
530
+ try:
531
+ value = root._maybe_resolve_interpolation(
532
+ parent=root,
533
+ key=last_key,
534
+ value=value,
535
+ throw_on_resolution_failure=throw_on_resolution_failure,
536
+ memo=memo,
537
+ )
538
+ finally:
539
+ if memo is not None:
540
+ # pop from memo "stack"
541
+ memo.remove(vid)
542
+
543
+ return root, last_key, value
544
+
545
+ def _resolve_interpolation_from_parse_tree(
546
+ self,
547
+ parent: Optional["Container"],
548
+ value: "Node",
549
+ key: Any,
550
+ parse_tree: OmegaConfGrammarParser.ConfigValueContext,
551
+ throw_on_resolution_failure: bool,
552
+ memo: Optional[Set[int]],
553
+ ) -> Optional["Node"]:
554
+ """
555
+ Resolve an interpolation.
556
+
557
+ This happens in two steps:
558
+ 1. The parse tree is visited, which outputs either a `Node` (e.g.,
559
+ for node interpolations "${foo}"), a string (e.g., for string
560
+ interpolations "hello ${name}", or any other arbitrary value
561
+ (e.g., or custom interpolations "${foo:bar}").
562
+ 2. This output is potentially validated and converted when the node
563
+ being resolved (`value`) is typed.
564
+
565
+ If an error occurs in one of the above steps, an `InterpolationResolutionError`
566
+ (or a subclass of it) is raised, *unless* `throw_on_resolution_failure` is set
567
+ to `False` (in which case the return value is `None`).
568
+
569
+ :param parent: Parent of the node being resolved.
570
+ :param value: Node being resolved.
571
+ :param key: The associated key in the parent.
572
+ :param parse_tree: The parse tree as obtained from `grammar_parser.parse()`.
573
+ :param throw_on_resolution_failure: If `False`, then exceptions raised during
574
+ the resolution of the interpolation are silenced, and instead `None` is
575
+ returned.
576
+
577
+ :return: A `Node` that contains the interpolation result. This may be an existing
578
+ node in the config (in the case of a node interpolation "${foo}"), or a new
579
+ node that is created to wrap the interpolated value. It is `None` if and only if
580
+ `throw_on_resolution_failure` is `False` and an error occurs during resolution.
581
+ """
582
+
583
+ try:
584
+ resolved = self.resolve_parse_tree(
585
+ parse_tree=parse_tree, node=value, key=key, memo=memo
586
+ )
587
+ except InterpolationResolutionError:
588
+ if throw_on_resolution_failure:
589
+ raise
590
+ return None
591
+
592
+ return self._validate_and_convert_interpolation_result(
593
+ parent=parent,
594
+ value=value,
595
+ key=key,
596
+ resolved=resolved,
597
+ throw_on_resolution_failure=throw_on_resolution_failure,
598
+ )
599
+
600
+ def _validate_and_convert_interpolation_result(
601
+ self,
602
+ parent: Optional["Container"],
603
+ value: "Node",
604
+ key: Any,
605
+ resolved: Any,
606
+ throw_on_resolution_failure: bool,
607
+ ) -> Optional["Node"]:
608
+ from .nodes import AnyNode, InterpolationResultNode, ValueNode
609
+
610
+ # If the output is not a Node already (e.g., because it is the output of a
611
+ # custom resolver), then we will need to wrap it within a Node.
612
+ must_wrap = not isinstance(resolved, Node)
613
+
614
+ # If the node is typed, validate (and possibly convert) the result.
615
+ if isinstance(value, ValueNode) and not isinstance(value, AnyNode):
616
+ res_value = _get_value(resolved)
617
+ try:
618
+ conv_value = value.validate_and_convert(res_value)
619
+ except ValidationError as e:
620
+ if throw_on_resolution_failure:
621
+ self._format_and_raise(
622
+ key=key,
623
+ value=res_value,
624
+ cause=e,
625
+ msg=f"While dereferencing interpolation '{value}': {e}",
626
+ type_override=InterpolationValidationError,
627
+ )
628
+ return None
629
+
630
+ # If the converted value is of the same type, it means that no conversion
631
+ # was actually needed. As a result, we can keep the original `resolved`
632
+ # (and otherwise, the converted value must be wrapped into a new node).
633
+ if type(conv_value) != type(res_value):
634
+ must_wrap = True
635
+ resolved = conv_value
636
+
637
+ if must_wrap:
638
+ return InterpolationResultNode(value=resolved, key=key, parent=parent)
639
+ else:
640
+ assert isinstance(resolved, Node)
641
+ return resolved
642
+
643
+ def _validate_not_dereferencing_to_parent(self, node: Node, target: Node) -> None:
644
+ parent: Optional[Node] = node
645
+ while parent is not None:
646
+ if parent is target:
647
+ raise InterpolationResolutionError(
648
+ "Interpolation to parent node detected"
649
+ )
650
+ parent = parent._get_parent()
651
+
652
+ def _resolve_node_interpolation(
653
+ self, inter_key: str, memo: Optional[Set[int]]
654
+ ) -> "Node":
655
+ """A node interpolation is of the form `${foo.bar}`"""
656
+ try:
657
+ root_node, inter_key = self._resolve_key_and_root(inter_key)
658
+ except ConfigKeyError as exc:
659
+ raise InterpolationKeyError(
660
+ f"ConfigKeyError while resolving interpolation: {exc}"
661
+ ).with_traceback(sys.exc_info()[2])
662
+
663
+ try:
664
+ parent, last_key, value = root_node._select_impl(
665
+ inter_key,
666
+ throw_on_missing=True,
667
+ throw_on_resolution_failure=True,
668
+ memo=memo,
669
+ )
670
+ except MissingMandatoryValue as exc:
671
+ raise InterpolationToMissingValueError(
672
+ f"MissingMandatoryValue while resolving interpolation: {exc}"
673
+ ).with_traceback(sys.exc_info()[2])
674
+
675
+ if parent is None or value is None:
676
+ raise InterpolationKeyError(f"Interpolation key '{inter_key}' not found")
677
+ else:
678
+ self._validate_not_dereferencing_to_parent(node=self, target=value)
679
+ return value
680
+
681
+ def _evaluate_custom_resolver(
682
+ self,
683
+ key: Any,
684
+ node: Node,
685
+ inter_type: str,
686
+ inter_args: Tuple[Any, ...],
687
+ inter_args_str: Tuple[str, ...],
688
+ ) -> Any:
689
+ from omegaconf import OmegaConf
690
+
691
+ resolver = OmegaConf._get_resolver(inter_type)
692
+ if resolver is not None:
693
+ root_node = self._get_root()
694
+ return resolver(
695
+ root_node,
696
+ self,
697
+ node,
698
+ inter_args,
699
+ inter_args_str,
700
+ )
701
+ else:
702
+ raise UnsupportedInterpolationType(
703
+ f"Unsupported interpolation type {inter_type}"
704
+ )
705
+
706
+ def _maybe_resolve_interpolation(
707
+ self,
708
+ parent: Optional["Container"],
709
+ key: Any,
710
+ value: Node,
711
+ throw_on_resolution_failure: bool,
712
+ memo: Optional[Set[int]] = None,
713
+ ) -> Optional[Node]:
714
+ value_kind = get_value_kind(value)
715
+ if value_kind != ValueKind.INTERPOLATION:
716
+ return value
717
+
718
+ parse_tree = parse(_get_value(value))
719
+ return self._resolve_interpolation_from_parse_tree(
720
+ parent=parent,
721
+ value=value,
722
+ key=key,
723
+ parse_tree=parse_tree,
724
+ throw_on_resolution_failure=throw_on_resolution_failure,
725
+ memo=memo if memo is not None else set(),
726
+ )
727
+
728
+ def resolve_parse_tree(
729
+ self,
730
+ parse_tree: ParserRuleContext,
731
+ node: Node,
732
+ memo: Optional[Set[int]] = None,
733
+ key: Optional[Any] = None,
734
+ ) -> Any:
735
+ """
736
+ Resolve a given parse tree into its value.
737
+
738
+ We make no assumption here on the type of the tree's root, so that the
739
+ return value may be of any type.
740
+ """
741
+
742
+ def node_interpolation_callback(
743
+ inter_key: str, memo: Optional[Set[int]]
744
+ ) -> Optional["Node"]:
745
+ return self._resolve_node_interpolation(inter_key=inter_key, memo=memo)
746
+
747
+ def resolver_interpolation_callback(
748
+ name: str, args: Tuple[Any, ...], args_str: Tuple[str, ...]
749
+ ) -> Any:
750
+ return self._evaluate_custom_resolver(
751
+ key=key,
752
+ node=node,
753
+ inter_type=name,
754
+ inter_args=args,
755
+ inter_args_str=args_str,
756
+ )
757
+
758
+ visitor = GrammarVisitor(
759
+ node_interpolation_callback=node_interpolation_callback,
760
+ resolver_interpolation_callback=resolver_interpolation_callback,
761
+ memo=memo,
762
+ )
763
+ try:
764
+ return visitor.visit(parse_tree)
765
+ except InterpolationResolutionError:
766
+ raise
767
+ except Exception as exc:
768
+ # Other kinds of exceptions are wrapped in an `InterpolationResolutionError`.
769
+ raise InterpolationResolutionError(
770
+ f"{type(exc).__name__} raised while resolving interpolation: {exc}"
771
+ ).with_traceback(sys.exc_info()[2])
772
+
773
+ def _invalidate_flags_cache(self) -> None:
774
+ from .dictconfig import DictConfig
775
+ from .listconfig import ListConfig
776
+
777
+ # invalidate subtree cache only if the cache is initialized in this node.
778
+
779
+ if self.__dict__["_flags_cache"] is not None:
780
+ self.__dict__["_flags_cache"] = None
781
+ if isinstance(self, DictConfig):
782
+ content = self.__dict__["_content"]
783
+ if isinstance(content, dict):
784
+ for value in self.__dict__["_content"].values():
785
+ value._invalidate_flags_cache()
786
+ elif isinstance(self, ListConfig):
787
+ content = self.__dict__["_content"]
788
+ if isinstance(content, list):
789
+ for item in self.__dict__["_content"]:
790
+ item._invalidate_flags_cache()
791
+
792
+
793
+ class SCMode(Enum):
794
+ DICT = 1 # Convert to plain dict
795
+ DICT_CONFIG = 2 # Keep as OmegaConf DictConfig
796
+ INSTANTIATE = 3 # Create a dataclass or attrs class instance
797
+
798
+
799
+ class UnionNode(Box):
800
+ """
801
+ This class handles Union type hints. The `_content` attribute is either a
802
+ child node that is compatible with the given Union ref_type, or it is a
803
+ special value (None or MISSING or interpolation).
804
+
805
+ Much of the logic for e.g. value assignment and type validation is
806
+ delegated to the child node. As such, UnionNode functions as a
807
+ "pass-through" node. User apps and downstream libraries should not need to
808
+ know about UnionNode (assuming they only use OmegaConf's public API).
809
+ """
810
+
811
+ _parent: Optional[Container]
812
+ _content: Union[Node, None, str]
813
+
814
+ def __init__(
815
+ self,
816
+ content: Any,
817
+ ref_type: Any,
818
+ is_optional: bool = True,
819
+ key: Any = None,
820
+ parent: Optional[Box] = None,
821
+ ) -> None:
822
+ try:
823
+ if not is_union_annotation(ref_type): # pragma: no cover
824
+ msg = (
825
+ f"UnionNode got unexpected ref_type {ref_type}. Please file a bug"
826
+ + " report at https://github.com/omry/omegaconf/issues"
827
+ )
828
+ raise AssertionError(msg)
829
+ if not isinstance(parent, (Container, NoneType)):
830
+ raise ConfigTypeError("Parent type is not omegaconf.Container")
831
+ super().__init__(
832
+ parent=parent,
833
+ metadata=Metadata(
834
+ ref_type=ref_type,
835
+ object_type=None,
836
+ optional=is_optional,
837
+ key=key,
838
+ flags={"convert": False},
839
+ ),
840
+ )
841
+ self._set_value(content)
842
+ except Exception as ex:
843
+ format_and_raise(node=None, key=key, value=content, msg=str(ex), cause=ex)
844
+
845
+ def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
846
+ parent = self._get_parent()
847
+ if parent is None:
848
+ if self._metadata.key is None:
849
+ return ""
850
+ else:
851
+ return str(self._metadata.key)
852
+ else:
853
+ return parent._get_full_key(self._metadata.key)
854
+
855
+ def __eq__(self, other: Any) -> bool:
856
+ content = self.__dict__["_content"]
857
+ if isinstance(content, Node):
858
+ ret = content.__eq__(other)
859
+ elif isinstance(other, Node):
860
+ ret = other.__eq__(content)
861
+ else:
862
+ ret = content.__eq__(other)
863
+ assert isinstance(ret, (bool, type(NotImplemented)))
864
+ return ret
865
+
866
+ def __ne__(self, other: Any) -> bool:
867
+ x = self.__eq__(other)
868
+ if x is NotImplemented:
869
+ return NotImplemented
870
+ return not x
871
+
872
+ def __hash__(self) -> int:
873
+ return hash(self.__dict__["_content"])
874
+
875
+ def _value(self) -> Union[Node, None, str]:
876
+ content = self.__dict__["_content"]
877
+ assert isinstance(content, (Node, NoneType, str))
878
+ return content
879
+
880
+ def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
881
+ previous_content = self.__dict__["_content"]
882
+ previous_metadata = self.__dict__["_metadata"]
883
+ try:
884
+ self._set_value_impl(value, flags)
885
+ except Exception as e:
886
+ self.__dict__["_content"] = previous_content
887
+ self.__dict__["_metadata"] = previous_metadata
888
+ raise e
889
+
890
+ def _set_value_impl(
891
+ self, value: Any, flags: Optional[Dict[str, bool]] = None
892
+ ) -> None:
893
+ from omegaconf.omegaconf import _node_wrap
894
+
895
+ ref_type = self._metadata.ref_type
896
+ type_hint = self._metadata.type_hint
897
+
898
+ value = _get_value(value)
899
+ if _is_special(value):
900
+ assert isinstance(value, (str, NoneType))
901
+ if value is None:
902
+ if not self._is_optional():
903
+ raise ValidationError(
904
+ f"Value '$VALUE' is incompatible with type hint '{type_str(type_hint)}'"
905
+ )
906
+ self.__dict__["_content"] = value
907
+ elif isinstance(value, Container):
908
+ raise ValidationError(
909
+ f"Cannot assign container '$VALUE' of type '$VALUE_TYPE' to {type_str(type_hint)}"
910
+ )
911
+ else:
912
+ for candidate_ref_type in ref_type.__args__:
913
+ try:
914
+ self.__dict__["_content"] = _node_wrap(
915
+ value=value,
916
+ ref_type=candidate_ref_type,
917
+ is_optional=False,
918
+ key=None,
919
+ parent=self,
920
+ )
921
+ break
922
+ except ValidationError:
923
+ continue
924
+ else:
925
+ raise ValidationError(
926
+ f"Value '$VALUE' of type '$VALUE_TYPE' is incompatible with type hint '{type_str(type_hint)}'"
927
+ )
928
+
929
+ def _is_optional(self) -> bool:
930
+ return self.__dict__["_metadata"].optional is True
931
+
932
+ def _is_interpolation(self) -> bool:
933
+ return _is_interpolation(self.__dict__["_content"])
934
+
935
+ def __str__(self) -> str:
936
+ return str(self.__dict__["_content"])
937
+
938
+ def __repr__(self) -> str:
939
+ return repr(self.__dict__["_content"])
940
+
941
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "UnionNode":
942
+ res = object.__new__(type(self))
943
+ for key, value in self.__dict__.items():
944
+ if key not in ("_content", "_parent"):
945
+ res.__dict__[key] = copy.deepcopy(value, memo=memo)
946
+
947
+ src_content = self.__dict__["_content"]
948
+ if isinstance(src_content, Node):
949
+ old_parent = src_content.__dict__["_parent"]
950
+ try:
951
+ src_content.__dict__["_parent"] = None
952
+ content_copy = copy.deepcopy(src_content, memo=memo)
953
+ content_copy.__dict__["_parent"] = res
954
+ finally:
955
+ src_content.__dict__["_parent"] = old_parent
956
+ else:
957
+ # None and strings can be assigned as is
958
+ content_copy = src_content
959
+
960
+ res.__dict__["_content"] = content_copy
961
+ res.__dict__["_parent"] = self.__dict__["_parent"]
962
+ return res
omegaconf/basecontainer.py ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import sys
3
+ from abc import ABC, abstractmethod
4
+ from enum import Enum
5
+ from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union
6
+
7
+ import yaml
8
+
9
+ from ._utils import (
10
+ _DEFAULT_MARKER_,
11
+ ValueKind,
12
+ _ensure_container,
13
+ _get_value,
14
+ _is_interpolation,
15
+ _is_missing_value,
16
+ _is_none,
17
+ _is_special,
18
+ _resolve_optional,
19
+ get_structured_config_data,
20
+ get_type_hint,
21
+ get_value_kind,
22
+ get_yaml_loader,
23
+ is_container_annotation,
24
+ is_dict_annotation,
25
+ is_list_annotation,
26
+ is_primitive_dict,
27
+ is_primitive_type_annotation,
28
+ is_structured_config,
29
+ is_tuple_annotation,
30
+ is_union_annotation,
31
+ )
32
+ from .base import (
33
+ Box,
34
+ Container,
35
+ ContainerMetadata,
36
+ DictKeyType,
37
+ Node,
38
+ SCMode,
39
+ UnionNode,
40
+ )
41
+ from .errors import (
42
+ ConfigCycleDetectedException,
43
+ ConfigTypeError,
44
+ InterpolationResolutionError,
45
+ KeyValidationError,
46
+ MissingMandatoryValue,
47
+ OmegaConfBaseException,
48
+ ReadonlyConfigError,
49
+ ValidationError,
50
+ )
51
+
52
+ if TYPE_CHECKING:
53
+ from .dictconfig import DictConfig # pragma: no cover
54
+
55
+
56
+ class BaseContainer(Container, ABC):
57
+ _resolvers: ClassVar[Dict[str, Any]] = {}
58
+
59
+ def __init__(self, parent: Optional[Box], metadata: ContainerMetadata):
60
+ if not (parent is None or isinstance(parent, Box)):
61
+ raise ConfigTypeError("Parent type is not omegaconf.Box")
62
+ super().__init__(parent=parent, metadata=metadata)
63
+
64
+ def _get_child(
65
+ self,
66
+ key: Any,
67
+ validate_access: bool = True,
68
+ validate_key: bool = True,
69
+ throw_on_missing_value: bool = False,
70
+ throw_on_missing_key: bool = False,
71
+ ) -> Union[Optional[Node], List[Optional[Node]]]:
72
+ """Like _get_node, passing through to the nearest concrete Node."""
73
+ child = self._get_node(
74
+ key=key,
75
+ validate_access=validate_access,
76
+ validate_key=validate_key,
77
+ throw_on_missing_value=throw_on_missing_value,
78
+ throw_on_missing_key=throw_on_missing_key,
79
+ )
80
+ if isinstance(child, UnionNode) and not _is_special(child):
81
+ value = child._value()
82
+ assert isinstance(value, Node) and not isinstance(value, UnionNode)
83
+ child = value
84
+ return child
85
+
86
+ def _resolve_with_default(
87
+ self,
88
+ key: Union[DictKeyType, int],
89
+ value: Node,
90
+ default_value: Any = _DEFAULT_MARKER_,
91
+ ) -> Any:
92
+ """returns the value with the specified key, like obj.key and obj['key']"""
93
+ if _is_missing_value(value):
94
+ if default_value is not _DEFAULT_MARKER_:
95
+ return default_value
96
+ raise MissingMandatoryValue("Missing mandatory value: $FULL_KEY")
97
+
98
+ resolved_node = self._maybe_resolve_interpolation(
99
+ parent=self,
100
+ key=key,
101
+ value=value,
102
+ throw_on_resolution_failure=True,
103
+ )
104
+
105
+ return _get_value(resolved_node)
106
+
107
+ def __str__(self) -> str:
108
+ return self.__repr__()
109
+
110
+ def __repr__(self) -> str:
111
+ if self.__dict__["_content"] is None:
112
+ return "None"
113
+ elif self._is_interpolation() or self._is_missing():
114
+ v = self.__dict__["_content"]
115
+ return f"'{v}'"
116
+ else:
117
+ return self.__dict__["_content"].__repr__() # type: ignore
118
+
119
+ # Support pickle
120
+ def __getstate__(self) -> Dict[str, Any]:
121
+ dict_copy = copy.copy(self.__dict__)
122
+
123
+ # no need to serialize the flags cache, it can be re-constructed later
124
+ dict_copy.pop("_flags_cache", None)
125
+
126
+ dict_copy["_metadata"] = copy.copy(dict_copy["_metadata"])
127
+ ref_type = self._metadata.ref_type
128
+ if is_container_annotation(ref_type):
129
+ if is_dict_annotation(ref_type):
130
+ dict_copy["_metadata"].ref_type = Dict
131
+ elif is_list_annotation(ref_type):
132
+ dict_copy["_metadata"].ref_type = List
133
+ else:
134
+ assert False
135
+ if sys.version_info < (3, 7): # pragma: no cover
136
+ element_type = self._metadata.element_type
137
+ if is_union_annotation(element_type):
138
+ raise OmegaConfBaseException(
139
+ "Serializing structured configs with `Union` element type requires python >= 3.7"
140
+ )
141
+ return dict_copy
142
+
143
+ # Support pickle
144
+ def __setstate__(self, d: Dict[str, Any]) -> None:
145
+ from omegaconf import DictConfig
146
+ from omegaconf._utils import is_generic_dict, is_generic_list
147
+
148
+ if isinstance(self, DictConfig):
149
+ key_type = d["_metadata"].key_type
150
+
151
+ # backward compatibility to load OmegaConf 2.0 configs
152
+ if key_type is None:
153
+ key_type = Any
154
+ d["_metadata"].key_type = key_type
155
+
156
+ element_type = d["_metadata"].element_type
157
+
158
+ # backward compatibility to load OmegaConf 2.0 configs
159
+ if element_type is None:
160
+ element_type = Any
161
+ d["_metadata"].element_type = element_type
162
+
163
+ ref_type = d["_metadata"].ref_type
164
+ if is_container_annotation(ref_type):
165
+ if is_generic_dict(ref_type):
166
+ d["_metadata"].ref_type = Dict[key_type, element_type] # type: ignore
167
+ elif is_generic_list(ref_type):
168
+ d["_metadata"].ref_type = List[element_type] # type: ignore
169
+ else:
170
+ assert False
171
+
172
+ d["_flags_cache"] = None
173
+ self.__dict__.update(d)
174
+
175
+ @abstractmethod
176
+ def __delitem__(self, key: Any) -> None:
177
+ ...
178
+
179
+ def __len__(self) -> int:
180
+ if self._is_none() or self._is_missing() or self._is_interpolation():
181
+ return 0
182
+ content = self.__dict__["_content"]
183
+ return len(content)
184
+
185
+ def merge_with_cli(self) -> None:
186
+ args_list = sys.argv[1:]
187
+ self.merge_with_dotlist(args_list)
188
+
189
+ def merge_with_dotlist(self, dotlist: List[str]) -> None:
190
+ from omegaconf import OmegaConf
191
+
192
+ def fail() -> None:
193
+ raise ValueError("Input list must be a list or a tuple of strings")
194
+
195
+ if not isinstance(dotlist, (list, tuple)):
196
+ fail()
197
+
198
+ for arg in dotlist:
199
+ if not isinstance(arg, str):
200
+ fail()
201
+
202
+ idx = arg.find("=")
203
+ if idx == -1:
204
+ key = arg
205
+ value = None
206
+ else:
207
+ key = arg[0:idx]
208
+ value = arg[idx + 1 :]
209
+ value = yaml.load(value, Loader=get_yaml_loader())
210
+
211
+ OmegaConf.update(self, key, value)
212
+
213
+ def is_empty(self) -> bool:
214
+ """return true if config is empty"""
215
+ return len(self.__dict__["_content"]) == 0
216
+
217
+ @staticmethod
218
+ def _to_content(
219
+ conf: Container,
220
+ resolve: bool,
221
+ throw_on_missing: bool,
222
+ enum_to_str: bool = False,
223
+ structured_config_mode: SCMode = SCMode.DICT,
224
+ ) -> Union[None, Any, str, Dict[DictKeyType, Any], List[Any]]:
225
+ from omegaconf import MISSING, DictConfig, ListConfig
226
+
227
+ def convert(val: Node) -> Any:
228
+ value = val._value()
229
+ if enum_to_str and isinstance(value, Enum):
230
+ value = f"{value.name}"
231
+
232
+ return value
233
+
234
+ def get_node_value(key: Union[DictKeyType, int]) -> Any:
235
+ try:
236
+ node = conf._get_child(key, throw_on_missing_value=throw_on_missing)
237
+ except MissingMandatoryValue as e:
238
+ conf._format_and_raise(key=key, value=None, cause=e)
239
+ assert isinstance(node, Node)
240
+ if resolve:
241
+ try:
242
+ node = node._dereference_node()
243
+ except InterpolationResolutionError as e:
244
+ conf._format_and_raise(key=key, value=None, cause=e)
245
+
246
+ if isinstance(node, Container):
247
+ value = BaseContainer._to_content(
248
+ node,
249
+ resolve=resolve,
250
+ throw_on_missing=throw_on_missing,
251
+ enum_to_str=enum_to_str,
252
+ structured_config_mode=structured_config_mode,
253
+ )
254
+ else:
255
+ value = convert(node)
256
+ return value
257
+
258
+ if conf._is_none():
259
+ return None
260
+ elif conf._is_missing():
261
+ if throw_on_missing:
262
+ conf._format_and_raise(
263
+ key=None,
264
+ value=None,
265
+ cause=MissingMandatoryValue("Missing mandatory value"),
266
+ )
267
+ else:
268
+ return MISSING
269
+ elif not resolve and conf._is_interpolation():
270
+ inter = conf._value()
271
+ assert isinstance(inter, str)
272
+ return inter
273
+
274
+ if resolve:
275
+ _conf = conf._dereference_node()
276
+ assert isinstance(_conf, Container)
277
+ conf = _conf
278
+
279
+ if isinstance(conf, DictConfig):
280
+ if (
281
+ conf._metadata.object_type not in (dict, None)
282
+ and structured_config_mode == SCMode.DICT_CONFIG
283
+ ):
284
+ return conf
285
+ if structured_config_mode == SCMode.INSTANTIATE and is_structured_config(
286
+ conf._metadata.object_type
287
+ ):
288
+ return conf._to_object()
289
+
290
+ retdict: Dict[DictKeyType, Any] = {}
291
+ for key in conf.keys():
292
+ value = get_node_value(key)
293
+ if enum_to_str and isinstance(key, Enum):
294
+ key = f"{key.name}"
295
+ retdict[key] = value
296
+ return retdict
297
+ elif isinstance(conf, ListConfig):
298
+ retlist: List[Any] = []
299
+ for index in range(len(conf)):
300
+ item = get_node_value(index)
301
+ retlist.append(item)
302
+
303
+ return retlist
304
+ assert False
305
+
306
+ @staticmethod
307
+ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
308
+ """merge src into dest and return a new copy, does not modified input"""
309
+ from omegaconf import AnyNode, DictConfig, ValueNode
310
+
311
+ assert isinstance(dest, DictConfig)
312
+ assert isinstance(src, DictConfig)
313
+ src_type = src._metadata.object_type
314
+ src_ref_type = get_type_hint(src)
315
+ assert src_ref_type is not None
316
+
317
+ # If source DictConfig is:
318
+ # - None => set the destination DictConfig to None
319
+ # - an interpolation => set the destination DictConfig to be the same interpolation
320
+ if src._is_none() or src._is_interpolation():
321
+ dest._set_value(src._value())
322
+ _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
323
+ return
324
+
325
+ dest._validate_merge(value=src)
326
+
327
+ def expand(node: Container) -> None:
328
+ rt = node._metadata.ref_type
329
+ val: Any
330
+ if rt is not Any:
331
+ if is_dict_annotation(rt):
332
+ val = {}
333
+ elif is_list_annotation(rt) or is_tuple_annotation(rt):
334
+ val = []
335
+ else:
336
+ val = rt
337
+ elif isinstance(node, DictConfig):
338
+ val = {}
339
+ else:
340
+ assert False
341
+
342
+ node._set_value(val)
343
+
344
+ if (
345
+ src._is_missing()
346
+ and not dest._is_missing()
347
+ and is_structured_config(src_ref_type)
348
+ ):
349
+ # Replace `src` with a prototype of its corresponding structured config
350
+ # whose fields are all missing (to avoid overwriting fields in `dest`).
351
+ assert src_type is None # src missing, so src's object_type should be None
352
+ src_type = src_ref_type
353
+ src = _create_structured_with_missing_fields(
354
+ ref_type=src_ref_type, object_type=src_type
355
+ )
356
+
357
+ if (dest._is_interpolation() or dest._is_missing()) and not src._is_missing():
358
+ expand(dest)
359
+
360
+ src_items = list(src) if not src._is_missing() else []
361
+ for key in src_items:
362
+ src_node = src._get_node(key, validate_access=False)
363
+ dest_node = dest._get_node(key, validate_access=False)
364
+ assert isinstance(src_node, Node)
365
+ assert dest_node is None or isinstance(dest_node, Node)
366
+ src_value = _get_value(src_node)
367
+
368
+ src_vk = get_value_kind(src_node)
369
+ src_node_missing = src_vk is ValueKind.MANDATORY_MISSING
370
+
371
+ if isinstance(dest_node, DictConfig):
372
+ dest_node._validate_merge(value=src_node)
373
+
374
+ if (
375
+ isinstance(dest_node, Container)
376
+ and dest_node._is_none()
377
+ and not src_node_missing
378
+ and not _is_none(src_node, resolve=True)
379
+ ):
380
+ expand(dest_node)
381
+
382
+ if dest_node is not None and dest_node._is_interpolation():
383
+ target_node = dest_node._maybe_dereference_node()
384
+ if isinstance(target_node, Container):
385
+ dest[key] = target_node
386
+ dest_node = dest._get_node(key)
387
+
388
+ is_optional, et = _resolve_optional(dest._metadata.element_type)
389
+ if dest_node is None and is_structured_config(et) and not src_node_missing:
390
+ # merging into a new node. Use element_type as a base
391
+ dest[key] = DictConfig(
392
+ et, parent=dest, ref_type=et, is_optional=is_optional
393
+ )
394
+ dest_node = dest._get_node(key)
395
+
396
+ if dest_node is not None:
397
+ if isinstance(dest_node, BaseContainer):
398
+ if isinstance(src_node, BaseContainer):
399
+ dest_node._merge_with(src_node)
400
+ elif not src_node_missing:
401
+ dest.__setitem__(key, src_node)
402
+ else:
403
+ if isinstance(src_node, BaseContainer):
404
+ dest.__setitem__(key, src_node)
405
+ else:
406
+ assert isinstance(dest_node, (ValueNode, UnionNode))
407
+ assert isinstance(src_node, (ValueNode, UnionNode))
408
+ try:
409
+ if isinstance(dest_node, AnyNode):
410
+ if src_node_missing:
411
+ node = copy.copy(src_node)
412
+ # if src node is missing, use the value from the dest_node,
413
+ # but validate it against the type of the src node before assigment
414
+ node._set_value(dest_node._value())
415
+ else:
416
+ node = src_node
417
+ dest.__setitem__(key, node)
418
+ else:
419
+ if not src_node_missing:
420
+ dest_node._set_value(src_value)
421
+
422
+ except (ValidationError, ReadonlyConfigError) as e:
423
+ dest._format_and_raise(key=key, value=src_value, cause=e)
424
+ else:
425
+ from omegaconf import open_dict
426
+
427
+ if is_structured_config(src_type):
428
+ # verified to be compatible above in _validate_merge
429
+ with open_dict(dest):
430
+ dest[key] = src._get_node(key)
431
+ else:
432
+ dest[key] = src._get_node(key)
433
+
434
+ _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
435
+
436
+ # explicit flags on the source config are replacing the flag values in the destination
437
+ flags = src._metadata.flags
438
+ assert flags is not None
439
+ for flag, value in flags.items():
440
+ if value is not None:
441
+ dest._set_flag(flag, value)
442
+
443
+ @staticmethod
444
+ def _list_merge(dest: Any, src: Any) -> None:
445
+ from omegaconf import DictConfig, ListConfig, OmegaConf
446
+
447
+ assert isinstance(dest, ListConfig)
448
+ assert isinstance(src, ListConfig)
449
+
450
+ if src._is_none():
451
+ dest._set_value(None)
452
+ elif src._is_missing():
453
+ # do not change dest if src is MISSING.
454
+ if dest._metadata.element_type is Any:
455
+ dest._metadata.element_type = src._metadata.element_type
456
+ elif src._is_interpolation():
457
+ dest._set_value(src._value())
458
+ else:
459
+ temp_target = ListConfig(content=[], parent=dest._get_parent())
460
+ temp_target.__dict__["_metadata"] = copy.deepcopy(
461
+ dest.__dict__["_metadata"]
462
+ )
463
+ is_optional, et = _resolve_optional(dest._metadata.element_type)
464
+ if is_structured_config(et):
465
+ prototype = DictConfig(et, ref_type=et, is_optional=is_optional)
466
+ for item in src._iter_ex(resolve=False):
467
+ if isinstance(item, DictConfig):
468
+ item = OmegaConf.merge(prototype, item)
469
+ temp_target.append(item)
470
+ else:
471
+ for item in src._iter_ex(resolve=False):
472
+ temp_target.append(item)
473
+
474
+ dest.__dict__["_content"] = temp_target.__dict__["_content"]
475
+
476
+ # explicit flags on the source config are replacing the flag values in the destination
477
+ flags = src._metadata.flags
478
+ assert flags is not None
479
+ for flag, value in flags.items():
480
+ if value is not None:
481
+ dest._set_flag(flag, value)
482
+
483
+ def merge_with(
484
+ self,
485
+ *others: Union[
486
+ "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
487
+ ],
488
+ ) -> None:
489
+ try:
490
+ self._merge_with(*others)
491
+ except Exception as e:
492
+ self._format_and_raise(key=None, value=None, cause=e)
493
+
494
+ def _merge_with(
495
+ self,
496
+ *others: Union[
497
+ "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
498
+ ],
499
+ ) -> None:
500
+ from .dictconfig import DictConfig
501
+ from .listconfig import ListConfig
502
+
503
+ """merge a list of other Config objects into this one, overriding as needed"""
504
+ for other in others:
505
+ if other is None:
506
+ raise ValueError("Cannot merge with a None config")
507
+
508
+ my_flags = {}
509
+ if self._get_flag("allow_objects") is True:
510
+ my_flags = {"allow_objects": True}
511
+ other = _ensure_container(other, flags=my_flags)
512
+
513
+ if isinstance(self, DictConfig) and isinstance(other, DictConfig):
514
+ BaseContainer._map_merge(self, other)
515
+ elif isinstance(self, ListConfig) and isinstance(other, ListConfig):
516
+ BaseContainer._list_merge(self, other)
517
+ else:
518
+ raise TypeError("Cannot merge DictConfig with ListConfig")
519
+
520
+ # recursively correct the parent hierarchy after the merge
521
+ self._re_parent()
522
+
523
+ # noinspection PyProtectedMember
524
+ def _set_item_impl(self, key: Any, value: Any) -> None:
525
+ """
526
+ Changes the value of the node key with the desired value. If the node key doesn't
527
+ exist it creates a new one.
528
+ """
529
+ from .nodes import AnyNode, ValueNode
530
+
531
+ if isinstance(value, Node):
532
+ do_deepcopy = not self._get_flag("no_deepcopy_set_nodes")
533
+ if not do_deepcopy and isinstance(value, Box):
534
+ # if value is from the same config, perform a deepcopy no matter what.
535
+ if self._get_root() is value._get_root():
536
+ do_deepcopy = True
537
+
538
+ if do_deepcopy:
539
+ value = copy.deepcopy(value)
540
+ value._set_parent(None)
541
+
542
+ try:
543
+ old = value._key()
544
+ value._set_key(key)
545
+ self._validate_set(key, value)
546
+ finally:
547
+ value._set_key(old)
548
+ else:
549
+ self._validate_set(key, value)
550
+
551
+ if self._get_flag("readonly"):
552
+ raise ReadonlyConfigError("Cannot change read-only config container")
553
+
554
+ input_is_node = isinstance(value, Node)
555
+ target_node_ref = self._get_node(key)
556
+ assert target_node_ref is None or isinstance(target_node_ref, Node)
557
+
558
+ input_is_typed_vnode = isinstance(value, ValueNode) and not isinstance(
559
+ value, AnyNode
560
+ )
561
+
562
+ def get_target_type_hint(val: Any) -> Any:
563
+ if not is_structured_config(val):
564
+ type_hint = self._metadata.element_type
565
+ else:
566
+ target = self._get_node(key)
567
+ if target is None:
568
+ type_hint = self._metadata.element_type
569
+ else:
570
+ assert isinstance(target, Node)
571
+ type_hint = target._metadata.type_hint
572
+ return type_hint
573
+
574
+ target_type_hint = get_target_type_hint(value)
575
+ _, target_ref_type = _resolve_optional(target_type_hint)
576
+
577
+ def assign(value_key: Any, val: Node) -> None:
578
+ assert val._get_parent() is None
579
+ v = val
580
+ v._set_parent(self)
581
+ v._set_key(value_key)
582
+ _deep_update_type_hint(node=v, type_hint=self._metadata.element_type)
583
+ self.__dict__["_content"][value_key] = v
584
+
585
+ if input_is_typed_vnode and not is_union_annotation(target_ref_type):
586
+ assign(key, value)
587
+ else:
588
+ # input is not a ValueNode, can be primitive or box
589
+
590
+ special_value = _is_special(value)
591
+ # We use the `Node._set_value` method if the target node exists and:
592
+ # 1. the target has an explicit ref_type, or
593
+ # 2. the target is an AnyNode and the input is a primitive type.
594
+ should_set_value = target_node_ref is not None and (
595
+ target_node_ref._has_ref_type()
596
+ or (
597
+ isinstance(target_node_ref, AnyNode)
598
+ and is_primitive_type_annotation(value)
599
+ )
600
+ )
601
+ if should_set_value:
602
+ if special_value and isinstance(value, Node):
603
+ value = value._value()
604
+ self.__dict__["_content"][key]._set_value(value)
605
+ elif input_is_node:
606
+ if (
607
+ special_value
608
+ and (
609
+ is_container_annotation(target_ref_type)
610
+ or is_structured_config(target_ref_type)
611
+ )
612
+ or is_primitive_type_annotation(target_ref_type)
613
+ or is_union_annotation(target_ref_type)
614
+ ):
615
+ value = _get_value(value)
616
+ self._wrap_value_and_set(key, value, target_type_hint)
617
+ else:
618
+ assign(key, value)
619
+ else:
620
+ self._wrap_value_and_set(key, value, target_type_hint)
621
+
622
+ def _wrap_value_and_set(self, key: Any, val: Any, type_hint: Any) -> None:
623
+ from omegaconf.omegaconf import _maybe_wrap
624
+
625
+ is_optional, ref_type = _resolve_optional(type_hint)
626
+
627
+ try:
628
+ wrapped = _maybe_wrap(
629
+ ref_type=ref_type,
630
+ key=key,
631
+ value=val,
632
+ is_optional=is_optional,
633
+ parent=self,
634
+ )
635
+ except ValidationError as e:
636
+ self._format_and_raise(key=key, value=val, cause=e)
637
+ self.__dict__["_content"][key] = wrapped
638
+
639
+ @staticmethod
640
+ def _item_eq(
641
+ c1: Container,
642
+ k1: Union[DictKeyType, int],
643
+ c2: Container,
644
+ k2: Union[DictKeyType, int],
645
+ ) -> bool:
646
+ v1 = c1._get_child(k1)
647
+ v2 = c2._get_child(k2)
648
+ assert v1 is not None and v2 is not None
649
+
650
+ assert isinstance(v1, Node)
651
+ assert isinstance(v2, Node)
652
+
653
+ if v1._is_none() and v2._is_none():
654
+ return True
655
+
656
+ if v1._is_missing() and v2._is_missing():
657
+ return True
658
+
659
+ v1_inter = v1._is_interpolation()
660
+ v2_inter = v2._is_interpolation()
661
+ dv1: Optional[Node] = v1
662
+ dv2: Optional[Node] = v2
663
+
664
+ if v1_inter:
665
+ dv1 = v1._maybe_dereference_node()
666
+ if v2_inter:
667
+ dv2 = v2._maybe_dereference_node()
668
+
669
+ if v1_inter and v2_inter:
670
+ if dv1 is None or dv2 is None:
671
+ return v1 == v2
672
+ else:
673
+ # both are not none, if both are containers compare as container
674
+ if isinstance(dv1, Container) and isinstance(dv2, Container):
675
+ if dv1 != dv2:
676
+ return False
677
+ dv1 = _get_value(dv1)
678
+ dv2 = _get_value(dv2)
679
+ return dv1 == dv2
680
+ elif not v1_inter and not v2_inter:
681
+ v1 = _get_value(v1)
682
+ v2 = _get_value(v2)
683
+ ret = v1 == v2
684
+ assert isinstance(ret, bool)
685
+ return ret
686
+ else:
687
+ dv1 = _get_value(dv1)
688
+ dv2 = _get_value(dv2)
689
+ ret = dv1 == dv2
690
+ assert isinstance(ret, bool)
691
+ return ret
692
+
693
+ def _is_optional(self) -> bool:
694
+ return self.__dict__["_metadata"].optional is True
695
+
696
+ def _is_interpolation(self) -> bool:
697
+ return _is_interpolation(self.__dict__["_content"])
698
+
699
+ @abstractmethod
700
+ def _validate_get(self, key: Any, value: Any = None) -> None:
701
+ ...
702
+
703
+ @abstractmethod
704
+ def _validate_set(self, key: Any, value: Any) -> None:
705
+ ...
706
+
707
+ def _value(self) -> Any:
708
+ return self.__dict__["_content"]
709
+
710
+ def _get_full_key(self, key: Union[DictKeyType, int, slice, None]) -> str:
711
+ from .listconfig import ListConfig
712
+ from .omegaconf import _select_one
713
+
714
+ if not isinstance(key, (int, str, Enum, float, bool, slice, bytes, type(None))):
715
+ return ""
716
+
717
+ def _slice_to_str(x: slice) -> str:
718
+ if x.step is not None:
719
+ return f"{x.start}:{x.stop}:{x.step}"
720
+ else:
721
+ return f"{x.start}:{x.stop}"
722
+
723
+ def prepand(
724
+ full_key: str,
725
+ parent_type: Any,
726
+ cur_type: Any,
727
+ key: Optional[Union[DictKeyType, int, slice]],
728
+ ) -> str:
729
+ if key is None:
730
+ return full_key
731
+
732
+ if isinstance(key, slice):
733
+ key = _slice_to_str(key)
734
+ elif isinstance(key, Enum):
735
+ key = key.name
736
+ else:
737
+ key = str(key)
738
+
739
+ assert isinstance(key, str)
740
+
741
+ if issubclass(parent_type, ListConfig):
742
+ if full_key != "":
743
+ if issubclass(cur_type, ListConfig):
744
+ full_key = f"[{key}]{full_key}"
745
+ else:
746
+ full_key = f"[{key}].{full_key}"
747
+ else:
748
+ full_key = f"[{key}]"
749
+ else:
750
+ if full_key == "":
751
+ full_key = key
752
+ else:
753
+ if issubclass(cur_type, ListConfig):
754
+ full_key = f"{key}{full_key}"
755
+ else:
756
+ full_key = f"{key}.{full_key}"
757
+ return full_key
758
+
759
+ if key is not None and key != "":
760
+ assert isinstance(self, Container)
761
+ cur, _ = _select_one(
762
+ c=self, key=str(key), throw_on_missing=False, throw_on_type_error=False
763
+ )
764
+ if cur is None:
765
+ cur = self
766
+ full_key = prepand("", type(cur), None, key)
767
+ if cur._key() is not None:
768
+ full_key = prepand(
769
+ full_key, type(cur._get_parent()), type(cur), cur._key()
770
+ )
771
+ else:
772
+ full_key = prepand("", type(cur._get_parent()), type(cur), cur._key())
773
+ else:
774
+ cur = self
775
+ if cur._key() is None:
776
+ return ""
777
+ full_key = self._key()
778
+
779
+ assert cur is not None
780
+ memo = {id(cur)} # remember already visited nodes so as to detect cycles
781
+ while cur._get_parent() is not None:
782
+ cur = cur._get_parent()
783
+ if id(cur) in memo:
784
+ raise ConfigCycleDetectedException(
785
+ f"Cycle when iterating over parents of key `{key!s}`"
786
+ )
787
+ memo.add(id(cur))
788
+ assert cur is not None
789
+ if cur._key() is not None:
790
+ full_key = prepand(
791
+ full_key, type(cur._get_parent()), type(cur), cur._key()
792
+ )
793
+
794
+ return full_key
795
+
796
+
797
+ def _create_structured_with_missing_fields(
798
+ ref_type: type, object_type: Optional[type] = None
799
+ ) -> "DictConfig":
800
+ from . import MISSING, DictConfig
801
+
802
+ cfg_data = get_structured_config_data(ref_type)
803
+ for v in cfg_data.values():
804
+ v._set_value(MISSING)
805
+
806
+ cfg = DictConfig(cfg_data)
807
+ cfg._metadata.optional, cfg._metadata.ref_type = _resolve_optional(ref_type)
808
+ cfg._metadata.object_type = object_type
809
+
810
+ return cfg
811
+
812
+
813
+ def _update_types(node: Node, ref_type: Any, object_type: Optional[type]) -> None:
814
+ if object_type is not None and not is_primitive_dict(object_type):
815
+ node._metadata.object_type = object_type
816
+
817
+ if node._metadata.ref_type is Any:
818
+ _deep_update_type_hint(node, ref_type)
819
+
820
+
821
+ def _deep_update_type_hint(node: Node, type_hint: Any) -> None:
822
+ """Ensure node is compatible with type_hint, mutating if necessary."""
823
+ from omegaconf import DictConfig, ListConfig
824
+
825
+ from ._utils import get_dict_key_value_types, get_list_element_type
826
+
827
+ if type_hint is Any:
828
+ return
829
+
830
+ _shallow_validate_type_hint(node, type_hint)
831
+
832
+ new_is_optional, new_ref_type = _resolve_optional(type_hint)
833
+ node._metadata.ref_type = new_ref_type
834
+ node._metadata.optional = new_is_optional
835
+
836
+ if is_list_annotation(new_ref_type) and isinstance(node, ListConfig):
837
+ new_element_type = get_list_element_type(new_ref_type)
838
+ node._metadata.element_type = new_element_type
839
+ if not _is_special(node):
840
+ for i in range(len(node)):
841
+ _deep_update_subnode(node, i, new_element_type)
842
+
843
+ if is_dict_annotation(new_ref_type) and isinstance(node, DictConfig):
844
+ new_key_type, new_element_type = get_dict_key_value_types(new_ref_type)
845
+ node._metadata.key_type = new_key_type
846
+ node._metadata.element_type = new_element_type
847
+ if not _is_special(node):
848
+ for key in node:
849
+ if new_key_type is not Any and not isinstance(key, new_key_type):
850
+ raise KeyValidationError(
851
+ f"Key {key!r} ({type(key).__name__}) is incompatible"
852
+ + f" with key type hint '{new_key_type.__name__}'"
853
+ )
854
+ _deep_update_subnode(node, key, new_element_type)
855
+
856
+
857
+ def _deep_update_subnode(node: BaseContainer, key: Any, value_type_hint: Any) -> None:
858
+ """Get node[key] and ensure it is compatible with value_type_hint, mutating if necessary."""
859
+ subnode = node._get_node(key)
860
+ assert isinstance(subnode, Node)
861
+ if _is_special(subnode):
862
+ # Ensure special values are wrapped in a Node subclass that
863
+ # is compatible with the type hint.
864
+ node._wrap_value_and_set(key, subnode._value(), value_type_hint)
865
+ subnode = node._get_node(key)
866
+ assert isinstance(subnode, Node)
867
+ _deep_update_type_hint(subnode, value_type_hint)
868
+
869
+
870
+ def _shallow_validate_type_hint(node: Node, type_hint: Any) -> None:
871
+ """Error if node's type, content and metadata are not compatible with type_hint."""
872
+ from omegaconf import DictConfig, ListConfig, ValueNode
873
+
874
+ is_optional, ref_type = _resolve_optional(type_hint)
875
+
876
+ vk = get_value_kind(node)
877
+
878
+ if node._is_none():
879
+ if not is_optional:
880
+ value = _get_value(node)
881
+ raise ValidationError(
882
+ f"Value {value!r} ({type(value).__name__})"
883
+ + f" is incompatible with type hint '{ref_type.__name__}'"
884
+ )
885
+ return
886
+ elif vk in (ValueKind.MANDATORY_MISSING, ValueKind.INTERPOLATION):
887
+ return
888
+ elif vk == ValueKind.VALUE:
889
+ if is_primitive_type_annotation(ref_type) and isinstance(node, ValueNode):
890
+ value = node._value()
891
+ if not isinstance(value, ref_type):
892
+ raise ValidationError(
893
+ f"Value {value!r} ({type(value).__name__})"
894
+ + f" is incompatible with type hint '{ref_type.__name__}'"
895
+ )
896
+ elif is_structured_config(ref_type) and isinstance(node, DictConfig):
897
+ return
898
+ elif is_dict_annotation(ref_type) and isinstance(node, DictConfig):
899
+ return
900
+ elif is_list_annotation(ref_type) and isinstance(node, ListConfig):
901
+ return
902
+ else:
903
+ if isinstance(node, ValueNode):
904
+ value = node._value()
905
+ raise ValidationError(
906
+ f"Value {value!r} ({type(value).__name__})"
907
+ + f" is incompatible with type hint '{ref_type}'"
908
+ )
909
+ else:
910
+ raise ValidationError(
911
+ f"'{type(node).__name__}' is incompatible"
912
+ + f" with type hint '{ref_type}'"
913
+ )
914
+
915
+ else:
916
+ assert False
omegaconf/dictconfig.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from enum import Enum
3
+ from typing import (
4
+ Any,
5
+ Dict,
6
+ ItemsView,
7
+ Iterable,
8
+ Iterator,
9
+ KeysView,
10
+ List,
11
+ MutableMapping,
12
+ Optional,
13
+ Sequence,
14
+ Tuple,
15
+ Type,
16
+ Union,
17
+ )
18
+
19
+ from ._utils import (
20
+ _DEFAULT_MARKER_,
21
+ ValueKind,
22
+ _get_value,
23
+ _is_interpolation,
24
+ _is_missing_literal,
25
+ _is_missing_value,
26
+ _is_none,
27
+ _resolve_optional,
28
+ _valid_dict_key_annotation_type,
29
+ format_and_raise,
30
+ get_structured_config_data,
31
+ get_structured_config_init_field_names,
32
+ get_type_of,
33
+ get_value_kind,
34
+ is_container_annotation,
35
+ is_dict,
36
+ is_primitive_dict,
37
+ is_structured_config,
38
+ is_structured_config_frozen,
39
+ type_str,
40
+ )
41
+ from .base import Box, Container, ContainerMetadata, DictKeyType, Node
42
+ from .basecontainer import BaseContainer
43
+ from .errors import (
44
+ ConfigAttributeError,
45
+ ConfigKeyError,
46
+ ConfigTypeError,
47
+ InterpolationResolutionError,
48
+ KeyValidationError,
49
+ MissingMandatoryValue,
50
+ OmegaConfBaseException,
51
+ ReadonlyConfigError,
52
+ ValidationError,
53
+ )
54
+ from .nodes import EnumNode, ValueNode
55
+
56
+
57
+ class DictConfig(BaseContainer, MutableMapping[Any, Any]):
58
+
59
+ _metadata: ContainerMetadata
60
+ _content: Union[Dict[DictKeyType, Node], None, str]
61
+
62
+ def __init__(
63
+ self,
64
+ content: Union[Dict[DictKeyType, Any], "DictConfig", Any],
65
+ key: Any = None,
66
+ parent: Optional[Box] = None,
67
+ ref_type: Union[Any, Type[Any]] = Any,
68
+ key_type: Union[Any, Type[Any]] = Any,
69
+ element_type: Union[Any, Type[Any]] = Any,
70
+ is_optional: bool = True,
71
+ flags: Optional[Dict[str, bool]] = None,
72
+ ) -> None:
73
+ try:
74
+ if isinstance(content, DictConfig):
75
+ if flags is None:
76
+ flags = content._metadata.flags
77
+ super().__init__(
78
+ parent=parent,
79
+ metadata=ContainerMetadata(
80
+ key=key,
81
+ optional=is_optional,
82
+ ref_type=ref_type,
83
+ object_type=dict,
84
+ key_type=key_type,
85
+ element_type=element_type,
86
+ flags=flags,
87
+ ),
88
+ )
89
+
90
+ if not _valid_dict_key_annotation_type(key_type):
91
+ raise KeyValidationError(f"Unsupported key type {key_type}")
92
+
93
+ if is_structured_config(content) or is_structured_config(ref_type):
94
+ self._set_value(content, flags=flags)
95
+ if is_structured_config_frozen(content) or is_structured_config_frozen(
96
+ ref_type
97
+ ):
98
+ self._set_flag("readonly", True)
99
+
100
+ else:
101
+ if isinstance(content, DictConfig):
102
+ metadata = copy.deepcopy(content._metadata)
103
+ metadata.key = key
104
+ metadata.ref_type = ref_type
105
+ metadata.optional = is_optional
106
+ metadata.element_type = element_type
107
+ metadata.key_type = key_type
108
+ self.__dict__["_metadata"] = metadata
109
+ self._set_value(content, flags=flags)
110
+ except Exception as ex:
111
+ format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))
112
+
113
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "DictConfig":
114
+ res = DictConfig(None)
115
+ res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
116
+ res.__dict__["_flags_cache"] = copy.deepcopy(
117
+ self.__dict__["_flags_cache"], memo=memo
118
+ )
119
+
120
+ src_content = self.__dict__["_content"]
121
+ if isinstance(src_content, dict):
122
+ content_copy = {}
123
+ for k, v in src_content.items():
124
+ old_parent = v.__dict__["_parent"]
125
+ try:
126
+ v.__dict__["_parent"] = None
127
+ vc = copy.deepcopy(v, memo=memo)
128
+ vc.__dict__["_parent"] = res
129
+ content_copy[k] = vc
130
+ finally:
131
+ v.__dict__["_parent"] = old_parent
132
+ else:
133
+ # None and strings can be assigned as is
134
+ content_copy = src_content
135
+
136
+ res.__dict__["_content"] = content_copy
137
+ # parent is retained, but not copied
138
+ res.__dict__["_parent"] = self.__dict__["_parent"]
139
+ return res
140
+
141
+ def copy(self) -> "DictConfig":
142
+ return copy.copy(self)
143
+
144
+ def _is_typed(self) -> bool:
145
+ return self._metadata.object_type not in (Any, None) and not is_dict(
146
+ self._metadata.object_type
147
+ )
148
+
149
+ def _validate_get(self, key: Any, value: Any = None) -> None:
150
+ is_typed = self._is_typed()
151
+
152
+ is_struct = self._get_flag("struct") is True
153
+ if key not in self.__dict__["_content"]:
154
+ if is_typed:
155
+ # do not raise an exception if struct is explicitly set to False
156
+ if self._get_node_flag("struct") is False:
157
+ return
158
+ if is_typed or is_struct:
159
+ if is_typed:
160
+ assert self._metadata.object_type not in (dict, None)
161
+ msg = f"Key '{key}' not in '{self._metadata.object_type.__name__}'"
162
+ else:
163
+ msg = f"Key '{key}' is not in struct"
164
+ self._format_and_raise(
165
+ key=key, value=value, cause=ConfigAttributeError(msg)
166
+ )
167
+
168
+ def _validate_set(self, key: Any, value: Any) -> None:
169
+ from omegaconf import OmegaConf
170
+
171
+ vk = get_value_kind(value)
172
+ if vk == ValueKind.INTERPOLATION:
173
+ return
174
+ if _is_none(value):
175
+ self._validate_non_optional(key, value)
176
+ return
177
+ if vk == ValueKind.MANDATORY_MISSING or value is None:
178
+ return
179
+
180
+ target = self._get_node(key) if key is not None else self
181
+
182
+ target_has_ref_type = isinstance(
183
+ target, DictConfig
184
+ ) and target._metadata.ref_type not in (Any, dict)
185
+ is_valid_target = target is None or not target_has_ref_type
186
+
187
+ if is_valid_target:
188
+ return
189
+
190
+ assert isinstance(target, Node)
191
+
192
+ target_type = target._metadata.ref_type
193
+ value_type = OmegaConf.get_type(value)
194
+
195
+ if is_dict(value_type) and is_dict(target_type):
196
+ return
197
+ if is_container_annotation(target_type) and not is_container_annotation(
198
+ value_type
199
+ ):
200
+ raise ValidationError(
201
+ f"Cannot assign {type_str(value_type)} to {type_str(target_type)}"
202
+ )
203
+
204
+ if target_type is not None and value_type is not None:
205
+ origin = getattr(target_type, "__origin__", target_type)
206
+ if not issubclass(value_type, origin):
207
+ self._raise_invalid_value(value, value_type, target_type)
208
+
209
+ def _validate_merge(self, value: Any) -> None:
210
+ from omegaconf import OmegaConf
211
+
212
+ dest = self
213
+ src = value
214
+
215
+ self._validate_non_optional(None, src)
216
+
217
+ dest_obj_type = OmegaConf.get_type(dest)
218
+ src_obj_type = OmegaConf.get_type(src)
219
+
220
+ if dest._is_missing() and src._metadata.object_type not in (dict, None):
221
+ self._validate_set(key=None, value=_get_value(src))
222
+
223
+ if src._is_missing():
224
+ return
225
+
226
+ validation_error = (
227
+ dest_obj_type is not None
228
+ and src_obj_type is not None
229
+ and is_structured_config(dest_obj_type)
230
+ and not src._is_none()
231
+ and not is_dict(src_obj_type)
232
+ and not issubclass(src_obj_type, dest_obj_type)
233
+ )
234
+ if validation_error:
235
+ msg = (
236
+ f"Merge error: {type_str(src_obj_type)} is not a "
237
+ f"subclass of {type_str(dest_obj_type)}. value: {src}"
238
+ )
239
+ raise ValidationError(msg)
240
+
241
+ def _validate_non_optional(self, key: Optional[DictKeyType], value: Any) -> None:
242
+ if _is_none(value, resolve=True, throw_on_resolution_failure=False):
243
+
244
+ if key is not None:
245
+ child = self._get_node(key)
246
+ if child is not None:
247
+ assert isinstance(child, Node)
248
+ field_is_optional = child._is_optional()
249
+ else:
250
+ field_is_optional, _ = _resolve_optional(
251
+ self._metadata.element_type
252
+ )
253
+ else:
254
+ field_is_optional = self._is_optional()
255
+
256
+ if not field_is_optional:
257
+ self._format_and_raise(
258
+ key=key,
259
+ value=value,
260
+ cause=ValidationError("field '$FULL_KEY' is not Optional"),
261
+ )
262
+
263
+ def _raise_invalid_value(
264
+ self, value: Any, value_type: Any, target_type: Any
265
+ ) -> None:
266
+ assert value_type is not None
267
+ assert target_type is not None
268
+ msg = (
269
+ f"Invalid type assigned: {type_str(value_type)} is not a "
270
+ f"subclass of {type_str(target_type)}. value: {value}"
271
+ )
272
+ raise ValidationError(msg)
273
+
274
+ def _validate_and_normalize_key(self, key: Any) -> DictKeyType:
275
+ return self._s_validate_and_normalize_key(self._metadata.key_type, key)
276
+
277
+ def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
278
+ if key_type is Any:
279
+ for t in DictKeyType.__args__: # type: ignore
280
+ if isinstance(key, t):
281
+ return key # type: ignore
282
+ raise KeyValidationError("Incompatible key type '$KEY_TYPE'")
283
+ elif key_type is bool and key in [0, 1]:
284
+ # Python treats True as 1 and False as 0 when used as dict keys
285
+ # assert hash(0) == hash(False)
286
+ # assert hash(1) == hash(True)
287
+ return bool(key)
288
+ elif key_type in (str, bytes, int, float, bool): # primitive type
289
+ if not isinstance(key, key_type):
290
+ raise KeyValidationError(
291
+ f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
292
+ )
293
+
294
+ return key # type: ignore
295
+ elif issubclass(key_type, Enum):
296
+ try:
297
+ return EnumNode.validate_and_convert_to_enum(key_type, key)
298
+ except ValidationError:
299
+ valid = ", ".join([x for x in key_type.__members__.keys()])
300
+ raise KeyValidationError(
301
+ f"Key '$KEY' is incompatible with the enum type '{key_type.__name__}', valid: [{valid}]"
302
+ )
303
+ else:
304
+ assert False, f"Unsupported key type {key_type}"
305
+
306
+ def __setitem__(self, key: DictKeyType, value: Any) -> None:
307
+ try:
308
+ self.__set_impl(key=key, value=value)
309
+ except AttributeError as e:
310
+ self._format_and_raise(
311
+ key=key, value=value, type_override=ConfigKeyError, cause=e
312
+ )
313
+ except Exception as e:
314
+ self._format_and_raise(key=key, value=value, cause=e)
315
+
316
+ def __set_impl(self, key: DictKeyType, value: Any) -> None:
317
+ key = self._validate_and_normalize_key(key)
318
+ self._set_item_impl(key, value)
319
+
320
+ # hide content while inspecting in debugger
321
+ def __dir__(self) -> Iterable[str]:
322
+ if self._is_missing() or self._is_none():
323
+ return []
324
+ return self.__dict__["_content"].keys() # type: ignore
325
+
326
+ def __setattr__(self, key: str, value: Any) -> None:
327
+ """
328
+ Allow assigning attributes to DictConfig
329
+ :param key:
330
+ :param value:
331
+ :return:
332
+ """
333
+ try:
334
+ self.__set_impl(key, value)
335
+ except Exception as e:
336
+ if isinstance(e, OmegaConfBaseException) and e._initialized:
337
+ raise e
338
+ self._format_and_raise(key=key, value=value, cause=e)
339
+ assert False
340
+
341
+ def __getattr__(self, key: str) -> Any:
342
+ """
343
+ Allow accessing dictionary values as attributes
344
+ :param key:
345
+ :return:
346
+ """
347
+ if key == "__name__":
348
+ raise AttributeError()
349
+
350
+ try:
351
+ return self._get_impl(
352
+ key=key, default_value=_DEFAULT_MARKER_, validate_key=False
353
+ )
354
+ except ConfigKeyError as e:
355
+ self._format_and_raise(
356
+ key=key, value=None, cause=e, type_override=ConfigAttributeError
357
+ )
358
+ except Exception as e:
359
+ self._format_and_raise(key=key, value=None, cause=e)
360
+
361
+ def __getitem__(self, key: DictKeyType) -> Any:
362
+ """
363
+ Allow map style access
364
+ :param key:
365
+ :return:
366
+ """
367
+
368
+ try:
369
+ return self._get_impl(key=key, default_value=_DEFAULT_MARKER_)
370
+ except AttributeError as e:
371
+ self._format_and_raise(
372
+ key=key, value=None, cause=e, type_override=ConfigKeyError
373
+ )
374
+ except Exception as e:
375
+ self._format_and_raise(key=key, value=None, cause=e)
376
+
377
+ def __delattr__(self, key: str) -> None:
378
+ """
379
+ Allow deleting dictionary values as attributes
380
+ :param key:
381
+ :return:
382
+ """
383
+ if self._get_flag("readonly"):
384
+ self._format_and_raise(
385
+ key=key,
386
+ value=None,
387
+ cause=ReadonlyConfigError(
388
+ "DictConfig in read-only mode does not support deletion"
389
+ ),
390
+ )
391
+ try:
392
+ del self.__dict__["_content"][key]
393
+ except KeyError:
394
+ msg = "Attribute not found: '$KEY'"
395
+ self._format_and_raise(key=key, value=None, cause=ConfigAttributeError(msg))
396
+
397
+ def __delitem__(self, key: DictKeyType) -> None:
398
+ key = self._validate_and_normalize_key(key)
399
+ if self._get_flag("readonly"):
400
+ self._format_and_raise(
401
+ key=key,
402
+ value=None,
403
+ cause=ReadonlyConfigError(
404
+ "DictConfig in read-only mode does not support deletion"
405
+ ),
406
+ )
407
+ if self._get_flag("struct"):
408
+ self._format_and_raise(
409
+ key=key,
410
+ value=None,
411
+ cause=ConfigTypeError(
412
+ "DictConfig in struct mode does not support deletion"
413
+ ),
414
+ )
415
+ if self._is_typed() and self._get_node_flag("struct") is not False:
416
+ self._format_and_raise(
417
+ key=key,
418
+ value=None,
419
+ cause=ConfigTypeError(
420
+ f"{type_str(self._metadata.object_type)} (DictConfig) does not support deletion"
421
+ ),
422
+ )
423
+
424
+ try:
425
+ del self.__dict__["_content"][key]
426
+ except KeyError:
427
+ msg = "Key not found: '$KEY'"
428
+ self._format_and_raise(key=key, value=None, cause=ConfigKeyError(msg))
429
+
430
+ def get(self, key: DictKeyType, default_value: Any = None) -> Any:
431
+ """Return the value for `key` if `key` is in the dictionary, else
432
+ `default_value` (defaulting to `None`)."""
433
+ try:
434
+ return self._get_impl(key=key, default_value=default_value)
435
+ except KeyValidationError as e:
436
+ self._format_and_raise(key=key, value=None, cause=e)
437
+
438
+ def _get_impl(
439
+ self, key: DictKeyType, default_value: Any, validate_key: bool = True
440
+ ) -> Any:
441
+ try:
442
+ node = self._get_child(
443
+ key=key, throw_on_missing_key=True, validate_key=validate_key
444
+ )
445
+ except (ConfigAttributeError, ConfigKeyError):
446
+ if default_value is not _DEFAULT_MARKER_:
447
+ return default_value
448
+ else:
449
+ raise
450
+ assert isinstance(node, Node)
451
+ return self._resolve_with_default(
452
+ key=key, value=node, default_value=default_value
453
+ )
454
+
455
+ def _get_node(
456
+ self,
457
+ key: DictKeyType,
458
+ validate_access: bool = True,
459
+ validate_key: bool = True,
460
+ throw_on_missing_value: bool = False,
461
+ throw_on_missing_key: bool = False,
462
+ ) -> Optional[Node]:
463
+ try:
464
+ key = self._validate_and_normalize_key(key)
465
+ except KeyValidationError:
466
+ if validate_access and validate_key:
467
+ raise
468
+ else:
469
+ if throw_on_missing_key:
470
+ raise ConfigAttributeError
471
+ else:
472
+ return None
473
+
474
+ if validate_access:
475
+ self._validate_get(key)
476
+
477
+ value: Optional[Node] = self.__dict__["_content"].get(key)
478
+ if value is None:
479
+ if throw_on_missing_key:
480
+ raise ConfigKeyError(f"Missing key {key!s}")
481
+ elif throw_on_missing_value and value._is_missing():
482
+ raise MissingMandatoryValue("Missing mandatory value: $KEY")
483
+ return value
484
+
485
+ def pop(self, key: DictKeyType, default: Any = _DEFAULT_MARKER_) -> Any:
486
+ try:
487
+ if self._get_flag("readonly"):
488
+ raise ReadonlyConfigError("Cannot pop from read-only node")
489
+ if self._get_flag("struct"):
490
+ raise ConfigTypeError("DictConfig in struct mode does not support pop")
491
+ if self._is_typed() and self._get_node_flag("struct") is not False:
492
+ raise ConfigTypeError(
493
+ f"{type_str(self._metadata.object_type)} (DictConfig) does not support pop"
494
+ )
495
+ key = self._validate_and_normalize_key(key)
496
+ node = self._get_child(key=key, validate_access=False)
497
+ if node is not None:
498
+ assert isinstance(node, Node)
499
+ value = self._resolve_with_default(
500
+ key=key, value=node, default_value=default
501
+ )
502
+
503
+ del self[key]
504
+ return value
505
+ else:
506
+ if default is not _DEFAULT_MARKER_:
507
+ return default
508
+ else:
509
+ full = self._get_full_key(key=key)
510
+ if full != key:
511
+ raise ConfigKeyError(
512
+ f"Key not found: '{key!s}' (path: '{full}')"
513
+ )
514
+ else:
515
+ raise ConfigKeyError(f"Key not found: '{key!s}'")
516
+ except Exception as e:
517
+ self._format_and_raise(key=key, value=None, cause=e)
518
+
519
+ def keys(self) -> KeysView[DictKeyType]:
520
+ if self._is_missing() or self._is_interpolation() or self._is_none():
521
+ return {}.keys()
522
+ ret = self.__dict__["_content"].keys()
523
+ assert isinstance(ret, KeysView)
524
+ return ret
525
+
526
+ def __contains__(self, key: object) -> bool:
527
+ """
528
+ A key is contained in a DictConfig if there is an associated value and
529
+ it is not a mandatory missing value ('???').
530
+ :param key:
531
+ :return:
532
+ """
533
+
534
+ try:
535
+ key = self._validate_and_normalize_key(key)
536
+ except KeyValidationError:
537
+ return False
538
+
539
+ try:
540
+ node = self._get_child(key)
541
+ assert node is None or isinstance(node, Node)
542
+ except (KeyError, AttributeError):
543
+ node = None
544
+
545
+ if node is None:
546
+ return False
547
+ else:
548
+ try:
549
+ self._resolve_with_default(key=key, value=node)
550
+ return True
551
+ except InterpolationResolutionError:
552
+ # Interpolations that fail count as existing.
553
+ return True
554
+ except MissingMandatoryValue:
555
+ # Missing values count as *not* existing.
556
+ return False
557
+
558
+ def __iter__(self) -> Iterator[DictKeyType]:
559
+ return iter(self.keys())
560
+
561
+ def items(self) -> ItemsView[DictKeyType, Any]:
562
+ return dict(self.items_ex(resolve=True, keys=None)).items()
563
+
564
+ def setdefault(self, key: DictKeyType, default: Any = None) -> Any:
565
+ if key in self:
566
+ ret = self.__getitem__(key)
567
+ else:
568
+ ret = default
569
+ self.__setitem__(key, default)
570
+ return ret
571
+
572
+ def items_ex(
573
+ self, resolve: bool = True, keys: Optional[Sequence[DictKeyType]] = None
574
+ ) -> List[Tuple[DictKeyType, Any]]:
575
+ items: List[Tuple[DictKeyType, Any]] = []
576
+
577
+ if self._is_none():
578
+ self._format_and_raise(
579
+ key=None,
580
+ value=None,
581
+ cause=TypeError("Cannot iterate a DictConfig object representing None"),
582
+ )
583
+ if self._is_missing():
584
+ raise MissingMandatoryValue("Cannot iterate a missing DictConfig")
585
+
586
+ for key in self.keys():
587
+ if resolve:
588
+ value = self[key]
589
+ else:
590
+ value = self.__dict__["_content"][key]
591
+ if isinstance(value, ValueNode):
592
+ value = value._value()
593
+ if keys is None or key in keys:
594
+ items.append((key, value))
595
+
596
+ return items
597
+
598
+ def __eq__(self, other: Any) -> bool:
599
+ if other is None:
600
+ return self.__dict__["_content"] is None
601
+ if is_primitive_dict(other) or is_structured_config(other):
602
+ other = DictConfig(other, flags={"allow_objects": True})
603
+ return DictConfig._dict_conf_eq(self, other)
604
+ if isinstance(other, DictConfig):
605
+ return DictConfig._dict_conf_eq(self, other)
606
+ if self._is_missing():
607
+ return _is_missing_literal(other)
608
+ return NotImplemented
609
+
610
+ def __ne__(self, other: Any) -> bool:
611
+ x = self.__eq__(other)
612
+ if x is not NotImplemented:
613
+ return not x
614
+ return NotImplemented
615
+
616
+ def __hash__(self) -> int:
617
+ return hash(str(self))
618
+
619
+ def _promote(self, type_or_prototype: Optional[Type[Any]]) -> None:
620
+ """
621
+ Retypes a node.
622
+ This should only be used in rare circumstances, where you want to dynamically change
623
+ the runtime structured-type of a DictConfig.
624
+ It will change the type and add the additional fields based on the input class or object
625
+ """
626
+ if type_or_prototype is None:
627
+ return
628
+ if not is_structured_config(type_or_prototype):
629
+ raise ValueError(f"Expected structured config class: {type_or_prototype}")
630
+
631
+ from omegaconf import OmegaConf
632
+
633
+ proto: DictConfig = OmegaConf.structured(type_or_prototype)
634
+ object_type = proto._metadata.object_type
635
+ # remove the type to prevent assignment validation from rejecting the promotion.
636
+ proto._metadata.object_type = None
637
+ self.merge_with(proto)
638
+ # restore the type.
639
+ self._metadata.object_type = object_type
640
+
641
+ def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
642
+ try:
643
+ previous_content = self.__dict__["_content"]
644
+ self._set_value_impl(value, flags)
645
+ except Exception as e:
646
+ self.__dict__["_content"] = previous_content
647
+ raise e
648
+
649
+ def _set_value_impl(
650
+ self, value: Any, flags: Optional[Dict[str, bool]] = None
651
+ ) -> None:
652
+ from omegaconf import MISSING, flag_override
653
+
654
+ if flags is None:
655
+ flags = {}
656
+
657
+ assert not isinstance(value, ValueNode)
658
+ self._validate_set(key=None, value=value)
659
+
660
+ if _is_none(value, resolve=True):
661
+ self.__dict__["_content"] = None
662
+ self._metadata.object_type = None
663
+ elif _is_interpolation(value, strict_interpolation_validation=True):
664
+ self.__dict__["_content"] = value
665
+ self._metadata.object_type = None
666
+ elif _is_missing_value(value):
667
+ self.__dict__["_content"] = MISSING
668
+ self._metadata.object_type = None
669
+ else:
670
+ self.__dict__["_content"] = {}
671
+ if is_structured_config(value):
672
+ self._metadata.object_type = None
673
+ ao = self._get_flag("allow_objects")
674
+ data = get_structured_config_data(value, allow_objects=ao)
675
+ with flag_override(self, ["struct", "readonly"], False):
676
+ for k, v in data.items():
677
+ self.__setitem__(k, v)
678
+ self._metadata.object_type = get_type_of(value)
679
+
680
+ elif isinstance(value, DictConfig):
681
+ self._metadata.flags = copy.deepcopy(flags)
682
+ with flag_override(self, ["struct", "readonly"], False):
683
+ for k, v in value.__dict__["_content"].items():
684
+ self.__setitem__(k, v)
685
+ self._metadata.object_type = value._metadata.object_type
686
+
687
+ elif isinstance(value, dict):
688
+ with flag_override(self, ["struct", "readonly"], False):
689
+ for k, v in value.items():
690
+ self.__setitem__(k, v)
691
+ self._metadata.object_type = dict
692
+
693
+ else: # pragma: no cover
694
+ msg = f"Unsupported value type: {value}"
695
+ raise ValidationError(msg)
696
+
697
+ @staticmethod
698
+ def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool:
699
+
700
+ d1_none = d1.__dict__["_content"] is None
701
+ d2_none = d2.__dict__["_content"] is None
702
+ if d1_none and d2_none:
703
+ return True
704
+ if d1_none != d2_none:
705
+ return False
706
+
707
+ assert isinstance(d1, DictConfig)
708
+ assert isinstance(d2, DictConfig)
709
+ if len(d1) != len(d2):
710
+ return False
711
+ if d1._is_missing() or d2._is_missing():
712
+ return d1._is_missing() is d2._is_missing()
713
+
714
+ for k, v in d1.items_ex(resolve=False):
715
+ if k not in d2.__dict__["_content"]:
716
+ return False
717
+ if not BaseContainer._item_eq(d1, k, d2, k):
718
+ return False
719
+
720
+ return True
721
+
722
+ def _to_object(self) -> Any:
723
+ """
724
+ Instantiate an instance of `self._metadata.object_type`.
725
+ This requires `self` to be a structured config.
726
+ Nested subconfigs are converted by calling `OmegaConf.to_object`.
727
+ """
728
+ from omegaconf import OmegaConf
729
+
730
+ object_type = self._metadata.object_type
731
+ assert is_structured_config(object_type)
732
+ init_field_names = set(get_structured_config_init_field_names(object_type))
733
+
734
+ init_field_items: Dict[str, Any] = {}
735
+ non_init_field_items: Dict[str, Any] = {}
736
+ for k in self.keys():
737
+ assert isinstance(k, str)
738
+ node = self._get_child(k)
739
+ assert isinstance(node, Node)
740
+ try:
741
+ node = node._dereference_node()
742
+ except InterpolationResolutionError as e:
743
+ self._format_and_raise(key=k, value=None, cause=e)
744
+ if node._is_missing():
745
+ if k not in init_field_names:
746
+ continue # MISSING is ignored for init=False fields
747
+ self._format_and_raise(
748
+ key=k,
749
+ value=None,
750
+ cause=MissingMandatoryValue(
751
+ "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY"
752
+ ),
753
+ )
754
+ if isinstance(node, Container):
755
+ v = OmegaConf.to_object(node)
756
+ else:
757
+ v = node._value()
758
+
759
+ if k in init_field_names:
760
+ init_field_items[k] = v
761
+ else:
762
+ non_init_field_items[k] = v
763
+
764
+ try:
765
+ result = object_type(**init_field_items)
766
+ except TypeError as exc:
767
+ self._format_and_raise(
768
+ key=None,
769
+ value=None,
770
+ cause=exc,
771
+ msg="Could not create instance of `$OBJECT_TYPE`: " + str(exc),
772
+ )
773
+
774
+ for k, v in non_init_field_items.items():
775
+ setattr(result, k, v)
776
+ return result
omegaconf/errors.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Type
2
+
3
+
4
+ class OmegaConfBaseException(Exception):
5
+ # would ideally be typed Optional[Node]
6
+ parent_node: Any
7
+ child_node: Any
8
+ key: Any
9
+ full_key: Optional[str]
10
+ value: Any
11
+ msg: Optional[str]
12
+ cause: Optional[Exception]
13
+ object_type: Optional[Type[Any]]
14
+ object_type_str: Optional[str]
15
+ ref_type: Optional[Type[Any]]
16
+ ref_type_str: Optional[str]
17
+
18
+ _initialized: bool = False
19
+
20
+ def __init__(self, *_args: Any, **_kwargs: Any) -> None:
21
+ self.parent_node = None
22
+ self.child_node = None
23
+ self.key = None
24
+ self.full_key = None
25
+ self.value = None
26
+ self.msg = None
27
+ self.object_type = None
28
+ self.ref_type = None
29
+
30
+
31
+ class MissingMandatoryValue(OmegaConfBaseException):
32
+ """Thrown when a variable flagged with '???' value is accessed to
33
+ indicate that the value was not set"""
34
+
35
+
36
+ class KeyValidationError(OmegaConfBaseException, ValueError):
37
+ """
38
+ Thrown when an a key of invalid type is used
39
+ """
40
+
41
+
42
+ class ValidationError(OmegaConfBaseException, ValueError):
43
+ """
44
+ Thrown when a value fails validation
45
+ """
46
+
47
+
48
+ class UnsupportedValueType(ValidationError, ValueError):
49
+ """
50
+ Thrown when an input value is not of supported type
51
+ """
52
+
53
+
54
+ class ReadonlyConfigError(OmegaConfBaseException):
55
+ """
56
+ Thrown when someone tries to modify a frozen config
57
+ """
58
+
59
+
60
+ class InterpolationResolutionError(OmegaConfBaseException, ValueError):
61
+ """
62
+ Base class for exceptions raised when resolving an interpolation.
63
+ """
64
+
65
+
66
+ class UnsupportedInterpolationType(InterpolationResolutionError):
67
+ """
68
+ Thrown when an attempt to use an unregistered interpolation is made
69
+ """
70
+
71
+
72
+ class InterpolationKeyError(InterpolationResolutionError):
73
+ """
74
+ Thrown when a node does not exist when resolving an interpolation.
75
+ """
76
+
77
+
78
+ class InterpolationToMissingValueError(InterpolationResolutionError):
79
+ """
80
+ Thrown when a node interpolation points to a node that is set to ???.
81
+ """
82
+
83
+
84
+ class InterpolationValidationError(InterpolationResolutionError, ValidationError):
85
+ """
86
+ Thrown when the result of an interpolation fails the validation step.
87
+ """
88
+
89
+
90
+ class ConfigKeyError(OmegaConfBaseException, KeyError):
91
+ """
92
+ Thrown from DictConfig when a regular dict access would have caused a KeyError.
93
+ """
94
+
95
+ msg: str
96
+
97
+ def __init__(self, msg: str) -> None:
98
+ super().__init__(msg)
99
+ self.msg = msg
100
+
101
+ def __str__(self) -> str:
102
+ """
103
+ Workaround to nasty KeyError quirk: https://bugs.python.org/issue2651
104
+ """
105
+ return self.msg
106
+
107
+
108
+ class ConfigAttributeError(OmegaConfBaseException, AttributeError):
109
+ """
110
+ Thrown from a config object when a regular access would have caused an AttributeError.
111
+ """
112
+
113
+
114
+ class ConfigTypeError(OmegaConfBaseException, TypeError):
115
+ """
116
+ Thrown from a config object when a regular access would have caused a TypeError.
117
+ """
118
+
119
+
120
+ class ConfigIndexError(OmegaConfBaseException, IndexError):
121
+ """
122
+ Thrown from a config object when a regular access would have caused an IndexError.
123
+ """
124
+
125
+
126
+ class ConfigValueError(OmegaConfBaseException, ValueError):
127
+ """
128
+ Thrown from a config object when a regular access would have caused a ValueError.
129
+ """
130
+
131
+
132
+ class ConfigCycleDetectedException(OmegaConfBaseException):
133
+ """
134
+ Thrown when a cycle is detected in the graph made by config nodes.
135
+ """
136
+
137
+
138
+ class GrammarParseError(OmegaConfBaseException):
139
+ """
140
+ Thrown when failing to parse an expression according to the ANTLR grammar.
141
+ """
omegaconf/grammar/OmegaConfGrammarLexer.g4 ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Regenerate lexer and parser by running 'python setup.py antlr' at project root.
2
+ // See `OmegaConfGrammarParser.g4` for some important information regarding how to
3
+ // properly maintain this grammar.
4
+
5
+ lexer grammar OmegaConfGrammarLexer;
6
+
7
+ // Re-usable fragments.
8
+ fragment CHAR: [a-zA-Z];
9
+ fragment DIGIT: [0-9];
10
+ fragment INT_UNSIGNED: '0' | [1-9] (('_')? DIGIT)*;
11
+ fragment ESC_BACKSLASH: '\\\\'; // escaped backslash
12
+
13
+ /////////////////////////////
14
+ // DEFAULT_MODE (TOPLEVEL) //
15
+ /////////////////////////////
16
+
17
+ TOP_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
18
+
19
+ // Regular string: anything that does not contain any $ and does not end with \
20
+ // (this ensures this rule will not consume characters required to recognize other tokens).
21
+ ANY_STR: ~[$]* ~[\\$];
22
+
23
+ // Escaped interpolation: '\${', optionally preceded by an even number of \
24
+ ESC_INTER: ESC_BACKSLASH* '\\${';
25
+
26
+ // Backslashes that *may* be escaped (even number).
27
+ TOP_ESC: ESC_BACKSLASH+;
28
+
29
+ // Other backslashes that will not need escaping (odd number due to not matching the previous rule).
30
+ BACKSLASHES: '\\'+ -> type(ANY_STR);
31
+
32
+ // The dollar sign must be singled out so that we can recognize interpolations.
33
+ DOLLAR: '$' -> type(ANY_STR);
34
+
35
+
36
+ ////////////////
37
+ // VALUE_MODE //
38
+ ////////////////
39
+
40
+ mode VALUE_MODE;
41
+
42
+ INTER_OPEN: '${' WS? -> pushMode(INTERPOLATION_MODE);
43
+ BRACE_OPEN: '{' WS? -> pushMode(VALUE_MODE); // must keep track of braces to detect end of interpolation
44
+ BRACE_CLOSE: WS? '}' -> popMode;
45
+ QUOTE_OPEN_SINGLE: '\'' -> pushMode(QUOTED_SINGLE_MODE);
46
+ QUOTE_OPEN_DOUBLE: '"' -> pushMode(QUOTED_DOUBLE_MODE);
47
+
48
+ COMMA: WS? ',' WS?;
49
+ BRACKET_OPEN: '[' WS?;
50
+ BRACKET_CLOSE: WS? ']';
51
+ COLON: WS? ':' WS?;
52
+
53
+ // Numbers.
54
+
55
+ fragment POINT_FLOAT: INT_UNSIGNED '.' | INT_UNSIGNED? '.' DIGIT (('_')? DIGIT)*;
56
+ fragment EXPONENT_FLOAT: (INT_UNSIGNED | POINT_FLOAT) [eE] [+-]? DIGIT (('_')? DIGIT)*;
57
+ FLOAT: [+-]? (POINT_FLOAT | EXPONENT_FLOAT | [Ii][Nn][Ff] | [Nn][Aa][Nn]);
58
+ INT: [+-]? INT_UNSIGNED;
59
+
60
+ // Other reserved keywords.
61
+
62
+ BOOL:
63
+ [Tt][Rr][Uu][Ee] // TRUE
64
+ | [Ff][Aa][Ll][Ss][Ee]; // FALSE
65
+
66
+ NULL: [Nn][Uu][Ll][Ll];
67
+
68
+ UNQUOTED_CHAR: [/\-\\+.$%*@?|]; // other characters allowed in unquoted strings
69
+ ID: (CHAR|'_') (CHAR|DIGIT|'_'|'-')*;
70
+ ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' |
71
+ '\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+;
72
+ WS: [ \t]+;
73
+
74
+
75
+ ////////////////////////
76
+ // INTERPOLATION_MODE //
77
+ ////////////////////////
78
+
79
+ mode INTERPOLATION_MODE;
80
+
81
+ NESTED_INTER_OPEN: INTER_OPEN WS? -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
82
+ INTER_COLON: WS? ':' WS? -> type(COLON), mode(VALUE_MODE);
83
+ INTER_CLOSE: WS? '}' -> popMode;
84
+
85
+ DOT: '.';
86
+ INTER_BRACKET_OPEN: '[' -> type(BRACKET_OPEN);
87
+ INTER_BRACKET_CLOSE: ']' -> type(BRACKET_CLOSE);
88
+ INTER_ID: ID -> type(ID);
89
+
90
+ // Interpolation key, may contain any non special character.
91
+ // Note that we can allow '$' because the parser does not support interpolations that
92
+ // are only part of a key name, i.e., "${foo${bar}}" is not allowed. As a result, it
93
+ // is ok to "consume" all '$' characters within the `INTER_KEY` token.
94
+ INTER_KEY: ~[\\{}()[\]:. \t'"]+;
95
+
96
+
97
+ ////////////////////////
98
+ // QUOTED_SINGLE_MODE //
99
+ ////////////////////////
100
+
101
+ mode QUOTED_SINGLE_MODE;
102
+
103
+ // This mode is very similar to `DEFAULT_MODE` except for the handling of quotes.
104
+
105
+ QSINGLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
106
+ MATCHING_QUOTE_CLOSE: '\'' -> popMode;
107
+
108
+ // Regular string: anything that does not contain any $ *or quote* and does not end with \
109
+ QSINGLE_STR: ~['$]* ~['\\$] -> type(ANY_STR);
110
+
111
+ QSINGLE_ESC_INTER: ESC_INTER -> type(ESC_INTER);
112
+
113
+ // Escaped quote (optionally preceded by an even number of backslashes).
114
+ QSINGLE_ESC_QUOTE: ESC_BACKSLASH* '\\\'' -> type(ESC);
115
+
116
+ QUOTED_ESC: ESC_BACKSLASH+;
117
+ QSINGLE_BACKSLASHES: '\\'+ -> type(ANY_STR);
118
+ QSINGLE_DOLLAR: '$' -> type(ANY_STR);
119
+
120
+
121
+ ////////////////////////
122
+ // QUOTED_DOUBLE_MODE //
123
+ ////////////////////////
124
+
125
+ mode QUOTED_DOUBLE_MODE;
126
+
127
+ // Same as `QUOTED_SINGLE_MODE` but for double quotes.
128
+
129
+ QDOUBLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
130
+ QDOUBLE_CLOSE: '"' -> type(MATCHING_QUOTE_CLOSE), popMode;
131
+
132
+ QDOUBLE_STR: ~["$]* ~["\\$] -> type(ANY_STR);
133
+ QDOUBLE_ESC_INTER: ESC_INTER -> type(ESC_INTER);
134
+ QDOUBLE_ESC_QUOTE: ESC_BACKSLASH* '\\"' -> type(ESC);
135
+ QDOUBLE_ESC: ESC_BACKSLASH+ -> type(QUOTED_ESC);
136
+ QDOUBLE_BACKSLASHES: '\\'+ -> type(ANY_STR);
137
+ QDOUBLE_DOLLAR: '$' -> type(ANY_STR);
omegaconf/grammar/OmegaConfGrammarParser.g4 ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Regenerate parser by running 'python setup.py antlr' at project root.
2
+
3
+ // Maintenance guidelines when modifying this grammar:
4
+ //
5
+ // - Consider whether the regex pattern `SIMPLE_INTERPOLATION_PATTERN` found in
6
+ // `grammar_parser.py` should be updated as well.
7
+ //
8
+ // - Update Hydra's grammar accordingly.
9
+ //
10
+ // - Keep up-to-date the comments in the visitor (in `grammar_visitor.py`)
11
+ // that contain grammar excerpts (within each `visit...()` method).
12
+ //
13
+ // - Remember to update the documentation (including the tutorial notebook as
14
+ // well as grammar.rst)
15
+
16
+ parser grammar OmegaConfGrammarParser;
17
+ options {tokenVocab = OmegaConfGrammarLexer;}
18
+
19
+ // Main rules used to parse OmegaConf strings.
20
+
21
+ configValue: text EOF;
22
+ singleElement: element EOF;
23
+
24
+
25
+ // Composite text expression (may contain interpolations).
26
+
27
+ text: (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+;
28
+
29
+
30
+ // Elements.
31
+
32
+ element:
33
+ primitive
34
+ | quotedValue
35
+ | listContainer
36
+ | dictContainer
37
+ ;
38
+
39
+
40
+ // Data structures.
41
+
42
+ listContainer: BRACKET_OPEN sequence? BRACKET_CLOSE; // [], [1,2,3], [a,b,[1,2]]
43
+ dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
44
+ dictKeyValuePair: dictKey COLON element;
45
+ sequence: (element (COMMA element?)*) | (COMMA element?)+;
46
+
47
+
48
+ // Interpolations.
49
+
50
+ interpolation: interpolationNode | interpolationResolver;
51
+
52
+ interpolationNode:
53
+ INTER_OPEN
54
+ DOT* // relative interpolation?
55
+ (configKey | BRACKET_OPEN configKey BRACKET_CLOSE) // foo, [foo]
56
+ (DOT configKey | BRACKET_OPEN configKey BRACKET_CLOSE)* // .foo, [foo], .foo[bar], [foo].bar[baz]
57
+ INTER_CLOSE;
58
+ interpolationResolver: INTER_OPEN resolverName COLON sequence? BRACE_CLOSE;
59
+ configKey: interpolation | ID | INTER_KEY;
60
+ resolverName: (interpolation | ID) (DOT (interpolation | ID))* ; // oc.env, myfunc, ns.${x}, ns1.ns2.f
61
+
62
+
63
+ // Primitive types.
64
+
65
+ // Ex: "hello world", 'hello ${world}'
66
+ quotedValue: (QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE) text? MATCHING_QUOTE_CLOSE;
67
+
68
+ primitive:
69
+ ( ID // foo_10
70
+ | NULL // null, NULL
71
+ | INT // 0, 10, -20, 1_000_000
72
+ | FLOAT // 3.14, -20.0, 1e-1, -10e3
73
+ | BOOL // true, TrUe, false, False
74
+ | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, |
75
+ | COLON // :
76
+ | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
77
+ | WS // whitespaces
78
+ | interpolation
79
+ )+;
80
+
81
+ // Same as `primitive` except that `COLON` and interpolations are not allowed.
82
+ dictKey:
83
+ ( ID // foo_10
84
+ | NULL // null, NULL
85
+ | INT // 0, 10, -20, 1_000_000
86
+ | FLOAT // 3.14, -20.0, 1e-1, -10e3
87
+ | BOOL // true, TrUe, false, False
88
+ | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, |
89
+ | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
90
+ | WS // whitespaces
91
+ )+;
omegaconf/grammar/__init__.py ADDED
File without changes
omegaconf/grammar/gen/__init__.py ADDED
File without changes
omegaconf/grammar_parser.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import threading
3
+ from typing import Any
4
+
5
+ from antlr4 import CommonTokenStream, InputStream, ParserRuleContext
6
+ from antlr4.error.ErrorListener import ErrorListener
7
+
8
+ from .errors import GrammarParseError
9
+
10
+ # Import from visitor in order to check the presence of generated grammar files
11
+ # files in a single place.
12
+ from .grammar_visitor import ( # type: ignore
13
+ OmegaConfGrammarLexer,
14
+ OmegaConfGrammarParser,
15
+ )
16
+
17
+ # Used to cache grammar objects to avoid re-creating them on each call to `parse()`.
18
+ # We use a per-thread cache to make it thread-safe.
19
+ _grammar_cache = threading.local()
20
+
21
+ # Build regex pattern to efficiently identify typical interpolations.
22
+ # See test `test_match_simple_interpolation_pattern` for examples.
23
+ _config_key = r"[$\w]+" # foo, $0, $bar, $foo_$bar123$
24
+ _key_maybe_brackets = f"{_config_key}|\\[{_config_key}\\]" # foo, [foo], [$bar]
25
+ _node_access = f"\\.{_key_maybe_brackets}" # .foo, [foo], [$bar]
26
+ _node_path = f"(\\.)*({_key_maybe_brackets})({_node_access})*" # [foo].bar, .foo[bar]
27
+ _node_inter = f"\\${{\\s*{_node_path}\\s*}}" # node interpolation ${foo.bar}
28
+ _id = "[a-zA-Z_][\\w\\-]*" # foo, foo_bar, foo-bar, abc123
29
+ _resolver_name = f"({_id}(\\.{_id})*)?" # foo, ns.bar3, ns_1.ns_2.b0z
30
+ _arg = r"[a-zA-Z_0-9/\-\+.$%*@?|]+" # string representing a resolver argument
31
+ _args = f"{_arg}(\\s*,\\s*{_arg})*" # list of resolver arguments
32
+ _resolver_inter = f"\\${{\\s*{_resolver_name}\\s*:\\s*{_args}?\\s*}}" # ${foo:bar}
33
+ _inter = f"({_node_inter}|{_resolver_inter})" # any kind of interpolation
34
+ _outer = "([^$]|\\$(?!{))+" # any character except $ (unless not followed by {)
35
+ SIMPLE_INTERPOLATION_PATTERN = re.compile(
36
+ f"({_outer})?({_inter}({_outer})?)+$", flags=re.ASCII
37
+ )
38
+ # NOTE: SIMPLE_INTERPOLATION_PATTERN must not generate false positive matches:
39
+ # it must not accept anything that isn't a valid interpolation (per the
40
+ # interpolation grammar defined in `omegaconf/grammar/*.g4`).
41
+
42
+
43
+ class OmegaConfErrorListener(ErrorListener): # type: ignore
44
+ def syntaxError(
45
+ self,
46
+ recognizer: Any,
47
+ offending_symbol: Any,
48
+ line: Any,
49
+ column: Any,
50
+ msg: Any,
51
+ e: Any,
52
+ ) -> None:
53
+ raise GrammarParseError(str(e) if msg is None else msg) from e
54
+
55
+ def reportAmbiguity(
56
+ self,
57
+ recognizer: Any,
58
+ dfa: Any,
59
+ startIndex: Any,
60
+ stopIndex: Any,
61
+ exact: Any,
62
+ ambigAlts: Any,
63
+ configs: Any,
64
+ ) -> None:
65
+ raise GrammarParseError("ANTLR error: Ambiguity") # pragma: no cover
66
+
67
+ def reportAttemptingFullContext(
68
+ self,
69
+ recognizer: Any,
70
+ dfa: Any,
71
+ startIndex: Any,
72
+ stopIndex: Any,
73
+ conflictingAlts: Any,
74
+ configs: Any,
75
+ ) -> None:
76
+ # Note: for now we raise an error to be safe. However this is mostly a
77
+ # performance warning, so in the future this may be relaxed if we need
78
+ # to change the grammar in such a way that this warning cannot be
79
+ # avoided (another option would be to switch to SLL parsing mode).
80
+ raise GrammarParseError(
81
+ "ANTLR error: Attempting Full Context"
82
+ ) # pragma: no cover
83
+
84
+ def reportContextSensitivity(
85
+ self,
86
+ recognizer: Any,
87
+ dfa: Any,
88
+ startIndex: Any,
89
+ stopIndex: Any,
90
+ prediction: Any,
91
+ configs: Any,
92
+ ) -> None:
93
+ raise GrammarParseError("ANTLR error: ContextSensitivity") # pragma: no cover
94
+
95
+
96
+ def parse(
97
+ value: str, parser_rule: str = "configValue", lexer_mode: str = "DEFAULT_MODE"
98
+ ) -> ParserRuleContext:
99
+ """
100
+ Parse interpolated string `value` (and return the parse tree).
101
+ """
102
+ l_mode = getattr(OmegaConfGrammarLexer, lexer_mode)
103
+ istream = InputStream(value)
104
+
105
+ cached = getattr(_grammar_cache, "data", None)
106
+ if cached is None:
107
+ error_listener = OmegaConfErrorListener()
108
+ lexer = OmegaConfGrammarLexer(istream)
109
+ lexer.removeErrorListeners()
110
+ lexer.addErrorListener(error_listener)
111
+ lexer.mode(l_mode)
112
+ token_stream = CommonTokenStream(lexer)
113
+ parser = OmegaConfGrammarParser(token_stream)
114
+ parser.removeErrorListeners()
115
+ parser.addErrorListener(error_listener)
116
+
117
+ # The two lines below could be enabled in the future if we decide to switch
118
+ # to SLL prediction mode. Warning though, it has not been fully tested yet!
119
+ # from antlr4 import PredictionMode
120
+ # parser._interp.predictionMode = PredictionMode.SLL
121
+
122
+ # Note that although the input stream `istream` is implicitly cached within
123
+ # the lexer, it will be replaced by a new input next time the lexer is re-used.
124
+ _grammar_cache.data = lexer, token_stream, parser
125
+
126
+ else:
127
+ lexer, token_stream, parser = cached
128
+ # Replace the old input stream with the new one.
129
+ lexer.inputStream = istream
130
+ # Initialize the lexer / token stream / parser to process the new input.
131
+ lexer.mode(l_mode)
132
+ token_stream.setTokenSource(lexer)
133
+ parser.reset()
134
+
135
+ try:
136
+ return getattr(parser, parser_rule)()
137
+ except Exception as exc:
138
+ if type(exc) is Exception and str(exc) == "Empty Stack":
139
+ # This exception is raised by antlr when trying to pop a mode while
140
+ # no mode has been pushed. We convert it into an `GrammarParseError`
141
+ # to facilitate exception handling from the caller.
142
+ raise GrammarParseError("Empty Stack")
143
+ else:
144
+ raise
omegaconf/grammar_visitor.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import warnings
3
+ from itertools import zip_longest
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ Callable,
8
+ Dict,
9
+ Generator,
10
+ List,
11
+ Optional,
12
+ Set,
13
+ Tuple,
14
+ Union,
15
+ )
16
+
17
+ from antlr4 import TerminalNode
18
+
19
+ from .errors import InterpolationResolutionError
20
+
21
+ if TYPE_CHECKING:
22
+ from .base import Node # noqa F401
23
+
24
+ try:
25
+ from omegaconf.grammar.gen.OmegaConfGrammarLexer import OmegaConfGrammarLexer
26
+ from omegaconf.grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
27
+ from omegaconf.grammar.gen.OmegaConfGrammarParserVisitor import (
28
+ OmegaConfGrammarParserVisitor,
29
+ )
30
+
31
+ except ModuleNotFoundError: # pragma: no cover
32
+ print(
33
+ "Error importing OmegaConf's generated parsers, run `python setup.py antlr` to regenerate.",
34
+ file=sys.stderr,
35
+ )
36
+ sys.exit(1)
37
+
38
+
39
+ class GrammarVisitor(OmegaConfGrammarParserVisitor):
40
+ def __init__(
41
+ self,
42
+ node_interpolation_callback: Callable[
43
+ [str, Optional[Set[int]]],
44
+ Optional["Node"],
45
+ ],
46
+ resolver_interpolation_callback: Callable[..., Any],
47
+ memo: Optional[Set[int]],
48
+ **kw: Dict[Any, Any],
49
+ ):
50
+ """
51
+ Constructor.
52
+
53
+ :param node_interpolation_callback: Callback function that is called when
54
+ needing to resolve a node interpolation. This function should take a single
55
+ string input which is the key's dot path (ex: `"foo.bar"`).
56
+
57
+ :param resolver_interpolation_callback: Callback function that is called when
58
+ needing to resolve a resolver interpolation. This function should accept
59
+ three keyword arguments: `name` (str, the name of the resolver),
60
+ `args` (tuple, the inputs to the resolver), and `args_str` (tuple,
61
+ the string representation of the inputs to the resolver).
62
+
63
+ :param kw: Additional keyword arguments to be forwarded to parent class.
64
+ """
65
+ super().__init__(**kw)
66
+ self.node_interpolation_callback = node_interpolation_callback
67
+ self.resolver_interpolation_callback = resolver_interpolation_callback
68
+ self.memo = memo
69
+
70
+ def aggregateResult(self, aggregate: List[Any], nextResult: Any) -> List[Any]:
71
+ raise NotImplementedError
72
+
73
+ def defaultResult(self) -> List[Any]:
74
+ # Raising an exception because not currently used (like `aggregateResult()`).
75
+ raise NotImplementedError
76
+
77
+ def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str:
78
+ from ._utils import _get_value
79
+
80
+ # interpolation | ID | INTER_KEY
81
+ assert ctx.getChildCount() == 1
82
+ child = ctx.getChild(0)
83
+ if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
84
+ res = _get_value(self.visitInterpolation(child))
85
+ if not isinstance(res, str):
86
+ raise InterpolationResolutionError(
87
+ f"The following interpolation is used to denote a config key and "
88
+ f"thus should return a string, but instead returned `{res}` of "
89
+ f"type `{type(res)}`: {ctx.getChild(0).getText()}"
90
+ )
91
+ return res
92
+ else:
93
+ assert isinstance(child, TerminalNode) and isinstance(
94
+ child.symbol.text, str
95
+ )
96
+ return child.symbol.text
97
+
98
+ def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any:
99
+ # text EOF
100
+ assert ctx.getChildCount() == 2
101
+ return self.visit(ctx.getChild(0))
102
+
103
+ def visitDictKey(self, ctx: OmegaConfGrammarParser.DictKeyContext) -> Any:
104
+ return self._createPrimitive(ctx)
105
+
106
+ def visitDictContainer(
107
+ self, ctx: OmegaConfGrammarParser.DictContainerContext
108
+ ) -> Dict[Any, Any]:
109
+ # BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE
110
+ assert ctx.getChildCount() >= 2
111
+ return dict(
112
+ self.visitDictKeyValuePair(ctx.getChild(i))
113
+ for i in range(1, ctx.getChildCount() - 1, 2)
114
+ )
115
+
116
+ def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any:
117
+ # primitive | quotedValue | listContainer | dictContainer
118
+ assert ctx.getChildCount() == 1
119
+ return self.visit(ctx.getChild(0))
120
+
121
+ def visitInterpolation(
122
+ self, ctx: OmegaConfGrammarParser.InterpolationContext
123
+ ) -> Any:
124
+ assert ctx.getChildCount() == 1 # interpolationNode | interpolationResolver
125
+ return self.visit(ctx.getChild(0))
126
+
127
+ def visitInterpolationNode(
128
+ self, ctx: OmegaConfGrammarParser.InterpolationNodeContext
129
+ ) -> Optional["Node"]:
130
+ # INTER_OPEN
131
+ # DOT* // relative interpolation?
132
+ # (configKey | BRACKET_OPEN configKey BRACKET_CLOSE) // foo, [foo]
133
+ # (DOT configKey | BRACKET_OPEN configKey BRACKET_CLOSE)* // .foo, [foo], .foo[bar], [foo].bar[baz]
134
+ # INTER_CLOSE;
135
+
136
+ assert ctx.getChildCount() >= 3
137
+
138
+ inter_key_tokens = [] # parsed elements of the dot path
139
+ for child in ctx.getChildren():
140
+ if isinstance(child, TerminalNode):
141
+ s = child.symbol
142
+ if s.type in [
143
+ OmegaConfGrammarLexer.DOT,
144
+ OmegaConfGrammarLexer.BRACKET_OPEN,
145
+ OmegaConfGrammarLexer.BRACKET_CLOSE,
146
+ ]:
147
+ inter_key_tokens.append(s.text)
148
+ else:
149
+ assert s.type in (
150
+ OmegaConfGrammarLexer.INTER_OPEN,
151
+ OmegaConfGrammarLexer.INTER_CLOSE,
152
+ )
153
+ else:
154
+ assert isinstance(child, OmegaConfGrammarParser.ConfigKeyContext)
155
+ inter_key_tokens.append(self.visitConfigKey(child))
156
+
157
+ inter_key = "".join(inter_key_tokens)
158
+ return self.node_interpolation_callback(inter_key, self.memo)
159
+
160
+ def visitInterpolationResolver(
161
+ self, ctx: OmegaConfGrammarParser.InterpolationResolverContext
162
+ ) -> Any:
163
+
164
+ # INTER_OPEN resolverName COLON sequence? BRACE_CLOSE
165
+ assert 4 <= ctx.getChildCount() <= 5
166
+
167
+ resolver_name = self.visit(ctx.getChild(1))
168
+ maybe_seq = ctx.getChild(3)
169
+ args = []
170
+ args_str = []
171
+ if isinstance(maybe_seq, TerminalNode): # means there are no args
172
+ assert maybe_seq.symbol.type == OmegaConfGrammarLexer.BRACE_CLOSE
173
+ else:
174
+ assert isinstance(maybe_seq, OmegaConfGrammarParser.SequenceContext)
175
+ for val, txt in self.visitSequence(maybe_seq):
176
+ args.append(val)
177
+ args_str.append(txt)
178
+
179
+ return self.resolver_interpolation_callback(
180
+ name=resolver_name,
181
+ args=tuple(args),
182
+ args_str=tuple(args_str),
183
+ )
184
+
185
+ def visitDictKeyValuePair(
186
+ self, ctx: OmegaConfGrammarParser.DictKeyValuePairContext
187
+ ) -> Tuple[Any, Any]:
188
+ from ._utils import _get_value
189
+
190
+ assert ctx.getChildCount() == 3 # dictKey COLON element
191
+ key = self.visit(ctx.getChild(0))
192
+ colon = ctx.getChild(1)
193
+ assert (
194
+ isinstance(colon, TerminalNode)
195
+ and colon.symbol.type == OmegaConfGrammarLexer.COLON
196
+ )
197
+ value = _get_value(self.visitElement(ctx.getChild(2)))
198
+ return key, value
199
+
200
+ def visitListContainer(
201
+ self, ctx: OmegaConfGrammarParser.ListContainerContext
202
+ ) -> List[Any]:
203
+ # BRACKET_OPEN sequence? BRACKET_CLOSE;
204
+ assert ctx.getChildCount() in (2, 3)
205
+ if ctx.getChildCount() == 2:
206
+ return []
207
+ sequence = ctx.getChild(1)
208
+ assert isinstance(sequence, OmegaConfGrammarParser.SequenceContext)
209
+ return list(val for val, _ in self.visitSequence(sequence)) # ignore raw text
210
+
211
+ def visitPrimitive(self, ctx: OmegaConfGrammarParser.PrimitiveContext) -> Any:
212
+ return self._createPrimitive(ctx)
213
+
214
+ def visitQuotedValue(self, ctx: OmegaConfGrammarParser.QuotedValueContext) -> str:
215
+ # (QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE) text? MATCHING_QUOTE_CLOSE
216
+ n = ctx.getChildCount()
217
+ assert n in [2, 3]
218
+ return str(self.visit(ctx.getChild(1))) if n == 3 else ""
219
+
220
+ def visitResolverName(self, ctx: OmegaConfGrammarParser.ResolverNameContext) -> str:
221
+ from ._utils import _get_value
222
+
223
+ # (interpolation | ID) (DOT (interpolation | ID))*
224
+ assert ctx.getChildCount() >= 1
225
+ items = []
226
+ for child in list(ctx.getChildren())[::2]:
227
+ if isinstance(child, TerminalNode):
228
+ assert child.symbol.type == OmegaConfGrammarLexer.ID
229
+ items.append(child.symbol.text)
230
+ else:
231
+ assert isinstance(child, OmegaConfGrammarParser.InterpolationContext)
232
+ item = _get_value(self.visitInterpolation(child))
233
+ if not isinstance(item, str):
234
+ raise InterpolationResolutionError(
235
+ f"The name of a resolver must be a string, but the interpolation "
236
+ f"{child.getText()} resolved to `{item}` which is of type "
237
+ f"{type(item)}"
238
+ )
239
+ items.append(item)
240
+ return ".".join(items)
241
+
242
+ def visitSequence(
243
+ self, ctx: OmegaConfGrammarParser.SequenceContext
244
+ ) -> Generator[Any, None, None]:
245
+ from ._utils import _get_value
246
+
247
+ # (element (COMMA element?)*) | (COMMA element?)+
248
+ assert ctx.getChildCount() >= 1
249
+
250
+ # DEPRECATED: remove in 2.2 (revert #571)
251
+ def empty_str_warning() -> None:
252
+ txt = ctx.getText()
253
+ warnings.warn(
254
+ f"In the sequence `{txt}` some elements are missing: please replace "
255
+ f"them with empty quoted strings. "
256
+ f"See https://github.com/omry/omegaconf/issues/572 for details.",
257
+ category=UserWarning,
258
+ )
259
+
260
+ is_previous_comma = True # whether previous child was a comma (init to True)
261
+ for child in ctx.getChildren():
262
+ if isinstance(child, OmegaConfGrammarParser.ElementContext):
263
+ # Also preserve the original text representation of `child` so
264
+ # as to allow backward compatibility with old resolvers (registered
265
+ # with `legacy_register_resolver()`). Note that we cannot just cast
266
+ # the value to string later as for instance `null` would become "None".
267
+ yield _get_value(self.visitElement(child)), child.getText()
268
+ is_previous_comma = False
269
+ else:
270
+ assert (
271
+ isinstance(child, TerminalNode)
272
+ and child.symbol.type == OmegaConfGrammarLexer.COMMA
273
+ )
274
+ if is_previous_comma:
275
+ empty_str_warning()
276
+ yield "", ""
277
+ else:
278
+ is_previous_comma = True
279
+ if is_previous_comma:
280
+ # Trailing comma.
281
+ empty_str_warning()
282
+ yield "", ""
283
+
284
+ def visitSingleElement(
285
+ self, ctx: OmegaConfGrammarParser.SingleElementContext
286
+ ) -> Any:
287
+ # element EOF
288
+ assert ctx.getChildCount() == 2
289
+ return self.visit(ctx.getChild(0))
290
+
291
+ def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any:
292
+ # (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+
293
+
294
+ # Single interpolation? If yes, return its resolved value "as is".
295
+ if ctx.getChildCount() == 1:
296
+ c = ctx.getChild(0)
297
+ if isinstance(c, OmegaConfGrammarParser.InterpolationContext):
298
+ return self.visitInterpolation(c)
299
+
300
+ # Otherwise, concatenate string representations together.
301
+ return self._unescape(list(ctx.getChildren()))
302
+
303
+ def _createPrimitive(
304
+ self,
305
+ ctx: Union[
306
+ OmegaConfGrammarParser.PrimitiveContext,
307
+ OmegaConfGrammarParser.DictKeyContext,
308
+ ],
309
+ ) -> Any:
310
+ # (ID | NULL | INT | FLOAT | BOOL | UNQUOTED_CHAR | COLON | ESC | WS | interpolation)+
311
+ if ctx.getChildCount() == 1:
312
+ child = ctx.getChild(0)
313
+ if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
314
+ return self.visitInterpolation(child)
315
+ assert isinstance(child, TerminalNode)
316
+ symbol = child.symbol
317
+ # Parse primitive types.
318
+ if symbol.type in (
319
+ OmegaConfGrammarLexer.ID,
320
+ OmegaConfGrammarLexer.UNQUOTED_CHAR,
321
+ OmegaConfGrammarLexer.COLON,
322
+ ):
323
+ return symbol.text
324
+ elif symbol.type == OmegaConfGrammarLexer.NULL:
325
+ return None
326
+ elif symbol.type == OmegaConfGrammarLexer.INT:
327
+ return int(symbol.text)
328
+ elif symbol.type == OmegaConfGrammarLexer.FLOAT:
329
+ return float(symbol.text)
330
+ elif symbol.type == OmegaConfGrammarLexer.BOOL:
331
+ return symbol.text.lower() == "true"
332
+ elif symbol.type == OmegaConfGrammarLexer.ESC:
333
+ return self._unescape([child])
334
+ elif symbol.type == OmegaConfGrammarLexer.WS: # pragma: no cover
335
+ # A single WS should have been "consumed" by another token.
336
+ raise AssertionError("WS should never be reached")
337
+ assert False, symbol.type
338
+ # Concatenation of multiple items ==> un-escape the concatenation.
339
+ return self._unescape(list(ctx.getChildren()))
340
+
341
+ def _unescape(
342
+ self,
343
+ seq: List[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]],
344
+ ) -> str:
345
+ """
346
+ Concatenate all symbols / interpolations in `seq`, unescaping symbols as needed.
347
+
348
+ Interpolations are resolved and cast to string *WITHOUT* escaping their result
349
+ (it is assumed that whatever escaping is required was already handled during the
350
+ resolving of the interpolation).
351
+ """
352
+ chrs = []
353
+ for node, next_node in zip_longest(seq, seq[1:]):
354
+ if isinstance(node, TerminalNode):
355
+ s = node.symbol
356
+ if s.type == OmegaConfGrammarLexer.ESC_INTER:
357
+ # `ESC_INTER` is of the form `\\...\${`: the formula below computes
358
+ # the number of characters to keep at the end of the string to remove
359
+ # the correct number of backslashes.
360
+ text = s.text[-(len(s.text) // 2 + 1) :]
361
+ elif (
362
+ # Character sequence identified as requiring un-escaping.
363
+ s.type == OmegaConfGrammarLexer.ESC
364
+ or (
365
+ # At top level, we need to un-escape backslashes that precede
366
+ # an interpolation.
367
+ s.type == OmegaConfGrammarLexer.TOP_ESC
368
+ and isinstance(
369
+ next_node, OmegaConfGrammarParser.InterpolationContext
370
+ )
371
+ )
372
+ or (
373
+ # In a quoted sring, we need to un-escape backslashes that
374
+ # either end the string, or are followed by an interpolation.
375
+ s.type == OmegaConfGrammarLexer.QUOTED_ESC
376
+ and (
377
+ next_node is None
378
+ or isinstance(
379
+ next_node, OmegaConfGrammarParser.InterpolationContext
380
+ )
381
+ )
382
+ )
383
+ ):
384
+ text = s.text[1::2] # un-escape the sequence
385
+ else:
386
+ text = s.text # keep the original text
387
+ else:
388
+ assert isinstance(node, OmegaConfGrammarParser.InterpolationContext)
389
+ text = str(self.visitInterpolation(node))
390
+ chrs.append(text)
391
+
392
+ return "".join(chrs)
omegaconf/listconfig.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import itertools
3
+ from typing import (
4
+ Any,
5
+ Callable,
6
+ Dict,
7
+ Iterable,
8
+ Iterator,
9
+ List,
10
+ MutableSequence,
11
+ Optional,
12
+ Tuple,
13
+ Type,
14
+ Union,
15
+ )
16
+
17
+ from ._utils import (
18
+ ValueKind,
19
+ _is_missing_literal,
20
+ _is_none,
21
+ _resolve_optional,
22
+ format_and_raise,
23
+ get_value_kind,
24
+ is_int,
25
+ is_primitive_list,
26
+ is_structured_config,
27
+ type_str,
28
+ )
29
+ from .base import Box, ContainerMetadata, Node
30
+ from .basecontainer import BaseContainer
31
+ from .errors import (
32
+ ConfigAttributeError,
33
+ ConfigTypeError,
34
+ ConfigValueError,
35
+ KeyValidationError,
36
+ MissingMandatoryValue,
37
+ ReadonlyConfigError,
38
+ ValidationError,
39
+ )
40
+
41
+
42
+ class ListConfig(BaseContainer, MutableSequence[Any]):
43
+
44
+ _content: Union[List[Node], None, str]
45
+
46
+ def __init__(
47
+ self,
48
+ content: Union[List[Any], Tuple[Any, ...], "ListConfig", str, None],
49
+ key: Any = None,
50
+ parent: Optional[Box] = None,
51
+ element_type: Union[Type[Any], Any] = Any,
52
+ is_optional: bool = True,
53
+ ref_type: Union[Type[Any], Any] = Any,
54
+ flags: Optional[Dict[str, bool]] = None,
55
+ ) -> None:
56
+ try:
57
+ if isinstance(content, ListConfig):
58
+ if flags is None:
59
+ flags = content._metadata.flags
60
+ super().__init__(
61
+ parent=parent,
62
+ metadata=ContainerMetadata(
63
+ ref_type=ref_type,
64
+ object_type=list,
65
+ key=key,
66
+ optional=is_optional,
67
+ element_type=element_type,
68
+ key_type=int,
69
+ flags=flags,
70
+ ),
71
+ )
72
+
73
+ if isinstance(content, ListConfig):
74
+ metadata = copy.deepcopy(content._metadata)
75
+ metadata.key = key
76
+ metadata.ref_type = ref_type
77
+ metadata.optional = is_optional
78
+ metadata.element_type = element_type
79
+ self.__dict__["_metadata"] = metadata
80
+ self._set_value(value=content, flags=flags)
81
+ except Exception as ex:
82
+ format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))
83
+
84
+ def _validate_get(self, key: Any, value: Any = None) -> None:
85
+ if not isinstance(key, (int, slice)):
86
+ raise KeyValidationError(
87
+ "ListConfig indices must be integers or slices, not $KEY_TYPE"
88
+ )
89
+
90
+ def _validate_set(self, key: Any, value: Any) -> None:
91
+ from omegaconf import OmegaConf
92
+
93
+ self._validate_get(key, value)
94
+
95
+ if self._get_flag("readonly"):
96
+ raise ReadonlyConfigError("ListConfig is read-only")
97
+
98
+ if 0 <= key < self.__len__():
99
+ target = self._get_node(key)
100
+ if target is not None:
101
+ assert isinstance(target, Node)
102
+ if value is None and not target._is_optional():
103
+ raise ValidationError(
104
+ "$FULL_KEY is not optional and cannot be assigned None"
105
+ )
106
+
107
+ vk = get_value_kind(value)
108
+ if vk == ValueKind.MANDATORY_MISSING:
109
+ return
110
+ else:
111
+ is_optional, target_type = _resolve_optional(self._metadata.element_type)
112
+ value_type = OmegaConf.get_type(value)
113
+
114
+ if (value_type is None and not is_optional) or (
115
+ is_structured_config(target_type)
116
+ and value_type is not None
117
+ and not issubclass(value_type, target_type)
118
+ ):
119
+ msg = (
120
+ f"Invalid type assigned: {type_str(value_type)} is not a "
121
+ f"subclass of {type_str(target_type)}. value: {value}"
122
+ )
123
+ raise ValidationError(msg)
124
+
125
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig":
126
+ res = ListConfig(None)
127
+ res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
128
+ res.__dict__["_flags_cache"] = copy.deepcopy(
129
+ self.__dict__["_flags_cache"], memo=memo
130
+ )
131
+
132
+ src_content = self.__dict__["_content"]
133
+ if isinstance(src_content, list):
134
+ content_copy: List[Optional[Node]] = []
135
+ for v in src_content:
136
+ old_parent = v.__dict__["_parent"]
137
+ try:
138
+ v.__dict__["_parent"] = None
139
+ vc = copy.deepcopy(v, memo=memo)
140
+ vc.__dict__["_parent"] = res
141
+ content_copy.append(vc)
142
+ finally:
143
+ v.__dict__["_parent"] = old_parent
144
+ else:
145
+ # None and strings can be assigned as is
146
+ content_copy = src_content
147
+
148
+ res.__dict__["_content"] = content_copy
149
+ res.__dict__["_parent"] = self.__dict__["_parent"]
150
+
151
+ return res
152
+
153
+ def copy(self) -> "ListConfig":
154
+ return copy.copy(self)
155
+
156
+ # hide content while inspecting in debugger
157
+ def __dir__(self) -> Iterable[str]:
158
+ if self._is_missing() or self._is_none():
159
+ return []
160
+ return [str(x) for x in range(0, len(self))]
161
+
162
+ def __setattr__(self, key: str, value: Any) -> None:
163
+ self._format_and_raise(
164
+ key=key,
165
+ value=value,
166
+ cause=ConfigAttributeError("ListConfig does not support attribute access"),
167
+ )
168
+ assert False
169
+
170
+ def __getattr__(self, key: str) -> Any:
171
+ # PyCharm is sometimes inspecting __members__, be sure to tell it we don't have that.
172
+ if key == "__members__":
173
+ raise AttributeError()
174
+
175
+ if key == "__name__":
176
+ raise AttributeError()
177
+
178
+ if is_int(key):
179
+ return self.__getitem__(int(key))
180
+ else:
181
+ self._format_and_raise(
182
+ key=key,
183
+ value=None,
184
+ cause=ConfigAttributeError(
185
+ "ListConfig does not support attribute access"
186
+ ),
187
+ )
188
+
189
+ def __getitem__(self, index: Union[int, slice]) -> Any:
190
+ try:
191
+ if self._is_missing():
192
+ raise MissingMandatoryValue("ListConfig is missing")
193
+ self._validate_get(index, None)
194
+ if self._is_none():
195
+ raise TypeError(
196
+ "ListConfig object representing None is not subscriptable"
197
+ )
198
+
199
+ assert isinstance(self.__dict__["_content"], list)
200
+ if isinstance(index, slice):
201
+ result = []
202
+ start, stop, step = self._correct_index_params(index)
203
+ for slice_idx in itertools.islice(
204
+ range(0, len(self)), start, stop, step
205
+ ):
206
+ val = self._resolve_with_default(
207
+ key=slice_idx, value=self.__dict__["_content"][slice_idx]
208
+ )
209
+ result.append(val)
210
+ if index.step and index.step < 0:
211
+ result.reverse()
212
+ return result
213
+ else:
214
+ return self._resolve_with_default(
215
+ key=index, value=self.__dict__["_content"][index]
216
+ )
217
+ except Exception as e:
218
+ self._format_and_raise(key=index, value=None, cause=e)
219
+
220
+ def _correct_index_params(self, index: slice) -> Tuple[int, int, int]:
221
+ start = index.start
222
+ stop = index.stop
223
+ step = index.step
224
+ if index.start and index.start < 0:
225
+ start = self.__len__() + index.start
226
+ if index.stop and index.stop < 0:
227
+ stop = self.__len__() + index.stop
228
+ if index.step and index.step < 0:
229
+ step = abs(step)
230
+ if start and stop:
231
+ if start > stop:
232
+ start, stop = stop + 1, start + 1
233
+ else:
234
+ start = stop = 0
235
+ elif not start and stop:
236
+ start = list(range(self.__len__() - 1, stop, -step))[0]
237
+ stop = None
238
+ elif start and not stop:
239
+ stop = start + 1
240
+ start = (stop - 1) % step
241
+ else:
242
+ start = (self.__len__() - 1) % step
243
+ return start, stop, step
244
+
245
+ def _set_at_index(self, index: Union[int, slice], value: Any) -> None:
246
+ self._set_item_impl(index, value)
247
+
248
+ def __setitem__(self, index: Union[int, slice], value: Any) -> None:
249
+ try:
250
+ if isinstance(index, slice):
251
+ _ = iter(value) # check iterable
252
+ self_indices = index.indices(len(self))
253
+ indexes = range(*self_indices)
254
+
255
+ # Ensure lengths match for extended slice assignment
256
+ if index.step not in (None, 1):
257
+ if len(indexes) != len(value):
258
+ raise ValueError(
259
+ f"attempt to assign sequence of size {len(value)}"
260
+ f" to extended slice of size {len(indexes)}"
261
+ )
262
+
263
+ # Initialize insertion offsets for empty slices
264
+ if len(indexes) == 0:
265
+ curr_index = self_indices[0] - 1
266
+ val_i = -1
267
+
268
+ work_copy = self.copy() # For atomicity manipulate a copy
269
+
270
+ # Delete and optionally replace non empty slices
271
+ only_removed = 0
272
+ for val_i, i in enumerate(indexes):
273
+ curr_index = i - only_removed
274
+ del work_copy[curr_index]
275
+ if val_i < len(value):
276
+ work_copy.insert(curr_index, value[val_i])
277
+ else:
278
+ only_removed += 1
279
+
280
+ # Insert any remaining input items
281
+ for val_i in range(val_i + 1, len(value)):
282
+ curr_index += 1
283
+ work_copy.insert(curr_index, value[val_i])
284
+
285
+ # Reinitialize self with work_copy
286
+ self.clear()
287
+ self.extend(work_copy)
288
+ else:
289
+ self._set_at_index(index, value)
290
+ except Exception as e:
291
+ self._format_and_raise(key=index, value=value, cause=e)
292
+
293
+ def append(self, item: Any) -> None:
294
+ content = self.__dict__["_content"]
295
+ index = len(content)
296
+ content.append(None)
297
+ try:
298
+ self._set_item_impl(index, item)
299
+ except Exception as e:
300
+ del content[index]
301
+ self._format_and_raise(key=index, value=item, cause=e)
302
+ assert False
303
+
304
+ def _update_keys(self) -> None:
305
+ for i in range(len(self)):
306
+ node = self._get_node(i)
307
+ if node is not None:
308
+ assert isinstance(node, Node)
309
+ node._metadata.key = i
310
+
311
+ def insert(self, index: int, item: Any) -> None:
312
+ from omegaconf.omegaconf import _maybe_wrap
313
+
314
+ try:
315
+ if self._get_flag("readonly"):
316
+ raise ReadonlyConfigError("Cannot insert into a read-only ListConfig")
317
+ if self._is_none():
318
+ raise TypeError(
319
+ "Cannot insert into ListConfig object representing None"
320
+ )
321
+ if self._is_missing():
322
+ raise MissingMandatoryValue("Cannot insert into missing ListConfig")
323
+
324
+ try:
325
+ assert isinstance(self.__dict__["_content"], list)
326
+ # insert place holder
327
+ self.__dict__["_content"].insert(index, None)
328
+ is_optional, ref_type = _resolve_optional(self._metadata.element_type)
329
+ node = _maybe_wrap(
330
+ ref_type=ref_type,
331
+ key=index,
332
+ value=item,
333
+ is_optional=is_optional,
334
+ parent=self,
335
+ )
336
+ self._validate_set(key=index, value=node)
337
+ self._set_at_index(index, node)
338
+ self._update_keys()
339
+ except Exception:
340
+ del self.__dict__["_content"][index]
341
+ self._update_keys()
342
+ raise
343
+ except Exception as e:
344
+ self._format_and_raise(key=index, value=item, cause=e)
345
+ assert False
346
+
347
+ def extend(self, lst: Iterable[Any]) -> None:
348
+ assert isinstance(lst, (tuple, list, ListConfig))
349
+ for x in lst:
350
+ self.append(x)
351
+
352
+ def remove(self, x: Any) -> None:
353
+ del self[self.index(x)]
354
+
355
+ def __delitem__(self, key: Union[int, slice]) -> None:
356
+ if self._get_flag("readonly"):
357
+ self._format_and_raise(
358
+ key=key,
359
+ value=None,
360
+ cause=ReadonlyConfigError(
361
+ "Cannot delete item from read-only ListConfig"
362
+ ),
363
+ )
364
+ del self.__dict__["_content"][key]
365
+ self._update_keys()
366
+
367
+ def clear(self) -> None:
368
+ del self[:]
369
+
370
+ def index(
371
+ self, x: Any, start: Optional[int] = None, end: Optional[int] = None
372
+ ) -> int:
373
+ if start is None:
374
+ start = 0
375
+ if end is None:
376
+ end = len(self)
377
+ assert start >= 0
378
+ assert end <= len(self)
379
+ found_idx = -1
380
+ for idx in range(start, end):
381
+ item = self[idx]
382
+ if x == item:
383
+ found_idx = idx
384
+ break
385
+ if found_idx != -1:
386
+ return found_idx
387
+ else:
388
+ self._format_and_raise(
389
+ key=None,
390
+ value=None,
391
+ cause=ConfigValueError("Item not found in ListConfig"),
392
+ )
393
+ assert False
394
+
395
+ def count(self, x: Any) -> int:
396
+ c = 0
397
+ for item in self:
398
+ if item == x:
399
+ c = c + 1
400
+ return c
401
+
402
+ def _get_node(
403
+ self,
404
+ key: Union[int, slice],
405
+ validate_access: bool = True,
406
+ validate_key: bool = True,
407
+ throw_on_missing_value: bool = False,
408
+ throw_on_missing_key: bool = False,
409
+ ) -> Union[Optional[Node], List[Optional[Node]]]:
410
+ try:
411
+ if self._is_none():
412
+ raise TypeError(
413
+ "Cannot get_node from a ListConfig object representing None"
414
+ )
415
+ if self._is_missing():
416
+ raise MissingMandatoryValue("Cannot get_node from a missing ListConfig")
417
+ assert isinstance(self.__dict__["_content"], list)
418
+ if validate_access:
419
+ self._validate_get(key)
420
+
421
+ value = self.__dict__["_content"][key]
422
+ if value is not None:
423
+ if isinstance(key, slice):
424
+ assert isinstance(value, list)
425
+ for v in value:
426
+ if throw_on_missing_value and v._is_missing():
427
+ raise MissingMandatoryValue("Missing mandatory value")
428
+ else:
429
+ assert isinstance(value, Node)
430
+ if throw_on_missing_value and value._is_missing():
431
+ raise MissingMandatoryValue("Missing mandatory value: $KEY")
432
+ return value
433
+ except (IndexError, TypeError, MissingMandatoryValue, KeyValidationError) as e:
434
+ if isinstance(e, MissingMandatoryValue) and throw_on_missing_value:
435
+ raise
436
+ if validate_access:
437
+ self._format_and_raise(key=key, value=None, cause=e)
438
+ assert False
439
+ else:
440
+ return None
441
+
442
+ def get(self, index: int, default_value: Any = None) -> Any:
443
+ try:
444
+ if self._is_none():
445
+ raise TypeError("Cannot get from a ListConfig object representing None")
446
+ if self._is_missing():
447
+ raise MissingMandatoryValue("Cannot get from a missing ListConfig")
448
+ self._validate_get(index, None)
449
+ assert isinstance(self.__dict__["_content"], list)
450
+ return self._resolve_with_default(
451
+ key=index,
452
+ value=self.__dict__["_content"][index],
453
+ default_value=default_value,
454
+ )
455
+ except Exception as e:
456
+ self._format_and_raise(key=index, value=None, cause=e)
457
+ assert False
458
+
459
+ def pop(self, index: int = -1) -> Any:
460
+ try:
461
+ if self._get_flag("readonly"):
462
+ raise ReadonlyConfigError("Cannot pop from read-only ListConfig")
463
+ if self._is_none():
464
+ raise TypeError("Cannot pop from a ListConfig object representing None")
465
+ if self._is_missing():
466
+ raise MissingMandatoryValue("Cannot pop from a missing ListConfig")
467
+
468
+ assert isinstance(self.__dict__["_content"], list)
469
+ node = self._get_child(index)
470
+ assert isinstance(node, Node)
471
+ ret = self._resolve_with_default(key=index, value=node, default_value=None)
472
+ del self.__dict__["_content"][index]
473
+ self._update_keys()
474
+ return ret
475
+ except KeyValidationError as e:
476
+ self._format_and_raise(
477
+ key=index, value=None, cause=e, type_override=ConfigTypeError
478
+ )
479
+ assert False
480
+ except Exception as e:
481
+ self._format_and_raise(key=index, value=None, cause=e)
482
+ assert False
483
+
484
+ def sort(
485
+ self, key: Optional[Callable[[Any], Any]] = None, reverse: bool = False
486
+ ) -> None:
487
+ try:
488
+ if self._get_flag("readonly"):
489
+ raise ReadonlyConfigError("Cannot sort a read-only ListConfig")
490
+ if self._is_none():
491
+ raise TypeError("Cannot sort a ListConfig object representing None")
492
+ if self._is_missing():
493
+ raise MissingMandatoryValue("Cannot sort a missing ListConfig")
494
+
495
+ if key is None:
496
+
497
+ def key1(x: Any) -> Any:
498
+ return x._value()
499
+
500
+ else:
501
+
502
+ def key1(x: Any) -> Any:
503
+ return key(x._value()) # type: ignore
504
+
505
+ assert isinstance(self.__dict__["_content"], list)
506
+ self.__dict__["_content"].sort(key=key1, reverse=reverse)
507
+
508
+ except Exception as e:
509
+ self._format_and_raise(key=None, value=None, cause=e)
510
+ assert False
511
+
512
+ def __eq__(self, other: Any) -> bool:
513
+ if isinstance(other, (list, tuple)) or other is None:
514
+ other = ListConfig(other, flags={"allow_objects": True})
515
+ return ListConfig._list_eq(self, other)
516
+ if other is None or isinstance(other, ListConfig):
517
+ return ListConfig._list_eq(self, other)
518
+ if self._is_missing():
519
+ return _is_missing_literal(other)
520
+ return NotImplemented
521
+
522
+ def __ne__(self, other: Any) -> bool:
523
+ x = self.__eq__(other)
524
+ if x is not NotImplemented:
525
+ return not x
526
+ return NotImplemented
527
+
528
+ def __hash__(self) -> int:
529
+ return hash(str(self))
530
+
531
+ def __iter__(self) -> Iterator[Any]:
532
+ return self._iter_ex(resolve=True)
533
+
534
+ class ListIterator(Iterator[Any]):
535
+ def __init__(self, lst: Any, resolve: bool) -> None:
536
+ self.resolve = resolve
537
+ self.iterator = iter(lst.__dict__["_content"])
538
+ self.index = 0
539
+ from .nodes import ValueNode
540
+
541
+ self.ValueNode = ValueNode
542
+
543
+ def __next__(self) -> Any:
544
+
545
+ x = next(self.iterator)
546
+ if self.resolve:
547
+ x = x._dereference_node()
548
+ if x._is_missing():
549
+ raise MissingMandatoryValue(f"Missing value at index {self.index}")
550
+
551
+ self.index = self.index + 1
552
+ if isinstance(x, self.ValueNode):
553
+ return x._value()
554
+ else:
555
+ # Must be omegaconf.Container. not checking for perf reasons.
556
+ if x._is_none():
557
+ return None
558
+ return x
559
+
560
+ def __repr__(self) -> str: # pragma: no cover
561
+ return f"ListConfig.ListIterator(resolve={self.resolve})"
562
+
563
+ def _iter_ex(self, resolve: bool) -> Iterator[Any]:
564
+ try:
565
+ if self._is_none():
566
+ raise TypeError("Cannot iterate a ListConfig object representing None")
567
+ if self._is_missing():
568
+ raise MissingMandatoryValue("Cannot iterate a missing ListConfig")
569
+
570
+ return ListConfig.ListIterator(self, resolve)
571
+ except (TypeError, MissingMandatoryValue) as e:
572
+ self._format_and_raise(key=None, value=None, cause=e)
573
+ assert False
574
+
575
+ def __add__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
576
+ # res is sharing this list's parent to allow interpolation to work as expected
577
+ res = ListConfig(parent=self._get_parent(), content=[])
578
+ res.extend(self)
579
+ res.extend(other)
580
+ return res
581
+
582
+ def __radd__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
583
+ # res is sharing this list's parent to allow interpolation to work as expected
584
+ res = ListConfig(parent=self._get_parent(), content=[])
585
+ res.extend(other)
586
+ res.extend(self)
587
+ return res
588
+
589
+ def __iadd__(self, other: Iterable[Any]) -> "ListConfig":
590
+ self.extend(other)
591
+ return self
592
+
593
+ def __contains__(self, item: Any) -> bool:
594
+ if self._is_none():
595
+ raise TypeError(
596
+ "Cannot check if an item is in a ListConfig object representing None"
597
+ )
598
+ if self._is_missing():
599
+ raise MissingMandatoryValue(
600
+ "Cannot check if an item is in missing ListConfig"
601
+ )
602
+
603
+ lst = self.__dict__["_content"]
604
+ for x in lst:
605
+ x = x._dereference_node()
606
+ if x == item:
607
+ return True
608
+ return False
609
+
610
+ def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
611
+ try:
612
+ previous_content = self.__dict__["_content"]
613
+ previous_metadata = self.__dict__["_metadata"]
614
+ self._set_value_impl(value, flags)
615
+ except Exception as e:
616
+ self.__dict__["_content"] = previous_content
617
+ self.__dict__["_metadata"] = previous_metadata
618
+ raise e
619
+
620
+ def _set_value_impl(
621
+ self, value: Any, flags: Optional[Dict[str, bool]] = None
622
+ ) -> None:
623
+ from omegaconf import MISSING, flag_override
624
+
625
+ if flags is None:
626
+ flags = {}
627
+
628
+ vk = get_value_kind(value, strict_interpolation_validation=True)
629
+ if _is_none(value):
630
+ if not self._is_optional():
631
+ raise ValidationError(
632
+ "Non optional ListConfig cannot be constructed from None"
633
+ )
634
+ self.__dict__["_content"] = None
635
+ self._metadata.object_type = None
636
+ elif vk is ValueKind.MANDATORY_MISSING:
637
+ self.__dict__["_content"] = MISSING
638
+ self._metadata.object_type = None
639
+ elif vk == ValueKind.INTERPOLATION:
640
+ self.__dict__["_content"] = value
641
+ self._metadata.object_type = None
642
+ else:
643
+ if not (is_primitive_list(value) or isinstance(value, ListConfig)):
644
+ type_ = type(value)
645
+ msg = f"Invalid value assigned: {type_.__name__} is not a ListConfig, list or tuple."
646
+ raise ValidationError(msg)
647
+
648
+ self.__dict__["_content"] = []
649
+ if isinstance(value, ListConfig):
650
+ self._metadata.flags = copy.deepcopy(flags)
651
+ # disable struct and readonly for the construction phase
652
+ # retaining other flags like allow_objects. The real flags are restored at the end of this function
653
+ with flag_override(self, ["struct", "readonly"], False):
654
+ for item in value._iter_ex(resolve=False):
655
+ self.append(item)
656
+ elif is_primitive_list(value):
657
+ with flag_override(self, ["struct", "readonly"], False):
658
+ for item in value:
659
+ self.append(item)
660
+ self._metadata.object_type = list
661
+
662
+ @staticmethod
663
+ def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool:
664
+ l1_none = l1.__dict__["_content"] is None
665
+ l2_none = l2.__dict__["_content"] is None
666
+ if l1_none and l2_none:
667
+ return True
668
+ if l1_none != l2_none:
669
+ return False
670
+
671
+ assert isinstance(l1, ListConfig)
672
+ assert isinstance(l2, ListConfig)
673
+ if len(l1) != len(l2):
674
+ return False
675
+ for i in range(len(l1)):
676
+ if not BaseContainer._item_eq(l1, i, l2, i):
677
+ return False
678
+
679
+ return True
omegaconf/nodes.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import sys
4
+ from abc import abstractmethod
5
+ from enum import Enum
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Type, Union
8
+
9
+ from omegaconf._utils import (
10
+ ValueKind,
11
+ _is_interpolation,
12
+ get_type_of,
13
+ get_value_kind,
14
+ is_primitive_container,
15
+ type_str,
16
+ )
17
+ from omegaconf.base import Box, DictKeyType, Metadata, Node
18
+ from omegaconf.errors import ReadonlyConfigError, UnsupportedValueType, ValidationError
19
+
20
+
21
+ class ValueNode(Node):
22
+ _val: Any
23
+
24
+ def __init__(self, parent: Optional[Box], value: Any, metadata: Metadata):
25
+ from omegaconf import read_write
26
+
27
+ super().__init__(parent=parent, metadata=metadata)
28
+ with read_write(self):
29
+ self._set_value(value) # lgtm [py/init-calls-subclass]
30
+
31
+ def _value(self) -> Any:
32
+ return self._val
33
+
34
+ def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
35
+ if self._get_flag("readonly"):
36
+ raise ReadonlyConfigError("Cannot set value of read-only config node")
37
+
38
+ if isinstance(value, str) and get_value_kind(
39
+ value, strict_interpolation_validation=True
40
+ ) in (
41
+ ValueKind.INTERPOLATION,
42
+ ValueKind.MANDATORY_MISSING,
43
+ ):
44
+ self._val = value
45
+ else:
46
+ self._val = self.validate_and_convert(value)
47
+
48
+ def _strict_validate_type(self, value: Any) -> None:
49
+ ref_type = self._metadata.ref_type
50
+ if isinstance(ref_type, type) and type(value) is not ref_type:
51
+ type_hint = type_str(self._metadata.type_hint)
52
+ raise ValidationError(
53
+ f"Value '$VALUE' of type '$VALUE_TYPE' is incompatible with type hint '{type_hint}'"
54
+ )
55
+
56
+ def validate_and_convert(self, value: Any) -> Any:
57
+ """
58
+ Validates input and converts to canonical form
59
+ :param value: input value
60
+ :return: converted value ("100" may be converted to 100 for example)
61
+ """
62
+ if value is None:
63
+ if self._is_optional():
64
+ return None
65
+ ref_type_str = type_str(self._metadata.ref_type)
66
+ raise ValidationError(
67
+ f"Incompatible value '{value}' for field of type '{ref_type_str}'"
68
+ )
69
+
70
+ # Subclasses can assume that `value` is not None in
71
+ # `_validate_and_convert_impl()` and in `_strict_validate_type()`.
72
+ if self._get_flag("convert") is False:
73
+ self._strict_validate_type(value)
74
+ return value
75
+ else:
76
+ return self._validate_and_convert_impl(value)
77
+
78
+ @abstractmethod
79
+ def _validate_and_convert_impl(self, value: Any) -> Any:
80
+ ...
81
+
82
+ def __str__(self) -> str:
83
+ return str(self._val)
84
+
85
+ def __repr__(self) -> str:
86
+ return repr(self._val) if hasattr(self, "_val") else "__INVALID__"
87
+
88
+ def __eq__(self, other: Any) -> bool:
89
+ if isinstance(other, AnyNode):
90
+ return self._val == other._val # type: ignore
91
+ else:
92
+ return self._val == other # type: ignore
93
+
94
+ def __ne__(self, other: Any) -> bool:
95
+ x = self.__eq__(other)
96
+ assert x is not NotImplemented
97
+ return not x
98
+
99
+ def __hash__(self) -> int:
100
+ return hash(self._val)
101
+
102
+ def _deepcopy_impl(self, res: Any, memo: Dict[int, Any]) -> None:
103
+ res.__dict__["_metadata"] = copy.deepcopy(self._metadata, memo=memo)
104
+ # shallow copy for value to support non-copyable value
105
+ res.__dict__["_val"] = self._val
106
+
107
+ # parent is retained, but not copied
108
+ res.__dict__["_parent"] = self._parent
109
+
110
+ def _is_optional(self) -> bool:
111
+ return self._metadata.optional
112
+
113
+ def _is_interpolation(self) -> bool:
114
+ return _is_interpolation(self._value())
115
+
116
+ def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
117
+ parent = self._get_parent()
118
+ if parent is None:
119
+ if self._metadata.key is None:
120
+ return ""
121
+ else:
122
+ return str(self._metadata.key)
123
+ else:
124
+ return parent._get_full_key(self._metadata.key)
125
+
126
+
127
+ class AnyNode(ValueNode):
128
+ def __init__(
129
+ self,
130
+ value: Any = None,
131
+ key: Any = None,
132
+ parent: Optional[Box] = None,
133
+ flags: Optional[Dict[str, bool]] = None,
134
+ ):
135
+ super().__init__(
136
+ parent=parent,
137
+ value=value,
138
+ metadata=Metadata(
139
+ ref_type=Any, object_type=None, key=key, optional=True, flags=flags
140
+ ),
141
+ )
142
+
143
+ def _validate_and_convert_impl(self, value: Any) -> Any:
144
+ from ._utils import is_primitive_type_annotation
145
+
146
+ # allow_objects is internal and not an official API. use at your own risk.
147
+ # Please be aware that this support is subject to change without notice.
148
+ # If this is deemed useful and supportable it may become an official API.
149
+
150
+ if self._get_flag(
151
+ "allow_objects"
152
+ ) is not True and not is_primitive_type_annotation(value):
153
+ t = get_type_of(value)
154
+ raise UnsupportedValueType(
155
+ f"Value '{t.__name__}' is not a supported primitive type"
156
+ )
157
+ return value
158
+
159
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "AnyNode":
160
+ res = AnyNode()
161
+ self._deepcopy_impl(res, memo)
162
+ return res
163
+
164
+
165
+ class StringNode(ValueNode):
166
+ def __init__(
167
+ self,
168
+ value: Any = None,
169
+ key: Any = None,
170
+ parent: Optional[Box] = None,
171
+ is_optional: bool = True,
172
+ flags: Optional[Dict[str, bool]] = None,
173
+ ):
174
+ super().__init__(
175
+ parent=parent,
176
+ value=value,
177
+ metadata=Metadata(
178
+ key=key,
179
+ optional=is_optional,
180
+ ref_type=str,
181
+ object_type=str,
182
+ flags=flags,
183
+ ),
184
+ )
185
+
186
+ def _validate_and_convert_impl(self, value: Any) -> str:
187
+ from omegaconf import OmegaConf
188
+
189
+ if (
190
+ OmegaConf.is_config(value)
191
+ or is_primitive_container(value)
192
+ or isinstance(value, bytes)
193
+ ):
194
+ raise ValidationError("Cannot convert '$VALUE_TYPE' to string: '$VALUE'")
195
+ return str(value)
196
+
197
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "StringNode":
198
+ res = StringNode()
199
+ self._deepcopy_impl(res, memo)
200
+ return res
201
+
202
+
203
+ class PathNode(ValueNode):
204
+ def __init__(
205
+ self,
206
+ value: Any = None,
207
+ key: Any = None,
208
+ parent: Optional[Box] = None,
209
+ is_optional: bool = True,
210
+ flags: Optional[Dict[str, bool]] = None,
211
+ ):
212
+ super().__init__(
213
+ parent=parent,
214
+ value=value,
215
+ metadata=Metadata(
216
+ key=key,
217
+ optional=is_optional,
218
+ ref_type=Path,
219
+ object_type=Path,
220
+ flags=flags,
221
+ ),
222
+ )
223
+
224
+ def _strict_validate_type(self, value: Any) -> None:
225
+ if not isinstance(value, Path):
226
+ raise ValidationError(
227
+ "Value '$VALUE' of type '$VALUE_TYPE' is not an instance of 'pathlib.Path'"
228
+ )
229
+
230
+ def _validate_and_convert_impl(self, value: Any) -> Path:
231
+ if not isinstance(value, (str, Path)):
232
+ raise ValidationError(
233
+ "Value '$VALUE' of type '$VALUE_TYPE' could not be converted to Path"
234
+ )
235
+
236
+ return Path(value)
237
+
238
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "PathNode":
239
+ res = PathNode()
240
+ self._deepcopy_impl(res, memo)
241
+ return res
242
+
243
+
244
+ class IntegerNode(ValueNode):
245
+ def __init__(
246
+ self,
247
+ value: Any = None,
248
+ key: Any = None,
249
+ parent: Optional[Box] = None,
250
+ is_optional: bool = True,
251
+ flags: Optional[Dict[str, bool]] = None,
252
+ ):
253
+ super().__init__(
254
+ parent=parent,
255
+ value=value,
256
+ metadata=Metadata(
257
+ key=key,
258
+ optional=is_optional,
259
+ ref_type=int,
260
+ object_type=int,
261
+ flags=flags,
262
+ ),
263
+ )
264
+
265
+ def _validate_and_convert_impl(self, value: Any) -> int:
266
+ try:
267
+ if type(value) in (str, int):
268
+ val = int(value)
269
+ else:
270
+ raise ValueError()
271
+ except ValueError:
272
+ raise ValidationError(
273
+ "Value '$VALUE' of type '$VALUE_TYPE' could not be converted to Integer"
274
+ )
275
+ return val
276
+
277
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "IntegerNode":
278
+ res = IntegerNode()
279
+ self._deepcopy_impl(res, memo)
280
+ return res
281
+
282
+
283
+ class BytesNode(ValueNode):
284
+ def __init__(
285
+ self,
286
+ value: Any = None,
287
+ key: Any = None,
288
+ parent: Optional[Box] = None,
289
+ is_optional: bool = True,
290
+ flags: Optional[Dict[str, bool]] = None,
291
+ ):
292
+ super().__init__(
293
+ parent=parent,
294
+ value=value,
295
+ metadata=Metadata(
296
+ key=key,
297
+ optional=is_optional,
298
+ ref_type=bytes,
299
+ object_type=bytes,
300
+ flags=flags,
301
+ ),
302
+ )
303
+
304
+ def _validate_and_convert_impl(self, value: Any) -> bytes:
305
+ if not isinstance(value, bytes):
306
+ raise ValidationError(
307
+ "Value '$VALUE' of type '$VALUE_TYPE' is not of type 'bytes'"
308
+ )
309
+ return value
310
+
311
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "BytesNode":
312
+ res = BytesNode()
313
+ self._deepcopy_impl(res, memo)
314
+ return res
315
+
316
+
317
+ class FloatNode(ValueNode):
318
+ def __init__(
319
+ self,
320
+ value: Any = None,
321
+ key: Any = None,
322
+ parent: Optional[Box] = None,
323
+ is_optional: bool = True,
324
+ flags: Optional[Dict[str, bool]] = None,
325
+ ):
326
+ super().__init__(
327
+ parent=parent,
328
+ value=value,
329
+ metadata=Metadata(
330
+ key=key,
331
+ optional=is_optional,
332
+ ref_type=float,
333
+ object_type=float,
334
+ flags=flags,
335
+ ),
336
+ )
337
+
338
+ def _validate_and_convert_impl(self, value: Any) -> float:
339
+ try:
340
+ if type(value) in (float, str, int):
341
+ return float(value)
342
+ else:
343
+ raise ValueError()
344
+ except ValueError:
345
+ raise ValidationError(
346
+ "Value '$VALUE' of type '$VALUE_TYPE' could not be converted to Float"
347
+ )
348
+
349
+ def __eq__(self, other: Any) -> bool:
350
+ if isinstance(other, ValueNode):
351
+ other_val = other._val
352
+ else:
353
+ other_val = other
354
+ if self._val is None and other is None:
355
+ return True
356
+ if self._val is None and other is not None:
357
+ return False
358
+ if self._val is not None and other is None:
359
+ return False
360
+ nan1 = math.isnan(self._val) if isinstance(self._val, float) else False
361
+ nan2 = math.isnan(other_val) if isinstance(other_val, float) else False
362
+ return self._val == other_val or (nan1 and nan2)
363
+
364
+ def __hash__(self) -> int:
365
+ return hash(self._val)
366
+
367
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "FloatNode":
368
+ res = FloatNode()
369
+ self._deepcopy_impl(res, memo)
370
+ return res
371
+
372
+
373
+ class BooleanNode(ValueNode):
374
+ def __init__(
375
+ self,
376
+ value: Any = None,
377
+ key: Any = None,
378
+ parent: Optional[Box] = None,
379
+ is_optional: bool = True,
380
+ flags: Optional[Dict[str, bool]] = None,
381
+ ):
382
+ super().__init__(
383
+ parent=parent,
384
+ value=value,
385
+ metadata=Metadata(
386
+ key=key,
387
+ optional=is_optional,
388
+ ref_type=bool,
389
+ object_type=bool,
390
+ flags=flags,
391
+ ),
392
+ )
393
+
394
+ def _validate_and_convert_impl(self, value: Any) -> bool:
395
+ if isinstance(value, bool):
396
+ return value
397
+ if isinstance(value, int):
398
+ return value != 0
399
+ elif isinstance(value, str):
400
+ try:
401
+ return self._validate_and_convert_impl(int(value))
402
+ except ValueError as e:
403
+ if value.lower() in ("yes", "y", "on", "true"):
404
+ return True
405
+ elif value.lower() in ("no", "n", "off", "false"):
406
+ return False
407
+ else:
408
+ raise ValidationError(
409
+ "Value '$VALUE' is not a valid bool (type $VALUE_TYPE)"
410
+ ).with_traceback(sys.exc_info()[2]) from e
411
+ else:
412
+ raise ValidationError(
413
+ "Value '$VALUE' is not a valid bool (type $VALUE_TYPE)"
414
+ )
415
+
416
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "BooleanNode":
417
+ res = BooleanNode()
418
+ self._deepcopy_impl(res, memo)
419
+ return res
420
+
421
+
422
+ class EnumNode(ValueNode): # lgtm [py/missing-equals] : Intentional.
423
+ """
424
+ NOTE: EnumNode is serialized to yaml as a string ("Color.BLUE"), not as a fully qualified yaml type.
425
+ this means serialization to YAML of a typed config (with EnumNode) will not retain the type of the Enum
426
+ when loaded.
427
+ This is intentional, Please open an issue against OmegaConf if you wish to discuss this decision.
428
+ """
429
+
430
+ def __init__(
431
+ self,
432
+ enum_type: Type[Enum],
433
+ value: Optional[Union[Enum, str]] = None,
434
+ key: Any = None,
435
+ parent: Optional[Box] = None,
436
+ is_optional: bool = True,
437
+ flags: Optional[Dict[str, bool]] = None,
438
+ ):
439
+ if not isinstance(enum_type, type) or not issubclass(enum_type, Enum):
440
+ raise ValidationError(
441
+ f"EnumNode can only operate on Enum subclasses ({enum_type})"
442
+ )
443
+ self.fields: Dict[str, str] = {}
444
+ self.enum_type: Type[Enum] = enum_type
445
+ for name, constant in enum_type.__members__.items():
446
+ self.fields[name] = constant.value
447
+ super().__init__(
448
+ parent=parent,
449
+ value=value,
450
+ metadata=Metadata(
451
+ key=key,
452
+ optional=is_optional,
453
+ ref_type=enum_type,
454
+ object_type=enum_type,
455
+ flags=flags,
456
+ ),
457
+ )
458
+
459
+ def _strict_validate_type(self, value: Any) -> None:
460
+ ref_type = self._metadata.ref_type
461
+ if not isinstance(value, ref_type):
462
+ type_hint = type_str(self._metadata.type_hint)
463
+ raise ValidationError(
464
+ f"Value '$VALUE' of type '$VALUE_TYPE' is incompatible with type hint '{type_hint}'"
465
+ )
466
+
467
+ def _validate_and_convert_impl(self, value: Any) -> Enum:
468
+ return self.validate_and_convert_to_enum(enum_type=self.enum_type, value=value)
469
+
470
+ @staticmethod
471
+ def validate_and_convert_to_enum(enum_type: Type[Enum], value: Any) -> Enum:
472
+ if not isinstance(value, (str, int)) and not isinstance(value, enum_type):
473
+ raise ValidationError(
474
+ f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type}"
475
+ )
476
+
477
+ if isinstance(value, enum_type):
478
+ return value
479
+
480
+ try:
481
+ if isinstance(value, (float, bool)):
482
+ raise ValueError
483
+
484
+ if isinstance(value, int):
485
+ return enum_type(value)
486
+
487
+ if isinstance(value, str):
488
+ prefix = f"{enum_type.__name__}."
489
+ if value.startswith(prefix):
490
+ value = value[len(prefix) :]
491
+ return enum_type[value]
492
+
493
+ assert False
494
+
495
+ except (ValueError, KeyError) as e:
496
+ valid = ", ".join([x for x in enum_type.__members__.keys()])
497
+ raise ValidationError(
498
+ f"Invalid value '$VALUE', expected one of [{valid}]"
499
+ ).with_traceback(sys.exc_info()[2]) from e
500
+
501
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "EnumNode":
502
+ res = EnumNode(enum_type=self.enum_type)
503
+ self._deepcopy_impl(res, memo)
504
+ return res
505
+
506
+
507
+ class InterpolationResultNode(ValueNode):
508
+ """
509
+ Special node type, used to wrap interpolation results.
510
+ """
511
+
512
+ def __init__(
513
+ self,
514
+ value: Any,
515
+ key: Any = None,
516
+ parent: Optional[Box] = None,
517
+ flags: Optional[Dict[str, bool]] = None,
518
+ ):
519
+ super().__init__(
520
+ parent=parent,
521
+ value=value,
522
+ metadata=Metadata(
523
+ ref_type=Any, object_type=None, key=key, optional=True, flags=flags
524
+ ),
525
+ )
526
+ # In general we should not try to write into interpolation results.
527
+ if flags is None or "readonly" not in flags:
528
+ self._set_flag("readonly", True)
529
+
530
+ def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
531
+ if self._get_flag("readonly"):
532
+ raise ReadonlyConfigError("Cannot set value of read-only config node")
533
+ self._val = self.validate_and_convert(value)
534
+
535
+ def _validate_and_convert_impl(self, value: Any) -> Any:
536
+ # Interpolation results may be anything.
537
+ return value
538
+
539
+ def __deepcopy__(self, memo: Dict[int, Any]) -> "InterpolationResultNode":
540
+ # Currently there should be no need to deep-copy such nodes.
541
+ raise NotImplementedError
542
+
543
+ def _is_interpolation(self) -> bool:
544
+ # The result of an interpolation cannot be itself an interpolation.
545
+ return False
omegaconf/omegaconf.py ADDED
@@ -0,0 +1,1157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OmegaConf module"""
2
+ import copy
3
+ import inspect
4
+ import io
5
+ import os
6
+ import pathlib
7
+ import sys
8
+ import warnings
9
+ from collections import defaultdict
10
+ from contextlib import contextmanager
11
+ from enum import Enum
12
+ from textwrap import dedent
13
+ from typing import (
14
+ IO,
15
+ Any,
16
+ Callable,
17
+ Dict,
18
+ Generator,
19
+ Iterable,
20
+ List,
21
+ Optional,
22
+ Set,
23
+ Tuple,
24
+ Type,
25
+ Union,
26
+ overload,
27
+ )
28
+
29
+ import yaml
30
+
31
+ from . import DictConfig, DictKeyType, ListConfig
32
+ from ._utils import (
33
+ _DEFAULT_MARKER_,
34
+ _ensure_container,
35
+ _get_value,
36
+ format_and_raise,
37
+ get_dict_key_value_types,
38
+ get_list_element_type,
39
+ get_omega_conf_dumper,
40
+ get_type_of,
41
+ is_attr_class,
42
+ is_dataclass,
43
+ is_dict_annotation,
44
+ is_int,
45
+ is_list_annotation,
46
+ is_primitive_container,
47
+ is_primitive_dict,
48
+ is_primitive_list,
49
+ is_structured_config,
50
+ is_tuple_annotation,
51
+ is_union_annotation,
52
+ nullcontext,
53
+ split_key,
54
+ type_str,
55
+ )
56
+ from .base import Box, Container, Node, SCMode, UnionNode
57
+ from .basecontainer import BaseContainer
58
+ from .errors import (
59
+ MissingMandatoryValue,
60
+ OmegaConfBaseException,
61
+ UnsupportedInterpolationType,
62
+ ValidationError,
63
+ )
64
+ from .nodes import (
65
+ AnyNode,
66
+ BooleanNode,
67
+ BytesNode,
68
+ EnumNode,
69
+ FloatNode,
70
+ IntegerNode,
71
+ PathNode,
72
+ StringNode,
73
+ ValueNode,
74
+ )
75
+
76
+ MISSING: Any = "???"
77
+
78
+ Resolver = Callable[..., Any]
79
+
80
+
81
+ def II(interpolation: str) -> Any:
82
+ """
83
+ Equivalent to ``${interpolation}``
84
+
85
+ :param interpolation:
86
+ :return: input ``${node}`` with type Any
87
+ """
88
+ return "${" + interpolation + "}"
89
+
90
+
91
+ def SI(interpolation: str) -> Any:
92
+ """
93
+ Use this for String interpolation, for example ``"http://${host}:${port}"``
94
+
95
+ :param interpolation: interpolation string
96
+ :return: input interpolation with type ``Any``
97
+ """
98
+ return interpolation
99
+
100
+
101
+ def register_default_resolvers() -> None:
102
+ from omegaconf.resolvers import oc
103
+
104
+ OmegaConf.register_new_resolver("oc.create", oc.create)
105
+ OmegaConf.register_new_resolver("oc.decode", oc.decode)
106
+ OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated)
107
+ OmegaConf.register_new_resolver("oc.env", oc.env)
108
+ OmegaConf.register_new_resolver("oc.select", oc.select)
109
+ OmegaConf.register_new_resolver("oc.dict.keys", oc.dict.keys)
110
+ OmegaConf.register_new_resolver("oc.dict.values", oc.dict.values)
111
+
112
+
113
+ class OmegaConf:
114
+ """OmegaConf primary class"""
115
+
116
+ def __init__(self) -> None:
117
+ raise NotImplementedError("Use one of the static construction functions")
118
+
119
+ @staticmethod
120
+ def structured(
121
+ obj: Any,
122
+ parent: Optional[BaseContainer] = None,
123
+ flags: Optional[Dict[str, bool]] = None,
124
+ ) -> Any:
125
+ return OmegaConf.create(obj, parent, flags)
126
+
127
+ @staticmethod
128
+ @overload
129
+ def create(
130
+ obj: str,
131
+ parent: Optional[BaseContainer] = None,
132
+ flags: Optional[Dict[str, bool]] = None,
133
+ ) -> Union[DictConfig, ListConfig]:
134
+ ...
135
+
136
+ @staticmethod
137
+ @overload
138
+ def create(
139
+ obj: Union[List[Any], Tuple[Any, ...]],
140
+ parent: Optional[BaseContainer] = None,
141
+ flags: Optional[Dict[str, bool]] = None,
142
+ ) -> ListConfig:
143
+ ...
144
+
145
+ @staticmethod
146
+ @overload
147
+ def create(
148
+ obj: DictConfig,
149
+ parent: Optional[BaseContainer] = None,
150
+ flags: Optional[Dict[str, bool]] = None,
151
+ ) -> DictConfig:
152
+ ...
153
+
154
+ @staticmethod
155
+ @overload
156
+ def create(
157
+ obj: ListConfig,
158
+ parent: Optional[BaseContainer] = None,
159
+ flags: Optional[Dict[str, bool]] = None,
160
+ ) -> ListConfig:
161
+ ...
162
+
163
+ @staticmethod
164
+ @overload
165
+ def create(
166
+ obj: Optional[Dict[Any, Any]] = None,
167
+ parent: Optional[BaseContainer] = None,
168
+ flags: Optional[Dict[str, bool]] = None,
169
+ ) -> DictConfig:
170
+ ...
171
+
172
+ @staticmethod
173
+ def create( # noqa F811
174
+ obj: Any = _DEFAULT_MARKER_,
175
+ parent: Optional[BaseContainer] = None,
176
+ flags: Optional[Dict[str, bool]] = None,
177
+ ) -> Union[DictConfig, ListConfig]:
178
+ return OmegaConf._create_impl(
179
+ obj=obj,
180
+ parent=parent,
181
+ flags=flags,
182
+ )
183
+
184
+ @staticmethod
185
+ def load(file_: Union[str, pathlib.Path, IO[Any]]) -> Union[DictConfig, ListConfig]:
186
+ from ._utils import get_yaml_loader
187
+
188
+ if isinstance(file_, (str, pathlib.Path)):
189
+ with io.open(os.path.abspath(file_), "r", encoding="utf-8") as f:
190
+ obj = yaml.load(f, Loader=get_yaml_loader())
191
+ elif getattr(file_, "read", None):
192
+ obj = yaml.load(file_, Loader=get_yaml_loader())
193
+ else:
194
+ raise TypeError("Unexpected file type")
195
+
196
+ if obj is not None and not isinstance(obj, (list, dict, str)):
197
+ raise IOError( # pragma: no cover
198
+ f"Invalid loaded object type: {type(obj).__name__}"
199
+ )
200
+
201
+ ret: Union[DictConfig, ListConfig]
202
+ if obj is None:
203
+ ret = OmegaConf.create()
204
+ else:
205
+ ret = OmegaConf.create(obj)
206
+ return ret
207
+
208
+ @staticmethod
209
+ def save(
210
+ config: Any, f: Union[str, pathlib.Path, IO[Any]], resolve: bool = False
211
+ ) -> None:
212
+ """
213
+ Save as configuration object to a file
214
+
215
+ :param config: omegaconf.Config object (DictConfig or ListConfig).
216
+ :param f: filename or file object
217
+ :param resolve: True to save a resolved config (defaults to False)
218
+ """
219
+ if is_dataclass(config) or is_attr_class(config):
220
+ config = OmegaConf.create(config)
221
+ data = OmegaConf.to_yaml(config, resolve=resolve)
222
+ if isinstance(f, (str, pathlib.Path)):
223
+ with io.open(os.path.abspath(f), "w", encoding="utf-8") as file:
224
+ file.write(data)
225
+ elif hasattr(f, "write"):
226
+ f.write(data)
227
+ f.flush()
228
+ else:
229
+ raise TypeError("Unexpected file type")
230
+
231
+ @staticmethod
232
+ def from_cli(args_list: Optional[List[str]] = None) -> DictConfig:
233
+ if args_list is None:
234
+ # Skip program name
235
+ args_list = sys.argv[1:]
236
+ return OmegaConf.from_dotlist(args_list)
237
+
238
+ @staticmethod
239
+ def from_dotlist(dotlist: List[str]) -> DictConfig:
240
+ """
241
+ Creates config from the content sys.argv or from the specified args list of not None
242
+
243
+ :param dotlist: A list of dotlist-style strings, e.g. ``["foo.bar=1", "baz=qux"]``.
244
+ :return: A ``DictConfig`` object created from the dotlist.
245
+ """
246
+ conf = OmegaConf.create()
247
+ conf.merge_with_dotlist(dotlist)
248
+ return conf
249
+
250
+ @staticmethod
251
+ def merge(
252
+ *configs: Union[
253
+ DictConfig,
254
+ ListConfig,
255
+ Dict[DictKeyType, Any],
256
+ List[Any],
257
+ Tuple[Any, ...],
258
+ Any,
259
+ ],
260
+ ) -> Union[ListConfig, DictConfig]:
261
+ """
262
+ Merge a list of previously created configs into a single one
263
+
264
+ :param configs: Input configs
265
+ :return: the merged config object.
266
+ """
267
+ assert len(configs) > 0
268
+ target = copy.deepcopy(configs[0])
269
+ target = _ensure_container(target)
270
+ assert isinstance(target, (DictConfig, ListConfig))
271
+
272
+ with flag_override(target, "readonly", False):
273
+ target.merge_with(*configs[1:])
274
+ turned_readonly = target._get_flag("readonly") is True
275
+
276
+ if turned_readonly:
277
+ OmegaConf.set_readonly(target, True)
278
+
279
+ return target
280
+
281
+ @staticmethod
282
+ def unsafe_merge(
283
+ *configs: Union[
284
+ DictConfig,
285
+ ListConfig,
286
+ Dict[DictKeyType, Any],
287
+ List[Any],
288
+ Tuple[Any, ...],
289
+ Any,
290
+ ],
291
+ ) -> Union[ListConfig, DictConfig]:
292
+ """
293
+ Merge a list of previously created configs into a single one
294
+ This is much faster than OmegaConf.merge() as the input configs are not copied.
295
+ However, the input configs must not be used after this operation as will become inconsistent.
296
+
297
+ :param configs: Input configs
298
+ :return: the merged config object.
299
+ """
300
+ assert len(configs) > 0
301
+ target = configs[0]
302
+ target = _ensure_container(target)
303
+ assert isinstance(target, (DictConfig, ListConfig))
304
+
305
+ with flag_override(
306
+ target, ["readonly", "no_deepcopy_set_nodes"], [False, True]
307
+ ):
308
+ target.merge_with(*configs[1:])
309
+ turned_readonly = target._get_flag("readonly") is True
310
+
311
+ if turned_readonly:
312
+ OmegaConf.set_readonly(target, True)
313
+
314
+ return target
315
+
316
+ @staticmethod
317
+ def register_resolver(name: str, resolver: Resolver) -> None:
318
+ warnings.warn(
319
+ dedent(
320
+ """\
321
+ register_resolver() is deprecated.
322
+ See https://github.com/omry/omegaconf/issues/426 for migration instructions.
323
+ """
324
+ ),
325
+ stacklevel=2,
326
+ )
327
+ return OmegaConf.legacy_register_resolver(name, resolver)
328
+
329
+ # This function will eventually be deprecated and removed.
330
+ @staticmethod
331
+ def legacy_register_resolver(name: str, resolver: Resolver) -> None:
332
+ assert callable(resolver), "resolver must be callable"
333
+ # noinspection PyProtectedMember
334
+ assert (
335
+ name not in BaseContainer._resolvers
336
+ ), f"resolver '{name}' is already registered"
337
+
338
+ def resolver_wrapper(
339
+ config: BaseContainer,
340
+ parent: BaseContainer,
341
+ node: Node,
342
+ args: Tuple[Any, ...],
343
+ args_str: Tuple[str, ...],
344
+ ) -> Any:
345
+ cache = OmegaConf.get_cache(config)[name]
346
+ # "Un-escape " spaces and commas.
347
+ args_unesc = [x.replace(r"\ ", " ").replace(r"\,", ",") for x in args_str]
348
+
349
+ # Nested interpolations behave in a potentially surprising way with
350
+ # legacy resolvers (they remain as strings, e.g., "${foo}"). If any
351
+ # input looks like an interpolation we thus raise an exception.
352
+ try:
353
+ bad_arg = next(i for i in args_unesc if "${" in i)
354
+ except StopIteration:
355
+ pass
356
+ else:
357
+ raise ValueError(
358
+ f"Resolver '{name}' was called with argument '{bad_arg}' that appears "
359
+ f"to be an interpolation. Nested interpolations are not supported for "
360
+ f"resolvers registered with `[legacy_]register_resolver()`, please use "
361
+ f"`register_new_resolver()` instead (see "
362
+ f"https://github.com/omry/omegaconf/issues/426 for migration instructions)."
363
+ )
364
+ key = args_str
365
+ val = cache[key] if key in cache else resolver(*args_unesc)
366
+ cache[key] = val
367
+ return val
368
+
369
+ # noinspection PyProtectedMember
370
+ BaseContainer._resolvers[name] = resolver_wrapper
371
+
372
+ @staticmethod
373
+ def register_new_resolver(
374
+ name: str,
375
+ resolver: Resolver,
376
+ *,
377
+ replace: bool = False,
378
+ use_cache: bool = False,
379
+ ) -> None:
380
+ """
381
+ Register a resolver.
382
+
383
+ :param name: Name of the resolver.
384
+ :param resolver: Callable whose arguments are provided in the interpolation,
385
+ e.g., with ${foo:x,0,${y.z}} these arguments are respectively "x" (str),
386
+ 0 (int) and the value of ``y.z``.
387
+ :param replace: If set to ``False`` (default), then a ``ValueError`` is raised if
388
+ an existing resolver has already been registered with the same name.
389
+ If set to ``True``, then the new resolver replaces the previous one.
390
+ NOTE: The cache on existing config objects is not affected, use
391
+ ``OmegaConf.clear_cache(cfg)`` to clear it.
392
+ :param use_cache: Whether the resolver's outputs should be cached. The cache is
393
+ based only on the string literals representing the resolver arguments, e.g.,
394
+ ${foo:${bar}} will always return the same value regardless of the value of
395
+ ``bar`` if the cache is enabled for ``foo``.
396
+ """
397
+ if not callable(resolver):
398
+ raise TypeError("resolver must be callable")
399
+ if not name:
400
+ raise ValueError("cannot use an empty resolver name")
401
+
402
+ if not replace and OmegaConf.has_resolver(name):
403
+ raise ValueError(f"resolver '{name}' is already registered")
404
+
405
+ try:
406
+ sig: Optional[inspect.Signature] = inspect.signature(resolver)
407
+ except ValueError:
408
+ sig = None
409
+
410
+ def _should_pass(special: str) -> bool:
411
+ ret = sig is not None and special in sig.parameters
412
+ if ret and use_cache:
413
+ raise ValueError(
414
+ f"use_cache=True is incompatible with functions that receive the {special}"
415
+ )
416
+ return ret
417
+
418
+ pass_parent = _should_pass("_parent_")
419
+ pass_node = _should_pass("_node_")
420
+ pass_root = _should_pass("_root_")
421
+
422
+ def resolver_wrapper(
423
+ config: BaseContainer,
424
+ parent: Container,
425
+ node: Node,
426
+ args: Tuple[Any, ...],
427
+ args_str: Tuple[str, ...],
428
+ ) -> Any:
429
+ if use_cache:
430
+ cache = OmegaConf.get_cache(config)[name]
431
+ try:
432
+ return cache[args_str]
433
+ except KeyError:
434
+ pass
435
+
436
+ # Call resolver.
437
+ kwargs: Dict[str, Node] = {}
438
+ if pass_parent:
439
+ kwargs["_parent_"] = parent
440
+ if pass_node:
441
+ kwargs["_node_"] = node
442
+ if pass_root:
443
+ kwargs["_root_"] = config
444
+
445
+ ret = resolver(*args, **kwargs)
446
+
447
+ if use_cache:
448
+ cache[args_str] = ret
449
+ return ret
450
+
451
+ # noinspection PyProtectedMember
452
+ BaseContainer._resolvers[name] = resolver_wrapper
453
+
454
+ @classmethod
455
+ def has_resolver(cls, name: str) -> bool:
456
+ return cls._get_resolver(name) is not None
457
+
458
+ # noinspection PyProtectedMember
459
+ @staticmethod
460
+ def clear_resolvers() -> None:
461
+ """
462
+ Clear(remove) all OmegaConf resolvers, then re-register OmegaConf's default resolvers.
463
+ """
464
+ BaseContainer._resolvers = {}
465
+ register_default_resolvers()
466
+
467
+ @classmethod
468
+ def clear_resolver(cls, name: str) -> bool:
469
+ """
470
+ Clear(remove) any resolver only if it exists.
471
+
472
+ Returns a bool: True if resolver is removed and False if not removed.
473
+
474
+ .. warning:
475
+ This method can remove deafult resolvers as well.
476
+
477
+ :param name: Name of the resolver.
478
+ :return: A bool (``True`` if resolver is removed, ``False`` if not found before removing).
479
+ """
480
+ if cls.has_resolver(name):
481
+ BaseContainer._resolvers.pop(name)
482
+ return True
483
+ else:
484
+ # return False if resolver does not exist
485
+ return False
486
+
487
+ @staticmethod
488
+ def get_cache(conf: BaseContainer) -> Dict[str, Any]:
489
+ return conf._metadata.resolver_cache
490
+
491
+ @staticmethod
492
+ def set_cache(conf: BaseContainer, cache: Dict[str, Any]) -> None:
493
+ conf._metadata.resolver_cache = copy.deepcopy(cache)
494
+
495
+ @staticmethod
496
+ def clear_cache(conf: BaseContainer) -> None:
497
+ OmegaConf.set_cache(conf, defaultdict(dict, {}))
498
+
499
+ @staticmethod
500
+ def copy_cache(from_config: BaseContainer, to_config: BaseContainer) -> None:
501
+ OmegaConf.set_cache(to_config, OmegaConf.get_cache(from_config))
502
+
503
+ @staticmethod
504
+ def set_readonly(conf: Node, value: Optional[bool]) -> None:
505
+ # noinspection PyProtectedMember
506
+ conf._set_flag("readonly", value)
507
+
508
+ @staticmethod
509
+ def is_readonly(conf: Node) -> Optional[bool]:
510
+ # noinspection PyProtectedMember
511
+ return conf._get_flag("readonly")
512
+
513
+ @staticmethod
514
+ def set_struct(conf: Container, value: Optional[bool]) -> None:
515
+ # noinspection PyProtectedMember
516
+ conf._set_flag("struct", value)
517
+
518
+ @staticmethod
519
+ def is_struct(conf: Container) -> Optional[bool]:
520
+ # noinspection PyProtectedMember
521
+ return conf._get_flag("struct")
522
+
523
+ @staticmethod
524
+ def masked_copy(conf: DictConfig, keys: Union[str, List[str]]) -> DictConfig:
525
+ """
526
+ Create a masked copy of of this config that contains a subset of the keys
527
+
528
+ :param conf: DictConfig object
529
+ :param keys: keys to preserve in the copy
530
+ :return: The masked ``DictConfig`` object.
531
+ """
532
+ from .dictconfig import DictConfig
533
+
534
+ if not isinstance(conf, DictConfig):
535
+ raise ValueError("masked_copy is only supported for DictConfig")
536
+
537
+ if isinstance(keys, str):
538
+ keys = [keys]
539
+ content = {key: value for key, value in conf.items_ex(resolve=False, keys=keys)}
540
+ return DictConfig(content=content)
541
+
542
+ @staticmethod
543
+ def to_container(
544
+ cfg: Any,
545
+ *,
546
+ resolve: bool = False,
547
+ throw_on_missing: bool = False,
548
+ enum_to_str: bool = False,
549
+ structured_config_mode: SCMode = SCMode.DICT,
550
+ ) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
551
+ """
552
+ Resursively converts an OmegaConf config to a primitive container (dict or list).
553
+
554
+ :param cfg: the config to convert
555
+ :param resolve: True to resolve all values
556
+ :param throw_on_missing: When True, raise MissingMandatoryValue if any missing values are present.
557
+ When False (the default), replace missing values with the string "???" in the output container.
558
+ :param enum_to_str: True to convert Enum keys and values to strings
559
+ :param structured_config_mode: Specify how Structured Configs (DictConfigs backed by a dataclass) are handled.
560
+ - By default (``structured_config_mode=SCMode.DICT``) structured configs are converted to plain dicts.
561
+ - If ``structured_config_mode=SCMode.DICT_CONFIG``, structured config nodes will remain as DictConfig.
562
+ - If ``structured_config_mode=SCMode.INSTANTIATE``, this function will instantiate structured configs
563
+ (DictConfigs backed by a dataclass), by creating an instance of the underlying dataclass.
564
+
565
+ See also OmegaConf.to_object.
566
+ :return: A dict or a list representing this config as a primitive container.
567
+ """
568
+ if not OmegaConf.is_config(cfg):
569
+ raise ValueError(
570
+ f"Input cfg is not an OmegaConf config object ({type_str(type(cfg))})"
571
+ )
572
+
573
+ return BaseContainer._to_content(
574
+ cfg,
575
+ resolve=resolve,
576
+ throw_on_missing=throw_on_missing,
577
+ enum_to_str=enum_to_str,
578
+ structured_config_mode=structured_config_mode,
579
+ )
580
+
581
+ @staticmethod
582
+ def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
583
+ """
584
+ Resursively converts an OmegaConf config to a primitive container (dict or list).
585
+ Any DictConfig objects backed by dataclasses or attrs classes are instantiated
586
+ as instances of those backing classes.
587
+
588
+ This is an alias for OmegaConf.to_container(..., resolve=True, throw_on_missing=True,
589
+ structured_config_mode=SCMode.INSTANTIATE)
590
+
591
+ :param cfg: the config to convert
592
+ :return: A dict or a list or dataclass representing this config.
593
+ """
594
+ return OmegaConf.to_container(
595
+ cfg=cfg,
596
+ resolve=True,
597
+ throw_on_missing=True,
598
+ enum_to_str=False,
599
+ structured_config_mode=SCMode.INSTANTIATE,
600
+ )
601
+
602
+ @staticmethod
603
+ def is_missing(cfg: Any, key: DictKeyType) -> bool:
604
+ assert isinstance(cfg, Container)
605
+ try:
606
+ node = cfg._get_child(key)
607
+ if node is None:
608
+ return False
609
+ assert isinstance(node, Node)
610
+ return node._is_missing()
611
+ except (UnsupportedInterpolationType, KeyError, AttributeError):
612
+ return False
613
+
614
+ @staticmethod
615
+ def is_interpolation(node: Any, key: Optional[Union[int, str]] = None) -> bool:
616
+ if key is not None:
617
+ assert isinstance(node, Container)
618
+ target = node._get_child(key)
619
+ else:
620
+ target = node
621
+ if target is not None:
622
+ assert isinstance(target, Node)
623
+ return target._is_interpolation()
624
+ return False
625
+
626
+ @staticmethod
627
+ def is_list(obj: Any) -> bool:
628
+ from . import ListConfig
629
+
630
+ return isinstance(obj, ListConfig)
631
+
632
+ @staticmethod
633
+ def is_dict(obj: Any) -> bool:
634
+ from . import DictConfig
635
+
636
+ return isinstance(obj, DictConfig)
637
+
638
+ @staticmethod
639
+ def is_config(obj: Any) -> bool:
640
+ from . import Container
641
+
642
+ return isinstance(obj, Container)
643
+
644
+ @staticmethod
645
+ def get_type(obj: Any, key: Optional[str] = None) -> Optional[Type[Any]]:
646
+ if key is not None:
647
+ c = obj._get_child(key)
648
+ else:
649
+ c = obj
650
+ return OmegaConf._get_obj_type(c)
651
+
652
+ @staticmethod
653
+ def select(
654
+ cfg: Container,
655
+ key: str,
656
+ *,
657
+ default: Any = _DEFAULT_MARKER_,
658
+ throw_on_resolution_failure: bool = True,
659
+ throw_on_missing: bool = False,
660
+ ) -> Any:
661
+ """
662
+ :param cfg: Config node to select from
663
+ :param key: Key to select
664
+ :param default: Default value to return if key is not found
665
+ :param throw_on_resolution_failure: Raise an exception if an interpolation
666
+ resolution error occurs, otherwise return None
667
+ :param throw_on_missing: Raise an exception if an attempt to select a missing key (with the value '???')
668
+ is made, otherwise return None
669
+ :return: selected value or None if not found.
670
+ """
671
+ from ._impl import select_value
672
+
673
+ try:
674
+ return select_value(
675
+ cfg=cfg,
676
+ key=key,
677
+ default=default,
678
+ throw_on_resolution_failure=throw_on_resolution_failure,
679
+ throw_on_missing=throw_on_missing,
680
+ )
681
+ except Exception as e:
682
+ format_and_raise(node=cfg, key=key, value=None, cause=e, msg=str(e))
683
+
684
+ @staticmethod
685
+ def update(
686
+ cfg: Container,
687
+ key: str,
688
+ value: Any = None,
689
+ *,
690
+ merge: bool = True,
691
+ force_add: bool = False,
692
+ ) -> None:
693
+ """
694
+ Updates a dot separated key sequence to a value
695
+
696
+ :param cfg: input config to update
697
+ :param key: key to update (can be a dot separated path)
698
+ :param value: value to set, if value if a list or a dict it will be merged or set
699
+ depending on merge_config_values
700
+ :param merge: If value is a dict or a list, True (default) to merge
701
+ into the destination, False to replace the destination.
702
+ :param force_add: insert the entire path regardless of Struct flag or Structured Config nodes.
703
+ """
704
+
705
+ split = split_key(key)
706
+ root = cfg
707
+ for i in range(len(split) - 1):
708
+ k = split[i]
709
+ # if next_root is a primitive (string, int etc) replace it with an empty map
710
+ next_root, key_ = _select_one(root, k, throw_on_missing=False)
711
+ if not isinstance(next_root, Container):
712
+ if force_add:
713
+ with flag_override(root, "struct", False):
714
+ root[key_] = {}
715
+ else:
716
+ root[key_] = {}
717
+ root = root[key_]
718
+
719
+ last = split[-1]
720
+
721
+ assert isinstance(
722
+ root, Container
723
+ ), f"Unexpected type for root: {type(root).__name__}"
724
+
725
+ last_key: Union[str, int] = last
726
+ if isinstance(root, ListConfig):
727
+ last_key = int(last)
728
+
729
+ ctx = flag_override(root, "struct", False) if force_add else nullcontext()
730
+ with ctx:
731
+ if merge and (OmegaConf.is_config(value) or is_primitive_container(value)):
732
+ assert isinstance(root, BaseContainer)
733
+ node = root._get_child(last_key)
734
+ if OmegaConf.is_config(node):
735
+ assert isinstance(node, BaseContainer)
736
+ node.merge_with(value)
737
+ return
738
+
739
+ if OmegaConf.is_dict(root):
740
+ assert isinstance(last_key, str)
741
+ root.__setattr__(last_key, value)
742
+ elif OmegaConf.is_list(root):
743
+ assert isinstance(last_key, int)
744
+ root.__setitem__(last_key, value)
745
+ else:
746
+ assert False
747
+
748
+ @staticmethod
749
+ def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str:
750
+ """
751
+ returns a yaml dump of this config object.
752
+
753
+ :param cfg: Config object, Structured Config type or instance
754
+ :param resolve: if True, will return a string with the interpolations resolved, otherwise
755
+ interpolations are preserved
756
+ :param sort_keys: If True, will print dict keys in sorted order. default False.
757
+ :return: A string containing the yaml representation.
758
+ """
759
+ cfg = _ensure_container(cfg)
760
+ container = OmegaConf.to_container(cfg, resolve=resolve, enum_to_str=True)
761
+ return yaml.dump( # type: ignore
762
+ container,
763
+ default_flow_style=False,
764
+ allow_unicode=True,
765
+ sort_keys=sort_keys,
766
+ Dumper=get_omega_conf_dumper(),
767
+ )
768
+
769
+ @staticmethod
770
+ def resolve(cfg: Container) -> None:
771
+ """
772
+ Resolves all interpolations in the given config object in-place.
773
+
774
+ :param cfg: An OmegaConf container (DictConfig, ListConfig)
775
+ Raises a ValueError if the input object is not an OmegaConf container.
776
+ """
777
+ import omegaconf._impl
778
+
779
+ if not OmegaConf.is_config(cfg):
780
+ # Since this function is mutating the input object in-place, it doesn't make sense to
781
+ # auto-convert the input object to an OmegaConf container
782
+ raise ValueError(
783
+ f"Invalid config type ({type(cfg).__name__}), expected an OmegaConf Container"
784
+ )
785
+ omegaconf._impl._resolve(cfg)
786
+
787
+ @staticmethod
788
+ def missing_keys(cfg: Any) -> Set[str]:
789
+ """
790
+ Returns a set of missing keys in a dotlist style.
791
+
792
+ :param cfg: An ``OmegaConf.Container``,
793
+ or a convertible object via ``OmegaConf.create`` (dict, list, ...).
794
+ :return: set of strings of the missing keys.
795
+ :raises ValueError: On input not representing a config.
796
+ """
797
+ cfg = _ensure_container(cfg)
798
+ missings: Set[str] = set()
799
+
800
+ def gather(_cfg: Container) -> None:
801
+ itr: Iterable[Any]
802
+ if isinstance(_cfg, ListConfig):
803
+ itr = range(len(_cfg))
804
+ else:
805
+ itr = _cfg
806
+
807
+ for key in itr:
808
+ if OmegaConf.is_missing(_cfg, key):
809
+ missings.add(_cfg._get_full_key(key))
810
+ elif OmegaConf.is_config(_cfg[key]):
811
+ gather(_cfg[key])
812
+
813
+ gather(cfg)
814
+ return missings
815
+
816
+ # === private === #
817
+
818
+ @staticmethod
819
+ def _create_impl( # noqa F811
820
+ obj: Any = _DEFAULT_MARKER_,
821
+ parent: Optional[BaseContainer] = None,
822
+ flags: Optional[Dict[str, bool]] = None,
823
+ ) -> Union[DictConfig, ListConfig]:
824
+ try:
825
+ from ._utils import get_yaml_loader
826
+ from .dictconfig import DictConfig
827
+ from .listconfig import ListConfig
828
+
829
+ if obj is _DEFAULT_MARKER_:
830
+ obj = {}
831
+ if isinstance(obj, str):
832
+ obj = yaml.load(obj, Loader=get_yaml_loader())
833
+ if obj is None:
834
+ return OmegaConf.create({}, parent=parent, flags=flags)
835
+ elif isinstance(obj, str):
836
+ return OmegaConf.create({obj: None}, parent=parent, flags=flags)
837
+ else:
838
+ assert isinstance(obj, (list, dict))
839
+ return OmegaConf.create(obj, parent=parent, flags=flags)
840
+
841
+ else:
842
+ if (
843
+ is_primitive_dict(obj)
844
+ or OmegaConf.is_dict(obj)
845
+ or is_structured_config(obj)
846
+ or obj is None
847
+ ):
848
+ if isinstance(obj, DictConfig):
849
+ return DictConfig(
850
+ content=obj,
851
+ parent=parent,
852
+ ref_type=obj._metadata.ref_type,
853
+ is_optional=obj._metadata.optional,
854
+ key_type=obj._metadata.key_type,
855
+ element_type=obj._metadata.element_type,
856
+ flags=flags,
857
+ )
858
+ else:
859
+ obj_type = OmegaConf.get_type(obj)
860
+ key_type, element_type = get_dict_key_value_types(obj_type)
861
+ return DictConfig(
862
+ content=obj,
863
+ parent=parent,
864
+ key_type=key_type,
865
+ element_type=element_type,
866
+ flags=flags,
867
+ )
868
+ elif is_primitive_list(obj) or OmegaConf.is_list(obj):
869
+ if isinstance(obj, ListConfig):
870
+ return ListConfig(
871
+ content=obj,
872
+ parent=parent,
873
+ element_type=obj._metadata.element_type,
874
+ ref_type=obj._metadata.ref_type,
875
+ is_optional=obj._metadata.optional,
876
+ flags=flags,
877
+ )
878
+ else:
879
+ obj_type = OmegaConf.get_type(obj)
880
+ element_type = get_list_element_type(obj_type)
881
+ return ListConfig(
882
+ content=obj,
883
+ parent=parent,
884
+ element_type=element_type,
885
+ ref_type=Any,
886
+ is_optional=True,
887
+ flags=flags,
888
+ )
889
+ else:
890
+ if isinstance(obj, type):
891
+ raise ValidationError(
892
+ f"Input class '{obj.__name__}' is not a structured config. "
893
+ "did you forget to decorate it as a dataclass?"
894
+ )
895
+ else:
896
+ raise ValidationError(
897
+ f"Object of unsupported type: '{type(obj).__name__}'"
898
+ )
899
+ except OmegaConfBaseException as e:
900
+ format_and_raise(node=None, key=None, value=None, msg=str(e), cause=e)
901
+ assert False
902
+
903
+ @staticmethod
904
+ def _get_obj_type(c: Any) -> Optional[Type[Any]]:
905
+ if is_structured_config(c):
906
+ return get_type_of(c)
907
+ elif c is None:
908
+ return None
909
+ elif isinstance(c, DictConfig):
910
+ if c._is_none():
911
+ return None
912
+ elif c._is_missing():
913
+ return None
914
+ else:
915
+ if is_structured_config(c._metadata.object_type):
916
+ return c._metadata.object_type
917
+ else:
918
+ return dict
919
+ elif isinstance(c, ListConfig):
920
+ return list
921
+ elif isinstance(c, ValueNode):
922
+ return type(c._value())
923
+ elif isinstance(c, UnionNode):
924
+ return type(_get_value(c))
925
+ elif isinstance(c, dict):
926
+ return dict
927
+ elif isinstance(c, (list, tuple)):
928
+ return list
929
+ else:
930
+ return get_type_of(c)
931
+
932
+ @staticmethod
933
+ def _get_resolver(
934
+ name: str,
935
+ ) -> Optional[
936
+ Callable[
937
+ [Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]],
938
+ Any,
939
+ ]
940
+ ]:
941
+ # noinspection PyProtectedMember
942
+ return (
943
+ BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None
944
+ )
945
+
946
+
947
+ # register all default resolvers
948
+ register_default_resolvers()
949
+
950
+
951
+ @contextmanager
952
+ def flag_override(
953
+ config: Node,
954
+ names: Union[List[str], str],
955
+ values: Union[List[Optional[bool]], Optional[bool]],
956
+ ) -> Generator[Node, None, None]:
957
+ if isinstance(names, str):
958
+ names = [names]
959
+ if values is None or isinstance(values, bool):
960
+ values = [values]
961
+
962
+ prev_states = [config._get_node_flag(name) for name in names]
963
+
964
+ try:
965
+ config._set_flag(names, values)
966
+ yield config
967
+ finally:
968
+ config._set_flag(names, prev_states)
969
+
970
+
971
+ @contextmanager
972
+ def read_write(config: Node) -> Generator[Node, None, None]:
973
+ prev_state = config._get_node_flag("readonly")
974
+ try:
975
+ OmegaConf.set_readonly(config, False)
976
+ yield config
977
+ finally:
978
+ OmegaConf.set_readonly(config, prev_state)
979
+
980
+
981
+ @contextmanager
982
+ def open_dict(config: Container) -> Generator[Container, None, None]:
983
+ prev_state = config._get_node_flag("struct")
984
+ try:
985
+ OmegaConf.set_struct(config, False)
986
+ yield config
987
+ finally:
988
+ OmegaConf.set_struct(config, prev_state)
989
+
990
+
991
+ # === private === #
992
+
993
+
994
+ def _node_wrap(
995
+ parent: Optional[Box],
996
+ is_optional: bool,
997
+ value: Any,
998
+ key: Any,
999
+ ref_type: Any = Any,
1000
+ ) -> Node:
1001
+ node: Node
1002
+ if is_dict_annotation(ref_type) or (is_primitive_dict(value) and ref_type is Any):
1003
+ key_type, element_type = get_dict_key_value_types(ref_type)
1004
+ node = DictConfig(
1005
+ content=value,
1006
+ key=key,
1007
+ parent=parent,
1008
+ ref_type=ref_type,
1009
+ is_optional=is_optional,
1010
+ key_type=key_type,
1011
+ element_type=element_type,
1012
+ )
1013
+ elif (is_list_annotation(ref_type) or is_tuple_annotation(ref_type)) or (
1014
+ type(value) in (list, tuple) and ref_type is Any
1015
+ ):
1016
+ element_type = get_list_element_type(ref_type)
1017
+ node = ListConfig(
1018
+ content=value,
1019
+ key=key,
1020
+ parent=parent,
1021
+ is_optional=is_optional,
1022
+ element_type=element_type,
1023
+ ref_type=ref_type,
1024
+ )
1025
+ elif is_structured_config(ref_type) or is_structured_config(value):
1026
+ key_type, element_type = get_dict_key_value_types(value)
1027
+ node = DictConfig(
1028
+ ref_type=ref_type,
1029
+ is_optional=is_optional,
1030
+ content=value,
1031
+ key=key,
1032
+ parent=parent,
1033
+ key_type=key_type,
1034
+ element_type=element_type,
1035
+ )
1036
+ elif is_union_annotation(ref_type):
1037
+ node = UnionNode(
1038
+ content=value,
1039
+ ref_type=ref_type,
1040
+ is_optional=is_optional,
1041
+ key=key,
1042
+ parent=parent,
1043
+ )
1044
+ elif ref_type == Any or ref_type is None:
1045
+ node = AnyNode(value=value, key=key, parent=parent)
1046
+ elif isinstance(ref_type, type) and issubclass(ref_type, Enum):
1047
+ node = EnumNode(
1048
+ enum_type=ref_type,
1049
+ value=value,
1050
+ key=key,
1051
+ parent=parent,
1052
+ is_optional=is_optional,
1053
+ )
1054
+ elif ref_type == int:
1055
+ node = IntegerNode(value=value, key=key, parent=parent, is_optional=is_optional)
1056
+ elif ref_type == float:
1057
+ node = FloatNode(value=value, key=key, parent=parent, is_optional=is_optional)
1058
+ elif ref_type == bool:
1059
+ node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional)
1060
+ elif ref_type == str:
1061
+ node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional)
1062
+ elif ref_type == bytes:
1063
+ node = BytesNode(value=value, key=key, parent=parent, is_optional=is_optional)
1064
+ elif ref_type == pathlib.Path:
1065
+ node = PathNode(value=value, key=key, parent=parent, is_optional=is_optional)
1066
+ else:
1067
+ if parent is not None and parent._get_flag("allow_objects") is True:
1068
+ if type(value) in (list, tuple):
1069
+ node = ListConfig(
1070
+ content=value,
1071
+ key=key,
1072
+ parent=parent,
1073
+ ref_type=ref_type,
1074
+ is_optional=is_optional,
1075
+ )
1076
+ elif is_primitive_dict(value):
1077
+ node = DictConfig(
1078
+ content=value,
1079
+ key=key,
1080
+ parent=parent,
1081
+ ref_type=ref_type,
1082
+ is_optional=is_optional,
1083
+ )
1084
+ else:
1085
+ node = AnyNode(value=value, key=key, parent=parent)
1086
+ else:
1087
+ raise ValidationError(f"Unexpected type annotation: {type_str(ref_type)}")
1088
+ return node
1089
+
1090
+
1091
+ def _maybe_wrap(
1092
+ ref_type: Any,
1093
+ key: Any,
1094
+ value: Any,
1095
+ is_optional: bool,
1096
+ parent: Optional[BaseContainer],
1097
+ ) -> Node:
1098
+ # if already a node, update key and parent and return as is.
1099
+ # NOTE: that this mutate the input node!
1100
+ if isinstance(value, Node):
1101
+ value._set_key(key)
1102
+ value._set_parent(parent)
1103
+ return value
1104
+ else:
1105
+ return _node_wrap(
1106
+ ref_type=ref_type,
1107
+ parent=parent,
1108
+ is_optional=is_optional,
1109
+ value=value,
1110
+ key=key,
1111
+ )
1112
+
1113
+
1114
+ def _select_one(
1115
+ c: Container, key: str, throw_on_missing: bool, throw_on_type_error: bool = True
1116
+ ) -> Tuple[Optional[Node], Union[str, int]]:
1117
+ from .dictconfig import DictConfig
1118
+ from .listconfig import ListConfig
1119
+
1120
+ ret_key: Union[str, int] = key
1121
+ assert isinstance(c, Container), f"Unexpected type: {c}"
1122
+ if c._is_none():
1123
+ return None, ret_key
1124
+
1125
+ if isinstance(c, DictConfig):
1126
+ assert isinstance(ret_key, str)
1127
+ val = c._get_child(ret_key, validate_access=False)
1128
+ elif isinstance(c, ListConfig):
1129
+ assert isinstance(ret_key, str)
1130
+ if not is_int(ret_key):
1131
+ if throw_on_type_error:
1132
+ raise TypeError(
1133
+ f"Index '{ret_key}' ({type(ret_key).__name__}) is not an int"
1134
+ )
1135
+ else:
1136
+ val = None
1137
+ else:
1138
+ ret_key = int(ret_key)
1139
+ if ret_key < 0 or ret_key + 1 > len(c):
1140
+ val = None
1141
+ else:
1142
+ val = c._get_child(ret_key)
1143
+ else:
1144
+ assert False
1145
+
1146
+ if val is not None:
1147
+ assert isinstance(val, Node)
1148
+ if val._is_missing():
1149
+ if throw_on_missing:
1150
+ raise MissingMandatoryValue(
1151
+ f"Missing mandatory value: {c._get_full_key(ret_key)}"
1152
+ )
1153
+ else:
1154
+ return val, ret_key
1155
+
1156
+ assert val is None or isinstance(val, Node)
1157
+ return val, ret_key
omegaconf/py.typed ADDED
File without changes
omegaconf/resolvers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from omegaconf.resolvers import oc
2
+
3
+ __all__ = [
4
+ "oc",
5
+ ]
omegaconf/resolvers/oc/__init__.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import string
3
+ import warnings
4
+ from typing import Any, Optional
5
+
6
+ from omegaconf import Container, Node
7
+ from omegaconf._utils import _DEFAULT_MARKER_, _get_value
8
+ from omegaconf.basecontainer import BaseContainer
9
+ from omegaconf.errors import ConfigKeyError
10
+ from omegaconf.grammar_parser import parse
11
+ from omegaconf.resolvers.oc import dict
12
+
13
+
14
+ def create(obj: Any, _parent_: Container) -> Any:
15
+ """Create a config object from `obj`, similar to `OmegaConf.create`"""
16
+ from omegaconf import OmegaConf
17
+
18
+ assert isinstance(_parent_, BaseContainer)
19
+ return OmegaConf.create(obj, parent=_parent_)
20
+
21
+
22
+ def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]:
23
+ """
24
+ :param key: Environment variable key
25
+ :param default: Optional default value to use in case the key environment variable is not set.
26
+ If default is not a string, it is converted with str(default).
27
+ None default is returned as is.
28
+ :return: The environment variable 'key'. If the environment variable is not set and a default is
29
+ provided, the default is used. If used, the default is converted to a string with str(default).
30
+ If the default is None, None is returned (without a string conversion).
31
+ """
32
+ try:
33
+ return os.environ[key]
34
+ except KeyError:
35
+ if default is not _DEFAULT_MARKER_:
36
+ return str(default) if default is not None else None
37
+ else:
38
+ raise KeyError(f"Environment variable '{key}' not found")
39
+
40
+
41
+ def decode(expr: Optional[str], _parent_: Container, _node_: Node) -> Any:
42
+ """
43
+ Parse and evaluate `expr` according to the `singleElement` rule of the grammar.
44
+
45
+ If `expr` is `None`, then return `None`.
46
+ """
47
+ if expr is None:
48
+ return None
49
+
50
+ if not isinstance(expr, str):
51
+ raise TypeError(
52
+ f"`oc.decode` can only take strings or None as input, "
53
+ f"but `{expr}` is of type {type(expr).__name__}"
54
+ )
55
+
56
+ parse_tree = parse(expr, parser_rule="singleElement", lexer_mode="VALUE_MODE")
57
+ val = _parent_.resolve_parse_tree(parse_tree, node=_node_)
58
+ return _get_value(val)
59
+
60
+
61
+ def deprecated(
62
+ key: str,
63
+ message: str = "'$OLD_KEY' is deprecated. Change your code and config to use '$NEW_KEY'",
64
+ *,
65
+ _parent_: Container,
66
+ _node_: Node,
67
+ ) -> Any:
68
+ from omegaconf._impl import select_node
69
+
70
+ if not isinstance(key, str):
71
+ raise TypeError(
72
+ f"oc.deprecated: interpolation key type is not a string ({type(key).__name__})"
73
+ )
74
+
75
+ if not isinstance(message, str):
76
+ raise TypeError(
77
+ f"oc.deprecated: interpolation message type is not a string ({type(message).__name__})"
78
+ )
79
+
80
+ full_key = _node_._get_full_key(key=None)
81
+ target_node = select_node(_parent_, key, absolute_key=True)
82
+ if target_node is None:
83
+ raise ConfigKeyError(
84
+ f"In oc.deprecated resolver at '{full_key}': Key not found: '{key}'"
85
+ )
86
+ new_key = target_node._get_full_key(key=None)
87
+ msg = string.Template(message).safe_substitute(
88
+ OLD_KEY=full_key,
89
+ NEW_KEY=new_key,
90
+ )
91
+ warnings.warn(category=UserWarning, message=msg)
92
+ return target_node
93
+
94
+
95
+ def select(
96
+ key: str,
97
+ default: Any = _DEFAULT_MARKER_,
98
+ *,
99
+ _parent_: Container,
100
+ ) -> Any:
101
+ from omegaconf._impl import select_value
102
+
103
+ return select_value(cfg=_parent_, key=key, absolute_key=True, default=default)
104
+
105
+
106
+ __all__ = [
107
+ "create",
108
+ "decode",
109
+ "deprecated",
110
+ "dict",
111
+ "env",
112
+ "select",
113
+ ]
omegaconf/resolvers/oc/dict.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List
2
+
3
+ from omegaconf import AnyNode, Container, DictConfig, ListConfig
4
+ from omegaconf._utils import Marker
5
+ from omegaconf.basecontainer import BaseContainer
6
+ from omegaconf.errors import ConfigKeyError
7
+
8
+ _DEFAULT_SELECT_MARKER_: Any = Marker("_DEFAULT_SELECT_MARKER_")
9
+
10
+
11
+ def keys(
12
+ key: str,
13
+ _parent_: Container,
14
+ ) -> ListConfig:
15
+ from omegaconf import OmegaConf
16
+
17
+ assert isinstance(_parent_, BaseContainer)
18
+
19
+ in_dict = _get_and_validate_dict_input(
20
+ key, parent=_parent_, resolver_name="oc.dict.keys"
21
+ )
22
+
23
+ ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_)
24
+ assert isinstance(ret, ListConfig)
25
+ return ret
26
+
27
+
28
+ def values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig:
29
+ assert isinstance(_parent_, BaseContainer)
30
+ in_dict = _get_and_validate_dict_input(
31
+ key, parent=_parent_, resolver_name="oc.dict.values"
32
+ )
33
+
34
+ content = in_dict._content
35
+ assert isinstance(content, dict)
36
+
37
+ ret = ListConfig([])
38
+ if key.startswith("."):
39
+ key = f".{key}" # extra dot to compensate for extra level of nesting within ret ListConfig
40
+ for k in content:
41
+ ref_node = AnyNode(f"${{{key}.{k!s}}}")
42
+ ret.append(ref_node)
43
+
44
+ # Finalize result by setting proper type and parent.
45
+ element_type: Any = in_dict._metadata.element_type
46
+ ret._metadata.element_type = element_type
47
+ ret._metadata.ref_type = List[element_type]
48
+ ret._set_parent(_parent_)
49
+
50
+ return ret
51
+
52
+
53
+ def _get_and_validate_dict_input(
54
+ key: str,
55
+ parent: BaseContainer,
56
+ resolver_name: str,
57
+ ) -> DictConfig:
58
+ from omegaconf._impl import select_value
59
+
60
+ if not isinstance(key, str):
61
+ raise TypeError(
62
+ f"`{resolver_name}` requires a string as input, but obtained `{key}` "
63
+ f"of type: {type(key).__name__}"
64
+ )
65
+
66
+ in_dict = select_value(
67
+ parent,
68
+ key,
69
+ throw_on_missing=True,
70
+ absolute_key=True,
71
+ default=_DEFAULT_SELECT_MARKER_,
72
+ )
73
+
74
+ if in_dict is _DEFAULT_SELECT_MARKER_:
75
+ raise ConfigKeyError(f"Key not found: '{key}'")
76
+
77
+ if not isinstance(in_dict, DictConfig):
78
+ raise TypeError(
79
+ f"`{resolver_name}` cannot be applied to objects of type: "
80
+ f"{type(in_dict).__name__}"
81
+ )
82
+
83
+ return in_dict
omegaconf/version.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys # pragma: no cover
2
+
3
+ __version__ = "2.3.0"
4
+
5
+ msg = """OmegaConf 2.0 and above is compatible with Python 3.6 and newer.
6
+ You have the following options:
7
+ 1. Upgrade to Python 3.6 or newer.
8
+ This is highly recommended. new features will not be added to OmegaConf 1.4.
9
+ 2. Continue using OmegaConf 1.4:
10
+ You can pip install 'OmegaConf<1.5' to do that.
11
+ """
12
+ if sys.version_info < (3, 6):
13
+ raise ImportError(msg) # pragma: no cover