Tiger14n commited on
Commit
0c229ff
·
1 Parent(s): 914e267
omegaconf/__init__.py DELETED
@@ -1,65 +0,0 @@
1
- from .base import Container, DictKeyType, Node, SCMode, UnionNode
2
- from .dictconfig import DictConfig
3
- from .errors import (
4
- KeyValidationError,
5
- MissingMandatoryValue,
6
- ReadonlyConfigError,
7
- UnsupportedValueType,
8
- ValidationError,
9
- )
10
- from .listconfig import ListConfig
11
- from .nodes import (
12
- AnyNode,
13
- BooleanNode,
14
- BytesNode,
15
- EnumNode,
16
- FloatNode,
17
- IntegerNode,
18
- PathNode,
19
- StringNode,
20
- ValueNode,
21
- )
22
- from .omegaconf import (
23
- II,
24
- MISSING,
25
- SI,
26
- OmegaConf,
27
- Resolver,
28
- flag_override,
29
- open_dict,
30
- read_write,
31
- )
32
- from .version import __version__
33
-
34
- __all__ = [
35
- "__version__",
36
- "MissingMandatoryValue",
37
- "ValidationError",
38
- "ReadonlyConfigError",
39
- "UnsupportedValueType",
40
- "KeyValidationError",
41
- "Container",
42
- "UnionNode",
43
- "ListConfig",
44
- "DictConfig",
45
- "DictKeyType",
46
- "OmegaConf",
47
- "Resolver",
48
- "SCMode",
49
- "flag_override",
50
- "read_write",
51
- "open_dict",
52
- "Node",
53
- "ValueNode",
54
- "AnyNode",
55
- "IntegerNode",
56
- "StringNode",
57
- "BytesNode",
58
- "PathNode",
59
- "BooleanNode",
60
- "EnumNode",
61
- "FloatNode",
62
- "MISSING",
63
- "SI",
64
- "II",
65
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/_impl.py DELETED
@@ -1,101 +0,0 @@
1
- from typing import Any
2
-
3
- from omegaconf import MISSING, Container, DictConfig, ListConfig, Node, ValueNode
4
- from omegaconf.errors import ConfigTypeError, InterpolationToMissingValueError
5
-
6
- from ._utils import _DEFAULT_MARKER_, _get_value
7
-
8
-
9
- def _resolve_container_value(cfg: Container, key: Any) -> None:
10
- node = cfg._get_child(key)
11
- assert isinstance(node, Node)
12
- if node._is_interpolation():
13
- try:
14
- resolved = node._dereference_node()
15
- except InterpolationToMissingValueError:
16
- node._set_value(MISSING)
17
- else:
18
- if isinstance(resolved, Container):
19
- _resolve(resolved)
20
- if isinstance(resolved, Container) and isinstance(node, ValueNode):
21
- cfg[key] = resolved
22
- else:
23
- node._set_value(_get_value(resolved))
24
- else:
25
- _resolve(node)
26
-
27
-
28
- def _resolve(cfg: Node) -> Node:
29
- assert isinstance(cfg, Node)
30
- if cfg._is_interpolation():
31
- try:
32
- resolved = cfg._dereference_node()
33
- except InterpolationToMissingValueError:
34
- cfg._set_value(MISSING)
35
- else:
36
- cfg._set_value(resolved._value())
37
-
38
- if isinstance(cfg, DictConfig):
39
- for k in cfg.keys():
40
- _resolve_container_value(cfg, k)
41
-
42
- elif isinstance(cfg, ListConfig):
43
- for i in range(len(cfg)):
44
- _resolve_container_value(cfg, i)
45
-
46
- return cfg
47
-
48
-
49
- def select_value(
50
- cfg: Container,
51
- key: str,
52
- *,
53
- default: Any = _DEFAULT_MARKER_,
54
- throw_on_resolution_failure: bool = True,
55
- throw_on_missing: bool = False,
56
- absolute_key: bool = False,
57
- ) -> Any:
58
- node = select_node(
59
- cfg=cfg,
60
- key=key,
61
- throw_on_resolution_failure=throw_on_resolution_failure,
62
- throw_on_missing=throw_on_missing,
63
- absolute_key=absolute_key,
64
- )
65
-
66
- node_not_found = node is None
67
- if node_not_found or node._is_missing():
68
- if default is not _DEFAULT_MARKER_:
69
- return default
70
- else:
71
- return None
72
-
73
- return _get_value(node)
74
-
75
-
76
- def select_node(
77
- cfg: Container,
78
- key: str,
79
- *,
80
- throw_on_resolution_failure: bool = True,
81
- throw_on_missing: bool = False,
82
- absolute_key: bool = False,
83
- ) -> Any:
84
- try:
85
- # for non relative keys, the interpretation can be:
86
- # 1. relative to cfg
87
- # 2. relative to the config root
88
- # This is controlled by the absolute_key flag. By default, such keys are relative to cfg.
89
- if not absolute_key and not key.startswith("."):
90
- key = f".{key}"
91
-
92
- cfg, key = cfg._resolve_key_and_root(key)
93
- _root, _last_key, node = cfg._select_impl(
94
- key,
95
- throw_on_missing=throw_on_missing,
96
- throw_on_resolution_failure=throw_on_resolution_failure,
97
- )
98
- except ConfigTypeError:
99
- return None
100
-
101
- return node
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/_utils.py DELETED
@@ -1,1039 +0,0 @@
1
- import copy
2
- import os
3
- import pathlib
4
- import re
5
- import string
6
- import sys
7
- import types
8
- import warnings
9
- from contextlib import contextmanager
10
- from enum import Enum
11
- from textwrap import dedent
12
- from typing import (
13
- Any,
14
- Dict,
15
- Iterator,
16
- List,
17
- Optional,
18
- Tuple,
19
- Type,
20
- Union,
21
- get_type_hints,
22
- )
23
-
24
- import yaml
25
-
26
- from .errors import (
27
- ConfigIndexError,
28
- ConfigTypeError,
29
- ConfigValueError,
30
- GrammarParseError,
31
- OmegaConfBaseException,
32
- ValidationError,
33
- )
34
- from .grammar_parser import SIMPLE_INTERPOLATION_PATTERN, parse
35
-
36
- try:
37
- import dataclasses
38
-
39
- except ImportError: # pragma: no cover
40
- dataclasses = None # type: ignore # pragma: no cover
41
-
42
- try:
43
- import attr
44
-
45
- except ImportError: # pragma: no cover
46
- attr = None # type: ignore # pragma: no cover
47
-
48
- NoneType: Type[None] = type(None)
49
-
50
- BUILTIN_VALUE_TYPES: Tuple[Type[Any], ...] = (
51
- int,
52
- float,
53
- bool,
54
- str,
55
- bytes,
56
- NoneType,
57
- )
58
-
59
- # Regexprs to match key paths like: a.b, a[b], ..a[c].d, etc.
60
- # We begin by matching the head (in these examples: a, a, ..a).
61
- # This can be read as "dots followed by any character but `.` or `[`"
62
- # Note that a key starting with brackets, like [a], is purposedly *not*
63
- # matched here and will instead be handled in the next regex below (this
64
- # is to keep this regex simple).
65
- KEY_PATH_HEAD = re.compile(r"(\.)*[^.[]*")
66
- # Then we match other keys. The following expression matches one key and can
67
- # be read as a choice between two syntaxes:
68
- # - `.` followed by anything except `.` or `[` (ex: .b, .d)
69
- # - `[` followed by anything then `]` (ex: [b], [c])
70
- KEY_PATH_OTHER = re.compile(r"\.([^.[]*)|\[(.*?)\]")
71
-
72
-
73
- # source: https://yaml.org/type/bool.html
74
- YAML_BOOL_TYPES = [
75
- "y",
76
- "Y",
77
- "yes",
78
- "Yes",
79
- "YES",
80
- "n",
81
- "N",
82
- "no",
83
- "No",
84
- "NO",
85
- "true",
86
- "True",
87
- "TRUE",
88
- "false",
89
- "False",
90
- "FALSE",
91
- "on",
92
- "On",
93
- "ON",
94
- "off",
95
- "Off",
96
- "OFF",
97
- ]
98
-
99
-
100
- class Marker:
101
- def __init__(self, desc: str):
102
- self.desc = desc
103
-
104
- def __repr__(self) -> str:
105
- return self.desc
106
-
107
-
108
- # To be used as default value when `None` is not an option.
109
- _DEFAULT_MARKER_: Any = Marker("_DEFAULT_MARKER_")
110
-
111
-
112
- class OmegaConfDumper(yaml.Dumper): # type: ignore
113
- str_representer_added = False
114
-
115
- @staticmethod
116
- def str_representer(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
117
- with_quotes = yaml_is_bool(data) or is_int(data) or is_float(data)
118
- return dumper.represent_scalar(
119
- yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG,
120
- data,
121
- style=("'" if with_quotes else None),
122
- )
123
-
124
-
125
- def get_omega_conf_dumper() -> Type[OmegaConfDumper]:
126
- if not OmegaConfDumper.str_representer_added:
127
- OmegaConfDumper.add_representer(str, OmegaConfDumper.str_representer)
128
- OmegaConfDumper.str_representer_added = True
129
- return OmegaConfDumper
130
-
131
-
132
- def yaml_is_bool(b: str) -> bool:
133
- return b in YAML_BOOL_TYPES
134
-
135
-
136
- def get_yaml_loader() -> Any:
137
- class OmegaConfLoader(yaml.SafeLoader): # type: ignore
138
- def construct_mapping(self, node: yaml.Node, deep: bool = False) -> Any:
139
- keys = set()
140
- for key_node, value_node in node.value:
141
- if key_node.tag != yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG:
142
- continue
143
- if key_node.value in keys:
144
- raise yaml.constructor.ConstructorError(
145
- "while constructing a mapping",
146
- node.start_mark,
147
- f"found duplicate key {key_node.value}",
148
- key_node.start_mark,
149
- )
150
- keys.add(key_node.value)
151
- return super().construct_mapping(node, deep=deep)
152
-
153
- loader = OmegaConfLoader
154
- loader.add_implicit_resolver(
155
- "tag:yaml.org,2002:float",
156
- re.compile(
157
- """^(?:
158
- [-+]?[0-9]+(?:_[0-9]+)*\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
159
- |[-+]?[0-9]+(?:_[0-9]+)*(?:[eE][-+]?[0-9]+)
160
- |\\.[0-9]+(?:_[0-9]+)*(?:[eE][-+][0-9]+)?
161
- |[-+]?[0-9]+(?:_[0-9]+)*(?::[0-5]?[0-9])+\\.[0-9_]*
162
- |[-+]?\\.(?:inf|Inf|INF)
163
- |\\.(?:nan|NaN|NAN))$""",
164
- re.X,
165
- ),
166
- list("-+0123456789."),
167
- )
168
- loader.yaml_implicit_resolvers = {
169
- key: [
170
- (tag, regexp)
171
- for tag, regexp in resolvers
172
- if tag != "tag:yaml.org,2002:timestamp"
173
- ]
174
- for key, resolvers in loader.yaml_implicit_resolvers.items()
175
- }
176
-
177
- loader.add_constructor(
178
- "tag:yaml.org,2002:python/object/apply:pathlib.Path",
179
- lambda loader, node: pathlib.Path(*loader.construct_sequence(node)),
180
- )
181
- loader.add_constructor(
182
- "tag:yaml.org,2002:python/object/apply:pathlib.PosixPath",
183
- lambda loader, node: pathlib.PosixPath(*loader.construct_sequence(node)),
184
- )
185
- loader.add_constructor(
186
- "tag:yaml.org,2002:python/object/apply:pathlib.WindowsPath",
187
- lambda loader, node: pathlib.WindowsPath(*loader.construct_sequence(node)),
188
- )
189
-
190
- return loader
191
-
192
-
193
- def _get_class(path: str) -> type:
194
- from importlib import import_module
195
-
196
- module_path, _, class_name = path.rpartition(".")
197
- mod = import_module(module_path)
198
- try:
199
- klass: type = getattr(mod, class_name)
200
- except AttributeError:
201
- raise ImportError(f"Class {class_name} is not in module {module_path}")
202
- return klass
203
-
204
-
205
- def is_union_annotation(type_: Any) -> bool:
206
- if sys.version_info >= (3, 10): # pragma: no cover
207
- if isinstance(type_, types.UnionType):
208
- return True
209
- return getattr(type_, "__origin__", None) is Union
210
-
211
-
212
- def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
213
- """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
214
- if is_union_annotation(type_):
215
- args = type_.__args__
216
- if NoneType in args:
217
- optional = True
218
- args = tuple(a for a in args if a is not NoneType)
219
- else:
220
- optional = False
221
- if len(args) == 1:
222
- return optional, args[0]
223
- elif len(args) >= 2:
224
- return optional, Union[args]
225
- else:
226
- assert False
227
-
228
- if type_ is Any:
229
- return True, Any
230
-
231
- if type_ in (None, NoneType):
232
- return True, NoneType
233
-
234
- return False, type_
235
-
236
-
237
- def _is_optional(obj: Any, key: Optional[Union[int, str]] = None) -> bool:
238
- """Check `obj` metadata to see if the given node is optional."""
239
- from .base import Container, Node
240
-
241
- if key is not None:
242
- assert isinstance(obj, Container)
243
- obj = obj._get_node(key)
244
- assert isinstance(obj, Node)
245
- return obj._is_optional()
246
-
247
-
248
- def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]:
249
- import typing # lgtm [py/import-and-import-from]
250
-
251
- forward = typing.ForwardRef if hasattr(typing, "ForwardRef") else typing._ForwardRef # type: ignore
252
- if type(type_) is forward:
253
- return _get_class(f"{module}.{type_.__forward_arg__}")
254
- else:
255
- if is_dict_annotation(type_):
256
- kt, vt = get_dict_key_value_types(type_)
257
- if kt is not None:
258
- kt = _resolve_forward(kt, module=module)
259
- if vt is not None:
260
- vt = _resolve_forward(vt, module=module)
261
- return Dict[kt, vt] # type: ignore
262
- if is_list_annotation(type_):
263
- et = get_list_element_type(type_)
264
- if et is not None:
265
- et = _resolve_forward(et, module=module)
266
- return List[et] # type: ignore
267
- if is_tuple_annotation(type_):
268
- its = get_tuple_item_types(type_)
269
- its = tuple(_resolve_forward(it, module=module) for it in its)
270
- return Tuple[its] # type: ignore
271
-
272
- return type_
273
-
274
-
275
- def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]]:
276
- """Check if obj is an instance of a subclass of Dict. If so, extract the Dict keys/values."""
277
- from omegaconf.omegaconf import _maybe_wrap
278
-
279
- is_type = isinstance(obj, type)
280
- obj_type = obj if is_type else type(obj)
281
- subclasses_dict = is_dict_subclass(obj_type)
282
-
283
- if subclasses_dict:
284
- warnings.warn(
285
- f"Class `{obj_type.__name__}` subclasses `Dict`."
286
- + " Subclassing `Dict` in Structured Config classes is deprecated,"
287
- + " see github.com/omry/omegaconf/issues/663",
288
- UserWarning,
289
- stacklevel=9,
290
- )
291
-
292
- if is_type:
293
- return None
294
- elif subclasses_dict:
295
- dict_subclass_data = {}
296
- key_type, element_type = get_dict_key_value_types(obj_type)
297
- for name, value in obj.items():
298
- is_optional, type_ = _resolve_optional(element_type)
299
- type_ = _resolve_forward(type_, obj.__module__)
300
- try:
301
- dict_subclass_data[name] = _maybe_wrap(
302
- ref_type=type_,
303
- is_optional=is_optional,
304
- key=name,
305
- value=value,
306
- parent=parent,
307
- )
308
- except ValidationError as ex:
309
- format_and_raise(
310
- node=None, key=name, value=value, cause=ex, msg=str(ex)
311
- )
312
- return dict_subclass_data
313
- else:
314
- return None
315
-
316
-
317
- def get_attr_class_fields(obj: Any) -> List["attr.Attribute[Any]"]:
318
- is_type = isinstance(obj, type)
319
- obj_type = obj if is_type else type(obj)
320
- fields = attr.fields_dict(obj_type).values()
321
- return [f for f in fields if f.metadata.get("omegaconf_ignore") is not True]
322
-
323
-
324
- def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]:
325
- from omegaconf.omegaconf import OmegaConf, _maybe_wrap
326
-
327
- flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
328
-
329
- from omegaconf import MISSING
330
-
331
- d = {}
332
- is_type = isinstance(obj, type)
333
- obj_type = obj if is_type else type(obj)
334
- dummy_parent = OmegaConf.create({}, flags=flags)
335
- dummy_parent._metadata.object_type = obj_type
336
- resolved_hints = get_type_hints(obj_type)
337
-
338
- for attrib in get_attr_class_fields(obj):
339
- name = attrib.name
340
- is_optional, type_ = _resolve_optional(resolved_hints[name])
341
- type_ = _resolve_forward(type_, obj.__module__)
342
- if not is_type:
343
- value = getattr(obj, name)
344
- else:
345
- value = attrib.default
346
- if value == attr.NOTHING:
347
- value = MISSING
348
- if is_union_annotation(type_) and not is_supported_union_annotation(type_):
349
- e = ConfigValueError(
350
- f"Unions of containers are not supported:\n{name}: {type_str(type_)}"
351
- )
352
- format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
353
-
354
- try:
355
- d[name] = _maybe_wrap(
356
- ref_type=type_,
357
- is_optional=is_optional,
358
- key=name,
359
- value=value,
360
- parent=dummy_parent,
361
- )
362
- except (ValidationError, GrammarParseError) as ex:
363
- format_and_raise(
364
- node=dummy_parent, key=name, value=value, cause=ex, msg=str(ex)
365
- )
366
- d[name]._set_parent(None)
367
- dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent)
368
- if dict_subclass_data is not None:
369
- d.update(dict_subclass_data)
370
- return d
371
-
372
-
373
- def get_dataclass_fields(obj: Any) -> List["dataclasses.Field[Any]"]:
374
- fields = dataclasses.fields(obj)
375
- return [f for f in fields if f.metadata.get("omegaconf_ignore") is not True]
376
-
377
-
378
- def get_dataclass_data(
379
- obj: Any, allow_objects: Optional[bool] = None
380
- ) -> Dict[str, Any]:
381
- from omegaconf.omegaconf import MISSING, OmegaConf, _maybe_wrap
382
-
383
- flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
384
- d = {}
385
- is_type = isinstance(obj, type)
386
- obj_type = get_type_of(obj)
387
- dummy_parent = OmegaConf.create({}, flags=flags)
388
- dummy_parent._metadata.object_type = obj_type
389
- resolved_hints = get_type_hints(obj_type)
390
- for field in get_dataclass_fields(obj):
391
- name = field.name
392
- is_optional, type_ = _resolve_optional(resolved_hints[field.name])
393
- type_ = _resolve_forward(type_, obj.__module__)
394
- has_default = field.default != dataclasses.MISSING
395
- has_default_factory = field.default_factory != dataclasses.MISSING
396
-
397
- if not is_type:
398
- value = getattr(obj, name)
399
- else:
400
- if has_default:
401
- value = field.default
402
- elif has_default_factory:
403
- value = field.default_factory() # type: ignore
404
- else:
405
- value = MISSING
406
-
407
- if is_union_annotation(type_) and not is_supported_union_annotation(type_):
408
- e = ConfigValueError(
409
- f"Unions of containers are not supported:\n{name}: {type_str(type_)}"
410
- )
411
- format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
412
- try:
413
- d[name] = _maybe_wrap(
414
- ref_type=type_,
415
- is_optional=is_optional,
416
- key=name,
417
- value=value,
418
- parent=dummy_parent,
419
- )
420
- except (ValidationError, GrammarParseError) as ex:
421
- format_and_raise(
422
- node=dummy_parent, key=name, value=value, cause=ex, msg=str(ex)
423
- )
424
- d[name]._set_parent(None)
425
- dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent)
426
- if dict_subclass_data is not None:
427
- d.update(dict_subclass_data)
428
- return d
429
-
430
-
431
- def is_dataclass(obj: Any) -> bool:
432
- from omegaconf.base import Node
433
-
434
- if dataclasses is None or isinstance(obj, Node):
435
- return False
436
- return dataclasses.is_dataclass(obj)
437
-
438
-
439
- def is_attr_class(obj: Any) -> bool:
440
- from omegaconf.base import Node
441
-
442
- if attr is None or isinstance(obj, Node):
443
- return False
444
- return attr.has(obj)
445
-
446
-
447
- def is_structured_config(obj: Any) -> bool:
448
- return is_attr_class(obj) or is_dataclass(obj)
449
-
450
-
451
- def is_dataclass_frozen(type_: Any) -> bool:
452
- return type_.__dataclass_params__.frozen # type: ignore
453
-
454
-
455
- def is_attr_frozen(type_: type) -> bool:
456
- # This is very hacky and probably fragile as well.
457
- # Unfortunately currently there isn't an official API in attr that can detect that.
458
- # noinspection PyProtectedMember
459
- return type_.__setattr__ == attr._make._frozen_setattrs # type: ignore
460
-
461
-
462
- def get_type_of(class_or_object: Any) -> Type[Any]:
463
- type_ = class_or_object
464
- if not isinstance(type_, type):
465
- type_ = type(class_or_object)
466
- assert isinstance(type_, type)
467
- return type_
468
-
469
-
470
- def is_structured_config_frozen(obj: Any) -> bool:
471
- type_ = get_type_of(obj)
472
-
473
- if is_dataclass(type_):
474
- return is_dataclass_frozen(type_)
475
- if is_attr_class(type_):
476
- return is_attr_frozen(type_)
477
- return False
478
-
479
-
480
- def get_structured_config_init_field_names(obj: Any) -> List[str]:
481
- fields: Union[List["dataclasses.Field[Any]"], List["attr.Attribute[Any]"]]
482
- if is_dataclass(obj):
483
- fields = get_dataclass_fields(obj)
484
- elif is_attr_class(obj):
485
- fields = get_attr_class_fields(obj)
486
- else:
487
- raise ValueError(f"Unsupported type: {type(obj).__name__}")
488
- return [f.name for f in fields if f.init]
489
-
490
-
491
- def get_structured_config_data(
492
- obj: Any, allow_objects: Optional[bool] = None
493
- ) -> Dict[str, Any]:
494
- if is_dataclass(obj):
495
- return get_dataclass_data(obj, allow_objects=allow_objects)
496
- elif is_attr_class(obj):
497
- return get_attr_data(obj, allow_objects=allow_objects)
498
- else:
499
- raise ValueError(f"Unsupported type: {type(obj).__name__}")
500
-
501
-
502
- class ValueKind(Enum):
503
- VALUE = 0
504
- MANDATORY_MISSING = 1
505
- INTERPOLATION = 2
506
-
507
-
508
- def _is_missing_value(value: Any) -> bool:
509
- from omegaconf import Node
510
-
511
- if isinstance(value, Node):
512
- value = value._value()
513
- return _is_missing_literal(value)
514
-
515
-
516
- def _is_missing_literal(value: Any) -> bool:
517
- # Uses literal '???' instead of the MISSING const for performance reasons.
518
- return isinstance(value, str) and value == "???"
519
-
520
-
521
- def _is_none(
522
- value: Any, resolve: bool = False, throw_on_resolution_failure: bool = True
523
- ) -> bool:
524
- from omegaconf import Node
525
-
526
- if not isinstance(value, Node):
527
- return value is None
528
-
529
- if resolve:
530
- value = value._maybe_dereference_node(
531
- throw_on_resolution_failure=throw_on_resolution_failure
532
- )
533
- if not throw_on_resolution_failure and value is None:
534
- # Resolution failure: consider that it is *not* None.
535
- return False
536
- assert isinstance(value, Node)
537
-
538
- return value._is_none()
539
-
540
-
541
- def get_value_kind(
542
- value: Any, strict_interpolation_validation: bool = False
543
- ) -> ValueKind:
544
- """
545
- Determine the kind of a value
546
- Examples:
547
- VALUE: "10", "20", True
548
- MANDATORY_MISSING: "???"
549
- INTERPOLATION: "${foo.bar}", "${foo.${bar}}", "${foo:bar}", "[${foo}, ${bar}]",
550
- "ftp://${host}/path", "${foo:${bar}, [true], {'baz': ${baz}}}"
551
-
552
- :param value: Input to classify.
553
- :param strict_interpolation_validation: If `True`, then when `value` is a string
554
- containing "${", it is parsed to validate the interpolation syntax. If `False`,
555
- this parsing step is skipped: this is more efficient, but will not detect errors.
556
- """
557
-
558
- if _is_missing_value(value):
559
- return ValueKind.MANDATORY_MISSING
560
-
561
- if _is_interpolation(value, strict_interpolation_validation):
562
- return ValueKind.INTERPOLATION
563
-
564
- return ValueKind.VALUE
565
-
566
-
567
- def _is_interpolation(v: Any, strict_interpolation_validation: bool = False) -> bool:
568
- from omegaconf import Node
569
-
570
- if isinstance(v, Node):
571
- v = v._value()
572
-
573
- if isinstance(v, str) and _is_interpolation_string(
574
- v, strict_interpolation_validation
575
- ):
576
- return True
577
- return False
578
-
579
-
580
- def _is_interpolation_string(value: str, strict_interpolation_validation: bool) -> bool:
581
- # We identify potential interpolations by the presence of "${" in the string.
582
- # Note that escaped interpolations (ex: "esc: \${bar}") are identified as
583
- # interpolations: this is intended, since they must be processed as interpolations
584
- # for the string to be properly un-escaped.
585
- # Keep in mind that invalid interpolations will only be detected when
586
- # `strict_interpolation_validation` is True.
587
- if "${" in value:
588
- if strict_interpolation_validation:
589
- # First try the cheap regex matching that detects common interpolations.
590
- if SIMPLE_INTERPOLATION_PATTERN.match(value) is None:
591
- # If no match, do the more expensive grammar parsing to detect errors.
592
- parse(value)
593
- return True
594
- return False
595
-
596
-
597
- def _is_special(value: Any) -> bool:
598
- """Special values are None, MISSING, and interpolation."""
599
- return _is_none(value) or get_value_kind(value) in (
600
- ValueKind.MANDATORY_MISSING,
601
- ValueKind.INTERPOLATION,
602
- )
603
-
604
-
605
- def is_float(st: str) -> bool:
606
- try:
607
- float(st)
608
- return True
609
- except ValueError:
610
- return False
611
-
612
-
613
- def is_int(st: str) -> bool:
614
- try:
615
- int(st)
616
- return True
617
- except ValueError:
618
- return False
619
-
620
-
621
- def is_primitive_list(obj: Any) -> bool:
622
- return isinstance(obj, (list, tuple))
623
-
624
-
625
- def is_primitive_dict(obj: Any) -> bool:
626
- t = get_type_of(obj)
627
- return t is dict
628
-
629
-
630
- def is_dict_annotation(type_: Any) -> bool:
631
- if type_ in (dict, Dict):
632
- return True
633
- origin = getattr(type_, "__origin__", None)
634
- # type_dict is a bit hard to detect.
635
- # this support is tentative, if it eventually causes issues in other areas it may be dropped.
636
- if sys.version_info < (3, 7, 0): # pragma: no cover
637
- typed_dict = hasattr(type_, "__base__") and type_.__base__ == Dict
638
- return origin is Dict or type_ is Dict or typed_dict
639
- else: # pragma: no cover
640
- typed_dict = hasattr(type_, "__base__") and type_.__base__ == dict
641
- return origin is dict or typed_dict
642
-
643
-
644
- def is_list_annotation(type_: Any) -> bool:
645
- if type_ in (list, List):
646
- return True
647
- origin = getattr(type_, "__origin__", None)
648
- if sys.version_info < (3, 7, 0):
649
- return origin is List or type_ is List # pragma: no cover
650
- else:
651
- return origin is list # pragma: no cover
652
-
653
-
654
- def is_tuple_annotation(type_: Any) -> bool:
655
- if type_ in (tuple, Tuple):
656
- return True
657
- origin = getattr(type_, "__origin__", None)
658
- if sys.version_info < (3, 7, 0):
659
- return origin is Tuple or type_ is Tuple # pragma: no cover
660
- else:
661
- return origin is tuple # pragma: no cover
662
-
663
-
664
- def is_supported_union_annotation(obj: Any) -> bool:
665
- """Currently only primitive types are supported in Unions, e.g. Union[int, str]"""
666
- if not is_union_annotation(obj):
667
- return False
668
- args = obj.__args__
669
- return all(is_primitive_type_annotation(arg) for arg in args)
670
-
671
-
672
- def is_dict_subclass(type_: Any) -> bool:
673
- return type_ is not None and isinstance(type_, type) and issubclass(type_, Dict)
674
-
675
-
676
- def is_dict(obj: Any) -> bool:
677
- return is_primitive_dict(obj) or is_dict_annotation(obj) or is_dict_subclass(obj)
678
-
679
-
680
- def is_primitive_container(obj: Any) -> bool:
681
- return is_primitive_list(obj) or is_primitive_dict(obj)
682
-
683
-
684
- def get_list_element_type(ref_type: Optional[Type[Any]]) -> Any:
685
- args = getattr(ref_type, "__args__", None)
686
- if ref_type is not List and args is not None and args[0]:
687
- element_type = args[0]
688
- else:
689
- element_type = Any
690
- return element_type
691
-
692
-
693
- def get_tuple_item_types(ref_type: Type[Any]) -> Tuple[Any, ...]:
694
- args = getattr(ref_type, "__args__", None)
695
- if args in (None, ()):
696
- args = (Any, ...)
697
- assert isinstance(args, tuple)
698
- return args
699
-
700
-
701
- def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]:
702
- args = getattr(ref_type, "__args__", None)
703
- if args is None:
704
- bases = getattr(ref_type, "__orig_bases__", None)
705
- if bases is not None and len(bases) > 0:
706
- args = getattr(bases[0], "__args__", None)
707
-
708
- key_type: Any
709
- element_type: Any
710
- if ref_type is None or ref_type == Dict:
711
- key_type = Any
712
- element_type = Any
713
- else:
714
- if args is not None:
715
- key_type = args[0]
716
- element_type = args[1]
717
- else:
718
- key_type = Any
719
- element_type = Any
720
-
721
- return key_type, element_type
722
-
723
-
724
- def is_valid_value_annotation(type_: Any) -> bool:
725
- _, type_ = _resolve_optional(type_)
726
- return (
727
- type_ is Any
728
- or is_primitive_type_annotation(type_)
729
- or is_structured_config(type_)
730
- or is_container_annotation(type_)
731
- or is_supported_union_annotation(type_)
732
- )
733
-
734
-
735
- def _valid_dict_key_annotation_type(type_: Any) -> bool:
736
- from omegaconf import DictKeyType
737
-
738
- return type_ is None or type_ is Any or issubclass(type_, DictKeyType.__args__) # type: ignore
739
-
740
-
741
- def is_primitive_type_annotation(type_: Any) -> bool:
742
- type_ = get_type_of(type_)
743
- return issubclass(type_, (Enum, pathlib.Path)) or type_ in BUILTIN_VALUE_TYPES
744
-
745
-
746
- def _get_value(value: Any) -> Any:
747
- from .base import Container, UnionNode
748
- from .nodes import ValueNode
749
-
750
- if isinstance(value, ValueNode):
751
- return value._value()
752
- elif isinstance(value, Container):
753
- boxed = value._value()
754
- if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed):
755
- return boxed
756
- elif isinstance(value, UnionNode):
757
- boxed = value._value()
758
- if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed):
759
- return boxed
760
- else:
761
- return _get_value(boxed) # pass through value of boxed node
762
-
763
- # return primitives and regular OmegaConf Containers as is
764
- return value
765
-
766
-
767
- def get_type_hint(obj: Any, key: Any = None) -> Optional[Type[Any]]:
768
- from omegaconf import Container, Node
769
-
770
- if isinstance(obj, Container):
771
- if key is not None:
772
- obj = obj._get_node(key)
773
- else:
774
- if key is not None:
775
- raise ValueError("Key must only be provided when obj is a container")
776
-
777
- if isinstance(obj, Node):
778
- ref_type = obj._metadata.ref_type
779
- if obj._is_optional() and ref_type is not Any:
780
- return Optional[ref_type] # type: ignore
781
- else:
782
- return ref_type
783
- else:
784
- return Any # type: ignore
785
-
786
-
787
- def _raise(ex: Exception, cause: Exception) -> None:
788
- # Set the environment variable OC_CAUSE=1 to get a stacktrace that includes the
789
- # causing exception.
790
- env_var = os.environ["OC_CAUSE"] if "OC_CAUSE" in os.environ else None
791
- debugging = sys.gettrace() is not None
792
- full_backtrace = (debugging and not env_var == "0") or (env_var == "1")
793
- if full_backtrace:
794
- ex.__cause__ = cause
795
- else:
796
- ex.__cause__ = None
797
- raise ex.with_traceback(sys.exc_info()[2]) # set env var OC_CAUSE=1 for full trace
798
-
799
-
800
- def format_and_raise(
801
- node: Any,
802
- key: Any,
803
- value: Any,
804
- msg: str,
805
- cause: Exception,
806
- type_override: Any = None,
807
- ) -> None:
808
- from omegaconf import OmegaConf
809
- from omegaconf.base import Node
810
-
811
- if isinstance(cause, AssertionError):
812
- raise
813
-
814
- if isinstance(cause, OmegaConfBaseException) and cause._initialized:
815
- ex = cause
816
- if type_override is not None:
817
- ex = type_override(str(cause))
818
- ex.__dict__ = copy.deepcopy(cause.__dict__)
819
- _raise(ex, cause)
820
-
821
- object_type: Optional[Type[Any]]
822
- object_type_str: Optional[str] = None
823
- ref_type: Optional[Type[Any]]
824
- ref_type_str: Optional[str]
825
-
826
- child_node: Optional[Node] = None
827
- if node is None:
828
- full_key = key if key is not None else ""
829
- object_type = None
830
- ref_type = None
831
- ref_type_str = None
832
- else:
833
- if key is not None and not node._is_none():
834
- child_node = node._get_node(key, validate_access=False)
835
-
836
- try:
837
- full_key = node._get_full_key(key=key)
838
- except Exception as exc:
839
- # Since we are handling an exception, raising a different one here would
840
- # be misleading. Instead, we display it in the key.
841
- full_key = f"<unresolvable due to {type(exc).__name__}: {exc}>"
842
-
843
- object_type = OmegaConf.get_type(node)
844
- object_type_str = type_str(object_type)
845
-
846
- ref_type = get_type_hint(node)
847
- ref_type_str = type_str(ref_type)
848
-
849
- msg = string.Template(msg).safe_substitute(
850
- REF_TYPE=ref_type_str,
851
- OBJECT_TYPE=object_type_str,
852
- KEY=key,
853
- FULL_KEY=full_key,
854
- VALUE=value,
855
- VALUE_TYPE=type_str(type(value), include_module_name=True),
856
- KEY_TYPE=f"{type(key).__name__}",
857
- )
858
-
859
- if ref_type not in (None, Any):
860
- template = dedent(
861
- """\
862
- $MSG
863
- full_key: $FULL_KEY
864
- reference_type=$REF_TYPE
865
- object_type=$OBJECT_TYPE"""
866
- )
867
- else:
868
- template = dedent(
869
- """\
870
- $MSG
871
- full_key: $FULL_KEY
872
- object_type=$OBJECT_TYPE"""
873
- )
874
- s = string.Template(template=template)
875
-
876
- message = s.substitute(
877
- REF_TYPE=ref_type_str, OBJECT_TYPE=object_type_str, MSG=msg, FULL_KEY=full_key
878
- )
879
- exception_type = type(cause) if type_override is None else type_override
880
- if exception_type == TypeError:
881
- exception_type = ConfigTypeError
882
- elif exception_type == IndexError:
883
- exception_type = ConfigIndexError
884
-
885
- ex = exception_type(f"{message}")
886
- if issubclass(exception_type, OmegaConfBaseException):
887
- ex._initialized = True
888
- ex.msg = message
889
- ex.parent_node = node
890
- ex.child_node = child_node
891
- ex.key = key
892
- ex.full_key = full_key
893
- ex.value = value
894
- ex.object_type = object_type
895
- ex.object_type_str = object_type_str
896
- ex.ref_type = ref_type
897
- ex.ref_type_str = ref_type_str
898
-
899
- _raise(ex, cause)
900
-
901
-
902
- def type_str(t: Any, include_module_name: bool = False) -> str:
903
- is_optional, t = _resolve_optional(t)
904
- if t is NoneType:
905
- return str(t.__name__)
906
- if t is Any:
907
- return "Any"
908
- if t is ...:
909
- return "..."
910
-
911
- if hasattr(t, "__name__"):
912
- name = str(t.__name__)
913
- elif getattr(t, "_name", None) is not None: # pragma: no cover
914
- name = str(t._name)
915
- elif getattr(t, "__origin__", None) is not None: # pragma: no cover
916
- name = type_str(t.__origin__)
917
- else:
918
- name = str(t)
919
- if name.startswith("typing."): # pragma: no cover
920
- name = name[len("typing.") :]
921
-
922
- args = getattr(t, "__args__", None)
923
- if args is not None:
924
- args = ", ".join(
925
- [type_str(t, include_module_name=include_module_name) for t in t.__args__]
926
- )
927
- ret = f"{name}[{args}]"
928
- else:
929
- ret = name
930
- if include_module_name:
931
- if (
932
- hasattr(t, "__module__")
933
- and t.__module__ != "builtins"
934
- and t.__module__ != "typing"
935
- and not t.__module__.startswith("omegaconf.")
936
- ):
937
- module_prefix = str(t.__module__) + "."
938
- else:
939
- module_prefix = ""
940
- ret = module_prefix + ret
941
- if is_optional:
942
- return f"Optional[{ret}]"
943
- else:
944
- return ret
945
-
946
-
947
- def _ensure_container(target: Any, flags: Optional[Dict[str, bool]] = None) -> Any:
948
- from omegaconf import OmegaConf
949
-
950
- if is_primitive_container(target):
951
- assert isinstance(target, (list, dict))
952
- target = OmegaConf.create(target, flags=flags)
953
- elif is_structured_config(target):
954
- target = OmegaConf.structured(target, flags=flags)
955
- elif not OmegaConf.is_config(target):
956
- raise ValueError(
957
- "Invalid input. Supports one of "
958
- + "[dict,list,DictConfig,ListConfig,dataclass,dataclass instance,attr class,attr class instance]"
959
- )
960
-
961
- return target
962
-
963
-
964
- def is_generic_list(type_: Any) -> bool:
965
- """
966
- Checks if a type is a generic list, for example:
967
- list returns False
968
- typing.List returns False
969
- typing.List[T] returns True
970
-
971
- :param type_: variable type
972
- :return: bool
973
- """
974
- return is_list_annotation(type_) and get_list_element_type(type_) is not None
975
-
976
-
977
- def is_generic_dict(type_: Any) -> bool:
978
- """
979
- Checks if a type is a generic dict, for example:
980
- list returns False
981
- typing.List returns False
982
- typing.List[T] returns True
983
-
984
- :param type_: variable type
985
- :return: bool
986
- """
987
- return is_dict_annotation(type_) and len(get_dict_key_value_types(type_)) > 0
988
-
989
-
990
- def is_container_annotation(type_: Any) -> bool:
991
- return is_list_annotation(type_) or is_dict_annotation(type_)
992
-
993
-
994
- def split_key(key: str) -> List[str]:
995
- """
996
- Split a full key path into its individual components.
997
-
998
- This is similar to `key.split(".")` but also works with the getitem syntax:
999
- "a.b" -> ["a", "b"]
1000
- "a[b]" -> ["a", "b"]
1001
- ".a.b[c].d" -> ["", "a", "b", "c", "d"]
1002
- "[a].b" -> ["a", "b"]
1003
- """
1004
- # Obtain the first part of the key (in docstring examples: a, a, .a, '')
1005
- first = KEY_PATH_HEAD.match(key)
1006
- assert first is not None
1007
- first_stop = first.span()[1]
1008
-
1009
- # `tokens` will contain all elements composing the key.
1010
- tokens = key[0:first_stop].split(".")
1011
-
1012
- # Optimization in case `key` has no other component: we are done.
1013
- if first_stop == len(key):
1014
- return tokens
1015
-
1016
- if key[first_stop] == "[" and not tokens[-1]:
1017
- # This is a special case where the first key starts with brackets, e.g.
1018
- # [a] or ..[a]. In that case there is an extra "" in `tokens` that we
1019
- # need to get rid of:
1020
- # [a] -> tokens = [""] but we would like []
1021
- # ..[a] -> tokens = ["", "", ""] but we would like ["", ""]
1022
- tokens.pop()
1023
-
1024
- # Identify other key elements (in docstring examples: b, b, b/c/d, b)
1025
- others = KEY_PATH_OTHER.findall(key[first_stop:])
1026
-
1027
- # There are two groups in the `KEY_PATH_OTHER` regex: one for keys starting
1028
- # with a dot (.b, .d) and one for keys starting with a bracket ([b], [c]).
1029
- # Only one group can be non-empty.
1030
- tokens += [dot_key if dot_key else bracket_key for dot_key, bracket_key in others]
1031
-
1032
- return tokens
1033
-
1034
-
1035
- # Similar to Python 3.7+'s `contextlib.nullcontext` (which should be used instead,
1036
- # once support for Python 3.6 is dropped).
1037
- @contextmanager
1038
- def nullcontext(enter_result: Any = None) -> Iterator[Any]:
1039
- yield enter_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/base.py DELETED
@@ -1,962 +0,0 @@
1
- import copy
2
- import sys
3
- from abc import ABC, abstractmethod
4
- from collections import defaultdict
5
- from dataclasses import dataclass, field
6
- from enum import Enum
7
- from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
8
-
9
- from antlr4 import ParserRuleContext
10
-
11
- from ._utils import (
12
- _DEFAULT_MARKER_,
13
- NoneType,
14
- ValueKind,
15
- _get_value,
16
- _is_interpolation,
17
- _is_missing_value,
18
- _is_special,
19
- format_and_raise,
20
- get_value_kind,
21
- is_union_annotation,
22
- is_valid_value_annotation,
23
- split_key,
24
- type_str,
25
- )
26
- from .errors import (
27
- ConfigKeyError,
28
- ConfigTypeError,
29
- InterpolationKeyError,
30
- InterpolationResolutionError,
31
- InterpolationToMissingValueError,
32
- InterpolationValidationError,
33
- MissingMandatoryValue,
34
- UnsupportedInterpolationType,
35
- ValidationError,
36
- )
37
- from .grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
38
- from .grammar_parser import parse
39
- from .grammar_visitor import GrammarVisitor
40
-
41
- DictKeyType = Union[str, bytes, int, Enum, float, bool]
42
-
43
-
44
- @dataclass
45
- class Metadata:
46
-
47
- ref_type: Union[Type[Any], Any]
48
-
49
- object_type: Union[Type[Any], Any]
50
-
51
- optional: bool
52
-
53
- key: Any
54
-
55
- # Flags have 3 modes:
56
- # unset : inherit from parent (None if no parent specifies)
57
- # set to true: flag is true
58
- # set to false: flag is false
59
- flags: Optional[Dict[str, bool]] = None
60
-
61
- # If True, when checking the value of a flag, if the flag is not set None is returned
62
- # otherwise, the parent node is queried.
63
- flags_root: bool = False
64
-
65
- resolver_cache: Dict[str, Any] = field(default_factory=lambda: defaultdict(dict))
66
-
67
- def __post_init__(self) -> None:
68
- if self.flags is None:
69
- self.flags = {}
70
-
71
- @property
72
- def type_hint(self) -> Union[Type[Any], Any]:
73
- """Compute `type_hint` from `self.optional` and `self.ref_type`"""
74
- # For compatibility with pickled OmegaConf objects created using older
75
- # versions of OmegaConf, we store `ref_type` and `object_type`
76
- # separately (rather than storing `type_hint` directly).
77
- if self.optional:
78
- return Optional[self.ref_type]
79
- else:
80
- return self.ref_type
81
-
82
-
83
- @dataclass
84
- class ContainerMetadata(Metadata):
85
- key_type: Any = None
86
- element_type: Any = None
87
-
88
- def __post_init__(self) -> None:
89
- if self.ref_type is None:
90
- self.ref_type = Any
91
- assert self.key_type is Any or isinstance(self.key_type, type)
92
- if self.element_type is not None:
93
- if not is_valid_value_annotation(self.element_type):
94
- raise ValidationError(
95
- f"Unsupported value type: '{type_str(self.element_type, include_module_name=True)}'"
96
- )
97
-
98
- if self.flags is None:
99
- self.flags = {}
100
-
101
-
102
- class Node(ABC):
103
- _metadata: Metadata
104
-
105
- _parent: Optional["Box"]
106
- _flags_cache: Optional[Dict[str, Optional[bool]]]
107
-
108
- def __init__(self, parent: Optional["Box"], metadata: Metadata):
109
- self.__dict__["_metadata"] = metadata
110
- self.__dict__["_parent"] = parent
111
- self.__dict__["_flags_cache"] = None
112
-
113
- def __getstate__(self) -> Dict[str, Any]:
114
- # Overridden to ensure that the flags cache is cleared on serialization.
115
- state_dict = copy.copy(self.__dict__)
116
- del state_dict["_flags_cache"]
117
- return state_dict
118
-
119
- def __setstate__(self, state_dict: Dict[str, Any]) -> None:
120
- self.__dict__.update(state_dict)
121
- self.__dict__["_flags_cache"] = None
122
-
123
- def _set_parent(self, parent: Optional["Box"]) -> None:
124
- assert parent is None or isinstance(parent, Box)
125
- self.__dict__["_parent"] = parent
126
- self._invalidate_flags_cache()
127
-
128
- def _invalidate_flags_cache(self) -> None:
129
- self.__dict__["_flags_cache"] = None
130
-
131
- def _get_parent(self) -> Optional["Box"]:
132
- parent = self.__dict__["_parent"]
133
- assert parent is None or isinstance(parent, Box)
134
- return parent
135
-
136
- def _get_parent_container(self) -> Optional["Container"]:
137
- """
138
- Like _get_parent, but returns the grandparent
139
- in the case where `self` is wrapped by a UnionNode.
140
- """
141
- parent = self.__dict__["_parent"]
142
- assert parent is None or isinstance(parent, Box)
143
-
144
- if isinstance(parent, UnionNode):
145
- grandparent = parent.__dict__["_parent"]
146
- assert grandparent is None or isinstance(grandparent, Container)
147
- return grandparent
148
- else:
149
- assert parent is None or isinstance(parent, Container)
150
- return parent
151
-
152
- def _set_flag(
153
- self,
154
- flags: Union[List[str], str],
155
- values: Union[List[Optional[bool]], Optional[bool]],
156
- ) -> "Node":
157
- if isinstance(flags, str):
158
- flags = [flags]
159
-
160
- if values is None or isinstance(values, bool):
161
- values = [values]
162
-
163
- if len(values) == 1:
164
- values = len(flags) * values
165
-
166
- if len(flags) != len(values):
167
- raise ValueError("Inconsistent lengths of input flag names and values")
168
-
169
- for idx, flag in enumerate(flags):
170
- value = values[idx]
171
- if value is None:
172
- assert self._metadata.flags is not None
173
- if flag in self._metadata.flags:
174
- del self._metadata.flags[flag]
175
- else:
176
- assert self._metadata.flags is not None
177
- self._metadata.flags[flag] = value
178
- self._invalidate_flags_cache()
179
- return self
180
-
181
- def _get_node_flag(self, flag: str) -> Optional[bool]:
182
- """
183
- :param flag: flag to inspect
184
- :return: the state of the flag on this node.
185
- """
186
- assert self._metadata.flags is not None
187
- return self._metadata.flags.get(flag)
188
-
189
- def _get_flag(self, flag: str) -> Optional[bool]:
190
- cache = self.__dict__["_flags_cache"]
191
- if cache is None:
192
- cache = self.__dict__["_flags_cache"] = {}
193
-
194
- ret = cache.get(flag, _DEFAULT_MARKER_)
195
- if ret is _DEFAULT_MARKER_:
196
- ret = self._get_flag_no_cache(flag)
197
- cache[flag] = ret
198
- assert ret is None or isinstance(ret, bool)
199
- return ret
200
-
201
- def _get_flag_no_cache(self, flag: str) -> Optional[bool]:
202
- """
203
- Returns True if this config node flag is set
204
- A flag is set if node.set_flag(True) was called
205
- or one if it's parents is flag is set
206
- :return:
207
- """
208
- flags = self._metadata.flags
209
- assert flags is not None
210
- if flag in flags and flags[flag] is not None:
211
- return flags[flag]
212
-
213
- if self._is_flags_root():
214
- return None
215
-
216
- parent = self._get_parent()
217
- if parent is None:
218
- return None
219
- else:
220
- # noinspection PyProtectedMember
221
- return parent._get_flag(flag)
222
-
223
- def _format_and_raise(
224
- self,
225
- key: Any,
226
- value: Any,
227
- cause: Exception,
228
- msg: Optional[str] = None,
229
- type_override: Any = None,
230
- ) -> None:
231
- format_and_raise(
232
- node=self,
233
- key=key,
234
- value=value,
235
- msg=str(cause) if msg is None else msg,
236
- cause=cause,
237
- type_override=type_override,
238
- )
239
- assert False
240
-
241
- @abstractmethod
242
- def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
243
- ...
244
-
245
- def _dereference_node(self) -> "Node":
246
- node = self._dereference_node_impl(throw_on_resolution_failure=True)
247
- assert node is not None
248
- return node
249
-
250
- def _maybe_dereference_node(
251
- self,
252
- throw_on_resolution_failure: bool = False,
253
- memo: Optional[Set[int]] = None,
254
- ) -> Optional["Node"]:
255
- return self._dereference_node_impl(
256
- throw_on_resolution_failure=throw_on_resolution_failure,
257
- memo=memo,
258
- )
259
-
260
- def _dereference_node_impl(
261
- self,
262
- throw_on_resolution_failure: bool,
263
- memo: Optional[Set[int]] = None,
264
- ) -> Optional["Node"]:
265
- if not self._is_interpolation():
266
- return self
267
-
268
- parent = self._get_parent_container()
269
- if parent is None:
270
- if throw_on_resolution_failure:
271
- raise InterpolationResolutionError(
272
- "Cannot resolve interpolation for a node without a parent"
273
- )
274
- return None
275
- assert parent is not None
276
- key = self._key()
277
- return parent._resolve_interpolation_from_parse_tree(
278
- parent=parent,
279
- key=key,
280
- value=self,
281
- parse_tree=parse(_get_value(self)),
282
- throw_on_resolution_failure=throw_on_resolution_failure,
283
- memo=memo,
284
- )
285
-
286
- def _get_root(self) -> "Container":
287
- root: Optional[Box] = self._get_parent()
288
- if root is None:
289
- assert isinstance(self, Container)
290
- return self
291
- assert root is not None and isinstance(root, Box)
292
- while root._get_parent() is not None:
293
- root = root._get_parent()
294
- assert root is not None and isinstance(root, Box)
295
- assert root is not None and isinstance(root, Container)
296
- return root
297
-
298
- def _is_missing(self) -> bool:
299
- """
300
- Check if the node's value is `???` (does *not* resolve interpolations).
301
- """
302
- return _is_missing_value(self)
303
-
304
- def _is_none(self) -> bool:
305
- """
306
- Check if the node's value is `None` (does *not* resolve interpolations).
307
- """
308
- return self._value() is None
309
-
310
- @abstractmethod
311
- def __eq__(self, other: Any) -> bool:
312
- ...
313
-
314
- @abstractmethod
315
- def __ne__(self, other: Any) -> bool:
316
- ...
317
-
318
- @abstractmethod
319
- def __hash__(self) -> int:
320
- ...
321
-
322
- @abstractmethod
323
- def _value(self) -> Any:
324
- ...
325
-
326
- @abstractmethod
327
- def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
328
- ...
329
-
330
- @abstractmethod
331
- def _is_optional(self) -> bool:
332
- ...
333
-
334
- @abstractmethod
335
- def _is_interpolation(self) -> bool:
336
- ...
337
-
338
- def _key(self) -> Any:
339
- return self._metadata.key
340
-
341
- def _set_key(self, key: Any) -> None:
342
- self._metadata.key = key
343
-
344
- def _is_flags_root(self) -> bool:
345
- return self._metadata.flags_root
346
-
347
- def _set_flags_root(self, flags_root: bool) -> None:
348
- if self._metadata.flags_root != flags_root:
349
- self._metadata.flags_root = flags_root
350
- self._invalidate_flags_cache()
351
-
352
- def _has_ref_type(self) -> bool:
353
- return self._metadata.ref_type is not Any
354
-
355
-
356
- class Box(Node):
357
- """
358
- Base class for nodes that can contain other nodes.
359
- Concrete subclasses include DictConfig, ListConfig, and UnionNode.
360
- """
361
-
362
- _content: Any
363
-
364
- def __init__(self, parent: Optional["Box"], metadata: Metadata):
365
- super().__init__(parent=parent, metadata=metadata)
366
- self.__dict__["_content"] = None
367
-
368
- def __copy__(self) -> Any:
369
- # real shallow copy is impossible because of the reference to the parent.
370
- return copy.deepcopy(self)
371
-
372
- def _re_parent(self) -> None:
373
- from .dictconfig import DictConfig
374
- from .listconfig import ListConfig
375
-
376
- # update parents of first level Config nodes to self
377
-
378
- if isinstance(self, DictConfig):
379
- content = self.__dict__["_content"]
380
- if isinstance(content, dict):
381
- for _key, value in self.__dict__["_content"].items():
382
- if value is not None:
383
- value._set_parent(self)
384
- if isinstance(value, Box):
385
- value._re_parent()
386
- elif isinstance(self, ListConfig):
387
- content = self.__dict__["_content"]
388
- if isinstance(content, list):
389
- for item in self.__dict__["_content"]:
390
- if item is not None:
391
- item._set_parent(self)
392
- if isinstance(item, Box):
393
- item._re_parent()
394
- elif isinstance(self, UnionNode):
395
- content = self.__dict__["_content"]
396
- if isinstance(content, Node):
397
- content._set_parent(self)
398
- if isinstance(content, Box): # pragma: no cover
399
- # No coverage here as support for containers inside
400
- # UnionNode is not yet implemented
401
- content._re_parent()
402
-
403
-
404
- class Container(Box):
405
- """
406
- Container tagging interface
407
- """
408
-
409
- _metadata: ContainerMetadata
410
-
411
- @abstractmethod
412
- def _get_child(
413
- self,
414
- key: Any,
415
- validate_access: bool = True,
416
- validate_key: bool = True,
417
- throw_on_missing_value: bool = False,
418
- throw_on_missing_key: bool = False,
419
- ) -> Union[Optional[Node], List[Optional[Node]]]:
420
- ...
421
-
422
- @abstractmethod
423
- def _get_node(
424
- self,
425
- key: Any,
426
- validate_access: bool = True,
427
- validate_key: bool = True,
428
- throw_on_missing_value: bool = False,
429
- throw_on_missing_key: bool = False,
430
- ) -> Union[Optional[Node], List[Optional[Node]]]:
431
- ...
432
-
433
- @abstractmethod
434
- def __delitem__(self, key: Any) -> None:
435
- ...
436
-
437
- @abstractmethod
438
- def __setitem__(self, key: Any, value: Any) -> None:
439
- ...
440
-
441
- @abstractmethod
442
- def __iter__(self) -> Iterator[Any]:
443
- ...
444
-
445
- @abstractmethod
446
- def __getitem__(self, key_or_index: Any) -> Any:
447
- ...
448
-
449
- def _resolve_key_and_root(self, key: str) -> Tuple["Container", str]:
450
- orig = key
451
- if not key.startswith("."):
452
- return self._get_root(), key
453
- else:
454
- root: Optional[Container] = self
455
- assert key.startswith(".")
456
- while True:
457
- assert root is not None
458
- key = key[1:]
459
- if not key.startswith("."):
460
- break
461
- root = root._get_parent_container()
462
- if root is None:
463
- raise ConfigKeyError(f"Error resolving key '{orig}'")
464
-
465
- return root, key
466
-
467
- def _select_impl(
468
- self,
469
- key: str,
470
- throw_on_missing: bool,
471
- throw_on_resolution_failure: bool,
472
- memo: Optional[Set[int]] = None,
473
- ) -> Tuple[Optional["Container"], Optional[str], Optional[Node]]:
474
- """
475
- Select a value using dot separated key sequence
476
- """
477
- from .omegaconf import _select_one
478
-
479
- if key == "":
480
- return self, "", self
481
-
482
- split = split_key(key)
483
- root: Optional[Container] = self
484
- for i in range(len(split) - 1):
485
- if root is None:
486
- break
487
-
488
- k = split[i]
489
- ret, _ = _select_one(
490
- c=root,
491
- key=k,
492
- throw_on_missing=throw_on_missing,
493
- throw_on_type_error=throw_on_resolution_failure,
494
- )
495
- if isinstance(ret, Node):
496
- ret = ret._maybe_dereference_node(
497
- throw_on_resolution_failure=throw_on_resolution_failure,
498
- memo=memo,
499
- )
500
-
501
- if ret is not None and not isinstance(ret, Container):
502
- parent_key = ".".join(split[0 : i + 1])
503
- child_key = split[i + 1]
504
- raise ConfigTypeError(
505
- f"Error trying to access {key}: node `{parent_key}` "
506
- f"is not a container and thus cannot contain `{child_key}`"
507
- )
508
- root = ret
509
-
510
- if root is None:
511
- return None, None, None
512
-
513
- last_key = split[-1]
514
- value, _ = _select_one(
515
- c=root,
516
- key=last_key,
517
- throw_on_missing=throw_on_missing,
518
- throw_on_type_error=throw_on_resolution_failure,
519
- )
520
- if value is None:
521
- return root, last_key, None
522
-
523
- if memo is not None:
524
- vid = id(value)
525
- if vid in memo:
526
- raise InterpolationResolutionError("Recursive interpolation detected")
527
- # push to memo "stack"
528
- memo.add(vid)
529
-
530
- try:
531
- value = root._maybe_resolve_interpolation(
532
- parent=root,
533
- key=last_key,
534
- value=value,
535
- throw_on_resolution_failure=throw_on_resolution_failure,
536
- memo=memo,
537
- )
538
- finally:
539
- if memo is not None:
540
- # pop from memo "stack"
541
- memo.remove(vid)
542
-
543
- return root, last_key, value
544
-
545
- def _resolve_interpolation_from_parse_tree(
546
- self,
547
- parent: Optional["Container"],
548
- value: "Node",
549
- key: Any,
550
- parse_tree: OmegaConfGrammarParser.ConfigValueContext,
551
- throw_on_resolution_failure: bool,
552
- memo: Optional[Set[int]],
553
- ) -> Optional["Node"]:
554
- """
555
- Resolve an interpolation.
556
-
557
- This happens in two steps:
558
- 1. The parse tree is visited, which outputs either a `Node` (e.g.,
559
- for node interpolations "${foo}"), a string (e.g., for string
560
- interpolations "hello ${name}", or any other arbitrary value
561
- (e.g., or custom interpolations "${foo:bar}").
562
- 2. This output is potentially validated and converted when the node
563
- being resolved (`value`) is typed.
564
-
565
- If an error occurs in one of the above steps, an `InterpolationResolutionError`
566
- (or a subclass of it) is raised, *unless* `throw_on_resolution_failure` is set
567
- to `False` (in which case the return value is `None`).
568
-
569
- :param parent: Parent of the node being resolved.
570
- :param value: Node being resolved.
571
- :param key: The associated key in the parent.
572
- :param parse_tree: The parse tree as obtained from `grammar_parser.parse()`.
573
- :param throw_on_resolution_failure: If `False`, then exceptions raised during
574
- the resolution of the interpolation are silenced, and instead `None` is
575
- returned.
576
-
577
- :return: A `Node` that contains the interpolation result. This may be an existing
578
- node in the config (in the case of a node interpolation "${foo}"), or a new
579
- node that is created to wrap the interpolated value. It is `None` if and only if
580
- `throw_on_resolution_failure` is `False` and an error occurs during resolution.
581
- """
582
-
583
- try:
584
- resolved = self.resolve_parse_tree(
585
- parse_tree=parse_tree, node=value, key=key, memo=memo
586
- )
587
- except InterpolationResolutionError:
588
- if throw_on_resolution_failure:
589
- raise
590
- return None
591
-
592
- return self._validate_and_convert_interpolation_result(
593
- parent=parent,
594
- value=value,
595
- key=key,
596
- resolved=resolved,
597
- throw_on_resolution_failure=throw_on_resolution_failure,
598
- )
599
-
600
- def _validate_and_convert_interpolation_result(
601
- self,
602
- parent: Optional["Container"],
603
- value: "Node",
604
- key: Any,
605
- resolved: Any,
606
- throw_on_resolution_failure: bool,
607
- ) -> Optional["Node"]:
608
- from .nodes import AnyNode, InterpolationResultNode, ValueNode
609
-
610
- # If the output is not a Node already (e.g., because it is the output of a
611
- # custom resolver), then we will need to wrap it within a Node.
612
- must_wrap = not isinstance(resolved, Node)
613
-
614
- # If the node is typed, validate (and possibly convert) the result.
615
- if isinstance(value, ValueNode) and not isinstance(value, AnyNode):
616
- res_value = _get_value(resolved)
617
- try:
618
- conv_value = value.validate_and_convert(res_value)
619
- except ValidationError as e:
620
- if throw_on_resolution_failure:
621
- self._format_and_raise(
622
- key=key,
623
- value=res_value,
624
- cause=e,
625
- msg=f"While dereferencing interpolation '{value}': {e}",
626
- type_override=InterpolationValidationError,
627
- )
628
- return None
629
-
630
- # If the converted value is of the same type, it means that no conversion
631
- # was actually needed. As a result, we can keep the original `resolved`
632
- # (and otherwise, the converted value must be wrapped into a new node).
633
- if type(conv_value) != type(res_value):
634
- must_wrap = True
635
- resolved = conv_value
636
-
637
- if must_wrap:
638
- return InterpolationResultNode(value=resolved, key=key, parent=parent)
639
- else:
640
- assert isinstance(resolved, Node)
641
- return resolved
642
-
643
- def _validate_not_dereferencing_to_parent(self, node: Node, target: Node) -> None:
644
- parent: Optional[Node] = node
645
- while parent is not None:
646
- if parent is target:
647
- raise InterpolationResolutionError(
648
- "Interpolation to parent node detected"
649
- )
650
- parent = parent._get_parent()
651
-
652
- def _resolve_node_interpolation(
653
- self, inter_key: str, memo: Optional[Set[int]]
654
- ) -> "Node":
655
- """A node interpolation is of the form `${foo.bar}`"""
656
- try:
657
- root_node, inter_key = self._resolve_key_and_root(inter_key)
658
- except ConfigKeyError as exc:
659
- raise InterpolationKeyError(
660
- f"ConfigKeyError while resolving interpolation: {exc}"
661
- ).with_traceback(sys.exc_info()[2])
662
-
663
- try:
664
- parent, last_key, value = root_node._select_impl(
665
- inter_key,
666
- throw_on_missing=True,
667
- throw_on_resolution_failure=True,
668
- memo=memo,
669
- )
670
- except MissingMandatoryValue as exc:
671
- raise InterpolationToMissingValueError(
672
- f"MissingMandatoryValue while resolving interpolation: {exc}"
673
- ).with_traceback(sys.exc_info()[2])
674
-
675
- if parent is None or value is None:
676
- raise InterpolationKeyError(f"Interpolation key '{inter_key}' not found")
677
- else:
678
- self._validate_not_dereferencing_to_parent(node=self, target=value)
679
- return value
680
-
681
- def _evaluate_custom_resolver(
682
- self,
683
- key: Any,
684
- node: Node,
685
- inter_type: str,
686
- inter_args: Tuple[Any, ...],
687
- inter_args_str: Tuple[str, ...],
688
- ) -> Any:
689
- from omegaconf import OmegaConf
690
-
691
- resolver = OmegaConf._get_resolver(inter_type)
692
- if resolver is not None:
693
- root_node = self._get_root()
694
- return resolver(
695
- root_node,
696
- self,
697
- node,
698
- inter_args,
699
- inter_args_str,
700
- )
701
- else:
702
- raise UnsupportedInterpolationType(
703
- f"Unsupported interpolation type {inter_type}"
704
- )
705
-
706
- def _maybe_resolve_interpolation(
707
- self,
708
- parent: Optional["Container"],
709
- key: Any,
710
- value: Node,
711
- throw_on_resolution_failure: bool,
712
- memo: Optional[Set[int]] = None,
713
- ) -> Optional[Node]:
714
- value_kind = get_value_kind(value)
715
- if value_kind != ValueKind.INTERPOLATION:
716
- return value
717
-
718
- parse_tree = parse(_get_value(value))
719
- return self._resolve_interpolation_from_parse_tree(
720
- parent=parent,
721
- value=value,
722
- key=key,
723
- parse_tree=parse_tree,
724
- throw_on_resolution_failure=throw_on_resolution_failure,
725
- memo=memo if memo is not None else set(),
726
- )
727
-
728
- def resolve_parse_tree(
729
- self,
730
- parse_tree: ParserRuleContext,
731
- node: Node,
732
- memo: Optional[Set[int]] = None,
733
- key: Optional[Any] = None,
734
- ) -> Any:
735
- """
736
- Resolve a given parse tree into its value.
737
-
738
- We make no assumption here on the type of the tree's root, so that the
739
- return value may be of any type.
740
- """
741
-
742
- def node_interpolation_callback(
743
- inter_key: str, memo: Optional[Set[int]]
744
- ) -> Optional["Node"]:
745
- return self._resolve_node_interpolation(inter_key=inter_key, memo=memo)
746
-
747
- def resolver_interpolation_callback(
748
- name: str, args: Tuple[Any, ...], args_str: Tuple[str, ...]
749
- ) -> Any:
750
- return self._evaluate_custom_resolver(
751
- key=key,
752
- node=node,
753
- inter_type=name,
754
- inter_args=args,
755
- inter_args_str=args_str,
756
- )
757
-
758
- visitor = GrammarVisitor(
759
- node_interpolation_callback=node_interpolation_callback,
760
- resolver_interpolation_callback=resolver_interpolation_callback,
761
- memo=memo,
762
- )
763
- try:
764
- return visitor.visit(parse_tree)
765
- except InterpolationResolutionError:
766
- raise
767
- except Exception as exc:
768
- # Other kinds of exceptions are wrapped in an `InterpolationResolutionError`.
769
- raise InterpolationResolutionError(
770
- f"{type(exc).__name__} raised while resolving interpolation: {exc}"
771
- ).with_traceback(sys.exc_info()[2])
772
-
773
- def _invalidate_flags_cache(self) -> None:
774
- from .dictconfig import DictConfig
775
- from .listconfig import ListConfig
776
-
777
- # invalidate subtree cache only if the cache is initialized in this node.
778
-
779
- if self.__dict__["_flags_cache"] is not None:
780
- self.__dict__["_flags_cache"] = None
781
- if isinstance(self, DictConfig):
782
- content = self.__dict__["_content"]
783
- if isinstance(content, dict):
784
- for value in self.__dict__["_content"].values():
785
- value._invalidate_flags_cache()
786
- elif isinstance(self, ListConfig):
787
- content = self.__dict__["_content"]
788
- if isinstance(content, list):
789
- for item in self.__dict__["_content"]:
790
- item._invalidate_flags_cache()
791
-
792
-
793
- class SCMode(Enum):
794
- DICT = 1 # Convert to plain dict
795
- DICT_CONFIG = 2 # Keep as OmegaConf DictConfig
796
- INSTANTIATE = 3 # Create a dataclass or attrs class instance
797
-
798
-
799
- class UnionNode(Box):
800
- """
801
- This class handles Union type hints. The `_content` attribute is either a
802
- child node that is compatible with the given Union ref_type, or it is a
803
- special value (None or MISSING or interpolation).
804
-
805
- Much of the logic for e.g. value assignment and type validation is
806
- delegated to the child node. As such, UnionNode functions as a
807
- "pass-through" node. User apps and downstream libraries should not need to
808
- know about UnionNode (assuming they only use OmegaConf's public API).
809
- """
810
-
811
- _parent: Optional[Container]
812
- _content: Union[Node, None, str]
813
-
814
- def __init__(
815
- self,
816
- content: Any,
817
- ref_type: Any,
818
- is_optional: bool = True,
819
- key: Any = None,
820
- parent: Optional[Box] = None,
821
- ) -> None:
822
- try:
823
- if not is_union_annotation(ref_type): # pragma: no cover
824
- msg = (
825
- f"UnionNode got unexpected ref_type {ref_type}. Please file a bug"
826
- + " report at https://github.com/omry/omegaconf/issues"
827
- )
828
- raise AssertionError(msg)
829
- if not isinstance(parent, (Container, NoneType)):
830
- raise ConfigTypeError("Parent type is not omegaconf.Container")
831
- super().__init__(
832
- parent=parent,
833
- metadata=Metadata(
834
- ref_type=ref_type,
835
- object_type=None,
836
- optional=is_optional,
837
- key=key,
838
- flags={"convert": False},
839
- ),
840
- )
841
- self._set_value(content)
842
- except Exception as ex:
843
- format_and_raise(node=None, key=key, value=content, msg=str(ex), cause=ex)
844
-
845
- def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
846
- parent = self._get_parent()
847
- if parent is None:
848
- if self._metadata.key is None:
849
- return ""
850
- else:
851
- return str(self._metadata.key)
852
- else:
853
- return parent._get_full_key(self._metadata.key)
854
-
855
- def __eq__(self, other: Any) -> bool:
856
- content = self.__dict__["_content"]
857
- if isinstance(content, Node):
858
- ret = content.__eq__(other)
859
- elif isinstance(other, Node):
860
- ret = other.__eq__(content)
861
- else:
862
- ret = content.__eq__(other)
863
- assert isinstance(ret, (bool, type(NotImplemented)))
864
- return ret
865
-
866
- def __ne__(self, other: Any) -> bool:
867
- x = self.__eq__(other)
868
- if x is NotImplemented:
869
- return NotImplemented
870
- return not x
871
-
872
- def __hash__(self) -> int:
873
- return hash(self.__dict__["_content"])
874
-
875
- def _value(self) -> Union[Node, None, str]:
876
- content = self.__dict__["_content"]
877
- assert isinstance(content, (Node, NoneType, str))
878
- return content
879
-
880
- def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
881
- previous_content = self.__dict__["_content"]
882
- previous_metadata = self.__dict__["_metadata"]
883
- try:
884
- self._set_value_impl(value, flags)
885
- except Exception as e:
886
- self.__dict__["_content"] = previous_content
887
- self.__dict__["_metadata"] = previous_metadata
888
- raise e
889
-
890
- def _set_value_impl(
891
- self, value: Any, flags: Optional[Dict[str, bool]] = None
892
- ) -> None:
893
- from omegaconf.omegaconf import _node_wrap
894
-
895
- ref_type = self._metadata.ref_type
896
- type_hint = self._metadata.type_hint
897
-
898
- value = _get_value(value)
899
- if _is_special(value):
900
- assert isinstance(value, (str, NoneType))
901
- if value is None:
902
- if not self._is_optional():
903
- raise ValidationError(
904
- f"Value '$VALUE' is incompatible with type hint '{type_str(type_hint)}'"
905
- )
906
- self.__dict__["_content"] = value
907
- elif isinstance(value, Container):
908
- raise ValidationError(
909
- f"Cannot assign container '$VALUE' of type '$VALUE_TYPE' to {type_str(type_hint)}"
910
- )
911
- else:
912
- for candidate_ref_type in ref_type.__args__:
913
- try:
914
- self.__dict__["_content"] = _node_wrap(
915
- value=value,
916
- ref_type=candidate_ref_type,
917
- is_optional=False,
918
- key=None,
919
- parent=self,
920
- )
921
- break
922
- except ValidationError:
923
- continue
924
- else:
925
- raise ValidationError(
926
- f"Value '$VALUE' of type '$VALUE_TYPE' is incompatible with type hint '{type_str(type_hint)}'"
927
- )
928
-
929
- def _is_optional(self) -> bool:
930
- return self.__dict__["_metadata"].optional is True
931
-
932
- def _is_interpolation(self) -> bool:
933
- return _is_interpolation(self.__dict__["_content"])
934
-
935
- def __str__(self) -> str:
936
- return str(self.__dict__["_content"])
937
-
938
- def __repr__(self) -> str:
939
- return repr(self.__dict__["_content"])
940
-
941
- def __deepcopy__(self, memo: Dict[int, Any]) -> "UnionNode":
942
- res = object.__new__(type(self))
943
- for key, value in self.__dict__.items():
944
- if key not in ("_content", "_parent"):
945
- res.__dict__[key] = copy.deepcopy(value, memo=memo)
946
-
947
- src_content = self.__dict__["_content"]
948
- if isinstance(src_content, Node):
949
- old_parent = src_content.__dict__["_parent"]
950
- try:
951
- src_content.__dict__["_parent"] = None
952
- content_copy = copy.deepcopy(src_content, memo=memo)
953
- content_copy.__dict__["_parent"] = res
954
- finally:
955
- src_content.__dict__["_parent"] = old_parent
956
- else:
957
- # None and strings can be assigned as is
958
- content_copy = src_content
959
-
960
- res.__dict__["_content"] = content_copy
961
- res.__dict__["_parent"] = self.__dict__["_parent"]
962
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/basecontainer.py DELETED
@@ -1,916 +0,0 @@
1
- import copy
2
- import sys
3
- from abc import ABC, abstractmethod
4
- from enum import Enum
5
- from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union
6
-
7
- import yaml
8
-
9
- from ._utils import (
10
- _DEFAULT_MARKER_,
11
- ValueKind,
12
- _ensure_container,
13
- _get_value,
14
- _is_interpolation,
15
- _is_missing_value,
16
- _is_none,
17
- _is_special,
18
- _resolve_optional,
19
- get_structured_config_data,
20
- get_type_hint,
21
- get_value_kind,
22
- get_yaml_loader,
23
- is_container_annotation,
24
- is_dict_annotation,
25
- is_list_annotation,
26
- is_primitive_dict,
27
- is_primitive_type_annotation,
28
- is_structured_config,
29
- is_tuple_annotation,
30
- is_union_annotation,
31
- )
32
- from .base import (
33
- Box,
34
- Container,
35
- ContainerMetadata,
36
- DictKeyType,
37
- Node,
38
- SCMode,
39
- UnionNode,
40
- )
41
- from .errors import (
42
- ConfigCycleDetectedException,
43
- ConfigTypeError,
44
- InterpolationResolutionError,
45
- KeyValidationError,
46
- MissingMandatoryValue,
47
- OmegaConfBaseException,
48
- ReadonlyConfigError,
49
- ValidationError,
50
- )
51
-
52
- if TYPE_CHECKING:
53
- from .dictconfig import DictConfig # pragma: no cover
54
-
55
-
56
- class BaseContainer(Container, ABC):
57
- _resolvers: ClassVar[Dict[str, Any]] = {}
58
-
59
- def __init__(self, parent: Optional[Box], metadata: ContainerMetadata):
60
- if not (parent is None or isinstance(parent, Box)):
61
- raise ConfigTypeError("Parent type is not omegaconf.Box")
62
- super().__init__(parent=parent, metadata=metadata)
63
-
64
- def _get_child(
65
- self,
66
- key: Any,
67
- validate_access: bool = True,
68
- validate_key: bool = True,
69
- throw_on_missing_value: bool = False,
70
- throw_on_missing_key: bool = False,
71
- ) -> Union[Optional[Node], List[Optional[Node]]]:
72
- """Like _get_node, passing through to the nearest concrete Node."""
73
- child = self._get_node(
74
- key=key,
75
- validate_access=validate_access,
76
- validate_key=validate_key,
77
- throw_on_missing_value=throw_on_missing_value,
78
- throw_on_missing_key=throw_on_missing_key,
79
- )
80
- if isinstance(child, UnionNode) and not _is_special(child):
81
- value = child._value()
82
- assert isinstance(value, Node) and not isinstance(value, UnionNode)
83
- child = value
84
- return child
85
-
86
- def _resolve_with_default(
87
- self,
88
- key: Union[DictKeyType, int],
89
- value: Node,
90
- default_value: Any = _DEFAULT_MARKER_,
91
- ) -> Any:
92
- """returns the value with the specified key, like obj.key and obj['key']"""
93
- if _is_missing_value(value):
94
- if default_value is not _DEFAULT_MARKER_:
95
- return default_value
96
- raise MissingMandatoryValue("Missing mandatory value: $FULL_KEY")
97
-
98
- resolved_node = self._maybe_resolve_interpolation(
99
- parent=self,
100
- key=key,
101
- value=value,
102
- throw_on_resolution_failure=True,
103
- )
104
-
105
- return _get_value(resolved_node)
106
-
107
- def __str__(self) -> str:
108
- return self.__repr__()
109
-
110
- def __repr__(self) -> str:
111
- if self.__dict__["_content"] is None:
112
- return "None"
113
- elif self._is_interpolation() or self._is_missing():
114
- v = self.__dict__["_content"]
115
- return f"'{v}'"
116
- else:
117
- return self.__dict__["_content"].__repr__() # type: ignore
118
-
119
- # Support pickle
120
- def __getstate__(self) -> Dict[str, Any]:
121
- dict_copy = copy.copy(self.__dict__)
122
-
123
- # no need to serialize the flags cache, it can be re-constructed later
124
- dict_copy.pop("_flags_cache", None)
125
-
126
- dict_copy["_metadata"] = copy.copy(dict_copy["_metadata"])
127
- ref_type = self._metadata.ref_type
128
- if is_container_annotation(ref_type):
129
- if is_dict_annotation(ref_type):
130
- dict_copy["_metadata"].ref_type = Dict
131
- elif is_list_annotation(ref_type):
132
- dict_copy["_metadata"].ref_type = List
133
- else:
134
- assert False
135
- if sys.version_info < (3, 7): # pragma: no cover
136
- element_type = self._metadata.element_type
137
- if is_union_annotation(element_type):
138
- raise OmegaConfBaseException(
139
- "Serializing structured configs with `Union` element type requires python >= 3.7"
140
- )
141
- return dict_copy
142
-
143
- # Support pickle
144
- def __setstate__(self, d: Dict[str, Any]) -> None:
145
- from omegaconf import DictConfig
146
- from omegaconf._utils import is_generic_dict, is_generic_list
147
-
148
- if isinstance(self, DictConfig):
149
- key_type = d["_metadata"].key_type
150
-
151
- # backward compatibility to load OmegaConf 2.0 configs
152
- if key_type is None:
153
- key_type = Any
154
- d["_metadata"].key_type = key_type
155
-
156
- element_type = d["_metadata"].element_type
157
-
158
- # backward compatibility to load OmegaConf 2.0 configs
159
- if element_type is None:
160
- element_type = Any
161
- d["_metadata"].element_type = element_type
162
-
163
- ref_type = d["_metadata"].ref_type
164
- if is_container_annotation(ref_type):
165
- if is_generic_dict(ref_type):
166
- d["_metadata"].ref_type = Dict[key_type, element_type] # type: ignore
167
- elif is_generic_list(ref_type):
168
- d["_metadata"].ref_type = List[element_type] # type: ignore
169
- else:
170
- assert False
171
-
172
- d["_flags_cache"] = None
173
- self.__dict__.update(d)
174
-
175
- @abstractmethod
176
- def __delitem__(self, key: Any) -> None:
177
- ...
178
-
179
- def __len__(self) -> int:
180
- if self._is_none() or self._is_missing() or self._is_interpolation():
181
- return 0
182
- content = self.__dict__["_content"]
183
- return len(content)
184
-
185
- def merge_with_cli(self) -> None:
186
- args_list = sys.argv[1:]
187
- self.merge_with_dotlist(args_list)
188
-
189
- def merge_with_dotlist(self, dotlist: List[str]) -> None:
190
- from omegaconf import OmegaConf
191
-
192
- def fail() -> None:
193
- raise ValueError("Input list must be a list or a tuple of strings")
194
-
195
- if not isinstance(dotlist, (list, tuple)):
196
- fail()
197
-
198
- for arg in dotlist:
199
- if not isinstance(arg, str):
200
- fail()
201
-
202
- idx = arg.find("=")
203
- if idx == -1:
204
- key = arg
205
- value = None
206
- else:
207
- key = arg[0:idx]
208
- value = arg[idx + 1 :]
209
- value = yaml.load(value, Loader=get_yaml_loader())
210
-
211
- OmegaConf.update(self, key, value)
212
-
213
- def is_empty(self) -> bool:
214
- """return true if config is empty"""
215
- return len(self.__dict__["_content"]) == 0
216
-
217
- @staticmethod
218
- def _to_content(
219
- conf: Container,
220
- resolve: bool,
221
- throw_on_missing: bool,
222
- enum_to_str: bool = False,
223
- structured_config_mode: SCMode = SCMode.DICT,
224
- ) -> Union[None, Any, str, Dict[DictKeyType, Any], List[Any]]:
225
- from omegaconf import MISSING, DictConfig, ListConfig
226
-
227
- def convert(val: Node) -> Any:
228
- value = val._value()
229
- if enum_to_str and isinstance(value, Enum):
230
- value = f"{value.name}"
231
-
232
- return value
233
-
234
- def get_node_value(key: Union[DictKeyType, int]) -> Any:
235
- try:
236
- node = conf._get_child(key, throw_on_missing_value=throw_on_missing)
237
- except MissingMandatoryValue as e:
238
- conf._format_and_raise(key=key, value=None, cause=e)
239
- assert isinstance(node, Node)
240
- if resolve:
241
- try:
242
- node = node._dereference_node()
243
- except InterpolationResolutionError as e:
244
- conf._format_and_raise(key=key, value=None, cause=e)
245
-
246
- if isinstance(node, Container):
247
- value = BaseContainer._to_content(
248
- node,
249
- resolve=resolve,
250
- throw_on_missing=throw_on_missing,
251
- enum_to_str=enum_to_str,
252
- structured_config_mode=structured_config_mode,
253
- )
254
- else:
255
- value = convert(node)
256
- return value
257
-
258
- if conf._is_none():
259
- return None
260
- elif conf._is_missing():
261
- if throw_on_missing:
262
- conf._format_and_raise(
263
- key=None,
264
- value=None,
265
- cause=MissingMandatoryValue("Missing mandatory value"),
266
- )
267
- else:
268
- return MISSING
269
- elif not resolve and conf._is_interpolation():
270
- inter = conf._value()
271
- assert isinstance(inter, str)
272
- return inter
273
-
274
- if resolve:
275
- _conf = conf._dereference_node()
276
- assert isinstance(_conf, Container)
277
- conf = _conf
278
-
279
- if isinstance(conf, DictConfig):
280
- if (
281
- conf._metadata.object_type not in (dict, None)
282
- and structured_config_mode == SCMode.DICT_CONFIG
283
- ):
284
- return conf
285
- if structured_config_mode == SCMode.INSTANTIATE and is_structured_config(
286
- conf._metadata.object_type
287
- ):
288
- return conf._to_object()
289
-
290
- retdict: Dict[DictKeyType, Any] = {}
291
- for key in conf.keys():
292
- value = get_node_value(key)
293
- if enum_to_str and isinstance(key, Enum):
294
- key = f"{key.name}"
295
- retdict[key] = value
296
- return retdict
297
- elif isinstance(conf, ListConfig):
298
- retlist: List[Any] = []
299
- for index in range(len(conf)):
300
- item = get_node_value(index)
301
- retlist.append(item)
302
-
303
- return retlist
304
- assert False
305
-
306
- @staticmethod
307
- def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
308
- """merge src into dest and return a new copy, does not modified input"""
309
- from omegaconf import AnyNode, DictConfig, ValueNode
310
-
311
- assert isinstance(dest, DictConfig)
312
- assert isinstance(src, DictConfig)
313
- src_type = src._metadata.object_type
314
- src_ref_type = get_type_hint(src)
315
- assert src_ref_type is not None
316
-
317
- # If source DictConfig is:
318
- # - None => set the destination DictConfig to None
319
- # - an interpolation => set the destination DictConfig to be the same interpolation
320
- if src._is_none() or src._is_interpolation():
321
- dest._set_value(src._value())
322
- _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
323
- return
324
-
325
- dest._validate_merge(value=src)
326
-
327
- def expand(node: Container) -> None:
328
- rt = node._metadata.ref_type
329
- val: Any
330
- if rt is not Any:
331
- if is_dict_annotation(rt):
332
- val = {}
333
- elif is_list_annotation(rt) or is_tuple_annotation(rt):
334
- val = []
335
- else:
336
- val = rt
337
- elif isinstance(node, DictConfig):
338
- val = {}
339
- else:
340
- assert False
341
-
342
- node._set_value(val)
343
-
344
- if (
345
- src._is_missing()
346
- and not dest._is_missing()
347
- and is_structured_config(src_ref_type)
348
- ):
349
- # Replace `src` with a prototype of its corresponding structured config
350
- # whose fields are all missing (to avoid overwriting fields in `dest`).
351
- assert src_type is None # src missing, so src's object_type should be None
352
- src_type = src_ref_type
353
- src = _create_structured_with_missing_fields(
354
- ref_type=src_ref_type, object_type=src_type
355
- )
356
-
357
- if (dest._is_interpolation() or dest._is_missing()) and not src._is_missing():
358
- expand(dest)
359
-
360
- src_items = list(src) if not src._is_missing() else []
361
- for key in src_items:
362
- src_node = src._get_node(key, validate_access=False)
363
- dest_node = dest._get_node(key, validate_access=False)
364
- assert isinstance(src_node, Node)
365
- assert dest_node is None or isinstance(dest_node, Node)
366
- src_value = _get_value(src_node)
367
-
368
- src_vk = get_value_kind(src_node)
369
- src_node_missing = src_vk is ValueKind.MANDATORY_MISSING
370
-
371
- if isinstance(dest_node, DictConfig):
372
- dest_node._validate_merge(value=src_node)
373
-
374
- if (
375
- isinstance(dest_node, Container)
376
- and dest_node._is_none()
377
- and not src_node_missing
378
- and not _is_none(src_node, resolve=True)
379
- ):
380
- expand(dest_node)
381
-
382
- if dest_node is not None and dest_node._is_interpolation():
383
- target_node = dest_node._maybe_dereference_node()
384
- if isinstance(target_node, Container):
385
- dest[key] = target_node
386
- dest_node = dest._get_node(key)
387
-
388
- is_optional, et = _resolve_optional(dest._metadata.element_type)
389
- if dest_node is None and is_structured_config(et) and not src_node_missing:
390
- # merging into a new node. Use element_type as a base
391
- dest[key] = DictConfig(
392
- et, parent=dest, ref_type=et, is_optional=is_optional
393
- )
394
- dest_node = dest._get_node(key)
395
-
396
- if dest_node is not None:
397
- if isinstance(dest_node, BaseContainer):
398
- if isinstance(src_node, BaseContainer):
399
- dest_node._merge_with(src_node)
400
- elif not src_node_missing:
401
- dest.__setitem__(key, src_node)
402
- else:
403
- if isinstance(src_node, BaseContainer):
404
- dest.__setitem__(key, src_node)
405
- else:
406
- assert isinstance(dest_node, (ValueNode, UnionNode))
407
- assert isinstance(src_node, (ValueNode, UnionNode))
408
- try:
409
- if isinstance(dest_node, AnyNode):
410
- if src_node_missing:
411
- node = copy.copy(src_node)
412
- # if src node is missing, use the value from the dest_node,
413
- # but validate it against the type of the src node before assigment
414
- node._set_value(dest_node._value())
415
- else:
416
- node = src_node
417
- dest.__setitem__(key, node)
418
- else:
419
- if not src_node_missing:
420
- dest_node._set_value(src_value)
421
-
422
- except (ValidationError, ReadonlyConfigError) as e:
423
- dest._format_and_raise(key=key, value=src_value, cause=e)
424
- else:
425
- from omegaconf import open_dict
426
-
427
- if is_structured_config(src_type):
428
- # verified to be compatible above in _validate_merge
429
- with open_dict(dest):
430
- dest[key] = src._get_node(key)
431
- else:
432
- dest[key] = src._get_node(key)
433
-
434
- _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
435
-
436
- # explicit flags on the source config are replacing the flag values in the destination
437
- flags = src._metadata.flags
438
- assert flags is not None
439
- for flag, value in flags.items():
440
- if value is not None:
441
- dest._set_flag(flag, value)
442
-
443
- @staticmethod
444
- def _list_merge(dest: Any, src: Any) -> None:
445
- from omegaconf import DictConfig, ListConfig, OmegaConf
446
-
447
- assert isinstance(dest, ListConfig)
448
- assert isinstance(src, ListConfig)
449
-
450
- if src._is_none():
451
- dest._set_value(None)
452
- elif src._is_missing():
453
- # do not change dest if src is MISSING.
454
- if dest._metadata.element_type is Any:
455
- dest._metadata.element_type = src._metadata.element_type
456
- elif src._is_interpolation():
457
- dest._set_value(src._value())
458
- else:
459
- temp_target = ListConfig(content=[], parent=dest._get_parent())
460
- temp_target.__dict__["_metadata"] = copy.deepcopy(
461
- dest.__dict__["_metadata"]
462
- )
463
- is_optional, et = _resolve_optional(dest._metadata.element_type)
464
- if is_structured_config(et):
465
- prototype = DictConfig(et, ref_type=et, is_optional=is_optional)
466
- for item in src._iter_ex(resolve=False):
467
- if isinstance(item, DictConfig):
468
- item = OmegaConf.merge(prototype, item)
469
- temp_target.append(item)
470
- else:
471
- for item in src._iter_ex(resolve=False):
472
- temp_target.append(item)
473
-
474
- dest.__dict__["_content"] = temp_target.__dict__["_content"]
475
-
476
- # explicit flags on the source config are replacing the flag values in the destination
477
- flags = src._metadata.flags
478
- assert flags is not None
479
- for flag, value in flags.items():
480
- if value is not None:
481
- dest._set_flag(flag, value)
482
-
483
- def merge_with(
484
- self,
485
- *others: Union[
486
- "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
487
- ],
488
- ) -> None:
489
- try:
490
- self._merge_with(*others)
491
- except Exception as e:
492
- self._format_and_raise(key=None, value=None, cause=e)
493
-
494
- def _merge_with(
495
- self,
496
- *others: Union[
497
- "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
498
- ],
499
- ) -> None:
500
- from .dictconfig import DictConfig
501
- from .listconfig import ListConfig
502
-
503
- """merge a list of other Config objects into this one, overriding as needed"""
504
- for other in others:
505
- if other is None:
506
- raise ValueError("Cannot merge with a None config")
507
-
508
- my_flags = {}
509
- if self._get_flag("allow_objects") is True:
510
- my_flags = {"allow_objects": True}
511
- other = _ensure_container(other, flags=my_flags)
512
-
513
- if isinstance(self, DictConfig) and isinstance(other, DictConfig):
514
- BaseContainer._map_merge(self, other)
515
- elif isinstance(self, ListConfig) and isinstance(other, ListConfig):
516
- BaseContainer._list_merge(self, other)
517
- else:
518
- raise TypeError("Cannot merge DictConfig with ListConfig")
519
-
520
- # recursively correct the parent hierarchy after the merge
521
- self._re_parent()
522
-
523
- # noinspection PyProtectedMember
524
- def _set_item_impl(self, key: Any, value: Any) -> None:
525
- """
526
- Changes the value of the node key with the desired value. If the node key doesn't
527
- exist it creates a new one.
528
- """
529
- from .nodes import AnyNode, ValueNode
530
-
531
- if isinstance(value, Node):
532
- do_deepcopy = not self._get_flag("no_deepcopy_set_nodes")
533
- if not do_deepcopy and isinstance(value, Box):
534
- # if value is from the same config, perform a deepcopy no matter what.
535
- if self._get_root() is value._get_root():
536
- do_deepcopy = True
537
-
538
- if do_deepcopy:
539
- value = copy.deepcopy(value)
540
- value._set_parent(None)
541
-
542
- try:
543
- old = value._key()
544
- value._set_key(key)
545
- self._validate_set(key, value)
546
- finally:
547
- value._set_key(old)
548
- else:
549
- self._validate_set(key, value)
550
-
551
- if self._get_flag("readonly"):
552
- raise ReadonlyConfigError("Cannot change read-only config container")
553
-
554
- input_is_node = isinstance(value, Node)
555
- target_node_ref = self._get_node(key)
556
- assert target_node_ref is None or isinstance(target_node_ref, Node)
557
-
558
- input_is_typed_vnode = isinstance(value, ValueNode) and not isinstance(
559
- value, AnyNode
560
- )
561
-
562
- def get_target_type_hint(val: Any) -> Any:
563
- if not is_structured_config(val):
564
- type_hint = self._metadata.element_type
565
- else:
566
- target = self._get_node(key)
567
- if target is None:
568
- type_hint = self._metadata.element_type
569
- else:
570
- assert isinstance(target, Node)
571
- type_hint = target._metadata.type_hint
572
- return type_hint
573
-
574
- target_type_hint = get_target_type_hint(value)
575
- _, target_ref_type = _resolve_optional(target_type_hint)
576
-
577
- def assign(value_key: Any, val: Node) -> None:
578
- assert val._get_parent() is None
579
- v = val
580
- v._set_parent(self)
581
- v._set_key(value_key)
582
- _deep_update_type_hint(node=v, type_hint=self._metadata.element_type)
583
- self.__dict__["_content"][value_key] = v
584
-
585
- if input_is_typed_vnode and not is_union_annotation(target_ref_type):
586
- assign(key, value)
587
- else:
588
- # input is not a ValueNode, can be primitive or box
589
-
590
- special_value = _is_special(value)
591
- # We use the `Node._set_value` method if the target node exists and:
592
- # 1. the target has an explicit ref_type, or
593
- # 2. the target is an AnyNode and the input is a primitive type.
594
- should_set_value = target_node_ref is not None and (
595
- target_node_ref._has_ref_type()
596
- or (
597
- isinstance(target_node_ref, AnyNode)
598
- and is_primitive_type_annotation(value)
599
- )
600
- )
601
- if should_set_value:
602
- if special_value and isinstance(value, Node):
603
- value = value._value()
604
- self.__dict__["_content"][key]._set_value(value)
605
- elif input_is_node:
606
- if (
607
- special_value
608
- and (
609
- is_container_annotation(target_ref_type)
610
- or is_structured_config(target_ref_type)
611
- )
612
- or is_primitive_type_annotation(target_ref_type)
613
- or is_union_annotation(target_ref_type)
614
- ):
615
- value = _get_value(value)
616
- self._wrap_value_and_set(key, value, target_type_hint)
617
- else:
618
- assign(key, value)
619
- else:
620
- self._wrap_value_and_set(key, value, target_type_hint)
621
-
622
- def _wrap_value_and_set(self, key: Any, val: Any, type_hint: Any) -> None:
623
- from omegaconf.omegaconf import _maybe_wrap
624
-
625
- is_optional, ref_type = _resolve_optional(type_hint)
626
-
627
- try:
628
- wrapped = _maybe_wrap(
629
- ref_type=ref_type,
630
- key=key,
631
- value=val,
632
- is_optional=is_optional,
633
- parent=self,
634
- )
635
- except ValidationError as e:
636
- self._format_and_raise(key=key, value=val, cause=e)
637
- self.__dict__["_content"][key] = wrapped
638
-
639
- @staticmethod
640
- def _item_eq(
641
- c1: Container,
642
- k1: Union[DictKeyType, int],
643
- c2: Container,
644
- k2: Union[DictKeyType, int],
645
- ) -> bool:
646
- v1 = c1._get_child(k1)
647
- v2 = c2._get_child(k2)
648
- assert v1 is not None and v2 is not None
649
-
650
- assert isinstance(v1, Node)
651
- assert isinstance(v2, Node)
652
-
653
- if v1._is_none() and v2._is_none():
654
- return True
655
-
656
- if v1._is_missing() and v2._is_missing():
657
- return True
658
-
659
- v1_inter = v1._is_interpolation()
660
- v2_inter = v2._is_interpolation()
661
- dv1: Optional[Node] = v1
662
- dv2: Optional[Node] = v2
663
-
664
- if v1_inter:
665
- dv1 = v1._maybe_dereference_node()
666
- if v2_inter:
667
- dv2 = v2._maybe_dereference_node()
668
-
669
- if v1_inter and v2_inter:
670
- if dv1 is None or dv2 is None:
671
- return v1 == v2
672
- else:
673
- # both are not none, if both are containers compare as container
674
- if isinstance(dv1, Container) and isinstance(dv2, Container):
675
- if dv1 != dv2:
676
- return False
677
- dv1 = _get_value(dv1)
678
- dv2 = _get_value(dv2)
679
- return dv1 == dv2
680
- elif not v1_inter and not v2_inter:
681
- v1 = _get_value(v1)
682
- v2 = _get_value(v2)
683
- ret = v1 == v2
684
- assert isinstance(ret, bool)
685
- return ret
686
- else:
687
- dv1 = _get_value(dv1)
688
- dv2 = _get_value(dv2)
689
- ret = dv1 == dv2
690
- assert isinstance(ret, bool)
691
- return ret
692
-
693
- def _is_optional(self) -> bool:
694
- return self.__dict__["_metadata"].optional is True
695
-
696
- def _is_interpolation(self) -> bool:
697
- return _is_interpolation(self.__dict__["_content"])
698
-
699
- @abstractmethod
700
- def _validate_get(self, key: Any, value: Any = None) -> None:
701
- ...
702
-
703
- @abstractmethod
704
- def _validate_set(self, key: Any, value: Any) -> None:
705
- ...
706
-
707
- def _value(self) -> Any:
708
- return self.__dict__["_content"]
709
-
710
- def _get_full_key(self, key: Union[DictKeyType, int, slice, None]) -> str:
711
- from .listconfig import ListConfig
712
- from .omegaconf import _select_one
713
-
714
- if not isinstance(key, (int, str, Enum, float, bool, slice, bytes, type(None))):
715
- return ""
716
-
717
- def _slice_to_str(x: slice) -> str:
718
- if x.step is not None:
719
- return f"{x.start}:{x.stop}:{x.step}"
720
- else:
721
- return f"{x.start}:{x.stop}"
722
-
723
- def prepand(
724
- full_key: str,
725
- parent_type: Any,
726
- cur_type: Any,
727
- key: Optional[Union[DictKeyType, int, slice]],
728
- ) -> str:
729
- if key is None:
730
- return full_key
731
-
732
- if isinstance(key, slice):
733
- key = _slice_to_str(key)
734
- elif isinstance(key, Enum):
735
- key = key.name
736
- else:
737
- key = str(key)
738
-
739
- assert isinstance(key, str)
740
-
741
- if issubclass(parent_type, ListConfig):
742
- if full_key != "":
743
- if issubclass(cur_type, ListConfig):
744
- full_key = f"[{key}]{full_key}"
745
- else:
746
- full_key = f"[{key}].{full_key}"
747
- else:
748
- full_key = f"[{key}]"
749
- else:
750
- if full_key == "":
751
- full_key = key
752
- else:
753
- if issubclass(cur_type, ListConfig):
754
- full_key = f"{key}{full_key}"
755
- else:
756
- full_key = f"{key}.{full_key}"
757
- return full_key
758
-
759
- if key is not None and key != "":
760
- assert isinstance(self, Container)
761
- cur, _ = _select_one(
762
- c=self, key=str(key), throw_on_missing=False, throw_on_type_error=False
763
- )
764
- if cur is None:
765
- cur = self
766
- full_key = prepand("", type(cur), None, key)
767
- if cur._key() is not None:
768
- full_key = prepand(
769
- full_key, type(cur._get_parent()), type(cur), cur._key()
770
- )
771
- else:
772
- full_key = prepand("", type(cur._get_parent()), type(cur), cur._key())
773
- else:
774
- cur = self
775
- if cur._key() is None:
776
- return ""
777
- full_key = self._key()
778
-
779
- assert cur is not None
780
- memo = {id(cur)} # remember already visited nodes so as to detect cycles
781
- while cur._get_parent() is not None:
782
- cur = cur._get_parent()
783
- if id(cur) in memo:
784
- raise ConfigCycleDetectedException(
785
- f"Cycle when iterating over parents of key `{key!s}`"
786
- )
787
- memo.add(id(cur))
788
- assert cur is not None
789
- if cur._key() is not None:
790
- full_key = prepand(
791
- full_key, type(cur._get_parent()), type(cur), cur._key()
792
- )
793
-
794
- return full_key
795
-
796
-
797
- def _create_structured_with_missing_fields(
798
- ref_type: type, object_type: Optional[type] = None
799
- ) -> "DictConfig":
800
- from . import MISSING, DictConfig
801
-
802
- cfg_data = get_structured_config_data(ref_type)
803
- for v in cfg_data.values():
804
- v._set_value(MISSING)
805
-
806
- cfg = DictConfig(cfg_data)
807
- cfg._metadata.optional, cfg._metadata.ref_type = _resolve_optional(ref_type)
808
- cfg._metadata.object_type = object_type
809
-
810
- return cfg
811
-
812
-
813
- def _update_types(node: Node, ref_type: Any, object_type: Optional[type]) -> None:
814
- if object_type is not None and not is_primitive_dict(object_type):
815
- node._metadata.object_type = object_type
816
-
817
- if node._metadata.ref_type is Any:
818
- _deep_update_type_hint(node, ref_type)
819
-
820
-
821
- def _deep_update_type_hint(node: Node, type_hint: Any) -> None:
822
- """Ensure node is compatible with type_hint, mutating if necessary."""
823
- from omegaconf import DictConfig, ListConfig
824
-
825
- from ._utils import get_dict_key_value_types, get_list_element_type
826
-
827
- if type_hint is Any:
828
- return
829
-
830
- _shallow_validate_type_hint(node, type_hint)
831
-
832
- new_is_optional, new_ref_type = _resolve_optional(type_hint)
833
- node._metadata.ref_type = new_ref_type
834
- node._metadata.optional = new_is_optional
835
-
836
- if is_list_annotation(new_ref_type) and isinstance(node, ListConfig):
837
- new_element_type = get_list_element_type(new_ref_type)
838
- node._metadata.element_type = new_element_type
839
- if not _is_special(node):
840
- for i in range(len(node)):
841
- _deep_update_subnode(node, i, new_element_type)
842
-
843
- if is_dict_annotation(new_ref_type) and isinstance(node, DictConfig):
844
- new_key_type, new_element_type = get_dict_key_value_types(new_ref_type)
845
- node._metadata.key_type = new_key_type
846
- node._metadata.element_type = new_element_type
847
- if not _is_special(node):
848
- for key in node:
849
- if new_key_type is not Any and not isinstance(key, new_key_type):
850
- raise KeyValidationError(
851
- f"Key {key!r} ({type(key).__name__}) is incompatible"
852
- + f" with key type hint '{new_key_type.__name__}'"
853
- )
854
- _deep_update_subnode(node, key, new_element_type)
855
-
856
-
857
- def _deep_update_subnode(node: BaseContainer, key: Any, value_type_hint: Any) -> None:
858
- """Get node[key] and ensure it is compatible with value_type_hint, mutating if necessary."""
859
- subnode = node._get_node(key)
860
- assert isinstance(subnode, Node)
861
- if _is_special(subnode):
862
- # Ensure special values are wrapped in a Node subclass that
863
- # is compatible with the type hint.
864
- node._wrap_value_and_set(key, subnode._value(), value_type_hint)
865
- subnode = node._get_node(key)
866
- assert isinstance(subnode, Node)
867
- _deep_update_type_hint(subnode, value_type_hint)
868
-
869
-
870
- def _shallow_validate_type_hint(node: Node, type_hint: Any) -> None:
871
- """Error if node's type, content and metadata are not compatible with type_hint."""
872
- from omegaconf import DictConfig, ListConfig, ValueNode
873
-
874
- is_optional, ref_type = _resolve_optional(type_hint)
875
-
876
- vk = get_value_kind(node)
877
-
878
- if node._is_none():
879
- if not is_optional:
880
- value = _get_value(node)
881
- raise ValidationError(
882
- f"Value {value!r} ({type(value).__name__})"
883
- + f" is incompatible with type hint '{ref_type.__name__}'"
884
- )
885
- return
886
- elif vk in (ValueKind.MANDATORY_MISSING, ValueKind.INTERPOLATION):
887
- return
888
- elif vk == ValueKind.VALUE:
889
- if is_primitive_type_annotation(ref_type) and isinstance(node, ValueNode):
890
- value = node._value()
891
- if not isinstance(value, ref_type):
892
- raise ValidationError(
893
- f"Value {value!r} ({type(value).__name__})"
894
- + f" is incompatible with type hint '{ref_type.__name__}'"
895
- )
896
- elif is_structured_config(ref_type) and isinstance(node, DictConfig):
897
- return
898
- elif is_dict_annotation(ref_type) and isinstance(node, DictConfig):
899
- return
900
- elif is_list_annotation(ref_type) and isinstance(node, ListConfig):
901
- return
902
- else:
903
- if isinstance(node, ValueNode):
904
- value = node._value()
905
- raise ValidationError(
906
- f"Value {value!r} ({type(value).__name__})"
907
- + f" is incompatible with type hint '{ref_type}'"
908
- )
909
- else:
910
- raise ValidationError(
911
- f"'{type(node).__name__}' is incompatible"
912
- + f" with type hint '{ref_type}'"
913
- )
914
-
915
- else:
916
- assert False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/dictconfig.py DELETED
@@ -1,776 +0,0 @@
1
- import copy
2
- from enum import Enum
3
- from typing import (
4
- Any,
5
- Dict,
6
- ItemsView,
7
- Iterable,
8
- Iterator,
9
- KeysView,
10
- List,
11
- MutableMapping,
12
- Optional,
13
- Sequence,
14
- Tuple,
15
- Type,
16
- Union,
17
- )
18
-
19
- from ._utils import (
20
- _DEFAULT_MARKER_,
21
- ValueKind,
22
- _get_value,
23
- _is_interpolation,
24
- _is_missing_literal,
25
- _is_missing_value,
26
- _is_none,
27
- _resolve_optional,
28
- _valid_dict_key_annotation_type,
29
- format_and_raise,
30
- get_structured_config_data,
31
- get_structured_config_init_field_names,
32
- get_type_of,
33
- get_value_kind,
34
- is_container_annotation,
35
- is_dict,
36
- is_primitive_dict,
37
- is_structured_config,
38
- is_structured_config_frozen,
39
- type_str,
40
- )
41
- from .base import Box, Container, ContainerMetadata, DictKeyType, Node
42
- from .basecontainer import BaseContainer
43
- from .errors import (
44
- ConfigAttributeError,
45
- ConfigKeyError,
46
- ConfigTypeError,
47
- InterpolationResolutionError,
48
- KeyValidationError,
49
- MissingMandatoryValue,
50
- OmegaConfBaseException,
51
- ReadonlyConfigError,
52
- ValidationError,
53
- )
54
- from .nodes import EnumNode, ValueNode
55
-
56
-
57
- class DictConfig(BaseContainer, MutableMapping[Any, Any]):
58
-
59
- _metadata: ContainerMetadata
60
- _content: Union[Dict[DictKeyType, Node], None, str]
61
-
62
- def __init__(
63
- self,
64
- content: Union[Dict[DictKeyType, Any], "DictConfig", Any],
65
- key: Any = None,
66
- parent: Optional[Box] = None,
67
- ref_type: Union[Any, Type[Any]] = Any,
68
- key_type: Union[Any, Type[Any]] = Any,
69
- element_type: Union[Any, Type[Any]] = Any,
70
- is_optional: bool = True,
71
- flags: Optional[Dict[str, bool]] = None,
72
- ) -> None:
73
- try:
74
- if isinstance(content, DictConfig):
75
- if flags is None:
76
- flags = content._metadata.flags
77
- super().__init__(
78
- parent=parent,
79
- metadata=ContainerMetadata(
80
- key=key,
81
- optional=is_optional,
82
- ref_type=ref_type,
83
- object_type=dict,
84
- key_type=key_type,
85
- element_type=element_type,
86
- flags=flags,
87
- ),
88
- )
89
-
90
- if not _valid_dict_key_annotation_type(key_type):
91
- raise KeyValidationError(f"Unsupported key type {key_type}")
92
-
93
- if is_structured_config(content) or is_structured_config(ref_type):
94
- self._set_value(content, flags=flags)
95
- if is_structured_config_frozen(content) or is_structured_config_frozen(
96
- ref_type
97
- ):
98
- self._set_flag("readonly", True)
99
-
100
- else:
101
- if isinstance(content, DictConfig):
102
- metadata = copy.deepcopy(content._metadata)
103
- metadata.key = key
104
- metadata.ref_type = ref_type
105
- metadata.optional = is_optional
106
- metadata.element_type = element_type
107
- metadata.key_type = key_type
108
- self.__dict__["_metadata"] = metadata
109
- self._set_value(content, flags=flags)
110
- except Exception as ex:
111
- format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))
112
-
113
- def __deepcopy__(self, memo: Dict[int, Any]) -> "DictConfig":
114
- res = DictConfig(None)
115
- res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
116
- res.__dict__["_flags_cache"] = copy.deepcopy(
117
- self.__dict__["_flags_cache"], memo=memo
118
- )
119
-
120
- src_content = self.__dict__["_content"]
121
- if isinstance(src_content, dict):
122
- content_copy = {}
123
- for k, v in src_content.items():
124
- old_parent = v.__dict__["_parent"]
125
- try:
126
- v.__dict__["_parent"] = None
127
- vc = copy.deepcopy(v, memo=memo)
128
- vc.__dict__["_parent"] = res
129
- content_copy[k] = vc
130
- finally:
131
- v.__dict__["_parent"] = old_parent
132
- else:
133
- # None and strings can be assigned as is
134
- content_copy = src_content
135
-
136
- res.__dict__["_content"] = content_copy
137
- # parent is retained, but not copied
138
- res.__dict__["_parent"] = self.__dict__["_parent"]
139
- return res
140
-
141
- def copy(self) -> "DictConfig":
142
- return copy.copy(self)
143
-
144
- def _is_typed(self) -> bool:
145
- return self._metadata.object_type not in (Any, None) and not is_dict(
146
- self._metadata.object_type
147
- )
148
-
149
- def _validate_get(self, key: Any, value: Any = None) -> None:
150
- is_typed = self._is_typed()
151
-
152
- is_struct = self._get_flag("struct") is True
153
- if key not in self.__dict__["_content"]:
154
- if is_typed:
155
- # do not raise an exception if struct is explicitly set to False
156
- if self._get_node_flag("struct") is False:
157
- return
158
- if is_typed or is_struct:
159
- if is_typed:
160
- assert self._metadata.object_type not in (dict, None)
161
- msg = f"Key '{key}' not in '{self._metadata.object_type.__name__}'"
162
- else:
163
- msg = f"Key '{key}' is not in struct"
164
- self._format_and_raise(
165
- key=key, value=value, cause=ConfigAttributeError(msg)
166
- )
167
-
168
- def _validate_set(self, key: Any, value: Any) -> None:
169
- from omegaconf import OmegaConf
170
-
171
- vk = get_value_kind(value)
172
- if vk == ValueKind.INTERPOLATION:
173
- return
174
- if _is_none(value):
175
- self._validate_non_optional(key, value)
176
- return
177
- if vk == ValueKind.MANDATORY_MISSING or value is None:
178
- return
179
-
180
- target = self._get_node(key) if key is not None else self
181
-
182
- target_has_ref_type = isinstance(
183
- target, DictConfig
184
- ) and target._metadata.ref_type not in (Any, dict)
185
- is_valid_target = target is None or not target_has_ref_type
186
-
187
- if is_valid_target:
188
- return
189
-
190
- assert isinstance(target, Node)
191
-
192
- target_type = target._metadata.ref_type
193
- value_type = OmegaConf.get_type(value)
194
-
195
- if is_dict(value_type) and is_dict(target_type):
196
- return
197
- if is_container_annotation(target_type) and not is_container_annotation(
198
- value_type
199
- ):
200
- raise ValidationError(
201
- f"Cannot assign {type_str(value_type)} to {type_str(target_type)}"
202
- )
203
-
204
- if target_type is not None and value_type is not None:
205
- origin = getattr(target_type, "__origin__", target_type)
206
- if not issubclass(value_type, origin):
207
- self._raise_invalid_value(value, value_type, target_type)
208
-
209
- def _validate_merge(self, value: Any) -> None:
210
- from omegaconf import OmegaConf
211
-
212
- dest = self
213
- src = value
214
-
215
- self._validate_non_optional(None, src)
216
-
217
- dest_obj_type = OmegaConf.get_type(dest)
218
- src_obj_type = OmegaConf.get_type(src)
219
-
220
- if dest._is_missing() and src._metadata.object_type not in (dict, None):
221
- self._validate_set(key=None, value=_get_value(src))
222
-
223
- if src._is_missing():
224
- return
225
-
226
- validation_error = (
227
- dest_obj_type is not None
228
- and src_obj_type is not None
229
- and is_structured_config(dest_obj_type)
230
- and not src._is_none()
231
- and not is_dict(src_obj_type)
232
- and not issubclass(src_obj_type, dest_obj_type)
233
- )
234
- if validation_error:
235
- msg = (
236
- f"Merge error: {type_str(src_obj_type)} is not a "
237
- f"subclass of {type_str(dest_obj_type)}. value: {src}"
238
- )
239
- raise ValidationError(msg)
240
-
241
- def _validate_non_optional(self, key: Optional[DictKeyType], value: Any) -> None:
242
- if _is_none(value, resolve=True, throw_on_resolution_failure=False):
243
-
244
- if key is not None:
245
- child = self._get_node(key)
246
- if child is not None:
247
- assert isinstance(child, Node)
248
- field_is_optional = child._is_optional()
249
- else:
250
- field_is_optional, _ = _resolve_optional(
251
- self._metadata.element_type
252
- )
253
- else:
254
- field_is_optional = self._is_optional()
255
-
256
- if not field_is_optional:
257
- self._format_and_raise(
258
- key=key,
259
- value=value,
260
- cause=ValidationError("field '$FULL_KEY' is not Optional"),
261
- )
262
-
263
- def _raise_invalid_value(
264
- self, value: Any, value_type: Any, target_type: Any
265
- ) -> None:
266
- assert value_type is not None
267
- assert target_type is not None
268
- msg = (
269
- f"Invalid type assigned: {type_str(value_type)} is not a "
270
- f"subclass of {type_str(target_type)}. value: {value}"
271
- )
272
- raise ValidationError(msg)
273
-
274
- def _validate_and_normalize_key(self, key: Any) -> DictKeyType:
275
- return self._s_validate_and_normalize_key(self._metadata.key_type, key)
276
-
277
- def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
278
- if key_type is Any:
279
- for t in DictKeyType.__args__: # type: ignore
280
- if isinstance(key, t):
281
- return key # type: ignore
282
- raise KeyValidationError("Incompatible key type '$KEY_TYPE'")
283
- elif key_type is bool and key in [0, 1]:
284
- # Python treats True as 1 and False as 0 when used as dict keys
285
- # assert hash(0) == hash(False)
286
- # assert hash(1) == hash(True)
287
- return bool(key)
288
- elif key_type in (str, bytes, int, float, bool): # primitive type
289
- if not isinstance(key, key_type):
290
- raise KeyValidationError(
291
- f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
292
- )
293
-
294
- return key # type: ignore
295
- elif issubclass(key_type, Enum):
296
- try:
297
- return EnumNode.validate_and_convert_to_enum(key_type, key)
298
- except ValidationError:
299
- valid = ", ".join([x for x in key_type.__members__.keys()])
300
- raise KeyValidationError(
301
- f"Key '$KEY' is incompatible with the enum type '{key_type.__name__}', valid: [{valid}]"
302
- )
303
- else:
304
- assert False, f"Unsupported key type {key_type}"
305
-
306
- def __setitem__(self, key: DictKeyType, value: Any) -> None:
307
- try:
308
- self.__set_impl(key=key, value=value)
309
- except AttributeError as e:
310
- self._format_and_raise(
311
- key=key, value=value, type_override=ConfigKeyError, cause=e
312
- )
313
- except Exception as e:
314
- self._format_and_raise(key=key, value=value, cause=e)
315
-
316
- def __set_impl(self, key: DictKeyType, value: Any) -> None:
317
- key = self._validate_and_normalize_key(key)
318
- self._set_item_impl(key, value)
319
-
320
- # hide content while inspecting in debugger
321
- def __dir__(self) -> Iterable[str]:
322
- if self._is_missing() or self._is_none():
323
- return []
324
- return self.__dict__["_content"].keys() # type: ignore
325
-
326
- def __setattr__(self, key: str, value: Any) -> None:
327
- """
328
- Allow assigning attributes to DictConfig
329
- :param key:
330
- :param value:
331
- :return:
332
- """
333
- try:
334
- self.__set_impl(key, value)
335
- except Exception as e:
336
- if isinstance(e, OmegaConfBaseException) and e._initialized:
337
- raise e
338
- self._format_and_raise(key=key, value=value, cause=e)
339
- assert False
340
-
341
- def __getattr__(self, key: str) -> Any:
342
- """
343
- Allow accessing dictionary values as attributes
344
- :param key:
345
- :return:
346
- """
347
- if key == "__name__":
348
- raise AttributeError()
349
-
350
- try:
351
- return self._get_impl(
352
- key=key, default_value=_DEFAULT_MARKER_, validate_key=False
353
- )
354
- except ConfigKeyError as e:
355
- self._format_and_raise(
356
- key=key, value=None, cause=e, type_override=ConfigAttributeError
357
- )
358
- except Exception as e:
359
- self._format_and_raise(key=key, value=None, cause=e)
360
-
361
- def __getitem__(self, key: DictKeyType) -> Any:
362
- """
363
- Allow map style access
364
- :param key:
365
- :return:
366
- """
367
-
368
- try:
369
- return self._get_impl(key=key, default_value=_DEFAULT_MARKER_)
370
- except AttributeError as e:
371
- self._format_and_raise(
372
- key=key, value=None, cause=e, type_override=ConfigKeyError
373
- )
374
- except Exception as e:
375
- self._format_and_raise(key=key, value=None, cause=e)
376
-
377
- def __delattr__(self, key: str) -> None:
378
- """
379
- Allow deleting dictionary values as attributes
380
- :param key:
381
- :return:
382
- """
383
- if self._get_flag("readonly"):
384
- self._format_and_raise(
385
- key=key,
386
- value=None,
387
- cause=ReadonlyConfigError(
388
- "DictConfig in read-only mode does not support deletion"
389
- ),
390
- )
391
- try:
392
- del self.__dict__["_content"][key]
393
- except KeyError:
394
- msg = "Attribute not found: '$KEY'"
395
- self._format_and_raise(key=key, value=None, cause=ConfigAttributeError(msg))
396
-
397
- def __delitem__(self, key: DictKeyType) -> None:
398
- key = self._validate_and_normalize_key(key)
399
- if self._get_flag("readonly"):
400
- self._format_and_raise(
401
- key=key,
402
- value=None,
403
- cause=ReadonlyConfigError(
404
- "DictConfig in read-only mode does not support deletion"
405
- ),
406
- )
407
- if self._get_flag("struct"):
408
- self._format_and_raise(
409
- key=key,
410
- value=None,
411
- cause=ConfigTypeError(
412
- "DictConfig in struct mode does not support deletion"
413
- ),
414
- )
415
- if self._is_typed() and self._get_node_flag("struct") is not False:
416
- self._format_and_raise(
417
- key=key,
418
- value=None,
419
- cause=ConfigTypeError(
420
- f"{type_str(self._metadata.object_type)} (DictConfig) does not support deletion"
421
- ),
422
- )
423
-
424
- try:
425
- del self.__dict__["_content"][key]
426
- except KeyError:
427
- msg = "Key not found: '$KEY'"
428
- self._format_and_raise(key=key, value=None, cause=ConfigKeyError(msg))
429
-
430
- def get(self, key: DictKeyType, default_value: Any = None) -> Any:
431
- """Return the value for `key` if `key` is in the dictionary, else
432
- `default_value` (defaulting to `None`)."""
433
- try:
434
- return self._get_impl(key=key, default_value=default_value)
435
- except KeyValidationError as e:
436
- self._format_and_raise(key=key, value=None, cause=e)
437
-
438
- def _get_impl(
439
- self, key: DictKeyType, default_value: Any, validate_key: bool = True
440
- ) -> Any:
441
- try:
442
- node = self._get_child(
443
- key=key, throw_on_missing_key=True, validate_key=validate_key
444
- )
445
- except (ConfigAttributeError, ConfigKeyError):
446
- if default_value is not _DEFAULT_MARKER_:
447
- return default_value
448
- else:
449
- raise
450
- assert isinstance(node, Node)
451
- return self._resolve_with_default(
452
- key=key, value=node, default_value=default_value
453
- )
454
-
455
- def _get_node(
456
- self,
457
- key: DictKeyType,
458
- validate_access: bool = True,
459
- validate_key: bool = True,
460
- throw_on_missing_value: bool = False,
461
- throw_on_missing_key: bool = False,
462
- ) -> Optional[Node]:
463
- try:
464
- key = self._validate_and_normalize_key(key)
465
- except KeyValidationError:
466
- if validate_access and validate_key:
467
- raise
468
- else:
469
- if throw_on_missing_key:
470
- raise ConfigAttributeError
471
- else:
472
- return None
473
-
474
- if validate_access:
475
- self._validate_get(key)
476
-
477
- value: Optional[Node] = self.__dict__["_content"].get(key)
478
- if value is None:
479
- if throw_on_missing_key:
480
- raise ConfigKeyError(f"Missing key {key!s}")
481
- elif throw_on_missing_value and value._is_missing():
482
- raise MissingMandatoryValue("Missing mandatory value: $KEY")
483
- return value
484
-
485
- def pop(self, key: DictKeyType, default: Any = _DEFAULT_MARKER_) -> Any:
486
- try:
487
- if self._get_flag("readonly"):
488
- raise ReadonlyConfigError("Cannot pop from read-only node")
489
- if self._get_flag("struct"):
490
- raise ConfigTypeError("DictConfig in struct mode does not support pop")
491
- if self._is_typed() and self._get_node_flag("struct") is not False:
492
- raise ConfigTypeError(
493
- f"{type_str(self._metadata.object_type)} (DictConfig) does not support pop"
494
- )
495
- key = self._validate_and_normalize_key(key)
496
- node = self._get_child(key=key, validate_access=False)
497
- if node is not None:
498
- assert isinstance(node, Node)
499
- value = self._resolve_with_default(
500
- key=key, value=node, default_value=default
501
- )
502
-
503
- del self[key]
504
- return value
505
- else:
506
- if default is not _DEFAULT_MARKER_:
507
- return default
508
- else:
509
- full = self._get_full_key(key=key)
510
- if full != key:
511
- raise ConfigKeyError(
512
- f"Key not found: '{key!s}' (path: '{full}')"
513
- )
514
- else:
515
- raise ConfigKeyError(f"Key not found: '{key!s}'")
516
- except Exception as e:
517
- self._format_and_raise(key=key, value=None, cause=e)
518
-
519
- def keys(self) -> KeysView[DictKeyType]:
520
- if self._is_missing() or self._is_interpolation() or self._is_none():
521
- return {}.keys()
522
- ret = self.__dict__["_content"].keys()
523
- assert isinstance(ret, KeysView)
524
- return ret
525
-
526
- def __contains__(self, key: object) -> bool:
527
- """
528
- A key is contained in a DictConfig if there is an associated value and
529
- it is not a mandatory missing value ('???').
530
- :param key:
531
- :return:
532
- """
533
-
534
- try:
535
- key = self._validate_and_normalize_key(key)
536
- except KeyValidationError:
537
- return False
538
-
539
- try:
540
- node = self._get_child(key)
541
- assert node is None or isinstance(node, Node)
542
- except (KeyError, AttributeError):
543
- node = None
544
-
545
- if node is None:
546
- return False
547
- else:
548
- try:
549
- self._resolve_with_default(key=key, value=node)
550
- return True
551
- except InterpolationResolutionError:
552
- # Interpolations that fail count as existing.
553
- return True
554
- except MissingMandatoryValue:
555
- # Missing values count as *not* existing.
556
- return False
557
-
558
- def __iter__(self) -> Iterator[DictKeyType]:
559
- return iter(self.keys())
560
-
561
- def items(self) -> ItemsView[DictKeyType, Any]:
562
- return dict(self.items_ex(resolve=True, keys=None)).items()
563
-
564
- def setdefault(self, key: DictKeyType, default: Any = None) -> Any:
565
- if key in self:
566
- ret = self.__getitem__(key)
567
- else:
568
- ret = default
569
- self.__setitem__(key, default)
570
- return ret
571
-
572
- def items_ex(
573
- self, resolve: bool = True, keys: Optional[Sequence[DictKeyType]] = None
574
- ) -> List[Tuple[DictKeyType, Any]]:
575
- items: List[Tuple[DictKeyType, Any]] = []
576
-
577
- if self._is_none():
578
- self._format_and_raise(
579
- key=None,
580
- value=None,
581
- cause=TypeError("Cannot iterate a DictConfig object representing None"),
582
- )
583
- if self._is_missing():
584
- raise MissingMandatoryValue("Cannot iterate a missing DictConfig")
585
-
586
- for key in self.keys():
587
- if resolve:
588
- value = self[key]
589
- else:
590
- value = self.__dict__["_content"][key]
591
- if isinstance(value, ValueNode):
592
- value = value._value()
593
- if keys is None or key in keys:
594
- items.append((key, value))
595
-
596
- return items
597
-
598
- def __eq__(self, other: Any) -> bool:
599
- if other is None:
600
- return self.__dict__["_content"] is None
601
- if is_primitive_dict(other) or is_structured_config(other):
602
- other = DictConfig(other, flags={"allow_objects": True})
603
- return DictConfig._dict_conf_eq(self, other)
604
- if isinstance(other, DictConfig):
605
- return DictConfig._dict_conf_eq(self, other)
606
- if self._is_missing():
607
- return _is_missing_literal(other)
608
- return NotImplemented
609
-
610
- def __ne__(self, other: Any) -> bool:
611
- x = self.__eq__(other)
612
- if x is not NotImplemented:
613
- return not x
614
- return NotImplemented
615
-
616
- def __hash__(self) -> int:
617
- return hash(str(self))
618
-
619
- def _promote(self, type_or_prototype: Optional[Type[Any]]) -> None:
620
- """
621
- Retypes a node.
622
- This should only be used in rare circumstances, where you want to dynamically change
623
- the runtime structured-type of a DictConfig.
624
- It will change the type and add the additional fields based on the input class or object
625
- """
626
- if type_or_prototype is None:
627
- return
628
- if not is_structured_config(type_or_prototype):
629
- raise ValueError(f"Expected structured config class: {type_or_prototype}")
630
-
631
- from omegaconf import OmegaConf
632
-
633
- proto: DictConfig = OmegaConf.structured(type_or_prototype)
634
- object_type = proto._metadata.object_type
635
- # remove the type to prevent assignment validation from rejecting the promotion.
636
- proto._metadata.object_type = None
637
- self.merge_with(proto)
638
- # restore the type.
639
- self._metadata.object_type = object_type
640
-
641
- def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
642
- try:
643
- previous_content = self.__dict__["_content"]
644
- self._set_value_impl(value, flags)
645
- except Exception as e:
646
- self.__dict__["_content"] = previous_content
647
- raise e
648
-
649
- def _set_value_impl(
650
- self, value: Any, flags: Optional[Dict[str, bool]] = None
651
- ) -> None:
652
- from omegaconf import MISSING, flag_override
653
-
654
- if flags is None:
655
- flags = {}
656
-
657
- assert not isinstance(value, ValueNode)
658
- self._validate_set(key=None, value=value)
659
-
660
- if _is_none(value, resolve=True):
661
- self.__dict__["_content"] = None
662
- self._metadata.object_type = None
663
- elif _is_interpolation(value, strict_interpolation_validation=True):
664
- self.__dict__["_content"] = value
665
- self._metadata.object_type = None
666
- elif _is_missing_value(value):
667
- self.__dict__["_content"] = MISSING
668
- self._metadata.object_type = None
669
- else:
670
- self.__dict__["_content"] = {}
671
- if is_structured_config(value):
672
- self._metadata.object_type = None
673
- ao = self._get_flag("allow_objects")
674
- data = get_structured_config_data(value, allow_objects=ao)
675
- with flag_override(self, ["struct", "readonly"], False):
676
- for k, v in data.items():
677
- self.__setitem__(k, v)
678
- self._metadata.object_type = get_type_of(value)
679
-
680
- elif isinstance(value, DictConfig):
681
- self._metadata.flags = copy.deepcopy(flags)
682
- with flag_override(self, ["struct", "readonly"], False):
683
- for k, v in value.__dict__["_content"].items():
684
- self.__setitem__(k, v)
685
- self._metadata.object_type = value._metadata.object_type
686
-
687
- elif isinstance(value, dict):
688
- with flag_override(self, ["struct", "readonly"], False):
689
- for k, v in value.items():
690
- self.__setitem__(k, v)
691
- self._metadata.object_type = dict
692
-
693
- else: # pragma: no cover
694
- msg = f"Unsupported value type: {value}"
695
- raise ValidationError(msg)
696
-
697
- @staticmethod
698
- def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool:
699
-
700
- d1_none = d1.__dict__["_content"] is None
701
- d2_none = d2.__dict__["_content"] is None
702
- if d1_none and d2_none:
703
- return True
704
- if d1_none != d2_none:
705
- return False
706
-
707
- assert isinstance(d1, DictConfig)
708
- assert isinstance(d2, DictConfig)
709
- if len(d1) != len(d2):
710
- return False
711
- if d1._is_missing() or d2._is_missing():
712
- return d1._is_missing() is d2._is_missing()
713
-
714
- for k, v in d1.items_ex(resolve=False):
715
- if k not in d2.__dict__["_content"]:
716
- return False
717
- if not BaseContainer._item_eq(d1, k, d2, k):
718
- return False
719
-
720
- return True
721
-
722
- def _to_object(self) -> Any:
723
- """
724
- Instantiate an instance of `self._metadata.object_type`.
725
- This requires `self` to be a structured config.
726
- Nested subconfigs are converted by calling `OmegaConf.to_object`.
727
- """
728
- from omegaconf import OmegaConf
729
-
730
- object_type = self._metadata.object_type
731
- assert is_structured_config(object_type)
732
- init_field_names = set(get_structured_config_init_field_names(object_type))
733
-
734
- init_field_items: Dict[str, Any] = {}
735
- non_init_field_items: Dict[str, Any] = {}
736
- for k in self.keys():
737
- assert isinstance(k, str)
738
- node = self._get_child(k)
739
- assert isinstance(node, Node)
740
- try:
741
- node = node._dereference_node()
742
- except InterpolationResolutionError as e:
743
- self._format_and_raise(key=k, value=None, cause=e)
744
- if node._is_missing():
745
- if k not in init_field_names:
746
- continue # MISSING is ignored for init=False fields
747
- self._format_and_raise(
748
- key=k,
749
- value=None,
750
- cause=MissingMandatoryValue(
751
- "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY"
752
- ),
753
- )
754
- if isinstance(node, Container):
755
- v = OmegaConf.to_object(node)
756
- else:
757
- v = node._value()
758
-
759
- if k in init_field_names:
760
- init_field_items[k] = v
761
- else:
762
- non_init_field_items[k] = v
763
-
764
- try:
765
- result = object_type(**init_field_items)
766
- except TypeError as exc:
767
- self._format_and_raise(
768
- key=None,
769
- value=None,
770
- cause=exc,
771
- msg="Could not create instance of `$OBJECT_TYPE`: " + str(exc),
772
- )
773
-
774
- for k, v in non_init_field_items.items():
775
- setattr(result, k, v)
776
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/errors.py DELETED
@@ -1,141 +0,0 @@
1
- from typing import Any, Optional, Type
2
-
3
-
4
- class OmegaConfBaseException(Exception):
5
- # would ideally be typed Optional[Node]
6
- parent_node: Any
7
- child_node: Any
8
- key: Any
9
- full_key: Optional[str]
10
- value: Any
11
- msg: Optional[str]
12
- cause: Optional[Exception]
13
- object_type: Optional[Type[Any]]
14
- object_type_str: Optional[str]
15
- ref_type: Optional[Type[Any]]
16
- ref_type_str: Optional[str]
17
-
18
- _initialized: bool = False
19
-
20
- def __init__(self, *_args: Any, **_kwargs: Any) -> None:
21
- self.parent_node = None
22
- self.child_node = None
23
- self.key = None
24
- self.full_key = None
25
- self.value = None
26
- self.msg = None
27
- self.object_type = None
28
- self.ref_type = None
29
-
30
-
31
- class MissingMandatoryValue(OmegaConfBaseException):
32
- """Thrown when a variable flagged with '???' value is accessed to
33
- indicate that the value was not set"""
34
-
35
-
36
- class KeyValidationError(OmegaConfBaseException, ValueError):
37
- """
38
- Thrown when an a key of invalid type is used
39
- """
40
-
41
-
42
- class ValidationError(OmegaConfBaseException, ValueError):
43
- """
44
- Thrown when a value fails validation
45
- """
46
-
47
-
48
- class UnsupportedValueType(ValidationError, ValueError):
49
- """
50
- Thrown when an input value is not of supported type
51
- """
52
-
53
-
54
- class ReadonlyConfigError(OmegaConfBaseException):
55
- """
56
- Thrown when someone tries to modify a frozen config
57
- """
58
-
59
-
60
- class InterpolationResolutionError(OmegaConfBaseException, ValueError):
61
- """
62
- Base class for exceptions raised when resolving an interpolation.
63
- """
64
-
65
-
66
- class UnsupportedInterpolationType(InterpolationResolutionError):
67
- """
68
- Thrown when an attempt to use an unregistered interpolation is made
69
- """
70
-
71
-
72
- class InterpolationKeyError(InterpolationResolutionError):
73
- """
74
- Thrown when a node does not exist when resolving an interpolation.
75
- """
76
-
77
-
78
- class InterpolationToMissingValueError(InterpolationResolutionError):
79
- """
80
- Thrown when a node interpolation points to a node that is set to ???.
81
- """
82
-
83
-
84
- class InterpolationValidationError(InterpolationResolutionError, ValidationError):
85
- """
86
- Thrown when the result of an interpolation fails the validation step.
87
- """
88
-
89
-
90
- class ConfigKeyError(OmegaConfBaseException, KeyError):
91
- """
92
- Thrown from DictConfig when a regular dict access would have caused a KeyError.
93
- """
94
-
95
- msg: str
96
-
97
- def __init__(self, msg: str) -> None:
98
- super().__init__(msg)
99
- self.msg = msg
100
-
101
- def __str__(self) -> str:
102
- """
103
- Workaround to nasty KeyError quirk: https://bugs.python.org/issue2651
104
- """
105
- return self.msg
106
-
107
-
108
- class ConfigAttributeError(OmegaConfBaseException, AttributeError):
109
- """
110
- Thrown from a config object when a regular access would have caused an AttributeError.
111
- """
112
-
113
-
114
- class ConfigTypeError(OmegaConfBaseException, TypeError):
115
- """
116
- Thrown from a config object when a regular access would have caused a TypeError.
117
- """
118
-
119
-
120
- class ConfigIndexError(OmegaConfBaseException, IndexError):
121
- """
122
- Thrown from a config object when a regular access would have caused an IndexError.
123
- """
124
-
125
-
126
- class ConfigValueError(OmegaConfBaseException, ValueError):
127
- """
128
- Thrown from a config object when a regular access would have caused a ValueError.
129
- """
130
-
131
-
132
- class ConfigCycleDetectedException(OmegaConfBaseException):
133
- """
134
- Thrown when a cycle is detected in the graph made by config nodes.
135
- """
136
-
137
-
138
- class GrammarParseError(OmegaConfBaseException):
139
- """
140
- Thrown when failing to parse an expression according to the ANTLR grammar.
141
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/grammar/OmegaConfGrammarLexer.g4 DELETED
@@ -1,137 +0,0 @@
1
- // Regenerate lexer and parser by running 'python setup.py antlr' at project root.
2
- // See `OmegaConfGrammarParser.g4` for some important information regarding how to
3
- // properly maintain this grammar.
4
-
5
- lexer grammar OmegaConfGrammarLexer;
6
-
7
- // Re-usable fragments.
8
- fragment CHAR: [a-zA-Z];
9
- fragment DIGIT: [0-9];
10
- fragment INT_UNSIGNED: '0' | [1-9] (('_')? DIGIT)*;
11
- fragment ESC_BACKSLASH: '\\\\'; // escaped backslash
12
-
13
- /////////////////////////////
14
- // DEFAULT_MODE (TOPLEVEL) //
15
- /////////////////////////////
16
-
17
- TOP_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
18
-
19
- // Regular string: anything that does not contain any $ and does not end with \
20
- // (this ensures this rule will not consume characters required to recognize other tokens).
21
- ANY_STR: ~[$]* ~[\\$];
22
-
23
- // Escaped interpolation: '\${', optionally preceded by an even number of \
24
- ESC_INTER: ESC_BACKSLASH* '\\${';
25
-
26
- // Backslashes that *may* be escaped (even number).
27
- TOP_ESC: ESC_BACKSLASH+;
28
-
29
- // Other backslashes that will not need escaping (odd number due to not matching the previous rule).
30
- BACKSLASHES: '\\'+ -> type(ANY_STR);
31
-
32
- // The dollar sign must be singled out so that we can recognize interpolations.
33
- DOLLAR: '$' -> type(ANY_STR);
34
-
35
-
36
- ////////////////
37
- // VALUE_MODE //
38
- ////////////////
39
-
40
- mode VALUE_MODE;
41
-
42
- INTER_OPEN: '${' WS? -> pushMode(INTERPOLATION_MODE);
43
- BRACE_OPEN: '{' WS? -> pushMode(VALUE_MODE); // must keep track of braces to detect end of interpolation
44
- BRACE_CLOSE: WS? '}' -> popMode;
45
- QUOTE_OPEN_SINGLE: '\'' -> pushMode(QUOTED_SINGLE_MODE);
46
- QUOTE_OPEN_DOUBLE: '"' -> pushMode(QUOTED_DOUBLE_MODE);
47
-
48
- COMMA: WS? ',' WS?;
49
- BRACKET_OPEN: '[' WS?;
50
- BRACKET_CLOSE: WS? ']';
51
- COLON: WS? ':' WS?;
52
-
53
- // Numbers.
54
-
55
- fragment POINT_FLOAT: INT_UNSIGNED '.' | INT_UNSIGNED? '.' DIGIT (('_')? DIGIT)*;
56
- fragment EXPONENT_FLOAT: (INT_UNSIGNED | POINT_FLOAT) [eE] [+-]? DIGIT (('_')? DIGIT)*;
57
- FLOAT: [+-]? (POINT_FLOAT | EXPONENT_FLOAT | [Ii][Nn][Ff] | [Nn][Aa][Nn]);
58
- INT: [+-]? INT_UNSIGNED;
59
-
60
- // Other reserved keywords.
61
-
62
- BOOL:
63
- [Tt][Rr][Uu][Ee] // TRUE
64
- | [Ff][Aa][Ll][Ss][Ee]; // FALSE
65
-
66
- NULL: [Nn][Uu][Ll][Ll];
67
-
68
- UNQUOTED_CHAR: [/\-\\+.$%*@?|]; // other characters allowed in unquoted strings
69
- ID: (CHAR|'_') (CHAR|DIGIT|'_'|'-')*;
70
- ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' |
71
- '\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+;
72
- WS: [ \t]+;
73
-
74
-
75
- ////////////////////////
76
- // INTERPOLATION_MODE //
77
- ////////////////////////
78
-
79
- mode INTERPOLATION_MODE;
80
-
81
- NESTED_INTER_OPEN: INTER_OPEN WS? -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
82
- INTER_COLON: WS? ':' WS? -> type(COLON), mode(VALUE_MODE);
83
- INTER_CLOSE: WS? '}' -> popMode;
84
-
85
- DOT: '.';
86
- INTER_BRACKET_OPEN: '[' -> type(BRACKET_OPEN);
87
- INTER_BRACKET_CLOSE: ']' -> type(BRACKET_CLOSE);
88
- INTER_ID: ID -> type(ID);
89
-
90
- // Interpolation key, may contain any non special character.
91
- // Note that we can allow '$' because the parser does not support interpolations that
92
- // are only part of a key name, i.e., "${foo${bar}}" is not allowed. As a result, it
93
- // is ok to "consume" all '$' characters within the `INTER_KEY` token.
94
- INTER_KEY: ~[\\{}()[\]:. \t'"]+;
95
-
96
-
97
- ////////////////////////
98
- // QUOTED_SINGLE_MODE //
99
- ////////////////////////
100
-
101
- mode QUOTED_SINGLE_MODE;
102
-
103
- // This mode is very similar to `DEFAULT_MODE` except for the handling of quotes.
104
-
105
- QSINGLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
106
- MATCHING_QUOTE_CLOSE: '\'' -> popMode;
107
-
108
- // Regular string: anything that does not contain any $ *or quote* and does not end with \
109
- QSINGLE_STR: ~['$]* ~['\\$] -> type(ANY_STR);
110
-
111
- QSINGLE_ESC_INTER: ESC_INTER -> type(ESC_INTER);
112
-
113
- // Escaped quote (optionally preceded by an even number of backslashes).
114
- QSINGLE_ESC_QUOTE: ESC_BACKSLASH* '\\\'' -> type(ESC);
115
-
116
- QUOTED_ESC: ESC_BACKSLASH+;
117
- QSINGLE_BACKSLASHES: '\\'+ -> type(ANY_STR);
118
- QSINGLE_DOLLAR: '$' -> type(ANY_STR);
119
-
120
-
121
- ////////////////////////
122
- // QUOTED_DOUBLE_MODE //
123
- ////////////////////////
124
-
125
- mode QUOTED_DOUBLE_MODE;
126
-
127
- // Same as `QUOTED_SINGLE_MODE` but for double quotes.
128
-
129
- QDOUBLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
130
- QDOUBLE_CLOSE: '"' -> type(MATCHING_QUOTE_CLOSE), popMode;
131
-
132
- QDOUBLE_STR: ~["$]* ~["\\$] -> type(ANY_STR);
133
- QDOUBLE_ESC_INTER: ESC_INTER -> type(ESC_INTER);
134
- QDOUBLE_ESC_QUOTE: ESC_BACKSLASH* '\\"' -> type(ESC);
135
- QDOUBLE_ESC: ESC_BACKSLASH+ -> type(QUOTED_ESC);
136
- QDOUBLE_BACKSLASHES: '\\'+ -> type(ANY_STR);
137
- QDOUBLE_DOLLAR: '$' -> type(ANY_STR);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/grammar/OmegaConfGrammarParser.g4 DELETED
@@ -1,91 +0,0 @@
1
- // Regenerate parser by running 'python setup.py antlr' at project root.
2
-
3
- // Maintenance guidelines when modifying this grammar:
4
- //
5
- // - Consider whether the regex pattern `SIMPLE_INTERPOLATION_PATTERN` found in
6
- // `grammar_parser.py` should be updated as well.
7
- //
8
- // - Update Hydra's grammar accordingly.
9
- //
10
- // - Keep up-to-date the comments in the visitor (in `grammar_visitor.py`)
11
- // that contain grammar excerpts (within each `visit...()` method).
12
- //
13
- // - Remember to update the documentation (including the tutorial notebook as
14
- // well as grammar.rst)
15
-
16
- parser grammar OmegaConfGrammarParser;
17
- options {tokenVocab = OmegaConfGrammarLexer;}
18
-
19
- // Main rules used to parse OmegaConf strings.
20
-
21
- configValue: text EOF;
22
- singleElement: element EOF;
23
-
24
-
25
- // Composite text expression (may contain interpolations).
26
-
27
- text: (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+;
28
-
29
-
30
- // Elements.
31
-
32
- element:
33
- primitive
34
- | quotedValue
35
- | listContainer
36
- | dictContainer
37
- ;
38
-
39
-
40
- // Data structures.
41
-
42
- listContainer: BRACKET_OPEN sequence? BRACKET_CLOSE; // [], [1,2,3], [a,b,[1,2]]
43
- dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
44
- dictKeyValuePair: dictKey COLON element;
45
- sequence: (element (COMMA element?)*) | (COMMA element?)+;
46
-
47
-
48
- // Interpolations.
49
-
50
- interpolation: interpolationNode | interpolationResolver;
51
-
52
- interpolationNode:
53
- INTER_OPEN
54
- DOT* // relative interpolation?
55
- (configKey | BRACKET_OPEN configKey BRACKET_CLOSE) // foo, [foo]
56
- (DOT configKey | BRACKET_OPEN configKey BRACKET_CLOSE)* // .foo, [foo], .foo[bar], [foo].bar[baz]
57
- INTER_CLOSE;
58
- interpolationResolver: INTER_OPEN resolverName COLON sequence? BRACE_CLOSE;
59
- configKey: interpolation | ID | INTER_KEY;
60
- resolverName: (interpolation | ID) (DOT (interpolation | ID))* ; // oc.env, myfunc, ns.${x}, ns1.ns2.f
61
-
62
-
63
- // Primitive types.
64
-
65
- // Ex: "hello world", 'hello ${world}'
66
- quotedValue: (QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE) text? MATCHING_QUOTE_CLOSE;
67
-
68
- primitive:
69
- ( ID // foo_10
70
- | NULL // null, NULL
71
- | INT // 0, 10, -20, 1_000_000
72
- | FLOAT // 3.14, -20.0, 1e-1, -10e3
73
- | BOOL // true, TrUe, false, False
74
- | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, |
75
- | COLON // :
76
- | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
77
- | WS // whitespaces
78
- | interpolation
79
- )+;
80
-
81
- // Same as `primitive` except that `COLON` and interpolations are not allowed.
82
- dictKey:
83
- ( ID // foo_10
84
- | NULL // null, NULL
85
- | INT // 0, 10, -20, 1_000_000
86
- | FLOAT // 3.14, -20.0, 1e-1, -10e3
87
- | BOOL // true, TrUe, false, False
88
- | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, |
89
- | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
90
- | WS // whitespaces
91
- )+;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/grammar/__init__.py DELETED
File without changes
omegaconf/grammar/gen/__init__.py DELETED
File without changes
omegaconf/grammar_parser.py DELETED
@@ -1,144 +0,0 @@
1
- import re
2
- import threading
3
- from typing import Any
4
-
5
- from antlr4 import CommonTokenStream, InputStream, ParserRuleContext
6
- from antlr4.error.ErrorListener import ErrorListener
7
-
8
- from .errors import GrammarParseError
9
-
10
- # Import from visitor in order to check the presence of generated grammar files
11
- # files in a single place.
12
- from .grammar_visitor import ( # type: ignore
13
- OmegaConfGrammarLexer,
14
- OmegaConfGrammarParser,
15
- )
16
-
17
- # Used to cache grammar objects to avoid re-creating them on each call to `parse()`.
18
- # We use a per-thread cache to make it thread-safe.
19
- _grammar_cache = threading.local()
20
-
21
- # Build regex pattern to efficiently identify typical interpolations.
22
- # See test `test_match_simple_interpolation_pattern` for examples.
23
- _config_key = r"[$\w]+" # foo, $0, $bar, $foo_$bar123$
24
- _key_maybe_brackets = f"{_config_key}|\\[{_config_key}\\]" # foo, [foo], [$bar]
25
- _node_access = f"\\.{_key_maybe_brackets}" # .foo, [foo], [$bar]
26
- _node_path = f"(\\.)*({_key_maybe_brackets})({_node_access})*" # [foo].bar, .foo[bar]
27
- _node_inter = f"\\${{\\s*{_node_path}\\s*}}" # node interpolation ${foo.bar}
28
- _id = "[a-zA-Z_][\\w\\-]*" # foo, foo_bar, foo-bar, abc123
29
- _resolver_name = f"({_id}(\\.{_id})*)?" # foo, ns.bar3, ns_1.ns_2.b0z
30
- _arg = r"[a-zA-Z_0-9/\-\+.$%*@?|]+" # string representing a resolver argument
31
- _args = f"{_arg}(\\s*,\\s*{_arg})*" # list of resolver arguments
32
- _resolver_inter = f"\\${{\\s*{_resolver_name}\\s*:\\s*{_args}?\\s*}}" # ${foo:bar}
33
- _inter = f"({_node_inter}|{_resolver_inter})" # any kind of interpolation
34
- _outer = "([^$]|\\$(?!{))+" # any character except $ (unless not followed by {)
35
- SIMPLE_INTERPOLATION_PATTERN = re.compile(
36
- f"({_outer})?({_inter}({_outer})?)+$", flags=re.ASCII
37
- )
38
- # NOTE: SIMPLE_INTERPOLATION_PATTERN must not generate false positive matches:
39
- # it must not accept anything that isn't a valid interpolation (per the
40
- # interpolation grammar defined in `omegaconf/grammar/*.g4`).
41
-
42
-
43
- class OmegaConfErrorListener(ErrorListener): # type: ignore
44
- def syntaxError(
45
- self,
46
- recognizer: Any,
47
- offending_symbol: Any,
48
- line: Any,
49
- column: Any,
50
- msg: Any,
51
- e: Any,
52
- ) -> None:
53
- raise GrammarParseError(str(e) if msg is None else msg) from e
54
-
55
- def reportAmbiguity(
56
- self,
57
- recognizer: Any,
58
- dfa: Any,
59
- startIndex: Any,
60
- stopIndex: Any,
61
- exact: Any,
62
- ambigAlts: Any,
63
- configs: Any,
64
- ) -> None:
65
- raise GrammarParseError("ANTLR error: Ambiguity") # pragma: no cover
66
-
67
- def reportAttemptingFullContext(
68
- self,
69
- recognizer: Any,
70
- dfa: Any,
71
- startIndex: Any,
72
- stopIndex: Any,
73
- conflictingAlts: Any,
74
- configs: Any,
75
- ) -> None:
76
- # Note: for now we raise an error to be safe. However this is mostly a
77
- # performance warning, so in the future this may be relaxed if we need
78
- # to change the grammar in such a way that this warning cannot be
79
- # avoided (another option would be to switch to SLL parsing mode).
80
- raise GrammarParseError(
81
- "ANTLR error: Attempting Full Context"
82
- ) # pragma: no cover
83
-
84
- def reportContextSensitivity(
85
- self,
86
- recognizer: Any,
87
- dfa: Any,
88
- startIndex: Any,
89
- stopIndex: Any,
90
- prediction: Any,
91
- configs: Any,
92
- ) -> None:
93
- raise GrammarParseError("ANTLR error: ContextSensitivity") # pragma: no cover
94
-
95
-
96
- def parse(
97
- value: str, parser_rule: str = "configValue", lexer_mode: str = "DEFAULT_MODE"
98
- ) -> ParserRuleContext:
99
- """
100
- Parse interpolated string `value` (and return the parse tree).
101
- """
102
- l_mode = getattr(OmegaConfGrammarLexer, lexer_mode)
103
- istream = InputStream(value)
104
-
105
- cached = getattr(_grammar_cache, "data", None)
106
- if cached is None:
107
- error_listener = OmegaConfErrorListener()
108
- lexer = OmegaConfGrammarLexer(istream)
109
- lexer.removeErrorListeners()
110
- lexer.addErrorListener(error_listener)
111
- lexer.mode(l_mode)
112
- token_stream = CommonTokenStream(lexer)
113
- parser = OmegaConfGrammarParser(token_stream)
114
- parser.removeErrorListeners()
115
- parser.addErrorListener(error_listener)
116
-
117
- # The two lines below could be enabled in the future if we decide to switch
118
- # to SLL prediction mode. Warning though, it has not been fully tested yet!
119
- # from antlr4 import PredictionMode
120
- # parser._interp.predictionMode = PredictionMode.SLL
121
-
122
- # Note that although the input stream `istream` is implicitly cached within
123
- # the lexer, it will be replaced by a new input next time the lexer is re-used.
124
- _grammar_cache.data = lexer, token_stream, parser
125
-
126
- else:
127
- lexer, token_stream, parser = cached
128
- # Replace the old input stream with the new one.
129
- lexer.inputStream = istream
130
- # Initialize the lexer / token stream / parser to process the new input.
131
- lexer.mode(l_mode)
132
- token_stream.setTokenSource(lexer)
133
- parser.reset()
134
-
135
- try:
136
- return getattr(parser, parser_rule)()
137
- except Exception as exc:
138
- if type(exc) is Exception and str(exc) == "Empty Stack":
139
- # This exception is raised by antlr when trying to pop a mode while
140
- # no mode has been pushed. We convert it into an `GrammarParseError`
141
- # to facilitate exception handling from the caller.
142
- raise GrammarParseError("Empty Stack")
143
- else:
144
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/grammar_visitor.py DELETED
@@ -1,392 +0,0 @@
1
- import sys
2
- import warnings
3
- from itertools import zip_longest
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- Callable,
8
- Dict,
9
- Generator,
10
- List,
11
- Optional,
12
- Set,
13
- Tuple,
14
- Union,
15
- )
16
-
17
- from antlr4 import TerminalNode
18
-
19
- from .errors import InterpolationResolutionError
20
-
21
- if TYPE_CHECKING:
22
- from .base import Node # noqa F401
23
-
24
- try:
25
- from omegaconf.grammar.gen.OmegaConfGrammarLexer import OmegaConfGrammarLexer
26
- from omegaconf.grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
27
- from omegaconf.grammar.gen.OmegaConfGrammarParserVisitor import (
28
- OmegaConfGrammarParserVisitor,
29
- )
30
-
31
- except ModuleNotFoundError: # pragma: no cover
32
- print(
33
- "Error importing OmegaConf's generated parsers, run `python setup.py antlr` to regenerate.",
34
- file=sys.stderr,
35
- )
36
- sys.exit(1)
37
-
38
-
39
- class GrammarVisitor(OmegaConfGrammarParserVisitor):
40
- def __init__(
41
- self,
42
- node_interpolation_callback: Callable[
43
- [str, Optional[Set[int]]],
44
- Optional["Node"],
45
- ],
46
- resolver_interpolation_callback: Callable[..., Any],
47
- memo: Optional[Set[int]],
48
- **kw: Dict[Any, Any],
49
- ):
50
- """
51
- Constructor.
52
-
53
- :param node_interpolation_callback: Callback function that is called when
54
- needing to resolve a node interpolation. This function should take a single
55
- string input which is the key's dot path (ex: `"foo.bar"`).
56
-
57
- :param resolver_interpolation_callback: Callback function that is called when
58
- needing to resolve a resolver interpolation. This function should accept
59
- three keyword arguments: `name` (str, the name of the resolver),
60
- `args` (tuple, the inputs to the resolver), and `args_str` (tuple,
61
- the string representation of the inputs to the resolver).
62
-
63
- :param kw: Additional keyword arguments to be forwarded to parent class.
64
- """
65
- super().__init__(**kw)
66
- self.node_interpolation_callback = node_interpolation_callback
67
- self.resolver_interpolation_callback = resolver_interpolation_callback
68
- self.memo = memo
69
-
70
- def aggregateResult(self, aggregate: List[Any], nextResult: Any) -> List[Any]:
71
- raise NotImplementedError
72
-
73
- def defaultResult(self) -> List[Any]:
74
- # Raising an exception because not currently used (like `aggregateResult()`).
75
- raise NotImplementedError
76
-
77
- def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str:
78
- from ._utils import _get_value
79
-
80
- # interpolation | ID | INTER_KEY
81
- assert ctx.getChildCount() == 1
82
- child = ctx.getChild(0)
83
- if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
84
- res = _get_value(self.visitInterpolation(child))
85
- if not isinstance(res, str):
86
- raise InterpolationResolutionError(
87
- f"The following interpolation is used to denote a config key and "
88
- f"thus should return a string, but instead returned `{res}` of "
89
- f"type `{type(res)}`: {ctx.getChild(0).getText()}"
90
- )
91
- return res
92
- else:
93
- assert isinstance(child, TerminalNode) and isinstance(
94
- child.symbol.text, str
95
- )
96
- return child.symbol.text
97
-
98
- def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any:
99
- # text EOF
100
- assert ctx.getChildCount() == 2
101
- return self.visit(ctx.getChild(0))
102
-
103
- def visitDictKey(self, ctx: OmegaConfGrammarParser.DictKeyContext) -> Any:
104
- return self._createPrimitive(ctx)
105
-
106
- def visitDictContainer(
107
- self, ctx: OmegaConfGrammarParser.DictContainerContext
108
- ) -> Dict[Any, Any]:
109
- # BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE
110
- assert ctx.getChildCount() >= 2
111
- return dict(
112
- self.visitDictKeyValuePair(ctx.getChild(i))
113
- for i in range(1, ctx.getChildCount() - 1, 2)
114
- )
115
-
116
- def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any:
117
- # primitive | quotedValue | listContainer | dictContainer
118
- assert ctx.getChildCount() == 1
119
- return self.visit(ctx.getChild(0))
120
-
121
- def visitInterpolation(
122
- self, ctx: OmegaConfGrammarParser.InterpolationContext
123
- ) -> Any:
124
- assert ctx.getChildCount() == 1 # interpolationNode | interpolationResolver
125
- return self.visit(ctx.getChild(0))
126
-
127
- def visitInterpolationNode(
128
- self, ctx: OmegaConfGrammarParser.InterpolationNodeContext
129
- ) -> Optional["Node"]:
130
- # INTER_OPEN
131
- # DOT* // relative interpolation?
132
- # (configKey | BRACKET_OPEN configKey BRACKET_CLOSE) // foo, [foo]
133
- # (DOT configKey | BRACKET_OPEN configKey BRACKET_CLOSE)* // .foo, [foo], .foo[bar], [foo].bar[baz]
134
- # INTER_CLOSE;
135
-
136
- assert ctx.getChildCount() >= 3
137
-
138
- inter_key_tokens = [] # parsed elements of the dot path
139
- for child in ctx.getChildren():
140
- if isinstance(child, TerminalNode):
141
- s = child.symbol
142
- if s.type in [
143
- OmegaConfGrammarLexer.DOT,
144
- OmegaConfGrammarLexer.BRACKET_OPEN,
145
- OmegaConfGrammarLexer.BRACKET_CLOSE,
146
- ]:
147
- inter_key_tokens.append(s.text)
148
- else:
149
- assert s.type in (
150
- OmegaConfGrammarLexer.INTER_OPEN,
151
- OmegaConfGrammarLexer.INTER_CLOSE,
152
- )
153
- else:
154
- assert isinstance(child, OmegaConfGrammarParser.ConfigKeyContext)
155
- inter_key_tokens.append(self.visitConfigKey(child))
156
-
157
- inter_key = "".join(inter_key_tokens)
158
- return self.node_interpolation_callback(inter_key, self.memo)
159
-
160
- def visitInterpolationResolver(
161
- self, ctx: OmegaConfGrammarParser.InterpolationResolverContext
162
- ) -> Any:
163
-
164
- # INTER_OPEN resolverName COLON sequence? BRACE_CLOSE
165
- assert 4 <= ctx.getChildCount() <= 5
166
-
167
- resolver_name = self.visit(ctx.getChild(1))
168
- maybe_seq = ctx.getChild(3)
169
- args = []
170
- args_str = []
171
- if isinstance(maybe_seq, TerminalNode): # means there are no args
172
- assert maybe_seq.symbol.type == OmegaConfGrammarLexer.BRACE_CLOSE
173
- else:
174
- assert isinstance(maybe_seq, OmegaConfGrammarParser.SequenceContext)
175
- for val, txt in self.visitSequence(maybe_seq):
176
- args.append(val)
177
- args_str.append(txt)
178
-
179
- return self.resolver_interpolation_callback(
180
- name=resolver_name,
181
- args=tuple(args),
182
- args_str=tuple(args_str),
183
- )
184
-
185
- def visitDictKeyValuePair(
186
- self, ctx: OmegaConfGrammarParser.DictKeyValuePairContext
187
- ) -> Tuple[Any, Any]:
188
- from ._utils import _get_value
189
-
190
- assert ctx.getChildCount() == 3 # dictKey COLON element
191
- key = self.visit(ctx.getChild(0))
192
- colon = ctx.getChild(1)
193
- assert (
194
- isinstance(colon, TerminalNode)
195
- and colon.symbol.type == OmegaConfGrammarLexer.COLON
196
- )
197
- value = _get_value(self.visitElement(ctx.getChild(2)))
198
- return key, value
199
-
200
- def visitListContainer(
201
- self, ctx: OmegaConfGrammarParser.ListContainerContext
202
- ) -> List[Any]:
203
- # BRACKET_OPEN sequence? BRACKET_CLOSE;
204
- assert ctx.getChildCount() in (2, 3)
205
- if ctx.getChildCount() == 2:
206
- return []
207
- sequence = ctx.getChild(1)
208
- assert isinstance(sequence, OmegaConfGrammarParser.SequenceContext)
209
- return list(val for val, _ in self.visitSequence(sequence)) # ignore raw text
210
-
211
- def visitPrimitive(self, ctx: OmegaConfGrammarParser.PrimitiveContext) -> Any:
212
- return self._createPrimitive(ctx)
213
-
214
- def visitQuotedValue(self, ctx: OmegaConfGrammarParser.QuotedValueContext) -> str:
215
- # (QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE) text? MATCHING_QUOTE_CLOSE
216
- n = ctx.getChildCount()
217
- assert n in [2, 3]
218
- return str(self.visit(ctx.getChild(1))) if n == 3 else ""
219
-
220
- def visitResolverName(self, ctx: OmegaConfGrammarParser.ResolverNameContext) -> str:
221
- from ._utils import _get_value
222
-
223
- # (interpolation | ID) (DOT (interpolation | ID))*
224
- assert ctx.getChildCount() >= 1
225
- items = []
226
- for child in list(ctx.getChildren())[::2]:
227
- if isinstance(child, TerminalNode):
228
- assert child.symbol.type == OmegaConfGrammarLexer.ID
229
- items.append(child.symbol.text)
230
- else:
231
- assert isinstance(child, OmegaConfGrammarParser.InterpolationContext)
232
- item = _get_value(self.visitInterpolation(child))
233
- if not isinstance(item, str):
234
- raise InterpolationResolutionError(
235
- f"The name of a resolver must be a string, but the interpolation "
236
- f"{child.getText()} resolved to `{item}` which is of type "
237
- f"{type(item)}"
238
- )
239
- items.append(item)
240
- return ".".join(items)
241
-
242
- def visitSequence(
243
- self, ctx: OmegaConfGrammarParser.SequenceContext
244
- ) -> Generator[Any, None, None]:
245
- from ._utils import _get_value
246
-
247
- # (element (COMMA element?)*) | (COMMA element?)+
248
- assert ctx.getChildCount() >= 1
249
-
250
- # DEPRECATED: remove in 2.2 (revert #571)
251
- def empty_str_warning() -> None:
252
- txt = ctx.getText()
253
- warnings.warn(
254
- f"In the sequence `{txt}` some elements are missing: please replace "
255
- f"them with empty quoted strings. "
256
- f"See https://github.com/omry/omegaconf/issues/572 for details.",
257
- category=UserWarning,
258
- )
259
-
260
- is_previous_comma = True # whether previous child was a comma (init to True)
261
- for child in ctx.getChildren():
262
- if isinstance(child, OmegaConfGrammarParser.ElementContext):
263
- # Also preserve the original text representation of `child` so
264
- # as to allow backward compatibility with old resolvers (registered
265
- # with `legacy_register_resolver()`). Note that we cannot just cast
266
- # the value to string later as for instance `null` would become "None".
267
- yield _get_value(self.visitElement(child)), child.getText()
268
- is_previous_comma = False
269
- else:
270
- assert (
271
- isinstance(child, TerminalNode)
272
- and child.symbol.type == OmegaConfGrammarLexer.COMMA
273
- )
274
- if is_previous_comma:
275
- empty_str_warning()
276
- yield "", ""
277
- else:
278
- is_previous_comma = True
279
- if is_previous_comma:
280
- # Trailing comma.
281
- empty_str_warning()
282
- yield "", ""
283
-
284
- def visitSingleElement(
285
- self, ctx: OmegaConfGrammarParser.SingleElementContext
286
- ) -> Any:
287
- # element EOF
288
- assert ctx.getChildCount() == 2
289
- return self.visit(ctx.getChild(0))
290
-
291
- def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any:
292
- # (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+
293
-
294
- # Single interpolation? If yes, return its resolved value "as is".
295
- if ctx.getChildCount() == 1:
296
- c = ctx.getChild(0)
297
- if isinstance(c, OmegaConfGrammarParser.InterpolationContext):
298
- return self.visitInterpolation(c)
299
-
300
- # Otherwise, concatenate string representations together.
301
- return self._unescape(list(ctx.getChildren()))
302
-
303
- def _createPrimitive(
304
- self,
305
- ctx: Union[
306
- OmegaConfGrammarParser.PrimitiveContext,
307
- OmegaConfGrammarParser.DictKeyContext,
308
- ],
309
- ) -> Any:
310
- # (ID | NULL | INT | FLOAT | BOOL | UNQUOTED_CHAR | COLON | ESC | WS | interpolation)+
311
- if ctx.getChildCount() == 1:
312
- child = ctx.getChild(0)
313
- if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
314
- return self.visitInterpolation(child)
315
- assert isinstance(child, TerminalNode)
316
- symbol = child.symbol
317
- # Parse primitive types.
318
- if symbol.type in (
319
- OmegaConfGrammarLexer.ID,
320
- OmegaConfGrammarLexer.UNQUOTED_CHAR,
321
- OmegaConfGrammarLexer.COLON,
322
- ):
323
- return symbol.text
324
- elif symbol.type == OmegaConfGrammarLexer.NULL:
325
- return None
326
- elif symbol.type == OmegaConfGrammarLexer.INT:
327
- return int(symbol.text)
328
- elif symbol.type == OmegaConfGrammarLexer.FLOAT:
329
- return float(symbol.text)
330
- elif symbol.type == OmegaConfGrammarLexer.BOOL:
331
- return symbol.text.lower() == "true"
332
- elif symbol.type == OmegaConfGrammarLexer.ESC:
333
- return self._unescape([child])
334
- elif symbol.type == OmegaConfGrammarLexer.WS: # pragma: no cover
335
- # A single WS should have been "consumed" by another token.
336
- raise AssertionError("WS should never be reached")
337
- assert False, symbol.type
338
- # Concatenation of multiple items ==> un-escape the concatenation.
339
- return self._unescape(list(ctx.getChildren()))
340
-
341
- def _unescape(
342
- self,
343
- seq: List[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]],
344
- ) -> str:
345
- """
346
- Concatenate all symbols / interpolations in `seq`, unescaping symbols as needed.
347
-
348
- Interpolations are resolved and cast to string *WITHOUT* escaping their result
349
- (it is assumed that whatever escaping is required was already handled during the
350
- resolving of the interpolation).
351
- """
352
- chrs = []
353
- for node, next_node in zip_longest(seq, seq[1:]):
354
- if isinstance(node, TerminalNode):
355
- s = node.symbol
356
- if s.type == OmegaConfGrammarLexer.ESC_INTER:
357
- # `ESC_INTER` is of the form `\\...\${`: the formula below computes
358
- # the number of characters to keep at the end of the string to remove
359
- # the correct number of backslashes.
360
- text = s.text[-(len(s.text) // 2 + 1) :]
361
- elif (
362
- # Character sequence identified as requiring un-escaping.
363
- s.type == OmegaConfGrammarLexer.ESC
364
- or (
365
- # At top level, we need to un-escape backslashes that precede
366
- # an interpolation.
367
- s.type == OmegaConfGrammarLexer.TOP_ESC
368
- and isinstance(
369
- next_node, OmegaConfGrammarParser.InterpolationContext
370
- )
371
- )
372
- or (
373
- # In a quoted sring, we need to un-escape backslashes that
374
- # either end the string, or are followed by an interpolation.
375
- s.type == OmegaConfGrammarLexer.QUOTED_ESC
376
- and (
377
- next_node is None
378
- or isinstance(
379
- next_node, OmegaConfGrammarParser.InterpolationContext
380
- )
381
- )
382
- )
383
- ):
384
- text = s.text[1::2] # un-escape the sequence
385
- else:
386
- text = s.text # keep the original text
387
- else:
388
- assert isinstance(node, OmegaConfGrammarParser.InterpolationContext)
389
- text = str(self.visitInterpolation(node))
390
- chrs.append(text)
391
-
392
- return "".join(chrs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/listconfig.py DELETED
@@ -1,679 +0,0 @@
1
- import copy
2
- import itertools
3
- from typing import (
4
- Any,
5
- Callable,
6
- Dict,
7
- Iterable,
8
- Iterator,
9
- List,
10
- MutableSequence,
11
- Optional,
12
- Tuple,
13
- Type,
14
- Union,
15
- )
16
-
17
- from ._utils import (
18
- ValueKind,
19
- _is_missing_literal,
20
- _is_none,
21
- _resolve_optional,
22
- format_and_raise,
23
- get_value_kind,
24
- is_int,
25
- is_primitive_list,
26
- is_structured_config,
27
- type_str,
28
- )
29
- from .base import Box, ContainerMetadata, Node
30
- from .basecontainer import BaseContainer
31
- from .errors import (
32
- ConfigAttributeError,
33
- ConfigTypeError,
34
- ConfigValueError,
35
- KeyValidationError,
36
- MissingMandatoryValue,
37
- ReadonlyConfigError,
38
- ValidationError,
39
- )
40
-
41
-
42
- class ListConfig(BaseContainer, MutableSequence[Any]):
43
-
44
- _content: Union[List[Node], None, str]
45
-
46
- def __init__(
47
- self,
48
- content: Union[List[Any], Tuple[Any, ...], "ListConfig", str, None],
49
- key: Any = None,
50
- parent: Optional[Box] = None,
51
- element_type: Union[Type[Any], Any] = Any,
52
- is_optional: bool = True,
53
- ref_type: Union[Type[Any], Any] = Any,
54
- flags: Optional[Dict[str, bool]] = None,
55
- ) -> None:
56
- try:
57
- if isinstance(content, ListConfig):
58
- if flags is None:
59
- flags = content._metadata.flags
60
- super().__init__(
61
- parent=parent,
62
- metadata=ContainerMetadata(
63
- ref_type=ref_type,
64
- object_type=list,
65
- key=key,
66
- optional=is_optional,
67
- element_type=element_type,
68
- key_type=int,
69
- flags=flags,
70
- ),
71
- )
72
-
73
- if isinstance(content, ListConfig):
74
- metadata = copy.deepcopy(content._metadata)
75
- metadata.key = key
76
- metadata.ref_type = ref_type
77
- metadata.optional = is_optional
78
- metadata.element_type = element_type
79
- self.__dict__["_metadata"] = metadata
80
- self._set_value(value=content, flags=flags)
81
- except Exception as ex:
82
- format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))
83
-
84
- def _validate_get(self, key: Any, value: Any = None) -> None:
85
- if not isinstance(key, (int, slice)):
86
- raise KeyValidationError(
87
- "ListConfig indices must be integers or slices, not $KEY_TYPE"
88
- )
89
-
90
- def _validate_set(self, key: Any, value: Any) -> None:
91
- from omegaconf import OmegaConf
92
-
93
- self._validate_get(key, value)
94
-
95
- if self._get_flag("readonly"):
96
- raise ReadonlyConfigError("ListConfig is read-only")
97
-
98
- if 0 <= key < self.__len__():
99
- target = self._get_node(key)
100
- if target is not None:
101
- assert isinstance(target, Node)
102
- if value is None and not target._is_optional():
103
- raise ValidationError(
104
- "$FULL_KEY is not optional and cannot be assigned None"
105
- )
106
-
107
- vk = get_value_kind(value)
108
- if vk == ValueKind.MANDATORY_MISSING:
109
- return
110
- else:
111
- is_optional, target_type = _resolve_optional(self._metadata.element_type)
112
- value_type = OmegaConf.get_type(value)
113
-
114
- if (value_type is None and not is_optional) or (
115
- is_structured_config(target_type)
116
- and value_type is not None
117
- and not issubclass(value_type, target_type)
118
- ):
119
- msg = (
120
- f"Invalid type assigned: {type_str(value_type)} is not a "
121
- f"subclass of {type_str(target_type)}. value: {value}"
122
- )
123
- raise ValidationError(msg)
124
-
125
- def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig":
126
- res = ListConfig(None)
127
- res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
128
- res.__dict__["_flags_cache"] = copy.deepcopy(
129
- self.__dict__["_flags_cache"], memo=memo
130
- )
131
-
132
- src_content = self.__dict__["_content"]
133
- if isinstance(src_content, list):
134
- content_copy: List[Optional[Node]] = []
135
- for v in src_content:
136
- old_parent = v.__dict__["_parent"]
137
- try:
138
- v.__dict__["_parent"] = None
139
- vc = copy.deepcopy(v, memo=memo)
140
- vc.__dict__["_parent"] = res
141
- content_copy.append(vc)
142
- finally:
143
- v.__dict__["_parent"] = old_parent
144
- else:
145
- # None and strings can be assigned as is
146
- content_copy = src_content
147
-
148
- res.__dict__["_content"] = content_copy
149
- res.__dict__["_parent"] = self.__dict__["_parent"]
150
-
151
- return res
152
-
153
- def copy(self) -> "ListConfig":
154
- return copy.copy(self)
155
-
156
- # hide content while inspecting in debugger
157
- def __dir__(self) -> Iterable[str]:
158
- if self._is_missing() or self._is_none():
159
- return []
160
- return [str(x) for x in range(0, len(self))]
161
-
162
- def __setattr__(self, key: str, value: Any) -> None:
163
- self._format_and_raise(
164
- key=key,
165
- value=value,
166
- cause=ConfigAttributeError("ListConfig does not support attribute access"),
167
- )
168
- assert False
169
-
170
- def __getattr__(self, key: str) -> Any:
171
- # PyCharm is sometimes inspecting __members__, be sure to tell it we don't have that.
172
- if key == "__members__":
173
- raise AttributeError()
174
-
175
- if key == "__name__":
176
- raise AttributeError()
177
-
178
- if is_int(key):
179
- return self.__getitem__(int(key))
180
- else:
181
- self._format_and_raise(
182
- key=key,
183
- value=None,
184
- cause=ConfigAttributeError(
185
- "ListConfig does not support attribute access"
186
- ),
187
- )
188
-
189
- def __getitem__(self, index: Union[int, slice]) -> Any:
190
- try:
191
- if self._is_missing():
192
- raise MissingMandatoryValue("ListConfig is missing")
193
- self._validate_get(index, None)
194
- if self._is_none():
195
- raise TypeError(
196
- "ListConfig object representing None is not subscriptable"
197
- )
198
-
199
- assert isinstance(self.__dict__["_content"], list)
200
- if isinstance(index, slice):
201
- result = []
202
- start, stop, step = self._correct_index_params(index)
203
- for slice_idx in itertools.islice(
204
- range(0, len(self)), start, stop, step
205
- ):
206
- val = self._resolve_with_default(
207
- key=slice_idx, value=self.__dict__["_content"][slice_idx]
208
- )
209
- result.append(val)
210
- if index.step and index.step < 0:
211
- result.reverse()
212
- return result
213
- else:
214
- return self._resolve_with_default(
215
- key=index, value=self.__dict__["_content"][index]
216
- )
217
- except Exception as e:
218
- self._format_and_raise(key=index, value=None, cause=e)
219
-
220
- def _correct_index_params(self, index: slice) -> Tuple[int, int, int]:
221
- start = index.start
222
- stop = index.stop
223
- step = index.step
224
- if index.start and index.start < 0:
225
- start = self.__len__() + index.start
226
- if index.stop and index.stop < 0:
227
- stop = self.__len__() + index.stop
228
- if index.step and index.step < 0:
229
- step = abs(step)
230
- if start and stop:
231
- if start > stop:
232
- start, stop = stop + 1, start + 1
233
- else:
234
- start = stop = 0
235
- elif not start and stop:
236
- start = list(range(self.__len__() - 1, stop, -step))[0]
237
- stop = None
238
- elif start and not stop:
239
- stop = start + 1
240
- start = (stop - 1) % step
241
- else:
242
- start = (self.__len__() - 1) % step
243
- return start, stop, step
244
-
245
- def _set_at_index(self, index: Union[int, slice], value: Any) -> None:
246
- self._set_item_impl(index, value)
247
-
248
- def __setitem__(self, index: Union[int, slice], value: Any) -> None:
249
- try:
250
- if isinstance(index, slice):
251
- _ = iter(value) # check iterable
252
- self_indices = index.indices(len(self))
253
- indexes = range(*self_indices)
254
-
255
- # Ensure lengths match for extended slice assignment
256
- if index.step not in (None, 1):
257
- if len(indexes) != len(value):
258
- raise ValueError(
259
- f"attempt to assign sequence of size {len(value)}"
260
- f" to extended slice of size {len(indexes)}"
261
- )
262
-
263
- # Initialize insertion offsets for empty slices
264
- if len(indexes) == 0:
265
- curr_index = self_indices[0] - 1
266
- val_i = -1
267
-
268
- work_copy = self.copy() # For atomicity manipulate a copy
269
-
270
- # Delete and optionally replace non empty slices
271
- only_removed = 0
272
- for val_i, i in enumerate(indexes):
273
- curr_index = i - only_removed
274
- del work_copy[curr_index]
275
- if val_i < len(value):
276
- work_copy.insert(curr_index, value[val_i])
277
- else:
278
- only_removed += 1
279
-
280
- # Insert any remaining input items
281
- for val_i in range(val_i + 1, len(value)):
282
- curr_index += 1
283
- work_copy.insert(curr_index, value[val_i])
284
-
285
- # Reinitialize self with work_copy
286
- self.clear()
287
- self.extend(work_copy)
288
- else:
289
- self._set_at_index(index, value)
290
- except Exception as e:
291
- self._format_and_raise(key=index, value=value, cause=e)
292
-
293
- def append(self, item: Any) -> None:
294
- content = self.__dict__["_content"]
295
- index = len(content)
296
- content.append(None)
297
- try:
298
- self._set_item_impl(index, item)
299
- except Exception as e:
300
- del content[index]
301
- self._format_and_raise(key=index, value=item, cause=e)
302
- assert False
303
-
304
- def _update_keys(self) -> None:
305
- for i in range(len(self)):
306
- node = self._get_node(i)
307
- if node is not None:
308
- assert isinstance(node, Node)
309
- node._metadata.key = i
310
-
311
- def insert(self, index: int, item: Any) -> None:
312
- from omegaconf.omegaconf import _maybe_wrap
313
-
314
- try:
315
- if self._get_flag("readonly"):
316
- raise ReadonlyConfigError("Cannot insert into a read-only ListConfig")
317
- if self._is_none():
318
- raise TypeError(
319
- "Cannot insert into ListConfig object representing None"
320
- )
321
- if self._is_missing():
322
- raise MissingMandatoryValue("Cannot insert into missing ListConfig")
323
-
324
- try:
325
- assert isinstance(self.__dict__["_content"], list)
326
- # insert place holder
327
- self.__dict__["_content"].insert(index, None)
328
- is_optional, ref_type = _resolve_optional(self._metadata.element_type)
329
- node = _maybe_wrap(
330
- ref_type=ref_type,
331
- key=index,
332
- value=item,
333
- is_optional=is_optional,
334
- parent=self,
335
- )
336
- self._validate_set(key=index, value=node)
337
- self._set_at_index(index, node)
338
- self._update_keys()
339
- except Exception:
340
- del self.__dict__["_content"][index]
341
- self._update_keys()
342
- raise
343
- except Exception as e:
344
- self._format_and_raise(key=index, value=item, cause=e)
345
- assert False
346
-
347
- def extend(self, lst: Iterable[Any]) -> None:
348
- assert isinstance(lst, (tuple, list, ListConfig))
349
- for x in lst:
350
- self.append(x)
351
-
352
- def remove(self, x: Any) -> None:
353
- del self[self.index(x)]
354
-
355
- def __delitem__(self, key: Union[int, slice]) -> None:
356
- if self._get_flag("readonly"):
357
- self._format_and_raise(
358
- key=key,
359
- value=None,
360
- cause=ReadonlyConfigError(
361
- "Cannot delete item from read-only ListConfig"
362
- ),
363
- )
364
- del self.__dict__["_content"][key]
365
- self._update_keys()
366
-
367
- def clear(self) -> None:
368
- del self[:]
369
-
370
- def index(
371
- self, x: Any, start: Optional[int] = None, end: Optional[int] = None
372
- ) -> int:
373
- if start is None:
374
- start = 0
375
- if end is None:
376
- end = len(self)
377
- assert start >= 0
378
- assert end <= len(self)
379
- found_idx = -1
380
- for idx in range(start, end):
381
- item = self[idx]
382
- if x == item:
383
- found_idx = idx
384
- break
385
- if found_idx != -1:
386
- return found_idx
387
- else:
388
- self._format_and_raise(
389
- key=None,
390
- value=None,
391
- cause=ConfigValueError("Item not found in ListConfig"),
392
- )
393
- assert False
394
-
395
- def count(self, x: Any) -> int:
396
- c = 0
397
- for item in self:
398
- if item == x:
399
- c = c + 1
400
- return c
401
-
402
- def _get_node(
403
- self,
404
- key: Union[int, slice],
405
- validate_access: bool = True,
406
- validate_key: bool = True,
407
- throw_on_missing_value: bool = False,
408
- throw_on_missing_key: bool = False,
409
- ) -> Union[Optional[Node], List[Optional[Node]]]:
410
- try:
411
- if self._is_none():
412
- raise TypeError(
413
- "Cannot get_node from a ListConfig object representing None"
414
- )
415
- if self._is_missing():
416
- raise MissingMandatoryValue("Cannot get_node from a missing ListConfig")
417
- assert isinstance(self.__dict__["_content"], list)
418
- if validate_access:
419
- self._validate_get(key)
420
-
421
- value = self.__dict__["_content"][key]
422
- if value is not None:
423
- if isinstance(key, slice):
424
- assert isinstance(value, list)
425
- for v in value:
426
- if throw_on_missing_value and v._is_missing():
427
- raise MissingMandatoryValue("Missing mandatory value")
428
- else:
429
- assert isinstance(value, Node)
430
- if throw_on_missing_value and value._is_missing():
431
- raise MissingMandatoryValue("Missing mandatory value: $KEY")
432
- return value
433
- except (IndexError, TypeError, MissingMandatoryValue, KeyValidationError) as e:
434
- if isinstance(e, MissingMandatoryValue) and throw_on_missing_value:
435
- raise
436
- if validate_access:
437
- self._format_and_raise(key=key, value=None, cause=e)
438
- assert False
439
- else:
440
- return None
441
-
442
- def get(self, index: int, default_value: Any = None) -> Any:
443
- try:
444
- if self._is_none():
445
- raise TypeError("Cannot get from a ListConfig object representing None")
446
- if self._is_missing():
447
- raise MissingMandatoryValue("Cannot get from a missing ListConfig")
448
- self._validate_get(index, None)
449
- assert isinstance(self.__dict__["_content"], list)
450
- return self._resolve_with_default(
451
- key=index,
452
- value=self.__dict__["_content"][index],
453
- default_value=default_value,
454
- )
455
- except Exception as e:
456
- self._format_and_raise(key=index, value=None, cause=e)
457
- assert False
458
-
459
- def pop(self, index: int = -1) -> Any:
460
- try:
461
- if self._get_flag("readonly"):
462
- raise ReadonlyConfigError("Cannot pop from read-only ListConfig")
463
- if self._is_none():
464
- raise TypeError("Cannot pop from a ListConfig object representing None")
465
- if self._is_missing():
466
- raise MissingMandatoryValue("Cannot pop from a missing ListConfig")
467
-
468
- assert isinstance(self.__dict__["_content"], list)
469
- node = self._get_child(index)
470
- assert isinstance(node, Node)
471
- ret = self._resolve_with_default(key=index, value=node, default_value=None)
472
- del self.__dict__["_content"][index]
473
- self._update_keys()
474
- return ret
475
- except KeyValidationError as e:
476
- self._format_and_raise(
477
- key=index, value=None, cause=e, type_override=ConfigTypeError
478
- )
479
- assert False
480
- except Exception as e:
481
- self._format_and_raise(key=index, value=None, cause=e)
482
- assert False
483
-
484
- def sort(
485
- self, key: Optional[Callable[[Any], Any]] = None, reverse: bool = False
486
- ) -> None:
487
- try:
488
- if self._get_flag("readonly"):
489
- raise ReadonlyConfigError("Cannot sort a read-only ListConfig")
490
- if self._is_none():
491
- raise TypeError("Cannot sort a ListConfig object representing None")
492
- if self._is_missing():
493
- raise MissingMandatoryValue("Cannot sort a missing ListConfig")
494
-
495
- if key is None:
496
-
497
- def key1(x: Any) -> Any:
498
- return x._value()
499
-
500
- else:
501
-
502
- def key1(x: Any) -> Any:
503
- return key(x._value()) # type: ignore
504
-
505
- assert isinstance(self.__dict__["_content"], list)
506
- self.__dict__["_content"].sort(key=key1, reverse=reverse)
507
-
508
- except Exception as e:
509
- self._format_and_raise(key=None, value=None, cause=e)
510
- assert False
511
-
512
- def __eq__(self, other: Any) -> bool:
513
- if isinstance(other, (list, tuple)) or other is None:
514
- other = ListConfig(other, flags={"allow_objects": True})
515
- return ListConfig._list_eq(self, other)
516
- if other is None or isinstance(other, ListConfig):
517
- return ListConfig._list_eq(self, other)
518
- if self._is_missing():
519
- return _is_missing_literal(other)
520
- return NotImplemented
521
-
522
- def __ne__(self, other: Any) -> bool:
523
- x = self.__eq__(other)
524
- if x is not NotImplemented:
525
- return not x
526
- return NotImplemented
527
-
528
- def __hash__(self) -> int:
529
- return hash(str(self))
530
-
531
- def __iter__(self) -> Iterator[Any]:
532
- return self._iter_ex(resolve=True)
533
-
534
- class ListIterator(Iterator[Any]):
535
- def __init__(self, lst: Any, resolve: bool) -> None:
536
- self.resolve = resolve
537
- self.iterator = iter(lst.__dict__["_content"])
538
- self.index = 0
539
- from .nodes import ValueNode
540
-
541
- self.ValueNode = ValueNode
542
-
543
- def __next__(self) -> Any:
544
-
545
- x = next(self.iterator)
546
- if self.resolve:
547
- x = x._dereference_node()
548
- if x._is_missing():
549
- raise MissingMandatoryValue(f"Missing value at index {self.index}")
550
-
551
- self.index = self.index + 1
552
- if isinstance(x, self.ValueNode):
553
- return x._value()
554
- else:
555
- # Must be omegaconf.Container. not checking for perf reasons.
556
- if x._is_none():
557
- return None
558
- return x
559
-
560
- def __repr__(self) -> str: # pragma: no cover
561
- return f"ListConfig.ListIterator(resolve={self.resolve})"
562
-
563
- def _iter_ex(self, resolve: bool) -> Iterator[Any]:
564
- try:
565
- if self._is_none():
566
- raise TypeError("Cannot iterate a ListConfig object representing None")
567
- if self._is_missing():
568
- raise MissingMandatoryValue("Cannot iterate a missing ListConfig")
569
-
570
- return ListConfig.ListIterator(self, resolve)
571
- except (TypeError, MissingMandatoryValue) as e:
572
- self._format_and_raise(key=None, value=None, cause=e)
573
- assert False
574
-
575
- def __add__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
576
- # res is sharing this list's parent to allow interpolation to work as expected
577
- res = ListConfig(parent=self._get_parent(), content=[])
578
- res.extend(self)
579
- res.extend(other)
580
- return res
581
-
582
- def __radd__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
583
- # res is sharing this list's parent to allow interpolation to work as expected
584
- res = ListConfig(parent=self._get_parent(), content=[])
585
- res.extend(other)
586
- res.extend(self)
587
- return res
588
-
589
- def __iadd__(self, other: Iterable[Any]) -> "ListConfig":
590
- self.extend(other)
591
- return self
592
-
593
- def __contains__(self, item: Any) -> bool:
594
- if self._is_none():
595
- raise TypeError(
596
- "Cannot check if an item is in a ListConfig object representing None"
597
- )
598
- if self._is_missing():
599
- raise MissingMandatoryValue(
600
- "Cannot check if an item is in missing ListConfig"
601
- )
602
-
603
- lst = self.__dict__["_content"]
604
- for x in lst:
605
- x = x._dereference_node()
606
- if x == item:
607
- return True
608
- return False
609
-
610
- def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
611
- try:
612
- previous_content = self.__dict__["_content"]
613
- previous_metadata = self.__dict__["_metadata"]
614
- self._set_value_impl(value, flags)
615
- except Exception as e:
616
- self.__dict__["_content"] = previous_content
617
- self.__dict__["_metadata"] = previous_metadata
618
- raise e
619
-
620
- def _set_value_impl(
621
- self, value: Any, flags: Optional[Dict[str, bool]] = None
622
- ) -> None:
623
- from omegaconf import MISSING, flag_override
624
-
625
- if flags is None:
626
- flags = {}
627
-
628
- vk = get_value_kind(value, strict_interpolation_validation=True)
629
- if _is_none(value):
630
- if not self._is_optional():
631
- raise ValidationError(
632
- "Non optional ListConfig cannot be constructed from None"
633
- )
634
- self.__dict__["_content"] = None
635
- self._metadata.object_type = None
636
- elif vk is ValueKind.MANDATORY_MISSING:
637
- self.__dict__["_content"] = MISSING
638
- self._metadata.object_type = None
639
- elif vk == ValueKind.INTERPOLATION:
640
- self.__dict__["_content"] = value
641
- self._metadata.object_type = None
642
- else:
643
- if not (is_primitive_list(value) or isinstance(value, ListConfig)):
644
- type_ = type(value)
645
- msg = f"Invalid value assigned: {type_.__name__} is not a ListConfig, list or tuple."
646
- raise ValidationError(msg)
647
-
648
- self.__dict__["_content"] = []
649
- if isinstance(value, ListConfig):
650
- self._metadata.flags = copy.deepcopy(flags)
651
- # disable struct and readonly for the construction phase
652
- # retaining other flags like allow_objects. The real flags are restored at the end of this function
653
- with flag_override(self, ["struct", "readonly"], False):
654
- for item in value._iter_ex(resolve=False):
655
- self.append(item)
656
- elif is_primitive_list(value):
657
- with flag_override(self, ["struct", "readonly"], False):
658
- for item in value:
659
- self.append(item)
660
- self._metadata.object_type = list
661
-
662
- @staticmethod
663
- def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool:
664
- l1_none = l1.__dict__["_content"] is None
665
- l2_none = l2.__dict__["_content"] is None
666
- if l1_none and l2_none:
667
- return True
668
- if l1_none != l2_none:
669
- return False
670
-
671
- assert isinstance(l1, ListConfig)
672
- assert isinstance(l2, ListConfig)
673
- if len(l1) != len(l2):
674
- return False
675
- for i in range(len(l1)):
676
- if not BaseContainer._item_eq(l1, i, l2, i):
677
- return False
678
-
679
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/nodes.py DELETED
@@ -1,545 +0,0 @@
1
- import copy
2
- import math
3
- import sys
4
- from abc import abstractmethod
5
- from enum import Enum
6
- from pathlib import Path
7
- from typing import Any, Dict, Optional, Type, Union
8
-
9
- from omegaconf._utils import (
10
- ValueKind,
11
- _is_interpolation,
12
- get_type_of,
13
- get_value_kind,
14
- is_primitive_container,
15
- type_str,
16
- )
17
- from omegaconf.base import Box, DictKeyType, Metadata, Node
18
- from omegaconf.errors import ReadonlyConfigError, UnsupportedValueType, ValidationError
19
-
20
-
21
- class ValueNode(Node):
22
- _val: Any
23
-
24
- def __init__(self, parent: Optional[Box], value: Any, metadata: Metadata):
25
- from omegaconf import read_write
26
-
27
- super().__init__(parent=parent, metadata=metadata)
28
- with read_write(self):
29
- self._set_value(value) # lgtm [py/init-calls-subclass]
30
-
31
- def _value(self) -> Any:
32
- return self._val
33
-
34
- def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
35
- if self._get_flag("readonly"):
36
- raise ReadonlyConfigError("Cannot set value of read-only config node")
37
-
38
- if isinstance(value, str) and get_value_kind(
39
- value, strict_interpolation_validation=True
40
- ) in (
41
- ValueKind.INTERPOLATION,
42
- ValueKind.MANDATORY_MISSING,
43
- ):
44
- self._val = value
45
- else:
46
- self._val = self.validate_and_convert(value)
47
-
48
- def _strict_validate_type(self, value: Any) -> None:
49
- ref_type = self._metadata.ref_type
50
- if isinstance(ref_type, type) and type(value) is not ref_type:
51
- type_hint = type_str(self._metadata.type_hint)
52
- raise ValidationError(
53
- f"Value '$VALUE' of type '$VALUE_TYPE' is incompatible with type hint '{type_hint}'"
54
- )
55
-
56
- def validate_and_convert(self, value: Any) -> Any:
57
- """
58
- Validates input and converts to canonical form
59
- :param value: input value
60
- :return: converted value ("100" may be converted to 100 for example)
61
- """
62
- if value is None:
63
- if self._is_optional():
64
- return None
65
- ref_type_str = type_str(self._metadata.ref_type)
66
- raise ValidationError(
67
- f"Incompatible value '{value}' for field of type '{ref_type_str}'"
68
- )
69
-
70
- # Subclasses can assume that `value` is not None in
71
- # `_validate_and_convert_impl()` and in `_strict_validate_type()`.
72
- if self._get_flag("convert") is False:
73
- self._strict_validate_type(value)
74
- return value
75
- else:
76
- return self._validate_and_convert_impl(value)
77
-
78
- @abstractmethod
79
- def _validate_and_convert_impl(self, value: Any) -> Any:
80
- ...
81
-
82
- def __str__(self) -> str:
83
- return str(self._val)
84
-
85
- def __repr__(self) -> str:
86
- return repr(self._val) if hasattr(self, "_val") else "__INVALID__"
87
-
88
- def __eq__(self, other: Any) -> bool:
89
- if isinstance(other, AnyNode):
90
- return self._val == other._val # type: ignore
91
- else:
92
- return self._val == other # type: ignore
93
-
94
- def __ne__(self, other: Any) -> bool:
95
- x = self.__eq__(other)
96
- assert x is not NotImplemented
97
- return not x
98
-
99
- def __hash__(self) -> int:
100
- return hash(self._val)
101
-
102
- def _deepcopy_impl(self, res: Any, memo: Dict[int, Any]) -> None:
103
- res.__dict__["_metadata"] = copy.deepcopy(self._metadata, memo=memo)
104
- # shallow copy for value to support non-copyable value
105
- res.__dict__["_val"] = self._val
106
-
107
- # parent is retained, but not copied
108
- res.__dict__["_parent"] = self._parent
109
-
110
- def _is_optional(self) -> bool:
111
- return self._metadata.optional
112
-
113
- def _is_interpolation(self) -> bool:
114
- return _is_interpolation(self._value())
115
-
116
- def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
117
- parent = self._get_parent()
118
- if parent is None:
119
- if self._metadata.key is None:
120
- return ""
121
- else:
122
- return str(self._metadata.key)
123
- else:
124
- return parent._get_full_key(self._metadata.key)
125
-
126
-
127
- class AnyNode(ValueNode):
128
- def __init__(
129
- self,
130
- value: Any = None,
131
- key: Any = None,
132
- parent: Optional[Box] = None,
133
- flags: Optional[Dict[str, bool]] = None,
134
- ):
135
- super().__init__(
136
- parent=parent,
137
- value=value,
138
- metadata=Metadata(
139
- ref_type=Any, object_type=None, key=key, optional=True, flags=flags
140
- ),
141
- )
142
-
143
- def _validate_and_convert_impl(self, value: Any) -> Any:
144
- from ._utils import is_primitive_type_annotation
145
-
146
- # allow_objects is internal and not an official API. use at your own risk.
147
- # Please be aware that this support is subject to change without notice.
148
- # If this is deemed useful and supportable it may become an official API.
149
-
150
- if self._get_flag(
151
- "allow_objects"
152
- ) is not True and not is_primitive_type_annotation(value):
153
- t = get_type_of(value)
154
- raise UnsupportedValueType(
155
- f"Value '{t.__name__}' is not a supported primitive type"
156
- )
157
- return value
158
-
159
- def __deepcopy__(self, memo: Dict[int, Any]) -> "AnyNode":
160
- res = AnyNode()
161
- self._deepcopy_impl(res, memo)
162
- return res
163
-
164
-
165
- class StringNode(ValueNode):
166
- def __init__(
167
- self,
168
- value: Any = None,
169
- key: Any = None,
170
- parent: Optional[Box] = None,
171
- is_optional: bool = True,
172
- flags: Optional[Dict[str, bool]] = None,
173
- ):
174
- super().__init__(
175
- parent=parent,
176
- value=value,
177
- metadata=Metadata(
178
- key=key,
179
- optional=is_optional,
180
- ref_type=str,
181
- object_type=str,
182
- flags=flags,
183
- ),
184
- )
185
-
186
- def _validate_and_convert_impl(self, value: Any) -> str:
187
- from omegaconf import OmegaConf
188
-
189
- if (
190
- OmegaConf.is_config(value)
191
- or is_primitive_container(value)
192
- or isinstance(value, bytes)
193
- ):
194
- raise ValidationError("Cannot convert '$VALUE_TYPE' to string: '$VALUE'")
195
- return str(value)
196
-
197
- def __deepcopy__(self, memo: Dict[int, Any]) -> "StringNode":
198
- res = StringNode()
199
- self._deepcopy_impl(res, memo)
200
- return res
201
-
202
-
203
- class PathNode(ValueNode):
204
- def __init__(
205
- self,
206
- value: Any = None,
207
- key: Any = None,
208
- parent: Optional[Box] = None,
209
- is_optional: bool = True,
210
- flags: Optional[Dict[str, bool]] = None,
211
- ):
212
- super().__init__(
213
- parent=parent,
214
- value=value,
215
- metadata=Metadata(
216
- key=key,
217
- optional=is_optional,
218
- ref_type=Path,
219
- object_type=Path,
220
- flags=flags,
221
- ),
222
- )
223
-
224
- def _strict_validate_type(self, value: Any) -> None:
225
- if not isinstance(value, Path):
226
- raise ValidationError(
227
- "Value '$VALUE' of type '$VALUE_TYPE' is not an instance of 'pathlib.Path'"
228
- )
229
-
230
- def _validate_and_convert_impl(self, value: Any) -> Path:
231
- if not isinstance(value, (str, Path)):
232
- raise ValidationError(
233
- "Value '$VALUE' of type '$VALUE_TYPE' could not be converted to Path"
234
- )
235
-
236
- return Path(value)
237
-
238
- def __deepcopy__(self, memo: Dict[int, Any]) -> "PathNode":
239
- res = PathNode()
240
- self._deepcopy_impl(res, memo)
241
- return res
242
-
243
-
244
- class IntegerNode(ValueNode):
245
- def __init__(
246
- self,
247
- value: Any = None,
248
- key: Any = None,
249
- parent: Optional[Box] = None,
250
- is_optional: bool = True,
251
- flags: Optional[Dict[str, bool]] = None,
252
- ):
253
- super().__init__(
254
- parent=parent,
255
- value=value,
256
- metadata=Metadata(
257
- key=key,
258
- optional=is_optional,
259
- ref_type=int,
260
- object_type=int,
261
- flags=flags,
262
- ),
263
- )
264
-
265
- def _validate_and_convert_impl(self, value: Any) -> int:
266
- try:
267
- if type(value) in (str, int):
268
- val = int(value)
269
- else:
270
- raise ValueError()
271
- except ValueError:
272
- raise ValidationError(
273
- "Value '$VALUE' of type '$VALUE_TYPE' could not be converted to Integer"
274
- )
275
- return val
276
-
277
- def __deepcopy__(self, memo: Dict[int, Any]) -> "IntegerNode":
278
- res = IntegerNode()
279
- self._deepcopy_impl(res, memo)
280
- return res
281
-
282
-
283
- class BytesNode(ValueNode):
284
- def __init__(
285
- self,
286
- value: Any = None,
287
- key: Any = None,
288
- parent: Optional[Box] = None,
289
- is_optional: bool = True,
290
- flags: Optional[Dict[str, bool]] = None,
291
- ):
292
- super().__init__(
293
- parent=parent,
294
- value=value,
295
- metadata=Metadata(
296
- key=key,
297
- optional=is_optional,
298
- ref_type=bytes,
299
- object_type=bytes,
300
- flags=flags,
301
- ),
302
- )
303
-
304
- def _validate_and_convert_impl(self, value: Any) -> bytes:
305
- if not isinstance(value, bytes):
306
- raise ValidationError(
307
- "Value '$VALUE' of type '$VALUE_TYPE' is not of type 'bytes'"
308
- )
309
- return value
310
-
311
- def __deepcopy__(self, memo: Dict[int, Any]) -> "BytesNode":
312
- res = BytesNode()
313
- self._deepcopy_impl(res, memo)
314
- return res
315
-
316
-
317
- class FloatNode(ValueNode):
318
- def __init__(
319
- self,
320
- value: Any = None,
321
- key: Any = None,
322
- parent: Optional[Box] = None,
323
- is_optional: bool = True,
324
- flags: Optional[Dict[str, bool]] = None,
325
- ):
326
- super().__init__(
327
- parent=parent,
328
- value=value,
329
- metadata=Metadata(
330
- key=key,
331
- optional=is_optional,
332
- ref_type=float,
333
- object_type=float,
334
- flags=flags,
335
- ),
336
- )
337
-
338
- def _validate_and_convert_impl(self, value: Any) -> float:
339
- try:
340
- if type(value) in (float, str, int):
341
- return float(value)
342
- else:
343
- raise ValueError()
344
- except ValueError:
345
- raise ValidationError(
346
- "Value '$VALUE' of type '$VALUE_TYPE' could not be converted to Float"
347
- )
348
-
349
- def __eq__(self, other: Any) -> bool:
350
- if isinstance(other, ValueNode):
351
- other_val = other._val
352
- else:
353
- other_val = other
354
- if self._val is None and other is None:
355
- return True
356
- if self._val is None and other is not None:
357
- return False
358
- if self._val is not None and other is None:
359
- return False
360
- nan1 = math.isnan(self._val) if isinstance(self._val, float) else False
361
- nan2 = math.isnan(other_val) if isinstance(other_val, float) else False
362
- return self._val == other_val or (nan1 and nan2)
363
-
364
- def __hash__(self) -> int:
365
- return hash(self._val)
366
-
367
- def __deepcopy__(self, memo: Dict[int, Any]) -> "FloatNode":
368
- res = FloatNode()
369
- self._deepcopy_impl(res, memo)
370
- return res
371
-
372
-
373
- class BooleanNode(ValueNode):
374
- def __init__(
375
- self,
376
- value: Any = None,
377
- key: Any = None,
378
- parent: Optional[Box] = None,
379
- is_optional: bool = True,
380
- flags: Optional[Dict[str, bool]] = None,
381
- ):
382
- super().__init__(
383
- parent=parent,
384
- value=value,
385
- metadata=Metadata(
386
- key=key,
387
- optional=is_optional,
388
- ref_type=bool,
389
- object_type=bool,
390
- flags=flags,
391
- ),
392
- )
393
-
394
- def _validate_and_convert_impl(self, value: Any) -> bool:
395
- if isinstance(value, bool):
396
- return value
397
- if isinstance(value, int):
398
- return value != 0
399
- elif isinstance(value, str):
400
- try:
401
- return self._validate_and_convert_impl(int(value))
402
- except ValueError as e:
403
- if value.lower() in ("yes", "y", "on", "true"):
404
- return True
405
- elif value.lower() in ("no", "n", "off", "false"):
406
- return False
407
- else:
408
- raise ValidationError(
409
- "Value '$VALUE' is not a valid bool (type $VALUE_TYPE)"
410
- ).with_traceback(sys.exc_info()[2]) from e
411
- else:
412
- raise ValidationError(
413
- "Value '$VALUE' is not a valid bool (type $VALUE_TYPE)"
414
- )
415
-
416
- def __deepcopy__(self, memo: Dict[int, Any]) -> "BooleanNode":
417
- res = BooleanNode()
418
- self._deepcopy_impl(res, memo)
419
- return res
420
-
421
-
422
- class EnumNode(ValueNode): # lgtm [py/missing-equals] : Intentional.
423
- """
424
- NOTE: EnumNode is serialized to yaml as a string ("Color.BLUE"), not as a fully qualified yaml type.
425
- this means serialization to YAML of a typed config (with EnumNode) will not retain the type of the Enum
426
- when loaded.
427
- This is intentional, Please open an issue against OmegaConf if you wish to discuss this decision.
428
- """
429
-
430
- def __init__(
431
- self,
432
- enum_type: Type[Enum],
433
- value: Optional[Union[Enum, str]] = None,
434
- key: Any = None,
435
- parent: Optional[Box] = None,
436
- is_optional: bool = True,
437
- flags: Optional[Dict[str, bool]] = None,
438
- ):
439
- if not isinstance(enum_type, type) or not issubclass(enum_type, Enum):
440
- raise ValidationError(
441
- f"EnumNode can only operate on Enum subclasses ({enum_type})"
442
- )
443
- self.fields: Dict[str, str] = {}
444
- self.enum_type: Type[Enum] = enum_type
445
- for name, constant in enum_type.__members__.items():
446
- self.fields[name] = constant.value
447
- super().__init__(
448
- parent=parent,
449
- value=value,
450
- metadata=Metadata(
451
- key=key,
452
- optional=is_optional,
453
- ref_type=enum_type,
454
- object_type=enum_type,
455
- flags=flags,
456
- ),
457
- )
458
-
459
- def _strict_validate_type(self, value: Any) -> None:
460
- ref_type = self._metadata.ref_type
461
- if not isinstance(value, ref_type):
462
- type_hint = type_str(self._metadata.type_hint)
463
- raise ValidationError(
464
- f"Value '$VALUE' of type '$VALUE_TYPE' is incompatible with type hint '{type_hint}'"
465
- )
466
-
467
- def _validate_and_convert_impl(self, value: Any) -> Enum:
468
- return self.validate_and_convert_to_enum(enum_type=self.enum_type, value=value)
469
-
470
- @staticmethod
471
- def validate_and_convert_to_enum(enum_type: Type[Enum], value: Any) -> Enum:
472
- if not isinstance(value, (str, int)) and not isinstance(value, enum_type):
473
- raise ValidationError(
474
- f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type}"
475
- )
476
-
477
- if isinstance(value, enum_type):
478
- return value
479
-
480
- try:
481
- if isinstance(value, (float, bool)):
482
- raise ValueError
483
-
484
- if isinstance(value, int):
485
- return enum_type(value)
486
-
487
- if isinstance(value, str):
488
- prefix = f"{enum_type.__name__}."
489
- if value.startswith(prefix):
490
- value = value[len(prefix) :]
491
- return enum_type[value]
492
-
493
- assert False
494
-
495
- except (ValueError, KeyError) as e:
496
- valid = ", ".join([x for x in enum_type.__members__.keys()])
497
- raise ValidationError(
498
- f"Invalid value '$VALUE', expected one of [{valid}]"
499
- ).with_traceback(sys.exc_info()[2]) from e
500
-
501
- def __deepcopy__(self, memo: Dict[int, Any]) -> "EnumNode":
502
- res = EnumNode(enum_type=self.enum_type)
503
- self._deepcopy_impl(res, memo)
504
- return res
505
-
506
-
507
- class InterpolationResultNode(ValueNode):
508
- """
509
- Special node type, used to wrap interpolation results.
510
- """
511
-
512
- def __init__(
513
- self,
514
- value: Any,
515
- key: Any = None,
516
- parent: Optional[Box] = None,
517
- flags: Optional[Dict[str, bool]] = None,
518
- ):
519
- super().__init__(
520
- parent=parent,
521
- value=value,
522
- metadata=Metadata(
523
- ref_type=Any, object_type=None, key=key, optional=True, flags=flags
524
- ),
525
- )
526
- # In general we should not try to write into interpolation results.
527
- if flags is None or "readonly" not in flags:
528
- self._set_flag("readonly", True)
529
-
530
- def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
531
- if self._get_flag("readonly"):
532
- raise ReadonlyConfigError("Cannot set value of read-only config node")
533
- self._val = self.validate_and_convert(value)
534
-
535
- def _validate_and_convert_impl(self, value: Any) -> Any:
536
- # Interpolation results may be anything.
537
- return value
538
-
539
- def __deepcopy__(self, memo: Dict[int, Any]) -> "InterpolationResultNode":
540
- # Currently there should be no need to deep-copy such nodes.
541
- raise NotImplementedError
542
-
543
- def _is_interpolation(self) -> bool:
544
- # The result of an interpolation cannot be itself an interpolation.
545
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/omegaconf.py DELETED
@@ -1,1157 +0,0 @@
1
- """OmegaConf module"""
2
- import copy
3
- import inspect
4
- import io
5
- import os
6
- import pathlib
7
- import sys
8
- import warnings
9
- from collections import defaultdict
10
- from contextlib import contextmanager
11
- from enum import Enum
12
- from textwrap import dedent
13
- from typing import (
14
- IO,
15
- Any,
16
- Callable,
17
- Dict,
18
- Generator,
19
- Iterable,
20
- List,
21
- Optional,
22
- Set,
23
- Tuple,
24
- Type,
25
- Union,
26
- overload,
27
- )
28
-
29
- import yaml
30
-
31
- from . import DictConfig, DictKeyType, ListConfig
32
- from ._utils import (
33
- _DEFAULT_MARKER_,
34
- _ensure_container,
35
- _get_value,
36
- format_and_raise,
37
- get_dict_key_value_types,
38
- get_list_element_type,
39
- get_omega_conf_dumper,
40
- get_type_of,
41
- is_attr_class,
42
- is_dataclass,
43
- is_dict_annotation,
44
- is_int,
45
- is_list_annotation,
46
- is_primitive_container,
47
- is_primitive_dict,
48
- is_primitive_list,
49
- is_structured_config,
50
- is_tuple_annotation,
51
- is_union_annotation,
52
- nullcontext,
53
- split_key,
54
- type_str,
55
- )
56
- from .base import Box, Container, Node, SCMode, UnionNode
57
- from .basecontainer import BaseContainer
58
- from .errors import (
59
- MissingMandatoryValue,
60
- OmegaConfBaseException,
61
- UnsupportedInterpolationType,
62
- ValidationError,
63
- )
64
- from .nodes import (
65
- AnyNode,
66
- BooleanNode,
67
- BytesNode,
68
- EnumNode,
69
- FloatNode,
70
- IntegerNode,
71
- PathNode,
72
- StringNode,
73
- ValueNode,
74
- )
75
-
76
- MISSING: Any = "???"
77
-
78
- Resolver = Callable[..., Any]
79
-
80
-
81
- def II(interpolation: str) -> Any:
82
- """
83
- Equivalent to ``${interpolation}``
84
-
85
- :param interpolation:
86
- :return: input ``${node}`` with type Any
87
- """
88
- return "${" + interpolation + "}"
89
-
90
-
91
- def SI(interpolation: str) -> Any:
92
- """
93
- Use this for String interpolation, for example ``"http://${host}:${port}"``
94
-
95
- :param interpolation: interpolation string
96
- :return: input interpolation with type ``Any``
97
- """
98
- return interpolation
99
-
100
-
101
- def register_default_resolvers() -> None:
102
- from omegaconf.resolvers import oc
103
-
104
- OmegaConf.register_new_resolver("oc.create", oc.create)
105
- OmegaConf.register_new_resolver("oc.decode", oc.decode)
106
- OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated)
107
- OmegaConf.register_new_resolver("oc.env", oc.env)
108
- OmegaConf.register_new_resolver("oc.select", oc.select)
109
- OmegaConf.register_new_resolver("oc.dict.keys", oc.dict.keys)
110
- OmegaConf.register_new_resolver("oc.dict.values", oc.dict.values)
111
-
112
-
113
- class OmegaConf:
114
- """OmegaConf primary class"""
115
-
116
- def __init__(self) -> None:
117
- raise NotImplementedError("Use one of the static construction functions")
118
-
119
- @staticmethod
120
- def structured(
121
- obj: Any,
122
- parent: Optional[BaseContainer] = None,
123
- flags: Optional[Dict[str, bool]] = None,
124
- ) -> Any:
125
- return OmegaConf.create(obj, parent, flags)
126
-
127
- @staticmethod
128
- @overload
129
- def create(
130
- obj: str,
131
- parent: Optional[BaseContainer] = None,
132
- flags: Optional[Dict[str, bool]] = None,
133
- ) -> Union[DictConfig, ListConfig]:
134
- ...
135
-
136
- @staticmethod
137
- @overload
138
- def create(
139
- obj: Union[List[Any], Tuple[Any, ...]],
140
- parent: Optional[BaseContainer] = None,
141
- flags: Optional[Dict[str, bool]] = None,
142
- ) -> ListConfig:
143
- ...
144
-
145
- @staticmethod
146
- @overload
147
- def create(
148
- obj: DictConfig,
149
- parent: Optional[BaseContainer] = None,
150
- flags: Optional[Dict[str, bool]] = None,
151
- ) -> DictConfig:
152
- ...
153
-
154
- @staticmethod
155
- @overload
156
- def create(
157
- obj: ListConfig,
158
- parent: Optional[BaseContainer] = None,
159
- flags: Optional[Dict[str, bool]] = None,
160
- ) -> ListConfig:
161
- ...
162
-
163
- @staticmethod
164
- @overload
165
- def create(
166
- obj: Optional[Dict[Any, Any]] = None,
167
- parent: Optional[BaseContainer] = None,
168
- flags: Optional[Dict[str, bool]] = None,
169
- ) -> DictConfig:
170
- ...
171
-
172
- @staticmethod
173
- def create( # noqa F811
174
- obj: Any = _DEFAULT_MARKER_,
175
- parent: Optional[BaseContainer] = None,
176
- flags: Optional[Dict[str, bool]] = None,
177
- ) -> Union[DictConfig, ListConfig]:
178
- return OmegaConf._create_impl(
179
- obj=obj,
180
- parent=parent,
181
- flags=flags,
182
- )
183
-
184
- @staticmethod
185
- def load(file_: Union[str, pathlib.Path, IO[Any]]) -> Union[DictConfig, ListConfig]:
186
- from ._utils import get_yaml_loader
187
-
188
- if isinstance(file_, (str, pathlib.Path)):
189
- with io.open(os.path.abspath(file_), "r", encoding="utf-8") as f:
190
- obj = yaml.load(f, Loader=get_yaml_loader())
191
- elif getattr(file_, "read", None):
192
- obj = yaml.load(file_, Loader=get_yaml_loader())
193
- else:
194
- raise TypeError("Unexpected file type")
195
-
196
- if obj is not None and not isinstance(obj, (list, dict, str)):
197
- raise IOError( # pragma: no cover
198
- f"Invalid loaded object type: {type(obj).__name__}"
199
- )
200
-
201
- ret: Union[DictConfig, ListConfig]
202
- if obj is None:
203
- ret = OmegaConf.create()
204
- else:
205
- ret = OmegaConf.create(obj)
206
- return ret
207
-
208
- @staticmethod
209
- def save(
210
- config: Any, f: Union[str, pathlib.Path, IO[Any]], resolve: bool = False
211
- ) -> None:
212
- """
213
- Save as configuration object to a file
214
-
215
- :param config: omegaconf.Config object (DictConfig or ListConfig).
216
- :param f: filename or file object
217
- :param resolve: True to save a resolved config (defaults to False)
218
- """
219
- if is_dataclass(config) or is_attr_class(config):
220
- config = OmegaConf.create(config)
221
- data = OmegaConf.to_yaml(config, resolve=resolve)
222
- if isinstance(f, (str, pathlib.Path)):
223
- with io.open(os.path.abspath(f), "w", encoding="utf-8") as file:
224
- file.write(data)
225
- elif hasattr(f, "write"):
226
- f.write(data)
227
- f.flush()
228
- else:
229
- raise TypeError("Unexpected file type")
230
-
231
- @staticmethod
232
- def from_cli(args_list: Optional[List[str]] = None) -> DictConfig:
233
- if args_list is None:
234
- # Skip program name
235
- args_list = sys.argv[1:]
236
- return OmegaConf.from_dotlist(args_list)
237
-
238
- @staticmethod
239
- def from_dotlist(dotlist: List[str]) -> DictConfig:
240
- """
241
- Creates config from the content sys.argv or from the specified args list of not None
242
-
243
- :param dotlist: A list of dotlist-style strings, e.g. ``["foo.bar=1", "baz=qux"]``.
244
- :return: A ``DictConfig`` object created from the dotlist.
245
- """
246
- conf = OmegaConf.create()
247
- conf.merge_with_dotlist(dotlist)
248
- return conf
249
-
250
- @staticmethod
251
- def merge(
252
- *configs: Union[
253
- DictConfig,
254
- ListConfig,
255
- Dict[DictKeyType, Any],
256
- List[Any],
257
- Tuple[Any, ...],
258
- Any,
259
- ],
260
- ) -> Union[ListConfig, DictConfig]:
261
- """
262
- Merge a list of previously created configs into a single one
263
-
264
- :param configs: Input configs
265
- :return: the merged config object.
266
- """
267
- assert len(configs) > 0
268
- target = copy.deepcopy(configs[0])
269
- target = _ensure_container(target)
270
- assert isinstance(target, (DictConfig, ListConfig))
271
-
272
- with flag_override(target, "readonly", False):
273
- target.merge_with(*configs[1:])
274
- turned_readonly = target._get_flag("readonly") is True
275
-
276
- if turned_readonly:
277
- OmegaConf.set_readonly(target, True)
278
-
279
- return target
280
-
281
- @staticmethod
282
- def unsafe_merge(
283
- *configs: Union[
284
- DictConfig,
285
- ListConfig,
286
- Dict[DictKeyType, Any],
287
- List[Any],
288
- Tuple[Any, ...],
289
- Any,
290
- ],
291
- ) -> Union[ListConfig, DictConfig]:
292
- """
293
- Merge a list of previously created configs into a single one
294
- This is much faster than OmegaConf.merge() as the input configs are not copied.
295
- However, the input configs must not be used after this operation as will become inconsistent.
296
-
297
- :param configs: Input configs
298
- :return: the merged config object.
299
- """
300
- assert len(configs) > 0
301
- target = configs[0]
302
- target = _ensure_container(target)
303
- assert isinstance(target, (DictConfig, ListConfig))
304
-
305
- with flag_override(
306
- target, ["readonly", "no_deepcopy_set_nodes"], [False, True]
307
- ):
308
- target.merge_with(*configs[1:])
309
- turned_readonly = target._get_flag("readonly") is True
310
-
311
- if turned_readonly:
312
- OmegaConf.set_readonly(target, True)
313
-
314
- return target
315
-
316
- @staticmethod
317
- def register_resolver(name: str, resolver: Resolver) -> None:
318
- warnings.warn(
319
- dedent(
320
- """\
321
- register_resolver() is deprecated.
322
- See https://github.com/omry/omegaconf/issues/426 for migration instructions.
323
- """
324
- ),
325
- stacklevel=2,
326
- )
327
- return OmegaConf.legacy_register_resolver(name, resolver)
328
-
329
- # This function will eventually be deprecated and removed.
330
- @staticmethod
331
- def legacy_register_resolver(name: str, resolver: Resolver) -> None:
332
- assert callable(resolver), "resolver must be callable"
333
- # noinspection PyProtectedMember
334
- assert (
335
- name not in BaseContainer._resolvers
336
- ), f"resolver '{name}' is already registered"
337
-
338
- def resolver_wrapper(
339
- config: BaseContainer,
340
- parent: BaseContainer,
341
- node: Node,
342
- args: Tuple[Any, ...],
343
- args_str: Tuple[str, ...],
344
- ) -> Any:
345
- cache = OmegaConf.get_cache(config)[name]
346
- # "Un-escape " spaces and commas.
347
- args_unesc = [x.replace(r"\ ", " ").replace(r"\,", ",") for x in args_str]
348
-
349
- # Nested interpolations behave in a potentially surprising way with
350
- # legacy resolvers (they remain as strings, e.g., "${foo}"). If any
351
- # input looks like an interpolation we thus raise an exception.
352
- try:
353
- bad_arg = next(i for i in args_unesc if "${" in i)
354
- except StopIteration:
355
- pass
356
- else:
357
- raise ValueError(
358
- f"Resolver '{name}' was called with argument '{bad_arg}' that appears "
359
- f"to be an interpolation. Nested interpolations are not supported for "
360
- f"resolvers registered with `[legacy_]register_resolver()`, please use "
361
- f"`register_new_resolver()` instead (see "
362
- f"https://github.com/omry/omegaconf/issues/426 for migration instructions)."
363
- )
364
- key = args_str
365
- val = cache[key] if key in cache else resolver(*args_unesc)
366
- cache[key] = val
367
- return val
368
-
369
- # noinspection PyProtectedMember
370
- BaseContainer._resolvers[name] = resolver_wrapper
371
-
372
- @staticmethod
373
- def register_new_resolver(
374
- name: str,
375
- resolver: Resolver,
376
- *,
377
- replace: bool = False,
378
- use_cache: bool = False,
379
- ) -> None:
380
- """
381
- Register a resolver.
382
-
383
- :param name: Name of the resolver.
384
- :param resolver: Callable whose arguments are provided in the interpolation,
385
- e.g., with ${foo:x,0,${y.z}} these arguments are respectively "x" (str),
386
- 0 (int) and the value of ``y.z``.
387
- :param replace: If set to ``False`` (default), then a ``ValueError`` is raised if
388
- an existing resolver has already been registered with the same name.
389
- If set to ``True``, then the new resolver replaces the previous one.
390
- NOTE: The cache on existing config objects is not affected, use
391
- ``OmegaConf.clear_cache(cfg)`` to clear it.
392
- :param use_cache: Whether the resolver's outputs should be cached. The cache is
393
- based only on the string literals representing the resolver arguments, e.g.,
394
- ${foo:${bar}} will always return the same value regardless of the value of
395
- ``bar`` if the cache is enabled for ``foo``.
396
- """
397
- if not callable(resolver):
398
- raise TypeError("resolver must be callable")
399
- if not name:
400
- raise ValueError("cannot use an empty resolver name")
401
-
402
- if not replace and OmegaConf.has_resolver(name):
403
- raise ValueError(f"resolver '{name}' is already registered")
404
-
405
- try:
406
- sig: Optional[inspect.Signature] = inspect.signature(resolver)
407
- except ValueError:
408
- sig = None
409
-
410
- def _should_pass(special: str) -> bool:
411
- ret = sig is not None and special in sig.parameters
412
- if ret and use_cache:
413
- raise ValueError(
414
- f"use_cache=True is incompatible with functions that receive the {special}"
415
- )
416
- return ret
417
-
418
- pass_parent = _should_pass("_parent_")
419
- pass_node = _should_pass("_node_")
420
- pass_root = _should_pass("_root_")
421
-
422
- def resolver_wrapper(
423
- config: BaseContainer,
424
- parent: Container,
425
- node: Node,
426
- args: Tuple[Any, ...],
427
- args_str: Tuple[str, ...],
428
- ) -> Any:
429
- if use_cache:
430
- cache = OmegaConf.get_cache(config)[name]
431
- try:
432
- return cache[args_str]
433
- except KeyError:
434
- pass
435
-
436
- # Call resolver.
437
- kwargs: Dict[str, Node] = {}
438
- if pass_parent:
439
- kwargs["_parent_"] = parent
440
- if pass_node:
441
- kwargs["_node_"] = node
442
- if pass_root:
443
- kwargs["_root_"] = config
444
-
445
- ret = resolver(*args, **kwargs)
446
-
447
- if use_cache:
448
- cache[args_str] = ret
449
- return ret
450
-
451
- # noinspection PyProtectedMember
452
- BaseContainer._resolvers[name] = resolver_wrapper
453
-
454
- @classmethod
455
- def has_resolver(cls, name: str) -> bool:
456
- return cls._get_resolver(name) is not None
457
-
458
- # noinspection PyProtectedMember
459
- @staticmethod
460
- def clear_resolvers() -> None:
461
- """
462
- Clear(remove) all OmegaConf resolvers, then re-register OmegaConf's default resolvers.
463
- """
464
- BaseContainer._resolvers = {}
465
- register_default_resolvers()
466
-
467
- @classmethod
468
- def clear_resolver(cls, name: str) -> bool:
469
- """
470
- Clear(remove) any resolver only if it exists.
471
-
472
- Returns a bool: True if resolver is removed and False if not removed.
473
-
474
- .. warning:
475
- This method can remove deafult resolvers as well.
476
-
477
- :param name: Name of the resolver.
478
- :return: A bool (``True`` if resolver is removed, ``False`` if not found before removing).
479
- """
480
- if cls.has_resolver(name):
481
- BaseContainer._resolvers.pop(name)
482
- return True
483
- else:
484
- # return False if resolver does not exist
485
- return False
486
-
487
- @staticmethod
488
- def get_cache(conf: BaseContainer) -> Dict[str, Any]:
489
- return conf._metadata.resolver_cache
490
-
491
- @staticmethod
492
- def set_cache(conf: BaseContainer, cache: Dict[str, Any]) -> None:
493
- conf._metadata.resolver_cache = copy.deepcopy(cache)
494
-
495
- @staticmethod
496
- def clear_cache(conf: BaseContainer) -> None:
497
- OmegaConf.set_cache(conf, defaultdict(dict, {}))
498
-
499
- @staticmethod
500
- def copy_cache(from_config: BaseContainer, to_config: BaseContainer) -> None:
501
- OmegaConf.set_cache(to_config, OmegaConf.get_cache(from_config))
502
-
503
- @staticmethod
504
- def set_readonly(conf: Node, value: Optional[bool]) -> None:
505
- # noinspection PyProtectedMember
506
- conf._set_flag("readonly", value)
507
-
508
- @staticmethod
509
- def is_readonly(conf: Node) -> Optional[bool]:
510
- # noinspection PyProtectedMember
511
- return conf._get_flag("readonly")
512
-
513
- @staticmethod
514
- def set_struct(conf: Container, value: Optional[bool]) -> None:
515
- # noinspection PyProtectedMember
516
- conf._set_flag("struct", value)
517
-
518
- @staticmethod
519
- def is_struct(conf: Container) -> Optional[bool]:
520
- # noinspection PyProtectedMember
521
- return conf._get_flag("struct")
522
-
523
- @staticmethod
524
- def masked_copy(conf: DictConfig, keys: Union[str, List[str]]) -> DictConfig:
525
- """
526
- Create a masked copy of of this config that contains a subset of the keys
527
-
528
- :param conf: DictConfig object
529
- :param keys: keys to preserve in the copy
530
- :return: The masked ``DictConfig`` object.
531
- """
532
- from .dictconfig import DictConfig
533
-
534
- if not isinstance(conf, DictConfig):
535
- raise ValueError("masked_copy is only supported for DictConfig")
536
-
537
- if isinstance(keys, str):
538
- keys = [keys]
539
- content = {key: value for key, value in conf.items_ex(resolve=False, keys=keys)}
540
- return DictConfig(content=content)
541
-
542
- @staticmethod
543
- def to_container(
544
- cfg: Any,
545
- *,
546
- resolve: bool = False,
547
- throw_on_missing: bool = False,
548
- enum_to_str: bool = False,
549
- structured_config_mode: SCMode = SCMode.DICT,
550
- ) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
551
- """
552
- Resursively converts an OmegaConf config to a primitive container (dict or list).
553
-
554
- :param cfg: the config to convert
555
- :param resolve: True to resolve all values
556
- :param throw_on_missing: When True, raise MissingMandatoryValue if any missing values are present.
557
- When False (the default), replace missing values with the string "???" in the output container.
558
- :param enum_to_str: True to convert Enum keys and values to strings
559
- :param structured_config_mode: Specify how Structured Configs (DictConfigs backed by a dataclass) are handled.
560
- - By default (``structured_config_mode=SCMode.DICT``) structured configs are converted to plain dicts.
561
- - If ``structured_config_mode=SCMode.DICT_CONFIG``, structured config nodes will remain as DictConfig.
562
- - If ``structured_config_mode=SCMode.INSTANTIATE``, this function will instantiate structured configs
563
- (DictConfigs backed by a dataclass), by creating an instance of the underlying dataclass.
564
-
565
- See also OmegaConf.to_object.
566
- :return: A dict or a list representing this config as a primitive container.
567
- """
568
- if not OmegaConf.is_config(cfg):
569
- raise ValueError(
570
- f"Input cfg is not an OmegaConf config object ({type_str(type(cfg))})"
571
- )
572
-
573
- return BaseContainer._to_content(
574
- cfg,
575
- resolve=resolve,
576
- throw_on_missing=throw_on_missing,
577
- enum_to_str=enum_to_str,
578
- structured_config_mode=structured_config_mode,
579
- )
580
-
581
- @staticmethod
582
- def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
583
- """
584
- Resursively converts an OmegaConf config to a primitive container (dict or list).
585
- Any DictConfig objects backed by dataclasses or attrs classes are instantiated
586
- as instances of those backing classes.
587
-
588
- This is an alias for OmegaConf.to_container(..., resolve=True, throw_on_missing=True,
589
- structured_config_mode=SCMode.INSTANTIATE)
590
-
591
- :param cfg: the config to convert
592
- :return: A dict or a list or dataclass representing this config.
593
- """
594
- return OmegaConf.to_container(
595
- cfg=cfg,
596
- resolve=True,
597
- throw_on_missing=True,
598
- enum_to_str=False,
599
- structured_config_mode=SCMode.INSTANTIATE,
600
- )
601
-
602
- @staticmethod
603
- def is_missing(cfg: Any, key: DictKeyType) -> bool:
604
- assert isinstance(cfg, Container)
605
- try:
606
- node = cfg._get_child(key)
607
- if node is None:
608
- return False
609
- assert isinstance(node, Node)
610
- return node._is_missing()
611
- except (UnsupportedInterpolationType, KeyError, AttributeError):
612
- return False
613
-
614
- @staticmethod
615
- def is_interpolation(node: Any, key: Optional[Union[int, str]] = None) -> bool:
616
- if key is not None:
617
- assert isinstance(node, Container)
618
- target = node._get_child(key)
619
- else:
620
- target = node
621
- if target is not None:
622
- assert isinstance(target, Node)
623
- return target._is_interpolation()
624
- return False
625
-
626
- @staticmethod
627
- def is_list(obj: Any) -> bool:
628
- from . import ListConfig
629
-
630
- return isinstance(obj, ListConfig)
631
-
632
- @staticmethod
633
- def is_dict(obj: Any) -> bool:
634
- from . import DictConfig
635
-
636
- return isinstance(obj, DictConfig)
637
-
638
- @staticmethod
639
- def is_config(obj: Any) -> bool:
640
- from . import Container
641
-
642
- return isinstance(obj, Container)
643
-
644
- @staticmethod
645
- def get_type(obj: Any, key: Optional[str] = None) -> Optional[Type[Any]]:
646
- if key is not None:
647
- c = obj._get_child(key)
648
- else:
649
- c = obj
650
- return OmegaConf._get_obj_type(c)
651
-
652
- @staticmethod
653
- def select(
654
- cfg: Container,
655
- key: str,
656
- *,
657
- default: Any = _DEFAULT_MARKER_,
658
- throw_on_resolution_failure: bool = True,
659
- throw_on_missing: bool = False,
660
- ) -> Any:
661
- """
662
- :param cfg: Config node to select from
663
- :param key: Key to select
664
- :param default: Default value to return if key is not found
665
- :param throw_on_resolution_failure: Raise an exception if an interpolation
666
- resolution error occurs, otherwise return None
667
- :param throw_on_missing: Raise an exception if an attempt to select a missing key (with the value '???')
668
- is made, otherwise return None
669
- :return: selected value or None if not found.
670
- """
671
- from ._impl import select_value
672
-
673
- try:
674
- return select_value(
675
- cfg=cfg,
676
- key=key,
677
- default=default,
678
- throw_on_resolution_failure=throw_on_resolution_failure,
679
- throw_on_missing=throw_on_missing,
680
- )
681
- except Exception as e:
682
- format_and_raise(node=cfg, key=key, value=None, cause=e, msg=str(e))
683
-
684
- @staticmethod
685
- def update(
686
- cfg: Container,
687
- key: str,
688
- value: Any = None,
689
- *,
690
- merge: bool = True,
691
- force_add: bool = False,
692
- ) -> None:
693
- """
694
- Updates a dot separated key sequence to a value
695
-
696
- :param cfg: input config to update
697
- :param key: key to update (can be a dot separated path)
698
- :param value: value to set, if value if a list or a dict it will be merged or set
699
- depending on merge_config_values
700
- :param merge: If value is a dict or a list, True (default) to merge
701
- into the destination, False to replace the destination.
702
- :param force_add: insert the entire path regardless of Struct flag or Structured Config nodes.
703
- """
704
-
705
- split = split_key(key)
706
- root = cfg
707
- for i in range(len(split) - 1):
708
- k = split[i]
709
- # if next_root is a primitive (string, int etc) replace it with an empty map
710
- next_root, key_ = _select_one(root, k, throw_on_missing=False)
711
- if not isinstance(next_root, Container):
712
- if force_add:
713
- with flag_override(root, "struct", False):
714
- root[key_] = {}
715
- else:
716
- root[key_] = {}
717
- root = root[key_]
718
-
719
- last = split[-1]
720
-
721
- assert isinstance(
722
- root, Container
723
- ), f"Unexpected type for root: {type(root).__name__}"
724
-
725
- last_key: Union[str, int] = last
726
- if isinstance(root, ListConfig):
727
- last_key = int(last)
728
-
729
- ctx = flag_override(root, "struct", False) if force_add else nullcontext()
730
- with ctx:
731
- if merge and (OmegaConf.is_config(value) or is_primitive_container(value)):
732
- assert isinstance(root, BaseContainer)
733
- node = root._get_child(last_key)
734
- if OmegaConf.is_config(node):
735
- assert isinstance(node, BaseContainer)
736
- node.merge_with(value)
737
- return
738
-
739
- if OmegaConf.is_dict(root):
740
- assert isinstance(last_key, str)
741
- root.__setattr__(last_key, value)
742
- elif OmegaConf.is_list(root):
743
- assert isinstance(last_key, int)
744
- root.__setitem__(last_key, value)
745
- else:
746
- assert False
747
-
748
- @staticmethod
749
- def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str:
750
- """
751
- returns a yaml dump of this config object.
752
-
753
- :param cfg: Config object, Structured Config type or instance
754
- :param resolve: if True, will return a string with the interpolations resolved, otherwise
755
- interpolations are preserved
756
- :param sort_keys: If True, will print dict keys in sorted order. default False.
757
- :return: A string containing the yaml representation.
758
- """
759
- cfg = _ensure_container(cfg)
760
- container = OmegaConf.to_container(cfg, resolve=resolve, enum_to_str=True)
761
- return yaml.dump( # type: ignore
762
- container,
763
- default_flow_style=False,
764
- allow_unicode=True,
765
- sort_keys=sort_keys,
766
- Dumper=get_omega_conf_dumper(),
767
- )
768
-
769
- @staticmethod
770
- def resolve(cfg: Container) -> None:
771
- """
772
- Resolves all interpolations in the given config object in-place.
773
-
774
- :param cfg: An OmegaConf container (DictConfig, ListConfig)
775
- Raises a ValueError if the input object is not an OmegaConf container.
776
- """
777
- import omegaconf._impl
778
-
779
- if not OmegaConf.is_config(cfg):
780
- # Since this function is mutating the input object in-place, it doesn't make sense to
781
- # auto-convert the input object to an OmegaConf container
782
- raise ValueError(
783
- f"Invalid config type ({type(cfg).__name__}), expected an OmegaConf Container"
784
- )
785
- omegaconf._impl._resolve(cfg)
786
-
787
- @staticmethod
788
- def missing_keys(cfg: Any) -> Set[str]:
789
- """
790
- Returns a set of missing keys in a dotlist style.
791
-
792
- :param cfg: An ``OmegaConf.Container``,
793
- or a convertible object via ``OmegaConf.create`` (dict, list, ...).
794
- :return: set of strings of the missing keys.
795
- :raises ValueError: On input not representing a config.
796
- """
797
- cfg = _ensure_container(cfg)
798
- missings: Set[str] = set()
799
-
800
- def gather(_cfg: Container) -> None:
801
- itr: Iterable[Any]
802
- if isinstance(_cfg, ListConfig):
803
- itr = range(len(_cfg))
804
- else:
805
- itr = _cfg
806
-
807
- for key in itr:
808
- if OmegaConf.is_missing(_cfg, key):
809
- missings.add(_cfg._get_full_key(key))
810
- elif OmegaConf.is_config(_cfg[key]):
811
- gather(_cfg[key])
812
-
813
- gather(cfg)
814
- return missings
815
-
816
- # === private === #
817
-
818
- @staticmethod
819
- def _create_impl( # noqa F811
820
- obj: Any = _DEFAULT_MARKER_,
821
- parent: Optional[BaseContainer] = None,
822
- flags: Optional[Dict[str, bool]] = None,
823
- ) -> Union[DictConfig, ListConfig]:
824
- try:
825
- from ._utils import get_yaml_loader
826
- from .dictconfig import DictConfig
827
- from .listconfig import ListConfig
828
-
829
- if obj is _DEFAULT_MARKER_:
830
- obj = {}
831
- if isinstance(obj, str):
832
- obj = yaml.load(obj, Loader=get_yaml_loader())
833
- if obj is None:
834
- return OmegaConf.create({}, parent=parent, flags=flags)
835
- elif isinstance(obj, str):
836
- return OmegaConf.create({obj: None}, parent=parent, flags=flags)
837
- else:
838
- assert isinstance(obj, (list, dict))
839
- return OmegaConf.create(obj, parent=parent, flags=flags)
840
-
841
- else:
842
- if (
843
- is_primitive_dict(obj)
844
- or OmegaConf.is_dict(obj)
845
- or is_structured_config(obj)
846
- or obj is None
847
- ):
848
- if isinstance(obj, DictConfig):
849
- return DictConfig(
850
- content=obj,
851
- parent=parent,
852
- ref_type=obj._metadata.ref_type,
853
- is_optional=obj._metadata.optional,
854
- key_type=obj._metadata.key_type,
855
- element_type=obj._metadata.element_type,
856
- flags=flags,
857
- )
858
- else:
859
- obj_type = OmegaConf.get_type(obj)
860
- key_type, element_type = get_dict_key_value_types(obj_type)
861
- return DictConfig(
862
- content=obj,
863
- parent=parent,
864
- key_type=key_type,
865
- element_type=element_type,
866
- flags=flags,
867
- )
868
- elif is_primitive_list(obj) or OmegaConf.is_list(obj):
869
- if isinstance(obj, ListConfig):
870
- return ListConfig(
871
- content=obj,
872
- parent=parent,
873
- element_type=obj._metadata.element_type,
874
- ref_type=obj._metadata.ref_type,
875
- is_optional=obj._metadata.optional,
876
- flags=flags,
877
- )
878
- else:
879
- obj_type = OmegaConf.get_type(obj)
880
- element_type = get_list_element_type(obj_type)
881
- return ListConfig(
882
- content=obj,
883
- parent=parent,
884
- element_type=element_type,
885
- ref_type=Any,
886
- is_optional=True,
887
- flags=flags,
888
- )
889
- else:
890
- if isinstance(obj, type):
891
- raise ValidationError(
892
- f"Input class '{obj.__name__}' is not a structured config. "
893
- "did you forget to decorate it as a dataclass?"
894
- )
895
- else:
896
- raise ValidationError(
897
- f"Object of unsupported type: '{type(obj).__name__}'"
898
- )
899
- except OmegaConfBaseException as e:
900
- format_and_raise(node=None, key=None, value=None, msg=str(e), cause=e)
901
- assert False
902
-
903
- @staticmethod
904
- def _get_obj_type(c: Any) -> Optional[Type[Any]]:
905
- if is_structured_config(c):
906
- return get_type_of(c)
907
- elif c is None:
908
- return None
909
- elif isinstance(c, DictConfig):
910
- if c._is_none():
911
- return None
912
- elif c._is_missing():
913
- return None
914
- else:
915
- if is_structured_config(c._metadata.object_type):
916
- return c._metadata.object_type
917
- else:
918
- return dict
919
- elif isinstance(c, ListConfig):
920
- return list
921
- elif isinstance(c, ValueNode):
922
- return type(c._value())
923
- elif isinstance(c, UnionNode):
924
- return type(_get_value(c))
925
- elif isinstance(c, dict):
926
- return dict
927
- elif isinstance(c, (list, tuple)):
928
- return list
929
- else:
930
- return get_type_of(c)
931
-
932
- @staticmethod
933
- def _get_resolver(
934
- name: str,
935
- ) -> Optional[
936
- Callable[
937
- [Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]],
938
- Any,
939
- ]
940
- ]:
941
- # noinspection PyProtectedMember
942
- return (
943
- BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None
944
- )
945
-
946
-
947
- # register all default resolvers
948
- register_default_resolvers()
949
-
950
-
951
- @contextmanager
952
- def flag_override(
953
- config: Node,
954
- names: Union[List[str], str],
955
- values: Union[List[Optional[bool]], Optional[bool]],
956
- ) -> Generator[Node, None, None]:
957
- if isinstance(names, str):
958
- names = [names]
959
- if values is None or isinstance(values, bool):
960
- values = [values]
961
-
962
- prev_states = [config._get_node_flag(name) for name in names]
963
-
964
- try:
965
- config._set_flag(names, values)
966
- yield config
967
- finally:
968
- config._set_flag(names, prev_states)
969
-
970
-
971
- @contextmanager
972
- def read_write(config: Node) -> Generator[Node, None, None]:
973
- prev_state = config._get_node_flag("readonly")
974
- try:
975
- OmegaConf.set_readonly(config, False)
976
- yield config
977
- finally:
978
- OmegaConf.set_readonly(config, prev_state)
979
-
980
-
981
- @contextmanager
982
- def open_dict(config: Container) -> Generator[Container, None, None]:
983
- prev_state = config._get_node_flag("struct")
984
- try:
985
- OmegaConf.set_struct(config, False)
986
- yield config
987
- finally:
988
- OmegaConf.set_struct(config, prev_state)
989
-
990
-
991
- # === private === #
992
-
993
-
994
- def _node_wrap(
995
- parent: Optional[Box],
996
- is_optional: bool,
997
- value: Any,
998
- key: Any,
999
- ref_type: Any = Any,
1000
- ) -> Node:
1001
- node: Node
1002
- if is_dict_annotation(ref_type) or (is_primitive_dict(value) and ref_type is Any):
1003
- key_type, element_type = get_dict_key_value_types(ref_type)
1004
- node = DictConfig(
1005
- content=value,
1006
- key=key,
1007
- parent=parent,
1008
- ref_type=ref_type,
1009
- is_optional=is_optional,
1010
- key_type=key_type,
1011
- element_type=element_type,
1012
- )
1013
- elif (is_list_annotation(ref_type) or is_tuple_annotation(ref_type)) or (
1014
- type(value) in (list, tuple) and ref_type is Any
1015
- ):
1016
- element_type = get_list_element_type(ref_type)
1017
- node = ListConfig(
1018
- content=value,
1019
- key=key,
1020
- parent=parent,
1021
- is_optional=is_optional,
1022
- element_type=element_type,
1023
- ref_type=ref_type,
1024
- )
1025
- elif is_structured_config(ref_type) or is_structured_config(value):
1026
- key_type, element_type = get_dict_key_value_types(value)
1027
- node = DictConfig(
1028
- ref_type=ref_type,
1029
- is_optional=is_optional,
1030
- content=value,
1031
- key=key,
1032
- parent=parent,
1033
- key_type=key_type,
1034
- element_type=element_type,
1035
- )
1036
- elif is_union_annotation(ref_type):
1037
- node = UnionNode(
1038
- content=value,
1039
- ref_type=ref_type,
1040
- is_optional=is_optional,
1041
- key=key,
1042
- parent=parent,
1043
- )
1044
- elif ref_type == Any or ref_type is None:
1045
- node = AnyNode(value=value, key=key, parent=parent)
1046
- elif isinstance(ref_type, type) and issubclass(ref_type, Enum):
1047
- node = EnumNode(
1048
- enum_type=ref_type,
1049
- value=value,
1050
- key=key,
1051
- parent=parent,
1052
- is_optional=is_optional,
1053
- )
1054
- elif ref_type == int:
1055
- node = IntegerNode(value=value, key=key, parent=parent, is_optional=is_optional)
1056
- elif ref_type == float:
1057
- node = FloatNode(value=value, key=key, parent=parent, is_optional=is_optional)
1058
- elif ref_type == bool:
1059
- node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional)
1060
- elif ref_type == str:
1061
- node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional)
1062
- elif ref_type == bytes:
1063
- node = BytesNode(value=value, key=key, parent=parent, is_optional=is_optional)
1064
- elif ref_type == pathlib.Path:
1065
- node = PathNode(value=value, key=key, parent=parent, is_optional=is_optional)
1066
- else:
1067
- if parent is not None and parent._get_flag("allow_objects") is True:
1068
- if type(value) in (list, tuple):
1069
- node = ListConfig(
1070
- content=value,
1071
- key=key,
1072
- parent=parent,
1073
- ref_type=ref_type,
1074
- is_optional=is_optional,
1075
- )
1076
- elif is_primitive_dict(value):
1077
- node = DictConfig(
1078
- content=value,
1079
- key=key,
1080
- parent=parent,
1081
- ref_type=ref_type,
1082
- is_optional=is_optional,
1083
- )
1084
- else:
1085
- node = AnyNode(value=value, key=key, parent=parent)
1086
- else:
1087
- raise ValidationError(f"Unexpected type annotation: {type_str(ref_type)}")
1088
- return node
1089
-
1090
-
1091
- def _maybe_wrap(
1092
- ref_type: Any,
1093
- key: Any,
1094
- value: Any,
1095
- is_optional: bool,
1096
- parent: Optional[BaseContainer],
1097
- ) -> Node:
1098
- # if already a node, update key and parent and return as is.
1099
- # NOTE: that this mutate the input node!
1100
- if isinstance(value, Node):
1101
- value._set_key(key)
1102
- value._set_parent(parent)
1103
- return value
1104
- else:
1105
- return _node_wrap(
1106
- ref_type=ref_type,
1107
- parent=parent,
1108
- is_optional=is_optional,
1109
- value=value,
1110
- key=key,
1111
- )
1112
-
1113
-
1114
- def _select_one(
1115
- c: Container, key: str, throw_on_missing: bool, throw_on_type_error: bool = True
1116
- ) -> Tuple[Optional[Node], Union[str, int]]:
1117
- from .dictconfig import DictConfig
1118
- from .listconfig import ListConfig
1119
-
1120
- ret_key: Union[str, int] = key
1121
- assert isinstance(c, Container), f"Unexpected type: {c}"
1122
- if c._is_none():
1123
- return None, ret_key
1124
-
1125
- if isinstance(c, DictConfig):
1126
- assert isinstance(ret_key, str)
1127
- val = c._get_child(ret_key, validate_access=False)
1128
- elif isinstance(c, ListConfig):
1129
- assert isinstance(ret_key, str)
1130
- if not is_int(ret_key):
1131
- if throw_on_type_error:
1132
- raise TypeError(
1133
- f"Index '{ret_key}' ({type(ret_key).__name__}) is not an int"
1134
- )
1135
- else:
1136
- val = None
1137
- else:
1138
- ret_key = int(ret_key)
1139
- if ret_key < 0 or ret_key + 1 > len(c):
1140
- val = None
1141
- else:
1142
- val = c._get_child(ret_key)
1143
- else:
1144
- assert False
1145
-
1146
- if val is not None:
1147
- assert isinstance(val, Node)
1148
- if val._is_missing():
1149
- if throw_on_missing:
1150
- raise MissingMandatoryValue(
1151
- f"Missing mandatory value: {c._get_full_key(ret_key)}"
1152
- )
1153
- else:
1154
- return val, ret_key
1155
-
1156
- assert val is None or isinstance(val, Node)
1157
- return val, ret_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/py.typed DELETED
File without changes
omegaconf/resolvers/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- from omegaconf.resolvers import oc
2
-
3
- __all__ = [
4
- "oc",
5
- ]
 
 
 
 
 
 
omegaconf/resolvers/oc/__init__.py DELETED
@@ -1,113 +0,0 @@
1
- import os
2
- import string
3
- import warnings
4
- from typing import Any, Optional
5
-
6
- from omegaconf import Container, Node
7
- from omegaconf._utils import _DEFAULT_MARKER_, _get_value
8
- from omegaconf.basecontainer import BaseContainer
9
- from omegaconf.errors import ConfigKeyError
10
- from omegaconf.grammar_parser import parse
11
- from omegaconf.resolvers.oc import dict
12
-
13
-
14
- def create(obj: Any, _parent_: Container) -> Any:
15
- """Create a config object from `obj`, similar to `OmegaConf.create`"""
16
- from omegaconf import OmegaConf
17
-
18
- assert isinstance(_parent_, BaseContainer)
19
- return OmegaConf.create(obj, parent=_parent_)
20
-
21
-
22
- def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]:
23
- """
24
- :param key: Environment variable key
25
- :param default: Optional default value to use in case the key environment variable is not set.
26
- If default is not a string, it is converted with str(default).
27
- None default is returned as is.
28
- :return: The environment variable 'key'. If the environment variable is not set and a default is
29
- provided, the default is used. If used, the default is converted to a string with str(default).
30
- If the default is None, None is returned (without a string conversion).
31
- """
32
- try:
33
- return os.environ[key]
34
- except KeyError:
35
- if default is not _DEFAULT_MARKER_:
36
- return str(default) if default is not None else None
37
- else:
38
- raise KeyError(f"Environment variable '{key}' not found")
39
-
40
-
41
- def decode(expr: Optional[str], _parent_: Container, _node_: Node) -> Any:
42
- """
43
- Parse and evaluate `expr` according to the `singleElement` rule of the grammar.
44
-
45
- If `expr` is `None`, then return `None`.
46
- """
47
- if expr is None:
48
- return None
49
-
50
- if not isinstance(expr, str):
51
- raise TypeError(
52
- f"`oc.decode` can only take strings or None as input, "
53
- f"but `{expr}` is of type {type(expr).__name__}"
54
- )
55
-
56
- parse_tree = parse(expr, parser_rule="singleElement", lexer_mode="VALUE_MODE")
57
- val = _parent_.resolve_parse_tree(parse_tree, node=_node_)
58
- return _get_value(val)
59
-
60
-
61
- def deprecated(
62
- key: str,
63
- message: str = "'$OLD_KEY' is deprecated. Change your code and config to use '$NEW_KEY'",
64
- *,
65
- _parent_: Container,
66
- _node_: Node,
67
- ) -> Any:
68
- from omegaconf._impl import select_node
69
-
70
- if not isinstance(key, str):
71
- raise TypeError(
72
- f"oc.deprecated: interpolation key type is not a string ({type(key).__name__})"
73
- )
74
-
75
- if not isinstance(message, str):
76
- raise TypeError(
77
- f"oc.deprecated: interpolation message type is not a string ({type(message).__name__})"
78
- )
79
-
80
- full_key = _node_._get_full_key(key=None)
81
- target_node = select_node(_parent_, key, absolute_key=True)
82
- if target_node is None:
83
- raise ConfigKeyError(
84
- f"In oc.deprecated resolver at '{full_key}': Key not found: '{key}'"
85
- )
86
- new_key = target_node._get_full_key(key=None)
87
- msg = string.Template(message).safe_substitute(
88
- OLD_KEY=full_key,
89
- NEW_KEY=new_key,
90
- )
91
- warnings.warn(category=UserWarning, message=msg)
92
- return target_node
93
-
94
-
95
- def select(
96
- key: str,
97
- default: Any = _DEFAULT_MARKER_,
98
- *,
99
- _parent_: Container,
100
- ) -> Any:
101
- from omegaconf._impl import select_value
102
-
103
- return select_value(cfg=_parent_, key=key, absolute_key=True, default=default)
104
-
105
-
106
- __all__ = [
107
- "create",
108
- "decode",
109
- "deprecated",
110
- "dict",
111
- "env",
112
- "select",
113
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/resolvers/oc/dict.py DELETED
@@ -1,83 +0,0 @@
1
- from typing import Any, List
2
-
3
- from omegaconf import AnyNode, Container, DictConfig, ListConfig
4
- from omegaconf._utils import Marker
5
- from omegaconf.basecontainer import BaseContainer
6
- from omegaconf.errors import ConfigKeyError
7
-
8
- _DEFAULT_SELECT_MARKER_: Any = Marker("_DEFAULT_SELECT_MARKER_")
9
-
10
-
11
- def keys(
12
- key: str,
13
- _parent_: Container,
14
- ) -> ListConfig:
15
- from omegaconf import OmegaConf
16
-
17
- assert isinstance(_parent_, BaseContainer)
18
-
19
- in_dict = _get_and_validate_dict_input(
20
- key, parent=_parent_, resolver_name="oc.dict.keys"
21
- )
22
-
23
- ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_)
24
- assert isinstance(ret, ListConfig)
25
- return ret
26
-
27
-
28
- def values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig:
29
- assert isinstance(_parent_, BaseContainer)
30
- in_dict = _get_and_validate_dict_input(
31
- key, parent=_parent_, resolver_name="oc.dict.values"
32
- )
33
-
34
- content = in_dict._content
35
- assert isinstance(content, dict)
36
-
37
- ret = ListConfig([])
38
- if key.startswith("."):
39
- key = f".{key}" # extra dot to compensate for extra level of nesting within ret ListConfig
40
- for k in content:
41
- ref_node = AnyNode(f"${{{key}.{k!s}}}")
42
- ret.append(ref_node)
43
-
44
- # Finalize result by setting proper type and parent.
45
- element_type: Any = in_dict._metadata.element_type
46
- ret._metadata.element_type = element_type
47
- ret._metadata.ref_type = List[element_type]
48
- ret._set_parent(_parent_)
49
-
50
- return ret
51
-
52
-
53
- def _get_and_validate_dict_input(
54
- key: str,
55
- parent: BaseContainer,
56
- resolver_name: str,
57
- ) -> DictConfig:
58
- from omegaconf._impl import select_value
59
-
60
- if not isinstance(key, str):
61
- raise TypeError(
62
- f"`{resolver_name}` requires a string as input, but obtained `{key}` "
63
- f"of type: {type(key).__name__}"
64
- )
65
-
66
- in_dict = select_value(
67
- parent,
68
- key,
69
- throw_on_missing=True,
70
- absolute_key=True,
71
- default=_DEFAULT_SELECT_MARKER_,
72
- )
73
-
74
- if in_dict is _DEFAULT_SELECT_MARKER_:
75
- raise ConfigKeyError(f"Key not found: '{key}'")
76
-
77
- if not isinstance(in_dict, DictConfig):
78
- raise TypeError(
79
- f"`{resolver_name}` cannot be applied to objects of type: "
80
- f"{type(in_dict).__name__}"
81
- )
82
-
83
- return in_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
omegaconf/version.py DELETED
@@ -1,13 +0,0 @@
1
- import sys # pragma: no cover
2
-
3
- __version__ = "2.3.0"
4
-
5
- msg = """OmegaConf 2.0 and above is compatible with Python 3.6 and newer.
6
- You have the following options:
7
- 1. Upgrade to Python 3.6 or newer.
8
- This is highly recommended. new features will not be added to OmegaConf 1.4.
9
- 2. Continue using OmegaConf 1.4:
10
- You can pip install 'OmegaConf<1.5' to do that.
11
- """
12
- if sys.version_info < (3, 6):
13
- raise ImportError(msg) # pragma: no cover
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -7,4 +7,4 @@ tensorboard
7
  slider==0.8.1
8
  torch_tb_profiler
9
  rosu_pp_py
10
- wandb
 
7
  slider==0.8.1
8
  torch_tb_profiler
9
  rosu_pp_py
10
+ omegaconf