edit
Browse files- omegaconf/__init__.py +0 -65
- omegaconf/_impl.py +0 -101
- omegaconf/_utils.py +0 -1039
- omegaconf/base.py +0 -962
- omegaconf/basecontainer.py +0 -916
- omegaconf/dictconfig.py +0 -776
- omegaconf/errors.py +0 -141
- omegaconf/grammar/OmegaConfGrammarLexer.g4 +0 -137
- omegaconf/grammar/OmegaConfGrammarParser.g4 +0 -91
- omegaconf/grammar/__init__.py +0 -0
- omegaconf/grammar/gen/__init__.py +0 -0
- omegaconf/grammar_parser.py +0 -144
- omegaconf/grammar_visitor.py +0 -392
- omegaconf/listconfig.py +0 -679
- omegaconf/nodes.py +0 -545
- omegaconf/omegaconf.py +0 -1157
- omegaconf/py.typed +0 -0
- omegaconf/resolvers/__init__.py +0 -5
- omegaconf/resolvers/oc/__init__.py +0 -113
- omegaconf/resolvers/oc/dict.py +0 -83
- omegaconf/version.py +0 -13
- requirements.txt +1 -1
omegaconf/__init__.py
DELETED
|
@@ -1,65 +0,0 @@
|
|
| 1 |
-
from .base import Container, DictKeyType, Node, SCMode, UnionNode
|
| 2 |
-
from .dictconfig import DictConfig
|
| 3 |
-
from .errors import (
|
| 4 |
-
KeyValidationError,
|
| 5 |
-
MissingMandatoryValue,
|
| 6 |
-
ReadonlyConfigError,
|
| 7 |
-
UnsupportedValueType,
|
| 8 |
-
ValidationError,
|
| 9 |
-
)
|
| 10 |
-
from .listconfig import ListConfig
|
| 11 |
-
from .nodes import (
|
| 12 |
-
AnyNode,
|
| 13 |
-
BooleanNode,
|
| 14 |
-
BytesNode,
|
| 15 |
-
EnumNode,
|
| 16 |
-
FloatNode,
|
| 17 |
-
IntegerNode,
|
| 18 |
-
PathNode,
|
| 19 |
-
StringNode,
|
| 20 |
-
ValueNode,
|
| 21 |
-
)
|
| 22 |
-
from .omegaconf import (
|
| 23 |
-
II,
|
| 24 |
-
MISSING,
|
| 25 |
-
SI,
|
| 26 |
-
OmegaConf,
|
| 27 |
-
Resolver,
|
| 28 |
-
flag_override,
|
| 29 |
-
open_dict,
|
| 30 |
-
read_write,
|
| 31 |
-
)
|
| 32 |
-
from .version import __version__
|
| 33 |
-
|
| 34 |
-
__all__ = [
|
| 35 |
-
"__version__",
|
| 36 |
-
"MissingMandatoryValue",
|
| 37 |
-
"ValidationError",
|
| 38 |
-
"ReadonlyConfigError",
|
| 39 |
-
"UnsupportedValueType",
|
| 40 |
-
"KeyValidationError",
|
| 41 |
-
"Container",
|
| 42 |
-
"UnionNode",
|
| 43 |
-
"ListConfig",
|
| 44 |
-
"DictConfig",
|
| 45 |
-
"DictKeyType",
|
| 46 |
-
"OmegaConf",
|
| 47 |
-
"Resolver",
|
| 48 |
-
"SCMode",
|
| 49 |
-
"flag_override",
|
| 50 |
-
"read_write",
|
| 51 |
-
"open_dict",
|
| 52 |
-
"Node",
|
| 53 |
-
"ValueNode",
|
| 54 |
-
"AnyNode",
|
| 55 |
-
"IntegerNode",
|
| 56 |
-
"StringNode",
|
| 57 |
-
"BytesNode",
|
| 58 |
-
"PathNode",
|
| 59 |
-
"BooleanNode",
|
| 60 |
-
"EnumNode",
|
| 61 |
-
"FloatNode",
|
| 62 |
-
"MISSING",
|
| 63 |
-
"SI",
|
| 64 |
-
"II",
|
| 65 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/_impl.py
DELETED
|
@@ -1,101 +0,0 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
-
|
| 3 |
-
from omegaconf import MISSING, Container, DictConfig, ListConfig, Node, ValueNode
|
| 4 |
-
from omegaconf.errors import ConfigTypeError, InterpolationToMissingValueError
|
| 5 |
-
|
| 6 |
-
from ._utils import _DEFAULT_MARKER_, _get_value
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def _resolve_container_value(cfg: Container, key: Any) -> None:
|
| 10 |
-
node = cfg._get_child(key)
|
| 11 |
-
assert isinstance(node, Node)
|
| 12 |
-
if node._is_interpolation():
|
| 13 |
-
try:
|
| 14 |
-
resolved = node._dereference_node()
|
| 15 |
-
except InterpolationToMissingValueError:
|
| 16 |
-
node._set_value(MISSING)
|
| 17 |
-
else:
|
| 18 |
-
if isinstance(resolved, Container):
|
| 19 |
-
_resolve(resolved)
|
| 20 |
-
if isinstance(resolved, Container) and isinstance(node, ValueNode):
|
| 21 |
-
cfg[key] = resolved
|
| 22 |
-
else:
|
| 23 |
-
node._set_value(_get_value(resolved))
|
| 24 |
-
else:
|
| 25 |
-
_resolve(node)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def _resolve(cfg: Node) -> Node:
|
| 29 |
-
assert isinstance(cfg, Node)
|
| 30 |
-
if cfg._is_interpolation():
|
| 31 |
-
try:
|
| 32 |
-
resolved = cfg._dereference_node()
|
| 33 |
-
except InterpolationToMissingValueError:
|
| 34 |
-
cfg._set_value(MISSING)
|
| 35 |
-
else:
|
| 36 |
-
cfg._set_value(resolved._value())
|
| 37 |
-
|
| 38 |
-
if isinstance(cfg, DictConfig):
|
| 39 |
-
for k in cfg.keys():
|
| 40 |
-
_resolve_container_value(cfg, k)
|
| 41 |
-
|
| 42 |
-
elif isinstance(cfg, ListConfig):
|
| 43 |
-
for i in range(len(cfg)):
|
| 44 |
-
_resolve_container_value(cfg, i)
|
| 45 |
-
|
| 46 |
-
return cfg
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def select_value(
|
| 50 |
-
cfg: Container,
|
| 51 |
-
key: str,
|
| 52 |
-
*,
|
| 53 |
-
default: Any = _DEFAULT_MARKER_,
|
| 54 |
-
throw_on_resolution_failure: bool = True,
|
| 55 |
-
throw_on_missing: bool = False,
|
| 56 |
-
absolute_key: bool = False,
|
| 57 |
-
) -> Any:
|
| 58 |
-
node = select_node(
|
| 59 |
-
cfg=cfg,
|
| 60 |
-
key=key,
|
| 61 |
-
throw_on_resolution_failure=throw_on_resolution_failure,
|
| 62 |
-
throw_on_missing=throw_on_missing,
|
| 63 |
-
absolute_key=absolute_key,
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
node_not_found = node is None
|
| 67 |
-
if node_not_found or node._is_missing():
|
| 68 |
-
if default is not _DEFAULT_MARKER_:
|
| 69 |
-
return default
|
| 70 |
-
else:
|
| 71 |
-
return None
|
| 72 |
-
|
| 73 |
-
return _get_value(node)
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def select_node(
|
| 77 |
-
cfg: Container,
|
| 78 |
-
key: str,
|
| 79 |
-
*,
|
| 80 |
-
throw_on_resolution_failure: bool = True,
|
| 81 |
-
throw_on_missing: bool = False,
|
| 82 |
-
absolute_key: bool = False,
|
| 83 |
-
) -> Any:
|
| 84 |
-
try:
|
| 85 |
-
# for non relative keys, the interpretation can be:
|
| 86 |
-
# 1. relative to cfg
|
| 87 |
-
# 2. relative to the config root
|
| 88 |
-
# This is controlled by the absolute_key flag. By default, such keys are relative to cfg.
|
| 89 |
-
if not absolute_key and not key.startswith("."):
|
| 90 |
-
key = f".{key}"
|
| 91 |
-
|
| 92 |
-
cfg, key = cfg._resolve_key_and_root(key)
|
| 93 |
-
_root, _last_key, node = cfg._select_impl(
|
| 94 |
-
key,
|
| 95 |
-
throw_on_missing=throw_on_missing,
|
| 96 |
-
throw_on_resolution_failure=throw_on_resolution_failure,
|
| 97 |
-
)
|
| 98 |
-
except ConfigTypeError:
|
| 99 |
-
return None
|
| 100 |
-
|
| 101 |
-
return node
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/_utils.py
DELETED
|
@@ -1,1039 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import os
|
| 3 |
-
import pathlib
|
| 4 |
-
import re
|
| 5 |
-
import string
|
| 6 |
-
import sys
|
| 7 |
-
import types
|
| 8 |
-
import warnings
|
| 9 |
-
from contextlib import contextmanager
|
| 10 |
-
from enum import Enum
|
| 11 |
-
from textwrap import dedent
|
| 12 |
-
from typing import (
|
| 13 |
-
Any,
|
| 14 |
-
Dict,
|
| 15 |
-
Iterator,
|
| 16 |
-
List,
|
| 17 |
-
Optional,
|
| 18 |
-
Tuple,
|
| 19 |
-
Type,
|
| 20 |
-
Union,
|
| 21 |
-
get_type_hints,
|
| 22 |
-
)
|
| 23 |
-
|
| 24 |
-
import yaml
|
| 25 |
-
|
| 26 |
-
from .errors import (
|
| 27 |
-
ConfigIndexError,
|
| 28 |
-
ConfigTypeError,
|
| 29 |
-
ConfigValueError,
|
| 30 |
-
GrammarParseError,
|
| 31 |
-
OmegaConfBaseException,
|
| 32 |
-
ValidationError,
|
| 33 |
-
)
|
| 34 |
-
from .grammar_parser import SIMPLE_INTERPOLATION_PATTERN, parse
|
| 35 |
-
|
| 36 |
-
try:
|
| 37 |
-
import dataclasses
|
| 38 |
-
|
| 39 |
-
except ImportError: # pragma: no cover
|
| 40 |
-
dataclasses = None # type: ignore # pragma: no cover
|
| 41 |
-
|
| 42 |
-
try:
|
| 43 |
-
import attr
|
| 44 |
-
|
| 45 |
-
except ImportError: # pragma: no cover
|
| 46 |
-
attr = None # type: ignore # pragma: no cover
|
| 47 |
-
|
| 48 |
-
NoneType: Type[None] = type(None)
|
| 49 |
-
|
| 50 |
-
BUILTIN_VALUE_TYPES: Tuple[Type[Any], ...] = (
|
| 51 |
-
int,
|
| 52 |
-
float,
|
| 53 |
-
bool,
|
| 54 |
-
str,
|
| 55 |
-
bytes,
|
| 56 |
-
NoneType,
|
| 57 |
-
)
|
| 58 |
-
|
| 59 |
-
# Regexprs to match key paths like: a.b, a[b], ..a[c].d, etc.
|
| 60 |
-
# We begin by matching the head (in these examples: a, a, ..a).
|
| 61 |
-
# This can be read as "dots followed by any character but `.` or `[`"
|
| 62 |
-
# Note that a key starting with brackets, like [a], is purposedly *not*
|
| 63 |
-
# matched here and will instead be handled in the next regex below (this
|
| 64 |
-
# is to keep this regex simple).
|
| 65 |
-
KEY_PATH_HEAD = re.compile(r"(\.)*[^.[]*")
|
| 66 |
-
# Then we match other keys. The following expression matches one key and can
|
| 67 |
-
# be read as a choice between two syntaxes:
|
| 68 |
-
# - `.` followed by anything except `.` or `[` (ex: .b, .d)
|
| 69 |
-
# - `[` followed by anything then `]` (ex: [b], [c])
|
| 70 |
-
KEY_PATH_OTHER = re.compile(r"\.([^.[]*)|\[(.*?)\]")
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
# source: https://yaml.org/type/bool.html
|
| 74 |
-
YAML_BOOL_TYPES = [
|
| 75 |
-
"y",
|
| 76 |
-
"Y",
|
| 77 |
-
"yes",
|
| 78 |
-
"Yes",
|
| 79 |
-
"YES",
|
| 80 |
-
"n",
|
| 81 |
-
"N",
|
| 82 |
-
"no",
|
| 83 |
-
"No",
|
| 84 |
-
"NO",
|
| 85 |
-
"true",
|
| 86 |
-
"True",
|
| 87 |
-
"TRUE",
|
| 88 |
-
"false",
|
| 89 |
-
"False",
|
| 90 |
-
"FALSE",
|
| 91 |
-
"on",
|
| 92 |
-
"On",
|
| 93 |
-
"ON",
|
| 94 |
-
"off",
|
| 95 |
-
"Off",
|
| 96 |
-
"OFF",
|
| 97 |
-
]
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
class Marker:
|
| 101 |
-
def __init__(self, desc: str):
|
| 102 |
-
self.desc = desc
|
| 103 |
-
|
| 104 |
-
def __repr__(self) -> str:
|
| 105 |
-
return self.desc
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# To be used as default value when `None` is not an option.
|
| 109 |
-
_DEFAULT_MARKER_: Any = Marker("_DEFAULT_MARKER_")
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class OmegaConfDumper(yaml.Dumper): # type: ignore
|
| 113 |
-
str_representer_added = False
|
| 114 |
-
|
| 115 |
-
@staticmethod
|
| 116 |
-
def str_representer(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
|
| 117 |
-
with_quotes = yaml_is_bool(data) or is_int(data) or is_float(data)
|
| 118 |
-
return dumper.represent_scalar(
|
| 119 |
-
yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG,
|
| 120 |
-
data,
|
| 121 |
-
style=("'" if with_quotes else None),
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def get_omega_conf_dumper() -> Type[OmegaConfDumper]:
|
| 126 |
-
if not OmegaConfDumper.str_representer_added:
|
| 127 |
-
OmegaConfDumper.add_representer(str, OmegaConfDumper.str_representer)
|
| 128 |
-
OmegaConfDumper.str_representer_added = True
|
| 129 |
-
return OmegaConfDumper
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def yaml_is_bool(b: str) -> bool:
|
| 133 |
-
return b in YAML_BOOL_TYPES
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def get_yaml_loader() -> Any:
|
| 137 |
-
class OmegaConfLoader(yaml.SafeLoader): # type: ignore
|
| 138 |
-
def construct_mapping(self, node: yaml.Node, deep: bool = False) -> Any:
|
| 139 |
-
keys = set()
|
| 140 |
-
for key_node, value_node in node.value:
|
| 141 |
-
if key_node.tag != yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG:
|
| 142 |
-
continue
|
| 143 |
-
if key_node.value in keys:
|
| 144 |
-
raise yaml.constructor.ConstructorError(
|
| 145 |
-
"while constructing a mapping",
|
| 146 |
-
node.start_mark,
|
| 147 |
-
f"found duplicate key {key_node.value}",
|
| 148 |
-
key_node.start_mark,
|
| 149 |
-
)
|
| 150 |
-
keys.add(key_node.value)
|
| 151 |
-
return super().construct_mapping(node, deep=deep)
|
| 152 |
-
|
| 153 |
-
loader = OmegaConfLoader
|
| 154 |
-
loader.add_implicit_resolver(
|
| 155 |
-
"tag:yaml.org,2002:float",
|
| 156 |
-
re.compile(
|
| 157 |
-
"""^(?:
|
| 158 |
-
[-+]?[0-9]+(?:_[0-9]+)*\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
| 159 |
-
|[-+]?[0-9]+(?:_[0-9]+)*(?:[eE][-+]?[0-9]+)
|
| 160 |
-
|\\.[0-9]+(?:_[0-9]+)*(?:[eE][-+][0-9]+)?
|
| 161 |
-
|[-+]?[0-9]+(?:_[0-9]+)*(?::[0-5]?[0-9])+\\.[0-9_]*
|
| 162 |
-
|[-+]?\\.(?:inf|Inf|INF)
|
| 163 |
-
|\\.(?:nan|NaN|NAN))$""",
|
| 164 |
-
re.X,
|
| 165 |
-
),
|
| 166 |
-
list("-+0123456789."),
|
| 167 |
-
)
|
| 168 |
-
loader.yaml_implicit_resolvers = {
|
| 169 |
-
key: [
|
| 170 |
-
(tag, regexp)
|
| 171 |
-
for tag, regexp in resolvers
|
| 172 |
-
if tag != "tag:yaml.org,2002:timestamp"
|
| 173 |
-
]
|
| 174 |
-
for key, resolvers in loader.yaml_implicit_resolvers.items()
|
| 175 |
-
}
|
| 176 |
-
|
| 177 |
-
loader.add_constructor(
|
| 178 |
-
"tag:yaml.org,2002:python/object/apply:pathlib.Path",
|
| 179 |
-
lambda loader, node: pathlib.Path(*loader.construct_sequence(node)),
|
| 180 |
-
)
|
| 181 |
-
loader.add_constructor(
|
| 182 |
-
"tag:yaml.org,2002:python/object/apply:pathlib.PosixPath",
|
| 183 |
-
lambda loader, node: pathlib.PosixPath(*loader.construct_sequence(node)),
|
| 184 |
-
)
|
| 185 |
-
loader.add_constructor(
|
| 186 |
-
"tag:yaml.org,2002:python/object/apply:pathlib.WindowsPath",
|
| 187 |
-
lambda loader, node: pathlib.WindowsPath(*loader.construct_sequence(node)),
|
| 188 |
-
)
|
| 189 |
-
|
| 190 |
-
return loader
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
def _get_class(path: str) -> type:
|
| 194 |
-
from importlib import import_module
|
| 195 |
-
|
| 196 |
-
module_path, _, class_name = path.rpartition(".")
|
| 197 |
-
mod = import_module(module_path)
|
| 198 |
-
try:
|
| 199 |
-
klass: type = getattr(mod, class_name)
|
| 200 |
-
except AttributeError:
|
| 201 |
-
raise ImportError(f"Class {class_name} is not in module {module_path}")
|
| 202 |
-
return klass
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def is_union_annotation(type_: Any) -> bool:
|
| 206 |
-
if sys.version_info >= (3, 10): # pragma: no cover
|
| 207 |
-
if isinstance(type_, types.UnionType):
|
| 208 |
-
return True
|
| 209 |
-
return getattr(type_, "__origin__", None) is Union
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
| 213 |
-
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
| 214 |
-
if is_union_annotation(type_):
|
| 215 |
-
args = type_.__args__
|
| 216 |
-
if NoneType in args:
|
| 217 |
-
optional = True
|
| 218 |
-
args = tuple(a for a in args if a is not NoneType)
|
| 219 |
-
else:
|
| 220 |
-
optional = False
|
| 221 |
-
if len(args) == 1:
|
| 222 |
-
return optional, args[0]
|
| 223 |
-
elif len(args) >= 2:
|
| 224 |
-
return optional, Union[args]
|
| 225 |
-
else:
|
| 226 |
-
assert False
|
| 227 |
-
|
| 228 |
-
if type_ is Any:
|
| 229 |
-
return True, Any
|
| 230 |
-
|
| 231 |
-
if type_ in (None, NoneType):
|
| 232 |
-
return True, NoneType
|
| 233 |
-
|
| 234 |
-
return False, type_
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
def _is_optional(obj: Any, key: Optional[Union[int, str]] = None) -> bool:
|
| 238 |
-
"""Check `obj` metadata to see if the given node is optional."""
|
| 239 |
-
from .base import Container, Node
|
| 240 |
-
|
| 241 |
-
if key is not None:
|
| 242 |
-
assert isinstance(obj, Container)
|
| 243 |
-
obj = obj._get_node(key)
|
| 244 |
-
assert isinstance(obj, Node)
|
| 245 |
-
return obj._is_optional()
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]:
|
| 249 |
-
import typing # lgtm [py/import-and-import-from]
|
| 250 |
-
|
| 251 |
-
forward = typing.ForwardRef if hasattr(typing, "ForwardRef") else typing._ForwardRef # type: ignore
|
| 252 |
-
if type(type_) is forward:
|
| 253 |
-
return _get_class(f"{module}.{type_.__forward_arg__}")
|
| 254 |
-
else:
|
| 255 |
-
if is_dict_annotation(type_):
|
| 256 |
-
kt, vt = get_dict_key_value_types(type_)
|
| 257 |
-
if kt is not None:
|
| 258 |
-
kt = _resolve_forward(kt, module=module)
|
| 259 |
-
if vt is not None:
|
| 260 |
-
vt = _resolve_forward(vt, module=module)
|
| 261 |
-
return Dict[kt, vt] # type: ignore
|
| 262 |
-
if is_list_annotation(type_):
|
| 263 |
-
et = get_list_element_type(type_)
|
| 264 |
-
if et is not None:
|
| 265 |
-
et = _resolve_forward(et, module=module)
|
| 266 |
-
return List[et] # type: ignore
|
| 267 |
-
if is_tuple_annotation(type_):
|
| 268 |
-
its = get_tuple_item_types(type_)
|
| 269 |
-
its = tuple(_resolve_forward(it, module=module) for it in its)
|
| 270 |
-
return Tuple[its] # type: ignore
|
| 271 |
-
|
| 272 |
-
return type_
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]]:
|
| 276 |
-
"""Check if obj is an instance of a subclass of Dict. If so, extract the Dict keys/values."""
|
| 277 |
-
from omegaconf.omegaconf import _maybe_wrap
|
| 278 |
-
|
| 279 |
-
is_type = isinstance(obj, type)
|
| 280 |
-
obj_type = obj if is_type else type(obj)
|
| 281 |
-
subclasses_dict = is_dict_subclass(obj_type)
|
| 282 |
-
|
| 283 |
-
if subclasses_dict:
|
| 284 |
-
warnings.warn(
|
| 285 |
-
f"Class `{obj_type.__name__}` subclasses `Dict`."
|
| 286 |
-
+ " Subclassing `Dict` in Structured Config classes is deprecated,"
|
| 287 |
-
+ " see github.com/omry/omegaconf/issues/663",
|
| 288 |
-
UserWarning,
|
| 289 |
-
stacklevel=9,
|
| 290 |
-
)
|
| 291 |
-
|
| 292 |
-
if is_type:
|
| 293 |
-
return None
|
| 294 |
-
elif subclasses_dict:
|
| 295 |
-
dict_subclass_data = {}
|
| 296 |
-
key_type, element_type = get_dict_key_value_types(obj_type)
|
| 297 |
-
for name, value in obj.items():
|
| 298 |
-
is_optional, type_ = _resolve_optional(element_type)
|
| 299 |
-
type_ = _resolve_forward(type_, obj.__module__)
|
| 300 |
-
try:
|
| 301 |
-
dict_subclass_data[name] = _maybe_wrap(
|
| 302 |
-
ref_type=type_,
|
| 303 |
-
is_optional=is_optional,
|
| 304 |
-
key=name,
|
| 305 |
-
value=value,
|
| 306 |
-
parent=parent,
|
| 307 |
-
)
|
| 308 |
-
except ValidationError as ex:
|
| 309 |
-
format_and_raise(
|
| 310 |
-
node=None, key=name, value=value, cause=ex, msg=str(ex)
|
| 311 |
-
)
|
| 312 |
-
return dict_subclass_data
|
| 313 |
-
else:
|
| 314 |
-
return None
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
def get_attr_class_fields(obj: Any) -> List["attr.Attribute[Any]"]:
|
| 318 |
-
is_type = isinstance(obj, type)
|
| 319 |
-
obj_type = obj if is_type else type(obj)
|
| 320 |
-
fields = attr.fields_dict(obj_type).values()
|
| 321 |
-
return [f for f in fields if f.metadata.get("omegaconf_ignore") is not True]
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]:
|
| 325 |
-
from omegaconf.omegaconf import OmegaConf, _maybe_wrap
|
| 326 |
-
|
| 327 |
-
flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
|
| 328 |
-
|
| 329 |
-
from omegaconf import MISSING
|
| 330 |
-
|
| 331 |
-
d = {}
|
| 332 |
-
is_type = isinstance(obj, type)
|
| 333 |
-
obj_type = obj if is_type else type(obj)
|
| 334 |
-
dummy_parent = OmegaConf.create({}, flags=flags)
|
| 335 |
-
dummy_parent._metadata.object_type = obj_type
|
| 336 |
-
resolved_hints = get_type_hints(obj_type)
|
| 337 |
-
|
| 338 |
-
for attrib in get_attr_class_fields(obj):
|
| 339 |
-
name = attrib.name
|
| 340 |
-
is_optional, type_ = _resolve_optional(resolved_hints[name])
|
| 341 |
-
type_ = _resolve_forward(type_, obj.__module__)
|
| 342 |
-
if not is_type:
|
| 343 |
-
value = getattr(obj, name)
|
| 344 |
-
else:
|
| 345 |
-
value = attrib.default
|
| 346 |
-
if value == attr.NOTHING:
|
| 347 |
-
value = MISSING
|
| 348 |
-
if is_union_annotation(type_) and not is_supported_union_annotation(type_):
|
| 349 |
-
e = ConfigValueError(
|
| 350 |
-
f"Unions of containers are not supported:\n{name}: {type_str(type_)}"
|
| 351 |
-
)
|
| 352 |
-
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
|
| 353 |
-
|
| 354 |
-
try:
|
| 355 |
-
d[name] = _maybe_wrap(
|
| 356 |
-
ref_type=type_,
|
| 357 |
-
is_optional=is_optional,
|
| 358 |
-
key=name,
|
| 359 |
-
value=value,
|
| 360 |
-
parent=dummy_parent,
|
| 361 |
-
)
|
| 362 |
-
except (ValidationError, GrammarParseError) as ex:
|
| 363 |
-
format_and_raise(
|
| 364 |
-
node=dummy_parent, key=name, value=value, cause=ex, msg=str(ex)
|
| 365 |
-
)
|
| 366 |
-
d[name]._set_parent(None)
|
| 367 |
-
dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent)
|
| 368 |
-
if dict_subclass_data is not None:
|
| 369 |
-
d.update(dict_subclass_data)
|
| 370 |
-
return d
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
def get_dataclass_fields(obj: Any) -> List["dataclasses.Field[Any]"]:
|
| 374 |
-
fields = dataclasses.fields(obj)
|
| 375 |
-
return [f for f in fields if f.metadata.get("omegaconf_ignore") is not True]
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
def get_dataclass_data(
|
| 379 |
-
obj: Any, allow_objects: Optional[bool] = None
|
| 380 |
-
) -> Dict[str, Any]:
|
| 381 |
-
from omegaconf.omegaconf import MISSING, OmegaConf, _maybe_wrap
|
| 382 |
-
|
| 383 |
-
flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
|
| 384 |
-
d = {}
|
| 385 |
-
is_type = isinstance(obj, type)
|
| 386 |
-
obj_type = get_type_of(obj)
|
| 387 |
-
dummy_parent = OmegaConf.create({}, flags=flags)
|
| 388 |
-
dummy_parent._metadata.object_type = obj_type
|
| 389 |
-
resolved_hints = get_type_hints(obj_type)
|
| 390 |
-
for field in get_dataclass_fields(obj):
|
| 391 |
-
name = field.name
|
| 392 |
-
is_optional, type_ = _resolve_optional(resolved_hints[field.name])
|
| 393 |
-
type_ = _resolve_forward(type_, obj.__module__)
|
| 394 |
-
has_default = field.default != dataclasses.MISSING
|
| 395 |
-
has_default_factory = field.default_factory != dataclasses.MISSING
|
| 396 |
-
|
| 397 |
-
if not is_type:
|
| 398 |
-
value = getattr(obj, name)
|
| 399 |
-
else:
|
| 400 |
-
if has_default:
|
| 401 |
-
value = field.default
|
| 402 |
-
elif has_default_factory:
|
| 403 |
-
value = field.default_factory() # type: ignore
|
| 404 |
-
else:
|
| 405 |
-
value = MISSING
|
| 406 |
-
|
| 407 |
-
if is_union_annotation(type_) and not is_supported_union_annotation(type_):
|
| 408 |
-
e = ConfigValueError(
|
| 409 |
-
f"Unions of containers are not supported:\n{name}: {type_str(type_)}"
|
| 410 |
-
)
|
| 411 |
-
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
|
| 412 |
-
try:
|
| 413 |
-
d[name] = _maybe_wrap(
|
| 414 |
-
ref_type=type_,
|
| 415 |
-
is_optional=is_optional,
|
| 416 |
-
key=name,
|
| 417 |
-
value=value,
|
| 418 |
-
parent=dummy_parent,
|
| 419 |
-
)
|
| 420 |
-
except (ValidationError, GrammarParseError) as ex:
|
| 421 |
-
format_and_raise(
|
| 422 |
-
node=dummy_parent, key=name, value=value, cause=ex, msg=str(ex)
|
| 423 |
-
)
|
| 424 |
-
d[name]._set_parent(None)
|
| 425 |
-
dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent)
|
| 426 |
-
if dict_subclass_data is not None:
|
| 427 |
-
d.update(dict_subclass_data)
|
| 428 |
-
return d
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
def is_dataclass(obj: Any) -> bool:
|
| 432 |
-
from omegaconf.base import Node
|
| 433 |
-
|
| 434 |
-
if dataclasses is None or isinstance(obj, Node):
|
| 435 |
-
return False
|
| 436 |
-
return dataclasses.is_dataclass(obj)
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
def is_attr_class(obj: Any) -> bool:
|
| 440 |
-
from omegaconf.base import Node
|
| 441 |
-
|
| 442 |
-
if attr is None or isinstance(obj, Node):
|
| 443 |
-
return False
|
| 444 |
-
return attr.has(obj)
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
def is_structured_config(obj: Any) -> bool:
|
| 448 |
-
return is_attr_class(obj) or is_dataclass(obj)
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
def is_dataclass_frozen(type_: Any) -> bool:
|
| 452 |
-
return type_.__dataclass_params__.frozen # type: ignore
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
def is_attr_frozen(type_: type) -> bool:
|
| 456 |
-
# This is very hacky and probably fragile as well.
|
| 457 |
-
# Unfortunately currently there isn't an official API in attr that can detect that.
|
| 458 |
-
# noinspection PyProtectedMember
|
| 459 |
-
return type_.__setattr__ == attr._make._frozen_setattrs # type: ignore
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
def get_type_of(class_or_object: Any) -> Type[Any]:
|
| 463 |
-
type_ = class_or_object
|
| 464 |
-
if not isinstance(type_, type):
|
| 465 |
-
type_ = type(class_or_object)
|
| 466 |
-
assert isinstance(type_, type)
|
| 467 |
-
return type_
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
def is_structured_config_frozen(obj: Any) -> bool:
|
| 471 |
-
type_ = get_type_of(obj)
|
| 472 |
-
|
| 473 |
-
if is_dataclass(type_):
|
| 474 |
-
return is_dataclass_frozen(type_)
|
| 475 |
-
if is_attr_class(type_):
|
| 476 |
-
return is_attr_frozen(type_)
|
| 477 |
-
return False
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
def get_structured_config_init_field_names(obj: Any) -> List[str]:
|
| 481 |
-
fields: Union[List["dataclasses.Field[Any]"], List["attr.Attribute[Any]"]]
|
| 482 |
-
if is_dataclass(obj):
|
| 483 |
-
fields = get_dataclass_fields(obj)
|
| 484 |
-
elif is_attr_class(obj):
|
| 485 |
-
fields = get_attr_class_fields(obj)
|
| 486 |
-
else:
|
| 487 |
-
raise ValueError(f"Unsupported type: {type(obj).__name__}")
|
| 488 |
-
return [f.name for f in fields if f.init]
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
def get_structured_config_data(
|
| 492 |
-
obj: Any, allow_objects: Optional[bool] = None
|
| 493 |
-
) -> Dict[str, Any]:
|
| 494 |
-
if is_dataclass(obj):
|
| 495 |
-
return get_dataclass_data(obj, allow_objects=allow_objects)
|
| 496 |
-
elif is_attr_class(obj):
|
| 497 |
-
return get_attr_data(obj, allow_objects=allow_objects)
|
| 498 |
-
else:
|
| 499 |
-
raise ValueError(f"Unsupported type: {type(obj).__name__}")
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
class ValueKind(Enum):
|
| 503 |
-
VALUE = 0
|
| 504 |
-
MANDATORY_MISSING = 1
|
| 505 |
-
INTERPOLATION = 2
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
def _is_missing_value(value: Any) -> bool:
|
| 509 |
-
from omegaconf import Node
|
| 510 |
-
|
| 511 |
-
if isinstance(value, Node):
|
| 512 |
-
value = value._value()
|
| 513 |
-
return _is_missing_literal(value)
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
def _is_missing_literal(value: Any) -> bool:
|
| 517 |
-
# Uses literal '???' instead of the MISSING const for performance reasons.
|
| 518 |
-
return isinstance(value, str) and value == "???"
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
def _is_none(
|
| 522 |
-
value: Any, resolve: bool = False, throw_on_resolution_failure: bool = True
|
| 523 |
-
) -> bool:
|
| 524 |
-
from omegaconf import Node
|
| 525 |
-
|
| 526 |
-
if not isinstance(value, Node):
|
| 527 |
-
return value is None
|
| 528 |
-
|
| 529 |
-
if resolve:
|
| 530 |
-
value = value._maybe_dereference_node(
|
| 531 |
-
throw_on_resolution_failure=throw_on_resolution_failure
|
| 532 |
-
)
|
| 533 |
-
if not throw_on_resolution_failure and value is None:
|
| 534 |
-
# Resolution failure: consider that it is *not* None.
|
| 535 |
-
return False
|
| 536 |
-
assert isinstance(value, Node)
|
| 537 |
-
|
| 538 |
-
return value._is_none()
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
def get_value_kind(
|
| 542 |
-
value: Any, strict_interpolation_validation: bool = False
|
| 543 |
-
) -> ValueKind:
|
| 544 |
-
"""
|
| 545 |
-
Determine the kind of a value
|
| 546 |
-
Examples:
|
| 547 |
-
VALUE: "10", "20", True
|
| 548 |
-
MANDATORY_MISSING: "???"
|
| 549 |
-
INTERPOLATION: "${foo.bar}", "${foo.${bar}}", "${foo:bar}", "[${foo}, ${bar}]",
|
| 550 |
-
"ftp://${host}/path", "${foo:${bar}, [true], {'baz': ${baz}}}"
|
| 551 |
-
|
| 552 |
-
:param value: Input to classify.
|
| 553 |
-
:param strict_interpolation_validation: If `True`, then when `value` is a string
|
| 554 |
-
containing "${", it is parsed to validate the interpolation syntax. If `False`,
|
| 555 |
-
this parsing step is skipped: this is more efficient, but will not detect errors.
|
| 556 |
-
"""
|
| 557 |
-
|
| 558 |
-
if _is_missing_value(value):
|
| 559 |
-
return ValueKind.MANDATORY_MISSING
|
| 560 |
-
|
| 561 |
-
if _is_interpolation(value, strict_interpolation_validation):
|
| 562 |
-
return ValueKind.INTERPOLATION
|
| 563 |
-
|
| 564 |
-
return ValueKind.VALUE
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
def _is_interpolation(v: Any, strict_interpolation_validation: bool = False) -> bool:
|
| 568 |
-
from omegaconf import Node
|
| 569 |
-
|
| 570 |
-
if isinstance(v, Node):
|
| 571 |
-
v = v._value()
|
| 572 |
-
|
| 573 |
-
if isinstance(v, str) and _is_interpolation_string(
|
| 574 |
-
v, strict_interpolation_validation
|
| 575 |
-
):
|
| 576 |
-
return True
|
| 577 |
-
return False
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
def _is_interpolation_string(value: str, strict_interpolation_validation: bool) -> bool:
|
| 581 |
-
# We identify potential interpolations by the presence of "${" in the string.
|
| 582 |
-
# Note that escaped interpolations (ex: "esc: \${bar}") are identified as
|
| 583 |
-
# interpolations: this is intended, since they must be processed as interpolations
|
| 584 |
-
# for the string to be properly un-escaped.
|
| 585 |
-
# Keep in mind that invalid interpolations will only be detected when
|
| 586 |
-
# `strict_interpolation_validation` is True.
|
| 587 |
-
if "${" in value:
|
| 588 |
-
if strict_interpolation_validation:
|
| 589 |
-
# First try the cheap regex matching that detects common interpolations.
|
| 590 |
-
if SIMPLE_INTERPOLATION_PATTERN.match(value) is None:
|
| 591 |
-
# If no match, do the more expensive grammar parsing to detect errors.
|
| 592 |
-
parse(value)
|
| 593 |
-
return True
|
| 594 |
-
return False
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
def _is_special(value: Any) -> bool:
|
| 598 |
-
"""Special values are None, MISSING, and interpolation."""
|
| 599 |
-
return _is_none(value) or get_value_kind(value) in (
|
| 600 |
-
ValueKind.MANDATORY_MISSING,
|
| 601 |
-
ValueKind.INTERPOLATION,
|
| 602 |
-
)
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
def is_float(st: str) -> bool:
|
| 606 |
-
try:
|
| 607 |
-
float(st)
|
| 608 |
-
return True
|
| 609 |
-
except ValueError:
|
| 610 |
-
return False
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
def is_int(st: str) -> bool:
|
| 614 |
-
try:
|
| 615 |
-
int(st)
|
| 616 |
-
return True
|
| 617 |
-
except ValueError:
|
| 618 |
-
return False
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
def is_primitive_list(obj: Any) -> bool:
|
| 622 |
-
return isinstance(obj, (list, tuple))
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
def is_primitive_dict(obj: Any) -> bool:
|
| 626 |
-
t = get_type_of(obj)
|
| 627 |
-
return t is dict
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
def is_dict_annotation(type_: Any) -> bool:
|
| 631 |
-
if type_ in (dict, Dict):
|
| 632 |
-
return True
|
| 633 |
-
origin = getattr(type_, "__origin__", None)
|
| 634 |
-
# type_dict is a bit hard to detect.
|
| 635 |
-
# this support is tentative, if it eventually causes issues in other areas it may be dropped.
|
| 636 |
-
if sys.version_info < (3, 7, 0): # pragma: no cover
|
| 637 |
-
typed_dict = hasattr(type_, "__base__") and type_.__base__ == Dict
|
| 638 |
-
return origin is Dict or type_ is Dict or typed_dict
|
| 639 |
-
else: # pragma: no cover
|
| 640 |
-
typed_dict = hasattr(type_, "__base__") and type_.__base__ == dict
|
| 641 |
-
return origin is dict or typed_dict
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
def is_list_annotation(type_: Any) -> bool:
|
| 645 |
-
if type_ in (list, List):
|
| 646 |
-
return True
|
| 647 |
-
origin = getattr(type_, "__origin__", None)
|
| 648 |
-
if sys.version_info < (3, 7, 0):
|
| 649 |
-
return origin is List or type_ is List # pragma: no cover
|
| 650 |
-
else:
|
| 651 |
-
return origin is list # pragma: no cover
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
def is_tuple_annotation(type_: Any) -> bool:
|
| 655 |
-
if type_ in (tuple, Tuple):
|
| 656 |
-
return True
|
| 657 |
-
origin = getattr(type_, "__origin__", None)
|
| 658 |
-
if sys.version_info < (3, 7, 0):
|
| 659 |
-
return origin is Tuple or type_ is Tuple # pragma: no cover
|
| 660 |
-
else:
|
| 661 |
-
return origin is tuple # pragma: no cover
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
def is_supported_union_annotation(obj: Any) -> bool:
|
| 665 |
-
"""Currently only primitive types are supported in Unions, e.g. Union[int, str]"""
|
| 666 |
-
if not is_union_annotation(obj):
|
| 667 |
-
return False
|
| 668 |
-
args = obj.__args__
|
| 669 |
-
return all(is_primitive_type_annotation(arg) for arg in args)
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
def is_dict_subclass(type_: Any) -> bool:
|
| 673 |
-
return type_ is not None and isinstance(type_, type) and issubclass(type_, Dict)
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
def is_dict(obj: Any) -> bool:
|
| 677 |
-
return is_primitive_dict(obj) or is_dict_annotation(obj) or is_dict_subclass(obj)
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
def is_primitive_container(obj: Any) -> bool:
|
| 681 |
-
return is_primitive_list(obj) or is_primitive_dict(obj)
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
def get_list_element_type(ref_type: Optional[Type[Any]]) -> Any:
|
| 685 |
-
args = getattr(ref_type, "__args__", None)
|
| 686 |
-
if ref_type is not List and args is not None and args[0]:
|
| 687 |
-
element_type = args[0]
|
| 688 |
-
else:
|
| 689 |
-
element_type = Any
|
| 690 |
-
return element_type
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
def get_tuple_item_types(ref_type: Type[Any]) -> Tuple[Any, ...]:
|
| 694 |
-
args = getattr(ref_type, "__args__", None)
|
| 695 |
-
if args in (None, ()):
|
| 696 |
-
args = (Any, ...)
|
| 697 |
-
assert isinstance(args, tuple)
|
| 698 |
-
return args
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]:
|
| 702 |
-
args = getattr(ref_type, "__args__", None)
|
| 703 |
-
if args is None:
|
| 704 |
-
bases = getattr(ref_type, "__orig_bases__", None)
|
| 705 |
-
if bases is not None and len(bases) > 0:
|
| 706 |
-
args = getattr(bases[0], "__args__", None)
|
| 707 |
-
|
| 708 |
-
key_type: Any
|
| 709 |
-
element_type: Any
|
| 710 |
-
if ref_type is None or ref_type == Dict:
|
| 711 |
-
key_type = Any
|
| 712 |
-
element_type = Any
|
| 713 |
-
else:
|
| 714 |
-
if args is not None:
|
| 715 |
-
key_type = args[0]
|
| 716 |
-
element_type = args[1]
|
| 717 |
-
else:
|
| 718 |
-
key_type = Any
|
| 719 |
-
element_type = Any
|
| 720 |
-
|
| 721 |
-
return key_type, element_type
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
def is_valid_value_annotation(type_: Any) -> bool:
|
| 725 |
-
_, type_ = _resolve_optional(type_)
|
| 726 |
-
return (
|
| 727 |
-
type_ is Any
|
| 728 |
-
or is_primitive_type_annotation(type_)
|
| 729 |
-
or is_structured_config(type_)
|
| 730 |
-
or is_container_annotation(type_)
|
| 731 |
-
or is_supported_union_annotation(type_)
|
| 732 |
-
)
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
def _valid_dict_key_annotation_type(type_: Any) -> bool:
|
| 736 |
-
from omegaconf import DictKeyType
|
| 737 |
-
|
| 738 |
-
return type_ is None or type_ is Any or issubclass(type_, DictKeyType.__args__) # type: ignore
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
def is_primitive_type_annotation(type_: Any) -> bool:
|
| 742 |
-
type_ = get_type_of(type_)
|
| 743 |
-
return issubclass(type_, (Enum, pathlib.Path)) or type_ in BUILTIN_VALUE_TYPES
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
def _get_value(value: Any) -> Any:
|
| 747 |
-
from .base import Container, UnionNode
|
| 748 |
-
from .nodes import ValueNode
|
| 749 |
-
|
| 750 |
-
if isinstance(value, ValueNode):
|
| 751 |
-
return value._value()
|
| 752 |
-
elif isinstance(value, Container):
|
| 753 |
-
boxed = value._value()
|
| 754 |
-
if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed):
|
| 755 |
-
return boxed
|
| 756 |
-
elif isinstance(value, UnionNode):
|
| 757 |
-
boxed = value._value()
|
| 758 |
-
if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed):
|
| 759 |
-
return boxed
|
| 760 |
-
else:
|
| 761 |
-
return _get_value(boxed) # pass through value of boxed node
|
| 762 |
-
|
| 763 |
-
# return primitives and regular OmegaConf Containers as is
|
| 764 |
-
return value
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
def get_type_hint(obj: Any, key: Any = None) -> Optional[Type[Any]]:
|
| 768 |
-
from omegaconf import Container, Node
|
| 769 |
-
|
| 770 |
-
if isinstance(obj, Container):
|
| 771 |
-
if key is not None:
|
| 772 |
-
obj = obj._get_node(key)
|
| 773 |
-
else:
|
| 774 |
-
if key is not None:
|
| 775 |
-
raise ValueError("Key must only be provided when obj is a container")
|
| 776 |
-
|
| 777 |
-
if isinstance(obj, Node):
|
| 778 |
-
ref_type = obj._metadata.ref_type
|
| 779 |
-
if obj._is_optional() and ref_type is not Any:
|
| 780 |
-
return Optional[ref_type] # type: ignore
|
| 781 |
-
else:
|
| 782 |
-
return ref_type
|
| 783 |
-
else:
|
| 784 |
-
return Any # type: ignore
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
def _raise(ex: Exception, cause: Exception) -> None:
|
| 788 |
-
# Set the environment variable OC_CAUSE=1 to get a stacktrace that includes the
|
| 789 |
-
# causing exception.
|
| 790 |
-
env_var = os.environ["OC_CAUSE"] if "OC_CAUSE" in os.environ else None
|
| 791 |
-
debugging = sys.gettrace() is not None
|
| 792 |
-
full_backtrace = (debugging and not env_var == "0") or (env_var == "1")
|
| 793 |
-
if full_backtrace:
|
| 794 |
-
ex.__cause__ = cause
|
| 795 |
-
else:
|
| 796 |
-
ex.__cause__ = None
|
| 797 |
-
raise ex.with_traceback(sys.exc_info()[2]) # set env var OC_CAUSE=1 for full trace
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
def format_and_raise(
|
| 801 |
-
node: Any,
|
| 802 |
-
key: Any,
|
| 803 |
-
value: Any,
|
| 804 |
-
msg: str,
|
| 805 |
-
cause: Exception,
|
| 806 |
-
type_override: Any = None,
|
| 807 |
-
) -> None:
|
| 808 |
-
from omegaconf import OmegaConf
|
| 809 |
-
from omegaconf.base import Node
|
| 810 |
-
|
| 811 |
-
if isinstance(cause, AssertionError):
|
| 812 |
-
raise
|
| 813 |
-
|
| 814 |
-
if isinstance(cause, OmegaConfBaseException) and cause._initialized:
|
| 815 |
-
ex = cause
|
| 816 |
-
if type_override is not None:
|
| 817 |
-
ex = type_override(str(cause))
|
| 818 |
-
ex.__dict__ = copy.deepcopy(cause.__dict__)
|
| 819 |
-
_raise(ex, cause)
|
| 820 |
-
|
| 821 |
-
object_type: Optional[Type[Any]]
|
| 822 |
-
object_type_str: Optional[str] = None
|
| 823 |
-
ref_type: Optional[Type[Any]]
|
| 824 |
-
ref_type_str: Optional[str]
|
| 825 |
-
|
| 826 |
-
child_node: Optional[Node] = None
|
| 827 |
-
if node is None:
|
| 828 |
-
full_key = key if key is not None else ""
|
| 829 |
-
object_type = None
|
| 830 |
-
ref_type = None
|
| 831 |
-
ref_type_str = None
|
| 832 |
-
else:
|
| 833 |
-
if key is not None and not node._is_none():
|
| 834 |
-
child_node = node._get_node(key, validate_access=False)
|
| 835 |
-
|
| 836 |
-
try:
|
| 837 |
-
full_key = node._get_full_key(key=key)
|
| 838 |
-
except Exception as exc:
|
| 839 |
-
# Since we are handling an exception, raising a different one here would
|
| 840 |
-
# be misleading. Instead, we display it in the key.
|
| 841 |
-
full_key = f"<unresolvable due to {type(exc).__name__}: {exc}>"
|
| 842 |
-
|
| 843 |
-
object_type = OmegaConf.get_type(node)
|
| 844 |
-
object_type_str = type_str(object_type)
|
| 845 |
-
|
| 846 |
-
ref_type = get_type_hint(node)
|
| 847 |
-
ref_type_str = type_str(ref_type)
|
| 848 |
-
|
| 849 |
-
msg = string.Template(msg).safe_substitute(
|
| 850 |
-
REF_TYPE=ref_type_str,
|
| 851 |
-
OBJECT_TYPE=object_type_str,
|
| 852 |
-
KEY=key,
|
| 853 |
-
FULL_KEY=full_key,
|
| 854 |
-
VALUE=value,
|
| 855 |
-
VALUE_TYPE=type_str(type(value), include_module_name=True),
|
| 856 |
-
KEY_TYPE=f"{type(key).__name__}",
|
| 857 |
-
)
|
| 858 |
-
|
| 859 |
-
if ref_type not in (None, Any):
|
| 860 |
-
template = dedent(
|
| 861 |
-
"""\
|
| 862 |
-
$MSG
|
| 863 |
-
full_key: $FULL_KEY
|
| 864 |
-
reference_type=$REF_TYPE
|
| 865 |
-
object_type=$OBJECT_TYPE"""
|
| 866 |
-
)
|
| 867 |
-
else:
|
| 868 |
-
template = dedent(
|
| 869 |
-
"""\
|
| 870 |
-
$MSG
|
| 871 |
-
full_key: $FULL_KEY
|
| 872 |
-
object_type=$OBJECT_TYPE"""
|
| 873 |
-
)
|
| 874 |
-
s = string.Template(template=template)
|
| 875 |
-
|
| 876 |
-
message = s.substitute(
|
| 877 |
-
REF_TYPE=ref_type_str, OBJECT_TYPE=object_type_str, MSG=msg, FULL_KEY=full_key
|
| 878 |
-
)
|
| 879 |
-
exception_type = type(cause) if type_override is None else type_override
|
| 880 |
-
if exception_type == TypeError:
|
| 881 |
-
exception_type = ConfigTypeError
|
| 882 |
-
elif exception_type == IndexError:
|
| 883 |
-
exception_type = ConfigIndexError
|
| 884 |
-
|
| 885 |
-
ex = exception_type(f"{message}")
|
| 886 |
-
if issubclass(exception_type, OmegaConfBaseException):
|
| 887 |
-
ex._initialized = True
|
| 888 |
-
ex.msg = message
|
| 889 |
-
ex.parent_node = node
|
| 890 |
-
ex.child_node = child_node
|
| 891 |
-
ex.key = key
|
| 892 |
-
ex.full_key = full_key
|
| 893 |
-
ex.value = value
|
| 894 |
-
ex.object_type = object_type
|
| 895 |
-
ex.object_type_str = object_type_str
|
| 896 |
-
ex.ref_type = ref_type
|
| 897 |
-
ex.ref_type_str = ref_type_str
|
| 898 |
-
|
| 899 |
-
_raise(ex, cause)
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
def type_str(t: Any, include_module_name: bool = False) -> str:
|
| 903 |
-
is_optional, t = _resolve_optional(t)
|
| 904 |
-
if t is NoneType:
|
| 905 |
-
return str(t.__name__)
|
| 906 |
-
if t is Any:
|
| 907 |
-
return "Any"
|
| 908 |
-
if t is ...:
|
| 909 |
-
return "..."
|
| 910 |
-
|
| 911 |
-
if hasattr(t, "__name__"):
|
| 912 |
-
name = str(t.__name__)
|
| 913 |
-
elif getattr(t, "_name", None) is not None: # pragma: no cover
|
| 914 |
-
name = str(t._name)
|
| 915 |
-
elif getattr(t, "__origin__", None) is not None: # pragma: no cover
|
| 916 |
-
name = type_str(t.__origin__)
|
| 917 |
-
else:
|
| 918 |
-
name = str(t)
|
| 919 |
-
if name.startswith("typing."): # pragma: no cover
|
| 920 |
-
name = name[len("typing.") :]
|
| 921 |
-
|
| 922 |
-
args = getattr(t, "__args__", None)
|
| 923 |
-
if args is not None:
|
| 924 |
-
args = ", ".join(
|
| 925 |
-
[type_str(t, include_module_name=include_module_name) for t in t.__args__]
|
| 926 |
-
)
|
| 927 |
-
ret = f"{name}[{args}]"
|
| 928 |
-
else:
|
| 929 |
-
ret = name
|
| 930 |
-
if include_module_name:
|
| 931 |
-
if (
|
| 932 |
-
hasattr(t, "__module__")
|
| 933 |
-
and t.__module__ != "builtins"
|
| 934 |
-
and t.__module__ != "typing"
|
| 935 |
-
and not t.__module__.startswith("omegaconf.")
|
| 936 |
-
):
|
| 937 |
-
module_prefix = str(t.__module__) + "."
|
| 938 |
-
else:
|
| 939 |
-
module_prefix = ""
|
| 940 |
-
ret = module_prefix + ret
|
| 941 |
-
if is_optional:
|
| 942 |
-
return f"Optional[{ret}]"
|
| 943 |
-
else:
|
| 944 |
-
return ret
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
def _ensure_container(target: Any, flags: Optional[Dict[str, bool]] = None) -> Any:
|
| 948 |
-
from omegaconf import OmegaConf
|
| 949 |
-
|
| 950 |
-
if is_primitive_container(target):
|
| 951 |
-
assert isinstance(target, (list, dict))
|
| 952 |
-
target = OmegaConf.create(target, flags=flags)
|
| 953 |
-
elif is_structured_config(target):
|
| 954 |
-
target = OmegaConf.structured(target, flags=flags)
|
| 955 |
-
elif not OmegaConf.is_config(target):
|
| 956 |
-
raise ValueError(
|
| 957 |
-
"Invalid input. Supports one of "
|
| 958 |
-
+ "[dict,list,DictConfig,ListConfig,dataclass,dataclass instance,attr class,attr class instance]"
|
| 959 |
-
)
|
| 960 |
-
|
| 961 |
-
return target
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
def is_generic_list(type_: Any) -> bool:
|
| 965 |
-
"""
|
| 966 |
-
Checks if a type is a generic list, for example:
|
| 967 |
-
list returns False
|
| 968 |
-
typing.List returns False
|
| 969 |
-
typing.List[T] returns True
|
| 970 |
-
|
| 971 |
-
:param type_: variable type
|
| 972 |
-
:return: bool
|
| 973 |
-
"""
|
| 974 |
-
return is_list_annotation(type_) and get_list_element_type(type_) is not None
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
def is_generic_dict(type_: Any) -> bool:
|
| 978 |
-
"""
|
| 979 |
-
Checks if a type is a generic dict, for example:
|
| 980 |
-
list returns False
|
| 981 |
-
typing.List returns False
|
| 982 |
-
typing.List[T] returns True
|
| 983 |
-
|
| 984 |
-
:param type_: variable type
|
| 985 |
-
:return: bool
|
| 986 |
-
"""
|
| 987 |
-
return is_dict_annotation(type_) and len(get_dict_key_value_types(type_)) > 0
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
def is_container_annotation(type_: Any) -> bool:
|
| 991 |
-
return is_list_annotation(type_) or is_dict_annotation(type_)
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
def split_key(key: str) -> List[str]:
|
| 995 |
-
"""
|
| 996 |
-
Split a full key path into its individual components.
|
| 997 |
-
|
| 998 |
-
This is similar to `key.split(".")` but also works with the getitem syntax:
|
| 999 |
-
"a.b" -> ["a", "b"]
|
| 1000 |
-
"a[b]" -> ["a", "b"]
|
| 1001 |
-
".a.b[c].d" -> ["", "a", "b", "c", "d"]
|
| 1002 |
-
"[a].b" -> ["a", "b"]
|
| 1003 |
-
"""
|
| 1004 |
-
# Obtain the first part of the key (in docstring examples: a, a, .a, '')
|
| 1005 |
-
first = KEY_PATH_HEAD.match(key)
|
| 1006 |
-
assert first is not None
|
| 1007 |
-
first_stop = first.span()[1]
|
| 1008 |
-
|
| 1009 |
-
# `tokens` will contain all elements composing the key.
|
| 1010 |
-
tokens = key[0:first_stop].split(".")
|
| 1011 |
-
|
| 1012 |
-
# Optimization in case `key` has no other component: we are done.
|
| 1013 |
-
if first_stop == len(key):
|
| 1014 |
-
return tokens
|
| 1015 |
-
|
| 1016 |
-
if key[first_stop] == "[" and not tokens[-1]:
|
| 1017 |
-
# This is a special case where the first key starts with brackets, e.g.
|
| 1018 |
-
# [a] or ..[a]. In that case there is an extra "" in `tokens` that we
|
| 1019 |
-
# need to get rid of:
|
| 1020 |
-
# [a] -> tokens = [""] but we would like []
|
| 1021 |
-
# ..[a] -> tokens = ["", "", ""] but we would like ["", ""]
|
| 1022 |
-
tokens.pop()
|
| 1023 |
-
|
| 1024 |
-
# Identify other key elements (in docstring examples: b, b, b/c/d, b)
|
| 1025 |
-
others = KEY_PATH_OTHER.findall(key[first_stop:])
|
| 1026 |
-
|
| 1027 |
-
# There are two groups in the `KEY_PATH_OTHER` regex: one for keys starting
|
| 1028 |
-
# with a dot (.b, .d) and one for keys starting with a bracket ([b], [c]).
|
| 1029 |
-
# Only one group can be non-empty.
|
| 1030 |
-
tokens += [dot_key if dot_key else bracket_key for dot_key, bracket_key in others]
|
| 1031 |
-
|
| 1032 |
-
return tokens
|
| 1033 |
-
|
| 1034 |
-
|
| 1035 |
-
# Similar to Python 3.7+'s `contextlib.nullcontext` (which should be used instead,
|
| 1036 |
-
# once support for Python 3.6 is dropped).
|
| 1037 |
-
@contextmanager
|
| 1038 |
-
def nullcontext(enter_result: Any = None) -> Iterator[Any]:
|
| 1039 |
-
yield enter_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/base.py
DELETED
|
@@ -1,962 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import sys
|
| 3 |
-
from abc import ABC, abstractmethod
|
| 4 |
-
from collections import defaultdict
|
| 5 |
-
from dataclasses import dataclass, field
|
| 6 |
-
from enum import Enum
|
| 7 |
-
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
|
| 8 |
-
|
| 9 |
-
from antlr4 import ParserRuleContext
|
| 10 |
-
|
| 11 |
-
from ._utils import (
|
| 12 |
-
_DEFAULT_MARKER_,
|
| 13 |
-
NoneType,
|
| 14 |
-
ValueKind,
|
| 15 |
-
_get_value,
|
| 16 |
-
_is_interpolation,
|
| 17 |
-
_is_missing_value,
|
| 18 |
-
_is_special,
|
| 19 |
-
format_and_raise,
|
| 20 |
-
get_value_kind,
|
| 21 |
-
is_union_annotation,
|
| 22 |
-
is_valid_value_annotation,
|
| 23 |
-
split_key,
|
| 24 |
-
type_str,
|
| 25 |
-
)
|
| 26 |
-
from .errors import (
|
| 27 |
-
ConfigKeyError,
|
| 28 |
-
ConfigTypeError,
|
| 29 |
-
InterpolationKeyError,
|
| 30 |
-
InterpolationResolutionError,
|
| 31 |
-
InterpolationToMissingValueError,
|
| 32 |
-
InterpolationValidationError,
|
| 33 |
-
MissingMandatoryValue,
|
| 34 |
-
UnsupportedInterpolationType,
|
| 35 |
-
ValidationError,
|
| 36 |
-
)
|
| 37 |
-
from .grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
|
| 38 |
-
from .grammar_parser import parse
|
| 39 |
-
from .grammar_visitor import GrammarVisitor
|
| 40 |
-
|
| 41 |
-
DictKeyType = Union[str, bytes, int, Enum, float, bool]
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
@dataclass
|
| 45 |
-
class Metadata:
|
| 46 |
-
|
| 47 |
-
ref_type: Union[Type[Any], Any]
|
| 48 |
-
|
| 49 |
-
object_type: Union[Type[Any], Any]
|
| 50 |
-
|
| 51 |
-
optional: bool
|
| 52 |
-
|
| 53 |
-
key: Any
|
| 54 |
-
|
| 55 |
-
# Flags have 3 modes:
|
| 56 |
-
# unset : inherit from parent (None if no parent specifies)
|
| 57 |
-
# set to true: flag is true
|
| 58 |
-
# set to false: flag is false
|
| 59 |
-
flags: Optional[Dict[str, bool]] = None
|
| 60 |
-
|
| 61 |
-
# If True, when checking the value of a flag, if the flag is not set None is returned
|
| 62 |
-
# otherwise, the parent node is queried.
|
| 63 |
-
flags_root: bool = False
|
| 64 |
-
|
| 65 |
-
resolver_cache: Dict[str, Any] = field(default_factory=lambda: defaultdict(dict))
|
| 66 |
-
|
| 67 |
-
def __post_init__(self) -> None:
|
| 68 |
-
if self.flags is None:
|
| 69 |
-
self.flags = {}
|
| 70 |
-
|
| 71 |
-
@property
|
| 72 |
-
def type_hint(self) -> Union[Type[Any], Any]:
|
| 73 |
-
"""Compute `type_hint` from `self.optional` and `self.ref_type`"""
|
| 74 |
-
# For compatibility with pickled OmegaConf objects created using older
|
| 75 |
-
# versions of OmegaConf, we store `ref_type` and `object_type`
|
| 76 |
-
# separately (rather than storing `type_hint` directly).
|
| 77 |
-
if self.optional:
|
| 78 |
-
return Optional[self.ref_type]
|
| 79 |
-
else:
|
| 80 |
-
return self.ref_type
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
@dataclass
|
| 84 |
-
class ContainerMetadata(Metadata):
|
| 85 |
-
key_type: Any = None
|
| 86 |
-
element_type: Any = None
|
| 87 |
-
|
| 88 |
-
def __post_init__(self) -> None:
|
| 89 |
-
if self.ref_type is None:
|
| 90 |
-
self.ref_type = Any
|
| 91 |
-
assert self.key_type is Any or isinstance(self.key_type, type)
|
| 92 |
-
if self.element_type is not None:
|
| 93 |
-
if not is_valid_value_annotation(self.element_type):
|
| 94 |
-
raise ValidationError(
|
| 95 |
-
f"Unsupported value type: '{type_str(self.element_type, include_module_name=True)}'"
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
if self.flags is None:
|
| 99 |
-
self.flags = {}
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
class Node(ABC):
|
| 103 |
-
_metadata: Metadata
|
| 104 |
-
|
| 105 |
-
_parent: Optional["Box"]
|
| 106 |
-
_flags_cache: Optional[Dict[str, Optional[bool]]]
|
| 107 |
-
|
| 108 |
-
def __init__(self, parent: Optional["Box"], metadata: Metadata):
|
| 109 |
-
self.__dict__["_metadata"] = metadata
|
| 110 |
-
self.__dict__["_parent"] = parent
|
| 111 |
-
self.__dict__["_flags_cache"] = None
|
| 112 |
-
|
| 113 |
-
def __getstate__(self) -> Dict[str, Any]:
|
| 114 |
-
# Overridden to ensure that the flags cache is cleared on serialization.
|
| 115 |
-
state_dict = copy.copy(self.__dict__)
|
| 116 |
-
del state_dict["_flags_cache"]
|
| 117 |
-
return state_dict
|
| 118 |
-
|
| 119 |
-
def __setstate__(self, state_dict: Dict[str, Any]) -> None:
|
| 120 |
-
self.__dict__.update(state_dict)
|
| 121 |
-
self.__dict__["_flags_cache"] = None
|
| 122 |
-
|
| 123 |
-
def _set_parent(self, parent: Optional["Box"]) -> None:
|
| 124 |
-
assert parent is None or isinstance(parent, Box)
|
| 125 |
-
self.__dict__["_parent"] = parent
|
| 126 |
-
self._invalidate_flags_cache()
|
| 127 |
-
|
| 128 |
-
def _invalidate_flags_cache(self) -> None:
|
| 129 |
-
self.__dict__["_flags_cache"] = None
|
| 130 |
-
|
| 131 |
-
def _get_parent(self) -> Optional["Box"]:
|
| 132 |
-
parent = self.__dict__["_parent"]
|
| 133 |
-
assert parent is None or isinstance(parent, Box)
|
| 134 |
-
return parent
|
| 135 |
-
|
| 136 |
-
def _get_parent_container(self) -> Optional["Container"]:
|
| 137 |
-
"""
|
| 138 |
-
Like _get_parent, but returns the grandparent
|
| 139 |
-
in the case where `self` is wrapped by a UnionNode.
|
| 140 |
-
"""
|
| 141 |
-
parent = self.__dict__["_parent"]
|
| 142 |
-
assert parent is None or isinstance(parent, Box)
|
| 143 |
-
|
| 144 |
-
if isinstance(parent, UnionNode):
|
| 145 |
-
grandparent = parent.__dict__["_parent"]
|
| 146 |
-
assert grandparent is None or isinstance(grandparent, Container)
|
| 147 |
-
return grandparent
|
| 148 |
-
else:
|
| 149 |
-
assert parent is None or isinstance(parent, Container)
|
| 150 |
-
return parent
|
| 151 |
-
|
| 152 |
-
def _set_flag(
|
| 153 |
-
self,
|
| 154 |
-
flags: Union[List[str], str],
|
| 155 |
-
values: Union[List[Optional[bool]], Optional[bool]],
|
| 156 |
-
) -> "Node":
|
| 157 |
-
if isinstance(flags, str):
|
| 158 |
-
flags = [flags]
|
| 159 |
-
|
| 160 |
-
if values is None or isinstance(values, bool):
|
| 161 |
-
values = [values]
|
| 162 |
-
|
| 163 |
-
if len(values) == 1:
|
| 164 |
-
values = len(flags) * values
|
| 165 |
-
|
| 166 |
-
if len(flags) != len(values):
|
| 167 |
-
raise ValueError("Inconsistent lengths of input flag names and values")
|
| 168 |
-
|
| 169 |
-
for idx, flag in enumerate(flags):
|
| 170 |
-
value = values[idx]
|
| 171 |
-
if value is None:
|
| 172 |
-
assert self._metadata.flags is not None
|
| 173 |
-
if flag in self._metadata.flags:
|
| 174 |
-
del self._metadata.flags[flag]
|
| 175 |
-
else:
|
| 176 |
-
assert self._metadata.flags is not None
|
| 177 |
-
self._metadata.flags[flag] = value
|
| 178 |
-
self._invalidate_flags_cache()
|
| 179 |
-
return self
|
| 180 |
-
|
| 181 |
-
def _get_node_flag(self, flag: str) -> Optional[bool]:
|
| 182 |
-
"""
|
| 183 |
-
:param flag: flag to inspect
|
| 184 |
-
:return: the state of the flag on this node.
|
| 185 |
-
"""
|
| 186 |
-
assert self._metadata.flags is not None
|
| 187 |
-
return self._metadata.flags.get(flag)
|
| 188 |
-
|
| 189 |
-
def _get_flag(self, flag: str) -> Optional[bool]:
|
| 190 |
-
cache = self.__dict__["_flags_cache"]
|
| 191 |
-
if cache is None:
|
| 192 |
-
cache = self.__dict__["_flags_cache"] = {}
|
| 193 |
-
|
| 194 |
-
ret = cache.get(flag, _DEFAULT_MARKER_)
|
| 195 |
-
if ret is _DEFAULT_MARKER_:
|
| 196 |
-
ret = self._get_flag_no_cache(flag)
|
| 197 |
-
cache[flag] = ret
|
| 198 |
-
assert ret is None or isinstance(ret, bool)
|
| 199 |
-
return ret
|
| 200 |
-
|
| 201 |
-
def _get_flag_no_cache(self, flag: str) -> Optional[bool]:
|
| 202 |
-
"""
|
| 203 |
-
Returns True if this config node flag is set
|
| 204 |
-
A flag is set if node.set_flag(True) was called
|
| 205 |
-
or one if it's parents is flag is set
|
| 206 |
-
:return:
|
| 207 |
-
"""
|
| 208 |
-
flags = self._metadata.flags
|
| 209 |
-
assert flags is not None
|
| 210 |
-
if flag in flags and flags[flag] is not None:
|
| 211 |
-
return flags[flag]
|
| 212 |
-
|
| 213 |
-
if self._is_flags_root():
|
| 214 |
-
return None
|
| 215 |
-
|
| 216 |
-
parent = self._get_parent()
|
| 217 |
-
if parent is None:
|
| 218 |
-
return None
|
| 219 |
-
else:
|
| 220 |
-
# noinspection PyProtectedMember
|
| 221 |
-
return parent._get_flag(flag)
|
| 222 |
-
|
| 223 |
-
def _format_and_raise(
|
| 224 |
-
self,
|
| 225 |
-
key: Any,
|
| 226 |
-
value: Any,
|
| 227 |
-
cause: Exception,
|
| 228 |
-
msg: Optional[str] = None,
|
| 229 |
-
type_override: Any = None,
|
| 230 |
-
) -> None:
|
| 231 |
-
format_and_raise(
|
| 232 |
-
node=self,
|
| 233 |
-
key=key,
|
| 234 |
-
value=value,
|
| 235 |
-
msg=str(cause) if msg is None else msg,
|
| 236 |
-
cause=cause,
|
| 237 |
-
type_override=type_override,
|
| 238 |
-
)
|
| 239 |
-
assert False
|
| 240 |
-
|
| 241 |
-
@abstractmethod
|
| 242 |
-
def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
|
| 243 |
-
...
|
| 244 |
-
|
| 245 |
-
def _dereference_node(self) -> "Node":
|
| 246 |
-
node = self._dereference_node_impl(throw_on_resolution_failure=True)
|
| 247 |
-
assert node is not None
|
| 248 |
-
return node
|
| 249 |
-
|
| 250 |
-
def _maybe_dereference_node(
|
| 251 |
-
self,
|
| 252 |
-
throw_on_resolution_failure: bool = False,
|
| 253 |
-
memo: Optional[Set[int]] = None,
|
| 254 |
-
) -> Optional["Node"]:
|
| 255 |
-
return self._dereference_node_impl(
|
| 256 |
-
throw_on_resolution_failure=throw_on_resolution_failure,
|
| 257 |
-
memo=memo,
|
| 258 |
-
)
|
| 259 |
-
|
| 260 |
-
def _dereference_node_impl(
|
| 261 |
-
self,
|
| 262 |
-
throw_on_resolution_failure: bool,
|
| 263 |
-
memo: Optional[Set[int]] = None,
|
| 264 |
-
) -> Optional["Node"]:
|
| 265 |
-
if not self._is_interpolation():
|
| 266 |
-
return self
|
| 267 |
-
|
| 268 |
-
parent = self._get_parent_container()
|
| 269 |
-
if parent is None:
|
| 270 |
-
if throw_on_resolution_failure:
|
| 271 |
-
raise InterpolationResolutionError(
|
| 272 |
-
"Cannot resolve interpolation for a node without a parent"
|
| 273 |
-
)
|
| 274 |
-
return None
|
| 275 |
-
assert parent is not None
|
| 276 |
-
key = self._key()
|
| 277 |
-
return parent._resolve_interpolation_from_parse_tree(
|
| 278 |
-
parent=parent,
|
| 279 |
-
key=key,
|
| 280 |
-
value=self,
|
| 281 |
-
parse_tree=parse(_get_value(self)),
|
| 282 |
-
throw_on_resolution_failure=throw_on_resolution_failure,
|
| 283 |
-
memo=memo,
|
| 284 |
-
)
|
| 285 |
-
|
| 286 |
-
def _get_root(self) -> "Container":
|
| 287 |
-
root: Optional[Box] = self._get_parent()
|
| 288 |
-
if root is None:
|
| 289 |
-
assert isinstance(self, Container)
|
| 290 |
-
return self
|
| 291 |
-
assert root is not None and isinstance(root, Box)
|
| 292 |
-
while root._get_parent() is not None:
|
| 293 |
-
root = root._get_parent()
|
| 294 |
-
assert root is not None and isinstance(root, Box)
|
| 295 |
-
assert root is not None and isinstance(root, Container)
|
| 296 |
-
return root
|
| 297 |
-
|
| 298 |
-
def _is_missing(self) -> bool:
|
| 299 |
-
"""
|
| 300 |
-
Check if the node's value is `???` (does *not* resolve interpolations).
|
| 301 |
-
"""
|
| 302 |
-
return _is_missing_value(self)
|
| 303 |
-
|
| 304 |
-
def _is_none(self) -> bool:
|
| 305 |
-
"""
|
| 306 |
-
Check if the node's value is `None` (does *not* resolve interpolations).
|
| 307 |
-
"""
|
| 308 |
-
return self._value() is None
|
| 309 |
-
|
| 310 |
-
@abstractmethod
|
| 311 |
-
def __eq__(self, other: Any) -> bool:
|
| 312 |
-
...
|
| 313 |
-
|
| 314 |
-
@abstractmethod
|
| 315 |
-
def __ne__(self, other: Any) -> bool:
|
| 316 |
-
...
|
| 317 |
-
|
| 318 |
-
@abstractmethod
|
| 319 |
-
def __hash__(self) -> int:
|
| 320 |
-
...
|
| 321 |
-
|
| 322 |
-
@abstractmethod
|
| 323 |
-
def _value(self) -> Any:
|
| 324 |
-
...
|
| 325 |
-
|
| 326 |
-
@abstractmethod
|
| 327 |
-
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
|
| 328 |
-
...
|
| 329 |
-
|
| 330 |
-
@abstractmethod
|
| 331 |
-
def _is_optional(self) -> bool:
|
| 332 |
-
...
|
| 333 |
-
|
| 334 |
-
@abstractmethod
|
| 335 |
-
def _is_interpolation(self) -> bool:
|
| 336 |
-
...
|
| 337 |
-
|
| 338 |
-
def _key(self) -> Any:
|
| 339 |
-
return self._metadata.key
|
| 340 |
-
|
| 341 |
-
def _set_key(self, key: Any) -> None:
|
| 342 |
-
self._metadata.key = key
|
| 343 |
-
|
| 344 |
-
def _is_flags_root(self) -> bool:
|
| 345 |
-
return self._metadata.flags_root
|
| 346 |
-
|
| 347 |
-
def _set_flags_root(self, flags_root: bool) -> None:
|
| 348 |
-
if self._metadata.flags_root != flags_root:
|
| 349 |
-
self._metadata.flags_root = flags_root
|
| 350 |
-
self._invalidate_flags_cache()
|
| 351 |
-
|
| 352 |
-
def _has_ref_type(self) -> bool:
|
| 353 |
-
return self._metadata.ref_type is not Any
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
class Box(Node):
|
| 357 |
-
"""
|
| 358 |
-
Base class for nodes that can contain other nodes.
|
| 359 |
-
Concrete subclasses include DictConfig, ListConfig, and UnionNode.
|
| 360 |
-
"""
|
| 361 |
-
|
| 362 |
-
_content: Any
|
| 363 |
-
|
| 364 |
-
def __init__(self, parent: Optional["Box"], metadata: Metadata):
|
| 365 |
-
super().__init__(parent=parent, metadata=metadata)
|
| 366 |
-
self.__dict__["_content"] = None
|
| 367 |
-
|
| 368 |
-
def __copy__(self) -> Any:
|
| 369 |
-
# real shallow copy is impossible because of the reference to the parent.
|
| 370 |
-
return copy.deepcopy(self)
|
| 371 |
-
|
| 372 |
-
def _re_parent(self) -> None:
|
| 373 |
-
from .dictconfig import DictConfig
|
| 374 |
-
from .listconfig import ListConfig
|
| 375 |
-
|
| 376 |
-
# update parents of first level Config nodes to self
|
| 377 |
-
|
| 378 |
-
if isinstance(self, DictConfig):
|
| 379 |
-
content = self.__dict__["_content"]
|
| 380 |
-
if isinstance(content, dict):
|
| 381 |
-
for _key, value in self.__dict__["_content"].items():
|
| 382 |
-
if value is not None:
|
| 383 |
-
value._set_parent(self)
|
| 384 |
-
if isinstance(value, Box):
|
| 385 |
-
value._re_parent()
|
| 386 |
-
elif isinstance(self, ListConfig):
|
| 387 |
-
content = self.__dict__["_content"]
|
| 388 |
-
if isinstance(content, list):
|
| 389 |
-
for item in self.__dict__["_content"]:
|
| 390 |
-
if item is not None:
|
| 391 |
-
item._set_parent(self)
|
| 392 |
-
if isinstance(item, Box):
|
| 393 |
-
item._re_parent()
|
| 394 |
-
elif isinstance(self, UnionNode):
|
| 395 |
-
content = self.__dict__["_content"]
|
| 396 |
-
if isinstance(content, Node):
|
| 397 |
-
content._set_parent(self)
|
| 398 |
-
if isinstance(content, Box): # pragma: no cover
|
| 399 |
-
# No coverage here as support for containers inside
|
| 400 |
-
# UnionNode is not yet implemented
|
| 401 |
-
content._re_parent()
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
class Container(Box):
|
| 405 |
-
"""
|
| 406 |
-
Container tagging interface
|
| 407 |
-
"""
|
| 408 |
-
|
| 409 |
-
_metadata: ContainerMetadata
|
| 410 |
-
|
| 411 |
-
@abstractmethod
|
| 412 |
-
def _get_child(
|
| 413 |
-
self,
|
| 414 |
-
key: Any,
|
| 415 |
-
validate_access: bool = True,
|
| 416 |
-
validate_key: bool = True,
|
| 417 |
-
throw_on_missing_value: bool = False,
|
| 418 |
-
throw_on_missing_key: bool = False,
|
| 419 |
-
) -> Union[Optional[Node], List[Optional[Node]]]:
|
| 420 |
-
...
|
| 421 |
-
|
| 422 |
-
@abstractmethod
|
| 423 |
-
def _get_node(
|
| 424 |
-
self,
|
| 425 |
-
key: Any,
|
| 426 |
-
validate_access: bool = True,
|
| 427 |
-
validate_key: bool = True,
|
| 428 |
-
throw_on_missing_value: bool = False,
|
| 429 |
-
throw_on_missing_key: bool = False,
|
| 430 |
-
) -> Union[Optional[Node], List[Optional[Node]]]:
|
| 431 |
-
...
|
| 432 |
-
|
| 433 |
-
@abstractmethod
|
| 434 |
-
def __delitem__(self, key: Any) -> None:
|
| 435 |
-
...
|
| 436 |
-
|
| 437 |
-
@abstractmethod
|
| 438 |
-
def __setitem__(self, key: Any, value: Any) -> None:
|
| 439 |
-
...
|
| 440 |
-
|
| 441 |
-
@abstractmethod
|
| 442 |
-
def __iter__(self) -> Iterator[Any]:
|
| 443 |
-
...
|
| 444 |
-
|
| 445 |
-
@abstractmethod
|
| 446 |
-
def __getitem__(self, key_or_index: Any) -> Any:
|
| 447 |
-
...
|
| 448 |
-
|
| 449 |
-
def _resolve_key_and_root(self, key: str) -> Tuple["Container", str]:
|
| 450 |
-
orig = key
|
| 451 |
-
if not key.startswith("."):
|
| 452 |
-
return self._get_root(), key
|
| 453 |
-
else:
|
| 454 |
-
root: Optional[Container] = self
|
| 455 |
-
assert key.startswith(".")
|
| 456 |
-
while True:
|
| 457 |
-
assert root is not None
|
| 458 |
-
key = key[1:]
|
| 459 |
-
if not key.startswith("."):
|
| 460 |
-
break
|
| 461 |
-
root = root._get_parent_container()
|
| 462 |
-
if root is None:
|
| 463 |
-
raise ConfigKeyError(f"Error resolving key '{orig}'")
|
| 464 |
-
|
| 465 |
-
return root, key
|
| 466 |
-
|
| 467 |
-
def _select_impl(
|
| 468 |
-
self,
|
| 469 |
-
key: str,
|
| 470 |
-
throw_on_missing: bool,
|
| 471 |
-
throw_on_resolution_failure: bool,
|
| 472 |
-
memo: Optional[Set[int]] = None,
|
| 473 |
-
) -> Tuple[Optional["Container"], Optional[str], Optional[Node]]:
|
| 474 |
-
"""
|
| 475 |
-
Select a value using dot separated key sequence
|
| 476 |
-
"""
|
| 477 |
-
from .omegaconf import _select_one
|
| 478 |
-
|
| 479 |
-
if key == "":
|
| 480 |
-
return self, "", self
|
| 481 |
-
|
| 482 |
-
split = split_key(key)
|
| 483 |
-
root: Optional[Container] = self
|
| 484 |
-
for i in range(len(split) - 1):
|
| 485 |
-
if root is None:
|
| 486 |
-
break
|
| 487 |
-
|
| 488 |
-
k = split[i]
|
| 489 |
-
ret, _ = _select_one(
|
| 490 |
-
c=root,
|
| 491 |
-
key=k,
|
| 492 |
-
throw_on_missing=throw_on_missing,
|
| 493 |
-
throw_on_type_error=throw_on_resolution_failure,
|
| 494 |
-
)
|
| 495 |
-
if isinstance(ret, Node):
|
| 496 |
-
ret = ret._maybe_dereference_node(
|
| 497 |
-
throw_on_resolution_failure=throw_on_resolution_failure,
|
| 498 |
-
memo=memo,
|
| 499 |
-
)
|
| 500 |
-
|
| 501 |
-
if ret is not None and not isinstance(ret, Container):
|
| 502 |
-
parent_key = ".".join(split[0 : i + 1])
|
| 503 |
-
child_key = split[i + 1]
|
| 504 |
-
raise ConfigTypeError(
|
| 505 |
-
f"Error trying to access {key}: node `{parent_key}` "
|
| 506 |
-
f"is not a container and thus cannot contain `{child_key}`"
|
| 507 |
-
)
|
| 508 |
-
root = ret
|
| 509 |
-
|
| 510 |
-
if root is None:
|
| 511 |
-
return None, None, None
|
| 512 |
-
|
| 513 |
-
last_key = split[-1]
|
| 514 |
-
value, _ = _select_one(
|
| 515 |
-
c=root,
|
| 516 |
-
key=last_key,
|
| 517 |
-
throw_on_missing=throw_on_missing,
|
| 518 |
-
throw_on_type_error=throw_on_resolution_failure,
|
| 519 |
-
)
|
| 520 |
-
if value is None:
|
| 521 |
-
return root, last_key, None
|
| 522 |
-
|
| 523 |
-
if memo is not None:
|
| 524 |
-
vid = id(value)
|
| 525 |
-
if vid in memo:
|
| 526 |
-
raise InterpolationResolutionError("Recursive interpolation detected")
|
| 527 |
-
# push to memo "stack"
|
| 528 |
-
memo.add(vid)
|
| 529 |
-
|
| 530 |
-
try:
|
| 531 |
-
value = root._maybe_resolve_interpolation(
|
| 532 |
-
parent=root,
|
| 533 |
-
key=last_key,
|
| 534 |
-
value=value,
|
| 535 |
-
throw_on_resolution_failure=throw_on_resolution_failure,
|
| 536 |
-
memo=memo,
|
| 537 |
-
)
|
| 538 |
-
finally:
|
| 539 |
-
if memo is not None:
|
| 540 |
-
# pop from memo "stack"
|
| 541 |
-
memo.remove(vid)
|
| 542 |
-
|
| 543 |
-
return root, last_key, value
|
| 544 |
-
|
| 545 |
-
def _resolve_interpolation_from_parse_tree(
|
| 546 |
-
self,
|
| 547 |
-
parent: Optional["Container"],
|
| 548 |
-
value: "Node",
|
| 549 |
-
key: Any,
|
| 550 |
-
parse_tree: OmegaConfGrammarParser.ConfigValueContext,
|
| 551 |
-
throw_on_resolution_failure: bool,
|
| 552 |
-
memo: Optional[Set[int]],
|
| 553 |
-
) -> Optional["Node"]:
|
| 554 |
-
"""
|
| 555 |
-
Resolve an interpolation.
|
| 556 |
-
|
| 557 |
-
This happens in two steps:
|
| 558 |
-
1. The parse tree is visited, which outputs either a `Node` (e.g.,
|
| 559 |
-
for node interpolations "${foo}"), a string (e.g., for string
|
| 560 |
-
interpolations "hello ${name}", or any other arbitrary value
|
| 561 |
-
(e.g., or custom interpolations "${foo:bar}").
|
| 562 |
-
2. This output is potentially validated and converted when the node
|
| 563 |
-
being resolved (`value`) is typed.
|
| 564 |
-
|
| 565 |
-
If an error occurs in one of the above steps, an `InterpolationResolutionError`
|
| 566 |
-
(or a subclass of it) is raised, *unless* `throw_on_resolution_failure` is set
|
| 567 |
-
to `False` (in which case the return value is `None`).
|
| 568 |
-
|
| 569 |
-
:param parent: Parent of the node being resolved.
|
| 570 |
-
:param value: Node being resolved.
|
| 571 |
-
:param key: The associated key in the parent.
|
| 572 |
-
:param parse_tree: The parse tree as obtained from `grammar_parser.parse()`.
|
| 573 |
-
:param throw_on_resolution_failure: If `False`, then exceptions raised during
|
| 574 |
-
the resolution of the interpolation are silenced, and instead `None` is
|
| 575 |
-
returned.
|
| 576 |
-
|
| 577 |
-
:return: A `Node` that contains the interpolation result. This may be an existing
|
| 578 |
-
node in the config (in the case of a node interpolation "${foo}"), or a new
|
| 579 |
-
node that is created to wrap the interpolated value. It is `None` if and only if
|
| 580 |
-
`throw_on_resolution_failure` is `False` and an error occurs during resolution.
|
| 581 |
-
"""
|
| 582 |
-
|
| 583 |
-
try:
|
| 584 |
-
resolved = self.resolve_parse_tree(
|
| 585 |
-
parse_tree=parse_tree, node=value, key=key, memo=memo
|
| 586 |
-
)
|
| 587 |
-
except InterpolationResolutionError:
|
| 588 |
-
if throw_on_resolution_failure:
|
| 589 |
-
raise
|
| 590 |
-
return None
|
| 591 |
-
|
| 592 |
-
return self._validate_and_convert_interpolation_result(
|
| 593 |
-
parent=parent,
|
| 594 |
-
value=value,
|
| 595 |
-
key=key,
|
| 596 |
-
resolved=resolved,
|
| 597 |
-
throw_on_resolution_failure=throw_on_resolution_failure,
|
| 598 |
-
)
|
| 599 |
-
|
| 600 |
-
def _validate_and_convert_interpolation_result(
|
| 601 |
-
self,
|
| 602 |
-
parent: Optional["Container"],
|
| 603 |
-
value: "Node",
|
| 604 |
-
key: Any,
|
| 605 |
-
resolved: Any,
|
| 606 |
-
throw_on_resolution_failure: bool,
|
| 607 |
-
) -> Optional["Node"]:
|
| 608 |
-
from .nodes import AnyNode, InterpolationResultNode, ValueNode
|
| 609 |
-
|
| 610 |
-
# If the output is not a Node already (e.g., because it is the output of a
|
| 611 |
-
# custom resolver), then we will need to wrap it within a Node.
|
| 612 |
-
must_wrap = not isinstance(resolved, Node)
|
| 613 |
-
|
| 614 |
-
# If the node is typed, validate (and possibly convert) the result.
|
| 615 |
-
if isinstance(value, ValueNode) and not isinstance(value, AnyNode):
|
| 616 |
-
res_value = _get_value(resolved)
|
| 617 |
-
try:
|
| 618 |
-
conv_value = value.validate_and_convert(res_value)
|
| 619 |
-
except ValidationError as e:
|
| 620 |
-
if throw_on_resolution_failure:
|
| 621 |
-
self._format_and_raise(
|
| 622 |
-
key=key,
|
| 623 |
-
value=res_value,
|
| 624 |
-
cause=e,
|
| 625 |
-
msg=f"While dereferencing interpolation '{value}': {e}",
|
| 626 |
-
type_override=InterpolationValidationError,
|
| 627 |
-
)
|
| 628 |
-
return None
|
| 629 |
-
|
| 630 |
-
# If the converted value is of the same type, it means that no conversion
|
| 631 |
-
# was actually needed. As a result, we can keep the original `resolved`
|
| 632 |
-
# (and otherwise, the converted value must be wrapped into a new node).
|
| 633 |
-
if type(conv_value) != type(res_value):
|
| 634 |
-
must_wrap = True
|
| 635 |
-
resolved = conv_value
|
| 636 |
-
|
| 637 |
-
if must_wrap:
|
| 638 |
-
return InterpolationResultNode(value=resolved, key=key, parent=parent)
|
| 639 |
-
else:
|
| 640 |
-
assert isinstance(resolved, Node)
|
| 641 |
-
return resolved
|
| 642 |
-
|
| 643 |
-
def _validate_not_dereferencing_to_parent(self, node: Node, target: Node) -> None:
|
| 644 |
-
parent: Optional[Node] = node
|
| 645 |
-
while parent is not None:
|
| 646 |
-
if parent is target:
|
| 647 |
-
raise InterpolationResolutionError(
|
| 648 |
-
"Interpolation to parent node detected"
|
| 649 |
-
)
|
| 650 |
-
parent = parent._get_parent()
|
| 651 |
-
|
| 652 |
-
def _resolve_node_interpolation(
|
| 653 |
-
self, inter_key: str, memo: Optional[Set[int]]
|
| 654 |
-
) -> "Node":
|
| 655 |
-
"""A node interpolation is of the form `${foo.bar}`"""
|
| 656 |
-
try:
|
| 657 |
-
root_node, inter_key = self._resolve_key_and_root(inter_key)
|
| 658 |
-
except ConfigKeyError as exc:
|
| 659 |
-
raise InterpolationKeyError(
|
| 660 |
-
f"ConfigKeyError while resolving interpolation: {exc}"
|
| 661 |
-
).with_traceback(sys.exc_info()[2])
|
| 662 |
-
|
| 663 |
-
try:
|
| 664 |
-
parent, last_key, value = root_node._select_impl(
|
| 665 |
-
inter_key,
|
| 666 |
-
throw_on_missing=True,
|
| 667 |
-
throw_on_resolution_failure=True,
|
| 668 |
-
memo=memo,
|
| 669 |
-
)
|
| 670 |
-
except MissingMandatoryValue as exc:
|
| 671 |
-
raise InterpolationToMissingValueError(
|
| 672 |
-
f"MissingMandatoryValue while resolving interpolation: {exc}"
|
| 673 |
-
).with_traceback(sys.exc_info()[2])
|
| 674 |
-
|
| 675 |
-
if parent is None or value is None:
|
| 676 |
-
raise InterpolationKeyError(f"Interpolation key '{inter_key}' not found")
|
| 677 |
-
else:
|
| 678 |
-
self._validate_not_dereferencing_to_parent(node=self, target=value)
|
| 679 |
-
return value
|
| 680 |
-
|
| 681 |
-
def _evaluate_custom_resolver(
|
| 682 |
-
self,
|
| 683 |
-
key: Any,
|
| 684 |
-
node: Node,
|
| 685 |
-
inter_type: str,
|
| 686 |
-
inter_args: Tuple[Any, ...],
|
| 687 |
-
inter_args_str: Tuple[str, ...],
|
| 688 |
-
) -> Any:
|
| 689 |
-
from omegaconf import OmegaConf
|
| 690 |
-
|
| 691 |
-
resolver = OmegaConf._get_resolver(inter_type)
|
| 692 |
-
if resolver is not None:
|
| 693 |
-
root_node = self._get_root()
|
| 694 |
-
return resolver(
|
| 695 |
-
root_node,
|
| 696 |
-
self,
|
| 697 |
-
node,
|
| 698 |
-
inter_args,
|
| 699 |
-
inter_args_str,
|
| 700 |
-
)
|
| 701 |
-
else:
|
| 702 |
-
raise UnsupportedInterpolationType(
|
| 703 |
-
f"Unsupported interpolation type {inter_type}"
|
| 704 |
-
)
|
| 705 |
-
|
| 706 |
-
def _maybe_resolve_interpolation(
|
| 707 |
-
self,
|
| 708 |
-
parent: Optional["Container"],
|
| 709 |
-
key: Any,
|
| 710 |
-
value: Node,
|
| 711 |
-
throw_on_resolution_failure: bool,
|
| 712 |
-
memo: Optional[Set[int]] = None,
|
| 713 |
-
) -> Optional[Node]:
|
| 714 |
-
value_kind = get_value_kind(value)
|
| 715 |
-
if value_kind != ValueKind.INTERPOLATION:
|
| 716 |
-
return value
|
| 717 |
-
|
| 718 |
-
parse_tree = parse(_get_value(value))
|
| 719 |
-
return self._resolve_interpolation_from_parse_tree(
|
| 720 |
-
parent=parent,
|
| 721 |
-
value=value,
|
| 722 |
-
key=key,
|
| 723 |
-
parse_tree=parse_tree,
|
| 724 |
-
throw_on_resolution_failure=throw_on_resolution_failure,
|
| 725 |
-
memo=memo if memo is not None else set(),
|
| 726 |
-
)
|
| 727 |
-
|
| 728 |
-
def resolve_parse_tree(
|
| 729 |
-
self,
|
| 730 |
-
parse_tree: ParserRuleContext,
|
| 731 |
-
node: Node,
|
| 732 |
-
memo: Optional[Set[int]] = None,
|
| 733 |
-
key: Optional[Any] = None,
|
| 734 |
-
) -> Any:
|
| 735 |
-
"""
|
| 736 |
-
Resolve a given parse tree into its value.
|
| 737 |
-
|
| 738 |
-
We make no assumption here on the type of the tree's root, so that the
|
| 739 |
-
return value may be of any type.
|
| 740 |
-
"""
|
| 741 |
-
|
| 742 |
-
def node_interpolation_callback(
|
| 743 |
-
inter_key: str, memo: Optional[Set[int]]
|
| 744 |
-
) -> Optional["Node"]:
|
| 745 |
-
return self._resolve_node_interpolation(inter_key=inter_key, memo=memo)
|
| 746 |
-
|
| 747 |
-
def resolver_interpolation_callback(
|
| 748 |
-
name: str, args: Tuple[Any, ...], args_str: Tuple[str, ...]
|
| 749 |
-
) -> Any:
|
| 750 |
-
return self._evaluate_custom_resolver(
|
| 751 |
-
key=key,
|
| 752 |
-
node=node,
|
| 753 |
-
inter_type=name,
|
| 754 |
-
inter_args=args,
|
| 755 |
-
inter_args_str=args_str,
|
| 756 |
-
)
|
| 757 |
-
|
| 758 |
-
visitor = GrammarVisitor(
|
| 759 |
-
node_interpolation_callback=node_interpolation_callback,
|
| 760 |
-
resolver_interpolation_callback=resolver_interpolation_callback,
|
| 761 |
-
memo=memo,
|
| 762 |
-
)
|
| 763 |
-
try:
|
| 764 |
-
return visitor.visit(parse_tree)
|
| 765 |
-
except InterpolationResolutionError:
|
| 766 |
-
raise
|
| 767 |
-
except Exception as exc:
|
| 768 |
-
# Other kinds of exceptions are wrapped in an `InterpolationResolutionError`.
|
| 769 |
-
raise InterpolationResolutionError(
|
| 770 |
-
f"{type(exc).__name__} raised while resolving interpolation: {exc}"
|
| 771 |
-
).with_traceback(sys.exc_info()[2])
|
| 772 |
-
|
| 773 |
-
def _invalidate_flags_cache(self) -> None:
|
| 774 |
-
from .dictconfig import DictConfig
|
| 775 |
-
from .listconfig import ListConfig
|
| 776 |
-
|
| 777 |
-
# invalidate subtree cache only if the cache is initialized in this node.
|
| 778 |
-
|
| 779 |
-
if self.__dict__["_flags_cache"] is not None:
|
| 780 |
-
self.__dict__["_flags_cache"] = None
|
| 781 |
-
if isinstance(self, DictConfig):
|
| 782 |
-
content = self.__dict__["_content"]
|
| 783 |
-
if isinstance(content, dict):
|
| 784 |
-
for value in self.__dict__["_content"].values():
|
| 785 |
-
value._invalidate_flags_cache()
|
| 786 |
-
elif isinstance(self, ListConfig):
|
| 787 |
-
content = self.__dict__["_content"]
|
| 788 |
-
if isinstance(content, list):
|
| 789 |
-
for item in self.__dict__["_content"]:
|
| 790 |
-
item._invalidate_flags_cache()
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
class SCMode(Enum):
|
| 794 |
-
DICT = 1 # Convert to plain dict
|
| 795 |
-
DICT_CONFIG = 2 # Keep as OmegaConf DictConfig
|
| 796 |
-
INSTANTIATE = 3 # Create a dataclass or attrs class instance
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
class UnionNode(Box):
|
| 800 |
-
"""
|
| 801 |
-
This class handles Union type hints. The `_content` attribute is either a
|
| 802 |
-
child node that is compatible with the given Union ref_type, or it is a
|
| 803 |
-
special value (None or MISSING or interpolation).
|
| 804 |
-
|
| 805 |
-
Much of the logic for e.g. value assignment and type validation is
|
| 806 |
-
delegated to the child node. As such, UnionNode functions as a
|
| 807 |
-
"pass-through" node. User apps and downstream libraries should not need to
|
| 808 |
-
know about UnionNode (assuming they only use OmegaConf's public API).
|
| 809 |
-
"""
|
| 810 |
-
|
| 811 |
-
_parent: Optional[Container]
|
| 812 |
-
_content: Union[Node, None, str]
|
| 813 |
-
|
| 814 |
-
def __init__(
|
| 815 |
-
self,
|
| 816 |
-
content: Any,
|
| 817 |
-
ref_type: Any,
|
| 818 |
-
is_optional: bool = True,
|
| 819 |
-
key: Any = None,
|
| 820 |
-
parent: Optional[Box] = None,
|
| 821 |
-
) -> None:
|
| 822 |
-
try:
|
| 823 |
-
if not is_union_annotation(ref_type): # pragma: no cover
|
| 824 |
-
msg = (
|
| 825 |
-
f"UnionNode got unexpected ref_type {ref_type}. Please file a bug"
|
| 826 |
-
+ " report at https://github.com/omry/omegaconf/issues"
|
| 827 |
-
)
|
| 828 |
-
raise AssertionError(msg)
|
| 829 |
-
if not isinstance(parent, (Container, NoneType)):
|
| 830 |
-
raise ConfigTypeError("Parent type is not omegaconf.Container")
|
| 831 |
-
super().__init__(
|
| 832 |
-
parent=parent,
|
| 833 |
-
metadata=Metadata(
|
| 834 |
-
ref_type=ref_type,
|
| 835 |
-
object_type=None,
|
| 836 |
-
optional=is_optional,
|
| 837 |
-
key=key,
|
| 838 |
-
flags={"convert": False},
|
| 839 |
-
),
|
| 840 |
-
)
|
| 841 |
-
self._set_value(content)
|
| 842 |
-
except Exception as ex:
|
| 843 |
-
format_and_raise(node=None, key=key, value=content, msg=str(ex), cause=ex)
|
| 844 |
-
|
| 845 |
-
def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
|
| 846 |
-
parent = self._get_parent()
|
| 847 |
-
if parent is None:
|
| 848 |
-
if self._metadata.key is None:
|
| 849 |
-
return ""
|
| 850 |
-
else:
|
| 851 |
-
return str(self._metadata.key)
|
| 852 |
-
else:
|
| 853 |
-
return parent._get_full_key(self._metadata.key)
|
| 854 |
-
|
| 855 |
-
def __eq__(self, other: Any) -> bool:
|
| 856 |
-
content = self.__dict__["_content"]
|
| 857 |
-
if isinstance(content, Node):
|
| 858 |
-
ret = content.__eq__(other)
|
| 859 |
-
elif isinstance(other, Node):
|
| 860 |
-
ret = other.__eq__(content)
|
| 861 |
-
else:
|
| 862 |
-
ret = content.__eq__(other)
|
| 863 |
-
assert isinstance(ret, (bool, type(NotImplemented)))
|
| 864 |
-
return ret
|
| 865 |
-
|
| 866 |
-
def __ne__(self, other: Any) -> bool:
|
| 867 |
-
x = self.__eq__(other)
|
| 868 |
-
if x is NotImplemented:
|
| 869 |
-
return NotImplemented
|
| 870 |
-
return not x
|
| 871 |
-
|
| 872 |
-
def __hash__(self) -> int:
|
| 873 |
-
return hash(self.__dict__["_content"])
|
| 874 |
-
|
| 875 |
-
def _value(self) -> Union[Node, None, str]:
|
| 876 |
-
content = self.__dict__["_content"]
|
| 877 |
-
assert isinstance(content, (Node, NoneType, str))
|
| 878 |
-
return content
|
| 879 |
-
|
| 880 |
-
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
|
| 881 |
-
previous_content = self.__dict__["_content"]
|
| 882 |
-
previous_metadata = self.__dict__["_metadata"]
|
| 883 |
-
try:
|
| 884 |
-
self._set_value_impl(value, flags)
|
| 885 |
-
except Exception as e:
|
| 886 |
-
self.__dict__["_content"] = previous_content
|
| 887 |
-
self.__dict__["_metadata"] = previous_metadata
|
| 888 |
-
raise e
|
| 889 |
-
|
| 890 |
-
def _set_value_impl(
|
| 891 |
-
self, value: Any, flags: Optional[Dict[str, bool]] = None
|
| 892 |
-
) -> None:
|
| 893 |
-
from omegaconf.omegaconf import _node_wrap
|
| 894 |
-
|
| 895 |
-
ref_type = self._metadata.ref_type
|
| 896 |
-
type_hint = self._metadata.type_hint
|
| 897 |
-
|
| 898 |
-
value = _get_value(value)
|
| 899 |
-
if _is_special(value):
|
| 900 |
-
assert isinstance(value, (str, NoneType))
|
| 901 |
-
if value is None:
|
| 902 |
-
if not self._is_optional():
|
| 903 |
-
raise ValidationError(
|
| 904 |
-
f"Value '$VALUE' is incompatible with type hint '{type_str(type_hint)}'"
|
| 905 |
-
)
|
| 906 |
-
self.__dict__["_content"] = value
|
| 907 |
-
elif isinstance(value, Container):
|
| 908 |
-
raise ValidationError(
|
| 909 |
-
f"Cannot assign container '$VALUE' of type '$VALUE_TYPE' to {type_str(type_hint)}"
|
| 910 |
-
)
|
| 911 |
-
else:
|
| 912 |
-
for candidate_ref_type in ref_type.__args__:
|
| 913 |
-
try:
|
| 914 |
-
self.__dict__["_content"] = _node_wrap(
|
| 915 |
-
value=value,
|
| 916 |
-
ref_type=candidate_ref_type,
|
| 917 |
-
is_optional=False,
|
| 918 |
-
key=None,
|
| 919 |
-
parent=self,
|
| 920 |
-
)
|
| 921 |
-
break
|
| 922 |
-
except ValidationError:
|
| 923 |
-
continue
|
| 924 |
-
else:
|
| 925 |
-
raise ValidationError(
|
| 926 |
-
f"Value '$VALUE' of type '$VALUE_TYPE' is incompatible with type hint '{type_str(type_hint)}'"
|
| 927 |
-
)
|
| 928 |
-
|
| 929 |
-
def _is_optional(self) -> bool:
|
| 930 |
-
return self.__dict__["_metadata"].optional is True
|
| 931 |
-
|
| 932 |
-
def _is_interpolation(self) -> bool:
|
| 933 |
-
return _is_interpolation(self.__dict__["_content"])
|
| 934 |
-
|
| 935 |
-
def __str__(self) -> str:
|
| 936 |
-
return str(self.__dict__["_content"])
|
| 937 |
-
|
| 938 |
-
def __repr__(self) -> str:
|
| 939 |
-
return repr(self.__dict__["_content"])
|
| 940 |
-
|
| 941 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "UnionNode":
|
| 942 |
-
res = object.__new__(type(self))
|
| 943 |
-
for key, value in self.__dict__.items():
|
| 944 |
-
if key not in ("_content", "_parent"):
|
| 945 |
-
res.__dict__[key] = copy.deepcopy(value, memo=memo)
|
| 946 |
-
|
| 947 |
-
src_content = self.__dict__["_content"]
|
| 948 |
-
if isinstance(src_content, Node):
|
| 949 |
-
old_parent = src_content.__dict__["_parent"]
|
| 950 |
-
try:
|
| 951 |
-
src_content.__dict__["_parent"] = None
|
| 952 |
-
content_copy = copy.deepcopy(src_content, memo=memo)
|
| 953 |
-
content_copy.__dict__["_parent"] = res
|
| 954 |
-
finally:
|
| 955 |
-
src_content.__dict__["_parent"] = old_parent
|
| 956 |
-
else:
|
| 957 |
-
# None and strings can be assigned as is
|
| 958 |
-
content_copy = src_content
|
| 959 |
-
|
| 960 |
-
res.__dict__["_content"] = content_copy
|
| 961 |
-
res.__dict__["_parent"] = self.__dict__["_parent"]
|
| 962 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/basecontainer.py
DELETED
|
@@ -1,916 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import sys
|
| 3 |
-
from abc import ABC, abstractmethod
|
| 4 |
-
from enum import Enum
|
| 5 |
-
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union
|
| 6 |
-
|
| 7 |
-
import yaml
|
| 8 |
-
|
| 9 |
-
from ._utils import (
|
| 10 |
-
_DEFAULT_MARKER_,
|
| 11 |
-
ValueKind,
|
| 12 |
-
_ensure_container,
|
| 13 |
-
_get_value,
|
| 14 |
-
_is_interpolation,
|
| 15 |
-
_is_missing_value,
|
| 16 |
-
_is_none,
|
| 17 |
-
_is_special,
|
| 18 |
-
_resolve_optional,
|
| 19 |
-
get_structured_config_data,
|
| 20 |
-
get_type_hint,
|
| 21 |
-
get_value_kind,
|
| 22 |
-
get_yaml_loader,
|
| 23 |
-
is_container_annotation,
|
| 24 |
-
is_dict_annotation,
|
| 25 |
-
is_list_annotation,
|
| 26 |
-
is_primitive_dict,
|
| 27 |
-
is_primitive_type_annotation,
|
| 28 |
-
is_structured_config,
|
| 29 |
-
is_tuple_annotation,
|
| 30 |
-
is_union_annotation,
|
| 31 |
-
)
|
| 32 |
-
from .base import (
|
| 33 |
-
Box,
|
| 34 |
-
Container,
|
| 35 |
-
ContainerMetadata,
|
| 36 |
-
DictKeyType,
|
| 37 |
-
Node,
|
| 38 |
-
SCMode,
|
| 39 |
-
UnionNode,
|
| 40 |
-
)
|
| 41 |
-
from .errors import (
|
| 42 |
-
ConfigCycleDetectedException,
|
| 43 |
-
ConfigTypeError,
|
| 44 |
-
InterpolationResolutionError,
|
| 45 |
-
KeyValidationError,
|
| 46 |
-
MissingMandatoryValue,
|
| 47 |
-
OmegaConfBaseException,
|
| 48 |
-
ReadonlyConfigError,
|
| 49 |
-
ValidationError,
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
if TYPE_CHECKING:
|
| 53 |
-
from .dictconfig import DictConfig # pragma: no cover
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class BaseContainer(Container, ABC):
|
| 57 |
-
_resolvers: ClassVar[Dict[str, Any]] = {}
|
| 58 |
-
|
| 59 |
-
def __init__(self, parent: Optional[Box], metadata: ContainerMetadata):
|
| 60 |
-
if not (parent is None or isinstance(parent, Box)):
|
| 61 |
-
raise ConfigTypeError("Parent type is not omegaconf.Box")
|
| 62 |
-
super().__init__(parent=parent, metadata=metadata)
|
| 63 |
-
|
| 64 |
-
def _get_child(
|
| 65 |
-
self,
|
| 66 |
-
key: Any,
|
| 67 |
-
validate_access: bool = True,
|
| 68 |
-
validate_key: bool = True,
|
| 69 |
-
throw_on_missing_value: bool = False,
|
| 70 |
-
throw_on_missing_key: bool = False,
|
| 71 |
-
) -> Union[Optional[Node], List[Optional[Node]]]:
|
| 72 |
-
"""Like _get_node, passing through to the nearest concrete Node."""
|
| 73 |
-
child = self._get_node(
|
| 74 |
-
key=key,
|
| 75 |
-
validate_access=validate_access,
|
| 76 |
-
validate_key=validate_key,
|
| 77 |
-
throw_on_missing_value=throw_on_missing_value,
|
| 78 |
-
throw_on_missing_key=throw_on_missing_key,
|
| 79 |
-
)
|
| 80 |
-
if isinstance(child, UnionNode) and not _is_special(child):
|
| 81 |
-
value = child._value()
|
| 82 |
-
assert isinstance(value, Node) and not isinstance(value, UnionNode)
|
| 83 |
-
child = value
|
| 84 |
-
return child
|
| 85 |
-
|
| 86 |
-
def _resolve_with_default(
|
| 87 |
-
self,
|
| 88 |
-
key: Union[DictKeyType, int],
|
| 89 |
-
value: Node,
|
| 90 |
-
default_value: Any = _DEFAULT_MARKER_,
|
| 91 |
-
) -> Any:
|
| 92 |
-
"""returns the value with the specified key, like obj.key and obj['key']"""
|
| 93 |
-
if _is_missing_value(value):
|
| 94 |
-
if default_value is not _DEFAULT_MARKER_:
|
| 95 |
-
return default_value
|
| 96 |
-
raise MissingMandatoryValue("Missing mandatory value: $FULL_KEY")
|
| 97 |
-
|
| 98 |
-
resolved_node = self._maybe_resolve_interpolation(
|
| 99 |
-
parent=self,
|
| 100 |
-
key=key,
|
| 101 |
-
value=value,
|
| 102 |
-
throw_on_resolution_failure=True,
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
return _get_value(resolved_node)
|
| 106 |
-
|
| 107 |
-
def __str__(self) -> str:
|
| 108 |
-
return self.__repr__()
|
| 109 |
-
|
| 110 |
-
def __repr__(self) -> str:
|
| 111 |
-
if self.__dict__["_content"] is None:
|
| 112 |
-
return "None"
|
| 113 |
-
elif self._is_interpolation() or self._is_missing():
|
| 114 |
-
v = self.__dict__["_content"]
|
| 115 |
-
return f"'{v}'"
|
| 116 |
-
else:
|
| 117 |
-
return self.__dict__["_content"].__repr__() # type: ignore
|
| 118 |
-
|
| 119 |
-
# Support pickle
|
| 120 |
-
def __getstate__(self) -> Dict[str, Any]:
|
| 121 |
-
dict_copy = copy.copy(self.__dict__)
|
| 122 |
-
|
| 123 |
-
# no need to serialize the flags cache, it can be re-constructed later
|
| 124 |
-
dict_copy.pop("_flags_cache", None)
|
| 125 |
-
|
| 126 |
-
dict_copy["_metadata"] = copy.copy(dict_copy["_metadata"])
|
| 127 |
-
ref_type = self._metadata.ref_type
|
| 128 |
-
if is_container_annotation(ref_type):
|
| 129 |
-
if is_dict_annotation(ref_type):
|
| 130 |
-
dict_copy["_metadata"].ref_type = Dict
|
| 131 |
-
elif is_list_annotation(ref_type):
|
| 132 |
-
dict_copy["_metadata"].ref_type = List
|
| 133 |
-
else:
|
| 134 |
-
assert False
|
| 135 |
-
if sys.version_info < (3, 7): # pragma: no cover
|
| 136 |
-
element_type = self._metadata.element_type
|
| 137 |
-
if is_union_annotation(element_type):
|
| 138 |
-
raise OmegaConfBaseException(
|
| 139 |
-
"Serializing structured configs with `Union` element type requires python >= 3.7"
|
| 140 |
-
)
|
| 141 |
-
return dict_copy
|
| 142 |
-
|
| 143 |
-
# Support pickle
|
| 144 |
-
def __setstate__(self, d: Dict[str, Any]) -> None:
|
| 145 |
-
from omegaconf import DictConfig
|
| 146 |
-
from omegaconf._utils import is_generic_dict, is_generic_list
|
| 147 |
-
|
| 148 |
-
if isinstance(self, DictConfig):
|
| 149 |
-
key_type = d["_metadata"].key_type
|
| 150 |
-
|
| 151 |
-
# backward compatibility to load OmegaConf 2.0 configs
|
| 152 |
-
if key_type is None:
|
| 153 |
-
key_type = Any
|
| 154 |
-
d["_metadata"].key_type = key_type
|
| 155 |
-
|
| 156 |
-
element_type = d["_metadata"].element_type
|
| 157 |
-
|
| 158 |
-
# backward compatibility to load OmegaConf 2.0 configs
|
| 159 |
-
if element_type is None:
|
| 160 |
-
element_type = Any
|
| 161 |
-
d["_metadata"].element_type = element_type
|
| 162 |
-
|
| 163 |
-
ref_type = d["_metadata"].ref_type
|
| 164 |
-
if is_container_annotation(ref_type):
|
| 165 |
-
if is_generic_dict(ref_type):
|
| 166 |
-
d["_metadata"].ref_type = Dict[key_type, element_type] # type: ignore
|
| 167 |
-
elif is_generic_list(ref_type):
|
| 168 |
-
d["_metadata"].ref_type = List[element_type] # type: ignore
|
| 169 |
-
else:
|
| 170 |
-
assert False
|
| 171 |
-
|
| 172 |
-
d["_flags_cache"] = None
|
| 173 |
-
self.__dict__.update(d)
|
| 174 |
-
|
| 175 |
-
@abstractmethod
|
| 176 |
-
def __delitem__(self, key: Any) -> None:
|
| 177 |
-
...
|
| 178 |
-
|
| 179 |
-
def __len__(self) -> int:
|
| 180 |
-
if self._is_none() or self._is_missing() or self._is_interpolation():
|
| 181 |
-
return 0
|
| 182 |
-
content = self.__dict__["_content"]
|
| 183 |
-
return len(content)
|
| 184 |
-
|
| 185 |
-
def merge_with_cli(self) -> None:
|
| 186 |
-
args_list = sys.argv[1:]
|
| 187 |
-
self.merge_with_dotlist(args_list)
|
| 188 |
-
|
| 189 |
-
def merge_with_dotlist(self, dotlist: List[str]) -> None:
|
| 190 |
-
from omegaconf import OmegaConf
|
| 191 |
-
|
| 192 |
-
def fail() -> None:
|
| 193 |
-
raise ValueError("Input list must be a list or a tuple of strings")
|
| 194 |
-
|
| 195 |
-
if not isinstance(dotlist, (list, tuple)):
|
| 196 |
-
fail()
|
| 197 |
-
|
| 198 |
-
for arg in dotlist:
|
| 199 |
-
if not isinstance(arg, str):
|
| 200 |
-
fail()
|
| 201 |
-
|
| 202 |
-
idx = arg.find("=")
|
| 203 |
-
if idx == -1:
|
| 204 |
-
key = arg
|
| 205 |
-
value = None
|
| 206 |
-
else:
|
| 207 |
-
key = arg[0:idx]
|
| 208 |
-
value = arg[idx + 1 :]
|
| 209 |
-
value = yaml.load(value, Loader=get_yaml_loader())
|
| 210 |
-
|
| 211 |
-
OmegaConf.update(self, key, value)
|
| 212 |
-
|
| 213 |
-
def is_empty(self) -> bool:
|
| 214 |
-
"""return true if config is empty"""
|
| 215 |
-
return len(self.__dict__["_content"]) == 0
|
| 216 |
-
|
| 217 |
-
@staticmethod
|
| 218 |
-
def _to_content(
|
| 219 |
-
conf: Container,
|
| 220 |
-
resolve: bool,
|
| 221 |
-
throw_on_missing: bool,
|
| 222 |
-
enum_to_str: bool = False,
|
| 223 |
-
structured_config_mode: SCMode = SCMode.DICT,
|
| 224 |
-
) -> Union[None, Any, str, Dict[DictKeyType, Any], List[Any]]:
|
| 225 |
-
from omegaconf import MISSING, DictConfig, ListConfig
|
| 226 |
-
|
| 227 |
-
def convert(val: Node) -> Any:
|
| 228 |
-
value = val._value()
|
| 229 |
-
if enum_to_str and isinstance(value, Enum):
|
| 230 |
-
value = f"{value.name}"
|
| 231 |
-
|
| 232 |
-
return value
|
| 233 |
-
|
| 234 |
-
def get_node_value(key: Union[DictKeyType, int]) -> Any:
|
| 235 |
-
try:
|
| 236 |
-
node = conf._get_child(key, throw_on_missing_value=throw_on_missing)
|
| 237 |
-
except MissingMandatoryValue as e:
|
| 238 |
-
conf._format_and_raise(key=key, value=None, cause=e)
|
| 239 |
-
assert isinstance(node, Node)
|
| 240 |
-
if resolve:
|
| 241 |
-
try:
|
| 242 |
-
node = node._dereference_node()
|
| 243 |
-
except InterpolationResolutionError as e:
|
| 244 |
-
conf._format_and_raise(key=key, value=None, cause=e)
|
| 245 |
-
|
| 246 |
-
if isinstance(node, Container):
|
| 247 |
-
value = BaseContainer._to_content(
|
| 248 |
-
node,
|
| 249 |
-
resolve=resolve,
|
| 250 |
-
throw_on_missing=throw_on_missing,
|
| 251 |
-
enum_to_str=enum_to_str,
|
| 252 |
-
structured_config_mode=structured_config_mode,
|
| 253 |
-
)
|
| 254 |
-
else:
|
| 255 |
-
value = convert(node)
|
| 256 |
-
return value
|
| 257 |
-
|
| 258 |
-
if conf._is_none():
|
| 259 |
-
return None
|
| 260 |
-
elif conf._is_missing():
|
| 261 |
-
if throw_on_missing:
|
| 262 |
-
conf._format_and_raise(
|
| 263 |
-
key=None,
|
| 264 |
-
value=None,
|
| 265 |
-
cause=MissingMandatoryValue("Missing mandatory value"),
|
| 266 |
-
)
|
| 267 |
-
else:
|
| 268 |
-
return MISSING
|
| 269 |
-
elif not resolve and conf._is_interpolation():
|
| 270 |
-
inter = conf._value()
|
| 271 |
-
assert isinstance(inter, str)
|
| 272 |
-
return inter
|
| 273 |
-
|
| 274 |
-
if resolve:
|
| 275 |
-
_conf = conf._dereference_node()
|
| 276 |
-
assert isinstance(_conf, Container)
|
| 277 |
-
conf = _conf
|
| 278 |
-
|
| 279 |
-
if isinstance(conf, DictConfig):
|
| 280 |
-
if (
|
| 281 |
-
conf._metadata.object_type not in (dict, None)
|
| 282 |
-
and structured_config_mode == SCMode.DICT_CONFIG
|
| 283 |
-
):
|
| 284 |
-
return conf
|
| 285 |
-
if structured_config_mode == SCMode.INSTANTIATE and is_structured_config(
|
| 286 |
-
conf._metadata.object_type
|
| 287 |
-
):
|
| 288 |
-
return conf._to_object()
|
| 289 |
-
|
| 290 |
-
retdict: Dict[DictKeyType, Any] = {}
|
| 291 |
-
for key in conf.keys():
|
| 292 |
-
value = get_node_value(key)
|
| 293 |
-
if enum_to_str and isinstance(key, Enum):
|
| 294 |
-
key = f"{key.name}"
|
| 295 |
-
retdict[key] = value
|
| 296 |
-
return retdict
|
| 297 |
-
elif isinstance(conf, ListConfig):
|
| 298 |
-
retlist: List[Any] = []
|
| 299 |
-
for index in range(len(conf)):
|
| 300 |
-
item = get_node_value(index)
|
| 301 |
-
retlist.append(item)
|
| 302 |
-
|
| 303 |
-
return retlist
|
| 304 |
-
assert False
|
| 305 |
-
|
| 306 |
-
@staticmethod
|
| 307 |
-
def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
|
| 308 |
-
"""merge src into dest and return a new copy, does not modified input"""
|
| 309 |
-
from omegaconf import AnyNode, DictConfig, ValueNode
|
| 310 |
-
|
| 311 |
-
assert isinstance(dest, DictConfig)
|
| 312 |
-
assert isinstance(src, DictConfig)
|
| 313 |
-
src_type = src._metadata.object_type
|
| 314 |
-
src_ref_type = get_type_hint(src)
|
| 315 |
-
assert src_ref_type is not None
|
| 316 |
-
|
| 317 |
-
# If source DictConfig is:
|
| 318 |
-
# - None => set the destination DictConfig to None
|
| 319 |
-
# - an interpolation => set the destination DictConfig to be the same interpolation
|
| 320 |
-
if src._is_none() or src._is_interpolation():
|
| 321 |
-
dest._set_value(src._value())
|
| 322 |
-
_update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
|
| 323 |
-
return
|
| 324 |
-
|
| 325 |
-
dest._validate_merge(value=src)
|
| 326 |
-
|
| 327 |
-
def expand(node: Container) -> None:
|
| 328 |
-
rt = node._metadata.ref_type
|
| 329 |
-
val: Any
|
| 330 |
-
if rt is not Any:
|
| 331 |
-
if is_dict_annotation(rt):
|
| 332 |
-
val = {}
|
| 333 |
-
elif is_list_annotation(rt) or is_tuple_annotation(rt):
|
| 334 |
-
val = []
|
| 335 |
-
else:
|
| 336 |
-
val = rt
|
| 337 |
-
elif isinstance(node, DictConfig):
|
| 338 |
-
val = {}
|
| 339 |
-
else:
|
| 340 |
-
assert False
|
| 341 |
-
|
| 342 |
-
node._set_value(val)
|
| 343 |
-
|
| 344 |
-
if (
|
| 345 |
-
src._is_missing()
|
| 346 |
-
and not dest._is_missing()
|
| 347 |
-
and is_structured_config(src_ref_type)
|
| 348 |
-
):
|
| 349 |
-
# Replace `src` with a prototype of its corresponding structured config
|
| 350 |
-
# whose fields are all missing (to avoid overwriting fields in `dest`).
|
| 351 |
-
assert src_type is None # src missing, so src's object_type should be None
|
| 352 |
-
src_type = src_ref_type
|
| 353 |
-
src = _create_structured_with_missing_fields(
|
| 354 |
-
ref_type=src_ref_type, object_type=src_type
|
| 355 |
-
)
|
| 356 |
-
|
| 357 |
-
if (dest._is_interpolation() or dest._is_missing()) and not src._is_missing():
|
| 358 |
-
expand(dest)
|
| 359 |
-
|
| 360 |
-
src_items = list(src) if not src._is_missing() else []
|
| 361 |
-
for key in src_items:
|
| 362 |
-
src_node = src._get_node(key, validate_access=False)
|
| 363 |
-
dest_node = dest._get_node(key, validate_access=False)
|
| 364 |
-
assert isinstance(src_node, Node)
|
| 365 |
-
assert dest_node is None or isinstance(dest_node, Node)
|
| 366 |
-
src_value = _get_value(src_node)
|
| 367 |
-
|
| 368 |
-
src_vk = get_value_kind(src_node)
|
| 369 |
-
src_node_missing = src_vk is ValueKind.MANDATORY_MISSING
|
| 370 |
-
|
| 371 |
-
if isinstance(dest_node, DictConfig):
|
| 372 |
-
dest_node._validate_merge(value=src_node)
|
| 373 |
-
|
| 374 |
-
if (
|
| 375 |
-
isinstance(dest_node, Container)
|
| 376 |
-
and dest_node._is_none()
|
| 377 |
-
and not src_node_missing
|
| 378 |
-
and not _is_none(src_node, resolve=True)
|
| 379 |
-
):
|
| 380 |
-
expand(dest_node)
|
| 381 |
-
|
| 382 |
-
if dest_node is not None and dest_node._is_interpolation():
|
| 383 |
-
target_node = dest_node._maybe_dereference_node()
|
| 384 |
-
if isinstance(target_node, Container):
|
| 385 |
-
dest[key] = target_node
|
| 386 |
-
dest_node = dest._get_node(key)
|
| 387 |
-
|
| 388 |
-
is_optional, et = _resolve_optional(dest._metadata.element_type)
|
| 389 |
-
if dest_node is None and is_structured_config(et) and not src_node_missing:
|
| 390 |
-
# merging into a new node. Use element_type as a base
|
| 391 |
-
dest[key] = DictConfig(
|
| 392 |
-
et, parent=dest, ref_type=et, is_optional=is_optional
|
| 393 |
-
)
|
| 394 |
-
dest_node = dest._get_node(key)
|
| 395 |
-
|
| 396 |
-
if dest_node is not None:
|
| 397 |
-
if isinstance(dest_node, BaseContainer):
|
| 398 |
-
if isinstance(src_node, BaseContainer):
|
| 399 |
-
dest_node._merge_with(src_node)
|
| 400 |
-
elif not src_node_missing:
|
| 401 |
-
dest.__setitem__(key, src_node)
|
| 402 |
-
else:
|
| 403 |
-
if isinstance(src_node, BaseContainer):
|
| 404 |
-
dest.__setitem__(key, src_node)
|
| 405 |
-
else:
|
| 406 |
-
assert isinstance(dest_node, (ValueNode, UnionNode))
|
| 407 |
-
assert isinstance(src_node, (ValueNode, UnionNode))
|
| 408 |
-
try:
|
| 409 |
-
if isinstance(dest_node, AnyNode):
|
| 410 |
-
if src_node_missing:
|
| 411 |
-
node = copy.copy(src_node)
|
| 412 |
-
# if src node is missing, use the value from the dest_node,
|
| 413 |
-
# but validate it against the type of the src node before assigment
|
| 414 |
-
node._set_value(dest_node._value())
|
| 415 |
-
else:
|
| 416 |
-
node = src_node
|
| 417 |
-
dest.__setitem__(key, node)
|
| 418 |
-
else:
|
| 419 |
-
if not src_node_missing:
|
| 420 |
-
dest_node._set_value(src_value)
|
| 421 |
-
|
| 422 |
-
except (ValidationError, ReadonlyConfigError) as e:
|
| 423 |
-
dest._format_and_raise(key=key, value=src_value, cause=e)
|
| 424 |
-
else:
|
| 425 |
-
from omegaconf import open_dict
|
| 426 |
-
|
| 427 |
-
if is_structured_config(src_type):
|
| 428 |
-
# verified to be compatible above in _validate_merge
|
| 429 |
-
with open_dict(dest):
|
| 430 |
-
dest[key] = src._get_node(key)
|
| 431 |
-
else:
|
| 432 |
-
dest[key] = src._get_node(key)
|
| 433 |
-
|
| 434 |
-
_update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
|
| 435 |
-
|
| 436 |
-
# explicit flags on the source config are replacing the flag values in the destination
|
| 437 |
-
flags = src._metadata.flags
|
| 438 |
-
assert flags is not None
|
| 439 |
-
for flag, value in flags.items():
|
| 440 |
-
if value is not None:
|
| 441 |
-
dest._set_flag(flag, value)
|
| 442 |
-
|
| 443 |
-
@staticmethod
|
| 444 |
-
def _list_merge(dest: Any, src: Any) -> None:
|
| 445 |
-
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 446 |
-
|
| 447 |
-
assert isinstance(dest, ListConfig)
|
| 448 |
-
assert isinstance(src, ListConfig)
|
| 449 |
-
|
| 450 |
-
if src._is_none():
|
| 451 |
-
dest._set_value(None)
|
| 452 |
-
elif src._is_missing():
|
| 453 |
-
# do not change dest if src is MISSING.
|
| 454 |
-
if dest._metadata.element_type is Any:
|
| 455 |
-
dest._metadata.element_type = src._metadata.element_type
|
| 456 |
-
elif src._is_interpolation():
|
| 457 |
-
dest._set_value(src._value())
|
| 458 |
-
else:
|
| 459 |
-
temp_target = ListConfig(content=[], parent=dest._get_parent())
|
| 460 |
-
temp_target.__dict__["_metadata"] = copy.deepcopy(
|
| 461 |
-
dest.__dict__["_metadata"]
|
| 462 |
-
)
|
| 463 |
-
is_optional, et = _resolve_optional(dest._metadata.element_type)
|
| 464 |
-
if is_structured_config(et):
|
| 465 |
-
prototype = DictConfig(et, ref_type=et, is_optional=is_optional)
|
| 466 |
-
for item in src._iter_ex(resolve=False):
|
| 467 |
-
if isinstance(item, DictConfig):
|
| 468 |
-
item = OmegaConf.merge(prototype, item)
|
| 469 |
-
temp_target.append(item)
|
| 470 |
-
else:
|
| 471 |
-
for item in src._iter_ex(resolve=False):
|
| 472 |
-
temp_target.append(item)
|
| 473 |
-
|
| 474 |
-
dest.__dict__["_content"] = temp_target.__dict__["_content"]
|
| 475 |
-
|
| 476 |
-
# explicit flags on the source config are replacing the flag values in the destination
|
| 477 |
-
flags = src._metadata.flags
|
| 478 |
-
assert flags is not None
|
| 479 |
-
for flag, value in flags.items():
|
| 480 |
-
if value is not None:
|
| 481 |
-
dest._set_flag(flag, value)
|
| 482 |
-
|
| 483 |
-
def merge_with(
|
| 484 |
-
self,
|
| 485 |
-
*others: Union[
|
| 486 |
-
"BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
|
| 487 |
-
],
|
| 488 |
-
) -> None:
|
| 489 |
-
try:
|
| 490 |
-
self._merge_with(*others)
|
| 491 |
-
except Exception as e:
|
| 492 |
-
self._format_and_raise(key=None, value=None, cause=e)
|
| 493 |
-
|
| 494 |
-
def _merge_with(
|
| 495 |
-
self,
|
| 496 |
-
*others: Union[
|
| 497 |
-
"BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
|
| 498 |
-
],
|
| 499 |
-
) -> None:
|
| 500 |
-
from .dictconfig import DictConfig
|
| 501 |
-
from .listconfig import ListConfig
|
| 502 |
-
|
| 503 |
-
"""merge a list of other Config objects into this one, overriding as needed"""
|
| 504 |
-
for other in others:
|
| 505 |
-
if other is None:
|
| 506 |
-
raise ValueError("Cannot merge with a None config")
|
| 507 |
-
|
| 508 |
-
my_flags = {}
|
| 509 |
-
if self._get_flag("allow_objects") is True:
|
| 510 |
-
my_flags = {"allow_objects": True}
|
| 511 |
-
other = _ensure_container(other, flags=my_flags)
|
| 512 |
-
|
| 513 |
-
if isinstance(self, DictConfig) and isinstance(other, DictConfig):
|
| 514 |
-
BaseContainer._map_merge(self, other)
|
| 515 |
-
elif isinstance(self, ListConfig) and isinstance(other, ListConfig):
|
| 516 |
-
BaseContainer._list_merge(self, other)
|
| 517 |
-
else:
|
| 518 |
-
raise TypeError("Cannot merge DictConfig with ListConfig")
|
| 519 |
-
|
| 520 |
-
# recursively correct the parent hierarchy after the merge
|
| 521 |
-
self._re_parent()
|
| 522 |
-
|
| 523 |
-
# noinspection PyProtectedMember
|
| 524 |
-
def _set_item_impl(self, key: Any, value: Any) -> None:
|
| 525 |
-
"""
|
| 526 |
-
Changes the value of the node key with the desired value. If the node key doesn't
|
| 527 |
-
exist it creates a new one.
|
| 528 |
-
"""
|
| 529 |
-
from .nodes import AnyNode, ValueNode
|
| 530 |
-
|
| 531 |
-
if isinstance(value, Node):
|
| 532 |
-
do_deepcopy = not self._get_flag("no_deepcopy_set_nodes")
|
| 533 |
-
if not do_deepcopy and isinstance(value, Box):
|
| 534 |
-
# if value is from the same config, perform a deepcopy no matter what.
|
| 535 |
-
if self._get_root() is value._get_root():
|
| 536 |
-
do_deepcopy = True
|
| 537 |
-
|
| 538 |
-
if do_deepcopy:
|
| 539 |
-
value = copy.deepcopy(value)
|
| 540 |
-
value._set_parent(None)
|
| 541 |
-
|
| 542 |
-
try:
|
| 543 |
-
old = value._key()
|
| 544 |
-
value._set_key(key)
|
| 545 |
-
self._validate_set(key, value)
|
| 546 |
-
finally:
|
| 547 |
-
value._set_key(old)
|
| 548 |
-
else:
|
| 549 |
-
self._validate_set(key, value)
|
| 550 |
-
|
| 551 |
-
if self._get_flag("readonly"):
|
| 552 |
-
raise ReadonlyConfigError("Cannot change read-only config container")
|
| 553 |
-
|
| 554 |
-
input_is_node = isinstance(value, Node)
|
| 555 |
-
target_node_ref = self._get_node(key)
|
| 556 |
-
assert target_node_ref is None or isinstance(target_node_ref, Node)
|
| 557 |
-
|
| 558 |
-
input_is_typed_vnode = isinstance(value, ValueNode) and not isinstance(
|
| 559 |
-
value, AnyNode
|
| 560 |
-
)
|
| 561 |
-
|
| 562 |
-
def get_target_type_hint(val: Any) -> Any:
|
| 563 |
-
if not is_structured_config(val):
|
| 564 |
-
type_hint = self._metadata.element_type
|
| 565 |
-
else:
|
| 566 |
-
target = self._get_node(key)
|
| 567 |
-
if target is None:
|
| 568 |
-
type_hint = self._metadata.element_type
|
| 569 |
-
else:
|
| 570 |
-
assert isinstance(target, Node)
|
| 571 |
-
type_hint = target._metadata.type_hint
|
| 572 |
-
return type_hint
|
| 573 |
-
|
| 574 |
-
target_type_hint = get_target_type_hint(value)
|
| 575 |
-
_, target_ref_type = _resolve_optional(target_type_hint)
|
| 576 |
-
|
| 577 |
-
def assign(value_key: Any, val: Node) -> None:
|
| 578 |
-
assert val._get_parent() is None
|
| 579 |
-
v = val
|
| 580 |
-
v._set_parent(self)
|
| 581 |
-
v._set_key(value_key)
|
| 582 |
-
_deep_update_type_hint(node=v, type_hint=self._metadata.element_type)
|
| 583 |
-
self.__dict__["_content"][value_key] = v
|
| 584 |
-
|
| 585 |
-
if input_is_typed_vnode and not is_union_annotation(target_ref_type):
|
| 586 |
-
assign(key, value)
|
| 587 |
-
else:
|
| 588 |
-
# input is not a ValueNode, can be primitive or box
|
| 589 |
-
|
| 590 |
-
special_value = _is_special(value)
|
| 591 |
-
# We use the `Node._set_value` method if the target node exists and:
|
| 592 |
-
# 1. the target has an explicit ref_type, or
|
| 593 |
-
# 2. the target is an AnyNode and the input is a primitive type.
|
| 594 |
-
should_set_value = target_node_ref is not None and (
|
| 595 |
-
target_node_ref._has_ref_type()
|
| 596 |
-
or (
|
| 597 |
-
isinstance(target_node_ref, AnyNode)
|
| 598 |
-
and is_primitive_type_annotation(value)
|
| 599 |
-
)
|
| 600 |
-
)
|
| 601 |
-
if should_set_value:
|
| 602 |
-
if special_value and isinstance(value, Node):
|
| 603 |
-
value = value._value()
|
| 604 |
-
self.__dict__["_content"][key]._set_value(value)
|
| 605 |
-
elif input_is_node:
|
| 606 |
-
if (
|
| 607 |
-
special_value
|
| 608 |
-
and (
|
| 609 |
-
is_container_annotation(target_ref_type)
|
| 610 |
-
or is_structured_config(target_ref_type)
|
| 611 |
-
)
|
| 612 |
-
or is_primitive_type_annotation(target_ref_type)
|
| 613 |
-
or is_union_annotation(target_ref_type)
|
| 614 |
-
):
|
| 615 |
-
value = _get_value(value)
|
| 616 |
-
self._wrap_value_and_set(key, value, target_type_hint)
|
| 617 |
-
else:
|
| 618 |
-
assign(key, value)
|
| 619 |
-
else:
|
| 620 |
-
self._wrap_value_and_set(key, value, target_type_hint)
|
| 621 |
-
|
| 622 |
-
def _wrap_value_and_set(self, key: Any, val: Any, type_hint: Any) -> None:
|
| 623 |
-
from omegaconf.omegaconf import _maybe_wrap
|
| 624 |
-
|
| 625 |
-
is_optional, ref_type = _resolve_optional(type_hint)
|
| 626 |
-
|
| 627 |
-
try:
|
| 628 |
-
wrapped = _maybe_wrap(
|
| 629 |
-
ref_type=ref_type,
|
| 630 |
-
key=key,
|
| 631 |
-
value=val,
|
| 632 |
-
is_optional=is_optional,
|
| 633 |
-
parent=self,
|
| 634 |
-
)
|
| 635 |
-
except ValidationError as e:
|
| 636 |
-
self._format_and_raise(key=key, value=val, cause=e)
|
| 637 |
-
self.__dict__["_content"][key] = wrapped
|
| 638 |
-
|
| 639 |
-
@staticmethod
|
| 640 |
-
def _item_eq(
|
| 641 |
-
c1: Container,
|
| 642 |
-
k1: Union[DictKeyType, int],
|
| 643 |
-
c2: Container,
|
| 644 |
-
k2: Union[DictKeyType, int],
|
| 645 |
-
) -> bool:
|
| 646 |
-
v1 = c1._get_child(k1)
|
| 647 |
-
v2 = c2._get_child(k2)
|
| 648 |
-
assert v1 is not None and v2 is not None
|
| 649 |
-
|
| 650 |
-
assert isinstance(v1, Node)
|
| 651 |
-
assert isinstance(v2, Node)
|
| 652 |
-
|
| 653 |
-
if v1._is_none() and v2._is_none():
|
| 654 |
-
return True
|
| 655 |
-
|
| 656 |
-
if v1._is_missing() and v2._is_missing():
|
| 657 |
-
return True
|
| 658 |
-
|
| 659 |
-
v1_inter = v1._is_interpolation()
|
| 660 |
-
v2_inter = v2._is_interpolation()
|
| 661 |
-
dv1: Optional[Node] = v1
|
| 662 |
-
dv2: Optional[Node] = v2
|
| 663 |
-
|
| 664 |
-
if v1_inter:
|
| 665 |
-
dv1 = v1._maybe_dereference_node()
|
| 666 |
-
if v2_inter:
|
| 667 |
-
dv2 = v2._maybe_dereference_node()
|
| 668 |
-
|
| 669 |
-
if v1_inter and v2_inter:
|
| 670 |
-
if dv1 is None or dv2 is None:
|
| 671 |
-
return v1 == v2
|
| 672 |
-
else:
|
| 673 |
-
# both are not none, if both are containers compare as container
|
| 674 |
-
if isinstance(dv1, Container) and isinstance(dv2, Container):
|
| 675 |
-
if dv1 != dv2:
|
| 676 |
-
return False
|
| 677 |
-
dv1 = _get_value(dv1)
|
| 678 |
-
dv2 = _get_value(dv2)
|
| 679 |
-
return dv1 == dv2
|
| 680 |
-
elif not v1_inter and not v2_inter:
|
| 681 |
-
v1 = _get_value(v1)
|
| 682 |
-
v2 = _get_value(v2)
|
| 683 |
-
ret = v1 == v2
|
| 684 |
-
assert isinstance(ret, bool)
|
| 685 |
-
return ret
|
| 686 |
-
else:
|
| 687 |
-
dv1 = _get_value(dv1)
|
| 688 |
-
dv2 = _get_value(dv2)
|
| 689 |
-
ret = dv1 == dv2
|
| 690 |
-
assert isinstance(ret, bool)
|
| 691 |
-
return ret
|
| 692 |
-
|
| 693 |
-
def _is_optional(self) -> bool:
|
| 694 |
-
return self.__dict__["_metadata"].optional is True
|
| 695 |
-
|
| 696 |
-
def _is_interpolation(self) -> bool:
|
| 697 |
-
return _is_interpolation(self.__dict__["_content"])
|
| 698 |
-
|
| 699 |
-
@abstractmethod
|
| 700 |
-
def _validate_get(self, key: Any, value: Any = None) -> None:
|
| 701 |
-
...
|
| 702 |
-
|
| 703 |
-
@abstractmethod
|
| 704 |
-
def _validate_set(self, key: Any, value: Any) -> None:
|
| 705 |
-
...
|
| 706 |
-
|
| 707 |
-
def _value(self) -> Any:
|
| 708 |
-
return self.__dict__["_content"]
|
| 709 |
-
|
| 710 |
-
def _get_full_key(self, key: Union[DictKeyType, int, slice, None]) -> str:
|
| 711 |
-
from .listconfig import ListConfig
|
| 712 |
-
from .omegaconf import _select_one
|
| 713 |
-
|
| 714 |
-
if not isinstance(key, (int, str, Enum, float, bool, slice, bytes, type(None))):
|
| 715 |
-
return ""
|
| 716 |
-
|
| 717 |
-
def _slice_to_str(x: slice) -> str:
|
| 718 |
-
if x.step is not None:
|
| 719 |
-
return f"{x.start}:{x.stop}:{x.step}"
|
| 720 |
-
else:
|
| 721 |
-
return f"{x.start}:{x.stop}"
|
| 722 |
-
|
| 723 |
-
def prepand(
|
| 724 |
-
full_key: str,
|
| 725 |
-
parent_type: Any,
|
| 726 |
-
cur_type: Any,
|
| 727 |
-
key: Optional[Union[DictKeyType, int, slice]],
|
| 728 |
-
) -> str:
|
| 729 |
-
if key is None:
|
| 730 |
-
return full_key
|
| 731 |
-
|
| 732 |
-
if isinstance(key, slice):
|
| 733 |
-
key = _slice_to_str(key)
|
| 734 |
-
elif isinstance(key, Enum):
|
| 735 |
-
key = key.name
|
| 736 |
-
else:
|
| 737 |
-
key = str(key)
|
| 738 |
-
|
| 739 |
-
assert isinstance(key, str)
|
| 740 |
-
|
| 741 |
-
if issubclass(parent_type, ListConfig):
|
| 742 |
-
if full_key != "":
|
| 743 |
-
if issubclass(cur_type, ListConfig):
|
| 744 |
-
full_key = f"[{key}]{full_key}"
|
| 745 |
-
else:
|
| 746 |
-
full_key = f"[{key}].{full_key}"
|
| 747 |
-
else:
|
| 748 |
-
full_key = f"[{key}]"
|
| 749 |
-
else:
|
| 750 |
-
if full_key == "":
|
| 751 |
-
full_key = key
|
| 752 |
-
else:
|
| 753 |
-
if issubclass(cur_type, ListConfig):
|
| 754 |
-
full_key = f"{key}{full_key}"
|
| 755 |
-
else:
|
| 756 |
-
full_key = f"{key}.{full_key}"
|
| 757 |
-
return full_key
|
| 758 |
-
|
| 759 |
-
if key is not None and key != "":
|
| 760 |
-
assert isinstance(self, Container)
|
| 761 |
-
cur, _ = _select_one(
|
| 762 |
-
c=self, key=str(key), throw_on_missing=False, throw_on_type_error=False
|
| 763 |
-
)
|
| 764 |
-
if cur is None:
|
| 765 |
-
cur = self
|
| 766 |
-
full_key = prepand("", type(cur), None, key)
|
| 767 |
-
if cur._key() is not None:
|
| 768 |
-
full_key = prepand(
|
| 769 |
-
full_key, type(cur._get_parent()), type(cur), cur._key()
|
| 770 |
-
)
|
| 771 |
-
else:
|
| 772 |
-
full_key = prepand("", type(cur._get_parent()), type(cur), cur._key())
|
| 773 |
-
else:
|
| 774 |
-
cur = self
|
| 775 |
-
if cur._key() is None:
|
| 776 |
-
return ""
|
| 777 |
-
full_key = self._key()
|
| 778 |
-
|
| 779 |
-
assert cur is not None
|
| 780 |
-
memo = {id(cur)} # remember already visited nodes so as to detect cycles
|
| 781 |
-
while cur._get_parent() is not None:
|
| 782 |
-
cur = cur._get_parent()
|
| 783 |
-
if id(cur) in memo:
|
| 784 |
-
raise ConfigCycleDetectedException(
|
| 785 |
-
f"Cycle when iterating over parents of key `{key!s}`"
|
| 786 |
-
)
|
| 787 |
-
memo.add(id(cur))
|
| 788 |
-
assert cur is not None
|
| 789 |
-
if cur._key() is not None:
|
| 790 |
-
full_key = prepand(
|
| 791 |
-
full_key, type(cur._get_parent()), type(cur), cur._key()
|
| 792 |
-
)
|
| 793 |
-
|
| 794 |
-
return full_key
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
def _create_structured_with_missing_fields(
|
| 798 |
-
ref_type: type, object_type: Optional[type] = None
|
| 799 |
-
) -> "DictConfig":
|
| 800 |
-
from . import MISSING, DictConfig
|
| 801 |
-
|
| 802 |
-
cfg_data = get_structured_config_data(ref_type)
|
| 803 |
-
for v in cfg_data.values():
|
| 804 |
-
v._set_value(MISSING)
|
| 805 |
-
|
| 806 |
-
cfg = DictConfig(cfg_data)
|
| 807 |
-
cfg._metadata.optional, cfg._metadata.ref_type = _resolve_optional(ref_type)
|
| 808 |
-
cfg._metadata.object_type = object_type
|
| 809 |
-
|
| 810 |
-
return cfg
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
def _update_types(node: Node, ref_type: Any, object_type: Optional[type]) -> None:
|
| 814 |
-
if object_type is not None and not is_primitive_dict(object_type):
|
| 815 |
-
node._metadata.object_type = object_type
|
| 816 |
-
|
| 817 |
-
if node._metadata.ref_type is Any:
|
| 818 |
-
_deep_update_type_hint(node, ref_type)
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
def _deep_update_type_hint(node: Node, type_hint: Any) -> None:
|
| 822 |
-
"""Ensure node is compatible with type_hint, mutating if necessary."""
|
| 823 |
-
from omegaconf import DictConfig, ListConfig
|
| 824 |
-
|
| 825 |
-
from ._utils import get_dict_key_value_types, get_list_element_type
|
| 826 |
-
|
| 827 |
-
if type_hint is Any:
|
| 828 |
-
return
|
| 829 |
-
|
| 830 |
-
_shallow_validate_type_hint(node, type_hint)
|
| 831 |
-
|
| 832 |
-
new_is_optional, new_ref_type = _resolve_optional(type_hint)
|
| 833 |
-
node._metadata.ref_type = new_ref_type
|
| 834 |
-
node._metadata.optional = new_is_optional
|
| 835 |
-
|
| 836 |
-
if is_list_annotation(new_ref_type) and isinstance(node, ListConfig):
|
| 837 |
-
new_element_type = get_list_element_type(new_ref_type)
|
| 838 |
-
node._metadata.element_type = new_element_type
|
| 839 |
-
if not _is_special(node):
|
| 840 |
-
for i in range(len(node)):
|
| 841 |
-
_deep_update_subnode(node, i, new_element_type)
|
| 842 |
-
|
| 843 |
-
if is_dict_annotation(new_ref_type) and isinstance(node, DictConfig):
|
| 844 |
-
new_key_type, new_element_type = get_dict_key_value_types(new_ref_type)
|
| 845 |
-
node._metadata.key_type = new_key_type
|
| 846 |
-
node._metadata.element_type = new_element_type
|
| 847 |
-
if not _is_special(node):
|
| 848 |
-
for key in node:
|
| 849 |
-
if new_key_type is not Any and not isinstance(key, new_key_type):
|
| 850 |
-
raise KeyValidationError(
|
| 851 |
-
f"Key {key!r} ({type(key).__name__}) is incompatible"
|
| 852 |
-
+ f" with key type hint '{new_key_type.__name__}'"
|
| 853 |
-
)
|
| 854 |
-
_deep_update_subnode(node, key, new_element_type)
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
def _deep_update_subnode(node: BaseContainer, key: Any, value_type_hint: Any) -> None:
|
| 858 |
-
"""Get node[key] and ensure it is compatible with value_type_hint, mutating if necessary."""
|
| 859 |
-
subnode = node._get_node(key)
|
| 860 |
-
assert isinstance(subnode, Node)
|
| 861 |
-
if _is_special(subnode):
|
| 862 |
-
# Ensure special values are wrapped in a Node subclass that
|
| 863 |
-
# is compatible with the type hint.
|
| 864 |
-
node._wrap_value_and_set(key, subnode._value(), value_type_hint)
|
| 865 |
-
subnode = node._get_node(key)
|
| 866 |
-
assert isinstance(subnode, Node)
|
| 867 |
-
_deep_update_type_hint(subnode, value_type_hint)
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
def _shallow_validate_type_hint(node: Node, type_hint: Any) -> None:
|
| 871 |
-
"""Error if node's type, content and metadata are not compatible with type_hint."""
|
| 872 |
-
from omegaconf import DictConfig, ListConfig, ValueNode
|
| 873 |
-
|
| 874 |
-
is_optional, ref_type = _resolve_optional(type_hint)
|
| 875 |
-
|
| 876 |
-
vk = get_value_kind(node)
|
| 877 |
-
|
| 878 |
-
if node._is_none():
|
| 879 |
-
if not is_optional:
|
| 880 |
-
value = _get_value(node)
|
| 881 |
-
raise ValidationError(
|
| 882 |
-
f"Value {value!r} ({type(value).__name__})"
|
| 883 |
-
+ f" is incompatible with type hint '{ref_type.__name__}'"
|
| 884 |
-
)
|
| 885 |
-
return
|
| 886 |
-
elif vk in (ValueKind.MANDATORY_MISSING, ValueKind.INTERPOLATION):
|
| 887 |
-
return
|
| 888 |
-
elif vk == ValueKind.VALUE:
|
| 889 |
-
if is_primitive_type_annotation(ref_type) and isinstance(node, ValueNode):
|
| 890 |
-
value = node._value()
|
| 891 |
-
if not isinstance(value, ref_type):
|
| 892 |
-
raise ValidationError(
|
| 893 |
-
f"Value {value!r} ({type(value).__name__})"
|
| 894 |
-
+ f" is incompatible with type hint '{ref_type.__name__}'"
|
| 895 |
-
)
|
| 896 |
-
elif is_structured_config(ref_type) and isinstance(node, DictConfig):
|
| 897 |
-
return
|
| 898 |
-
elif is_dict_annotation(ref_type) and isinstance(node, DictConfig):
|
| 899 |
-
return
|
| 900 |
-
elif is_list_annotation(ref_type) and isinstance(node, ListConfig):
|
| 901 |
-
return
|
| 902 |
-
else:
|
| 903 |
-
if isinstance(node, ValueNode):
|
| 904 |
-
value = node._value()
|
| 905 |
-
raise ValidationError(
|
| 906 |
-
f"Value {value!r} ({type(value).__name__})"
|
| 907 |
-
+ f" is incompatible with type hint '{ref_type}'"
|
| 908 |
-
)
|
| 909 |
-
else:
|
| 910 |
-
raise ValidationError(
|
| 911 |
-
f"'{type(node).__name__}' is incompatible"
|
| 912 |
-
+ f" with type hint '{ref_type}'"
|
| 913 |
-
)
|
| 914 |
-
|
| 915 |
-
else:
|
| 916 |
-
assert False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/dictconfig.py
DELETED
|
@@ -1,776 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
from enum import Enum
|
| 3 |
-
from typing import (
|
| 4 |
-
Any,
|
| 5 |
-
Dict,
|
| 6 |
-
ItemsView,
|
| 7 |
-
Iterable,
|
| 8 |
-
Iterator,
|
| 9 |
-
KeysView,
|
| 10 |
-
List,
|
| 11 |
-
MutableMapping,
|
| 12 |
-
Optional,
|
| 13 |
-
Sequence,
|
| 14 |
-
Tuple,
|
| 15 |
-
Type,
|
| 16 |
-
Union,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
from ._utils import (
|
| 20 |
-
_DEFAULT_MARKER_,
|
| 21 |
-
ValueKind,
|
| 22 |
-
_get_value,
|
| 23 |
-
_is_interpolation,
|
| 24 |
-
_is_missing_literal,
|
| 25 |
-
_is_missing_value,
|
| 26 |
-
_is_none,
|
| 27 |
-
_resolve_optional,
|
| 28 |
-
_valid_dict_key_annotation_type,
|
| 29 |
-
format_and_raise,
|
| 30 |
-
get_structured_config_data,
|
| 31 |
-
get_structured_config_init_field_names,
|
| 32 |
-
get_type_of,
|
| 33 |
-
get_value_kind,
|
| 34 |
-
is_container_annotation,
|
| 35 |
-
is_dict,
|
| 36 |
-
is_primitive_dict,
|
| 37 |
-
is_structured_config,
|
| 38 |
-
is_structured_config_frozen,
|
| 39 |
-
type_str,
|
| 40 |
-
)
|
| 41 |
-
from .base import Box, Container, ContainerMetadata, DictKeyType, Node
|
| 42 |
-
from .basecontainer import BaseContainer
|
| 43 |
-
from .errors import (
|
| 44 |
-
ConfigAttributeError,
|
| 45 |
-
ConfigKeyError,
|
| 46 |
-
ConfigTypeError,
|
| 47 |
-
InterpolationResolutionError,
|
| 48 |
-
KeyValidationError,
|
| 49 |
-
MissingMandatoryValue,
|
| 50 |
-
OmegaConfBaseException,
|
| 51 |
-
ReadonlyConfigError,
|
| 52 |
-
ValidationError,
|
| 53 |
-
)
|
| 54 |
-
from .nodes import EnumNode, ValueNode
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
class DictConfig(BaseContainer, MutableMapping[Any, Any]):
|
| 58 |
-
|
| 59 |
-
_metadata: ContainerMetadata
|
| 60 |
-
_content: Union[Dict[DictKeyType, Node], None, str]
|
| 61 |
-
|
| 62 |
-
def __init__(
|
| 63 |
-
self,
|
| 64 |
-
content: Union[Dict[DictKeyType, Any], "DictConfig", Any],
|
| 65 |
-
key: Any = None,
|
| 66 |
-
parent: Optional[Box] = None,
|
| 67 |
-
ref_type: Union[Any, Type[Any]] = Any,
|
| 68 |
-
key_type: Union[Any, Type[Any]] = Any,
|
| 69 |
-
element_type: Union[Any, Type[Any]] = Any,
|
| 70 |
-
is_optional: bool = True,
|
| 71 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 72 |
-
) -> None:
|
| 73 |
-
try:
|
| 74 |
-
if isinstance(content, DictConfig):
|
| 75 |
-
if flags is None:
|
| 76 |
-
flags = content._metadata.flags
|
| 77 |
-
super().__init__(
|
| 78 |
-
parent=parent,
|
| 79 |
-
metadata=ContainerMetadata(
|
| 80 |
-
key=key,
|
| 81 |
-
optional=is_optional,
|
| 82 |
-
ref_type=ref_type,
|
| 83 |
-
object_type=dict,
|
| 84 |
-
key_type=key_type,
|
| 85 |
-
element_type=element_type,
|
| 86 |
-
flags=flags,
|
| 87 |
-
),
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
if not _valid_dict_key_annotation_type(key_type):
|
| 91 |
-
raise KeyValidationError(f"Unsupported key type {key_type}")
|
| 92 |
-
|
| 93 |
-
if is_structured_config(content) or is_structured_config(ref_type):
|
| 94 |
-
self._set_value(content, flags=flags)
|
| 95 |
-
if is_structured_config_frozen(content) or is_structured_config_frozen(
|
| 96 |
-
ref_type
|
| 97 |
-
):
|
| 98 |
-
self._set_flag("readonly", True)
|
| 99 |
-
|
| 100 |
-
else:
|
| 101 |
-
if isinstance(content, DictConfig):
|
| 102 |
-
metadata = copy.deepcopy(content._metadata)
|
| 103 |
-
metadata.key = key
|
| 104 |
-
metadata.ref_type = ref_type
|
| 105 |
-
metadata.optional = is_optional
|
| 106 |
-
metadata.element_type = element_type
|
| 107 |
-
metadata.key_type = key_type
|
| 108 |
-
self.__dict__["_metadata"] = metadata
|
| 109 |
-
self._set_value(content, flags=flags)
|
| 110 |
-
except Exception as ex:
|
| 111 |
-
format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))
|
| 112 |
-
|
| 113 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "DictConfig":
|
| 114 |
-
res = DictConfig(None)
|
| 115 |
-
res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
|
| 116 |
-
res.__dict__["_flags_cache"] = copy.deepcopy(
|
| 117 |
-
self.__dict__["_flags_cache"], memo=memo
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
src_content = self.__dict__["_content"]
|
| 121 |
-
if isinstance(src_content, dict):
|
| 122 |
-
content_copy = {}
|
| 123 |
-
for k, v in src_content.items():
|
| 124 |
-
old_parent = v.__dict__["_parent"]
|
| 125 |
-
try:
|
| 126 |
-
v.__dict__["_parent"] = None
|
| 127 |
-
vc = copy.deepcopy(v, memo=memo)
|
| 128 |
-
vc.__dict__["_parent"] = res
|
| 129 |
-
content_copy[k] = vc
|
| 130 |
-
finally:
|
| 131 |
-
v.__dict__["_parent"] = old_parent
|
| 132 |
-
else:
|
| 133 |
-
# None and strings can be assigned as is
|
| 134 |
-
content_copy = src_content
|
| 135 |
-
|
| 136 |
-
res.__dict__["_content"] = content_copy
|
| 137 |
-
# parent is retained, but not copied
|
| 138 |
-
res.__dict__["_parent"] = self.__dict__["_parent"]
|
| 139 |
-
return res
|
| 140 |
-
|
| 141 |
-
def copy(self) -> "DictConfig":
|
| 142 |
-
return copy.copy(self)
|
| 143 |
-
|
| 144 |
-
def _is_typed(self) -> bool:
|
| 145 |
-
return self._metadata.object_type not in (Any, None) and not is_dict(
|
| 146 |
-
self._metadata.object_type
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
def _validate_get(self, key: Any, value: Any = None) -> None:
|
| 150 |
-
is_typed = self._is_typed()
|
| 151 |
-
|
| 152 |
-
is_struct = self._get_flag("struct") is True
|
| 153 |
-
if key not in self.__dict__["_content"]:
|
| 154 |
-
if is_typed:
|
| 155 |
-
# do not raise an exception if struct is explicitly set to False
|
| 156 |
-
if self._get_node_flag("struct") is False:
|
| 157 |
-
return
|
| 158 |
-
if is_typed or is_struct:
|
| 159 |
-
if is_typed:
|
| 160 |
-
assert self._metadata.object_type not in (dict, None)
|
| 161 |
-
msg = f"Key '{key}' not in '{self._metadata.object_type.__name__}'"
|
| 162 |
-
else:
|
| 163 |
-
msg = f"Key '{key}' is not in struct"
|
| 164 |
-
self._format_and_raise(
|
| 165 |
-
key=key, value=value, cause=ConfigAttributeError(msg)
|
| 166 |
-
)
|
| 167 |
-
|
| 168 |
-
def _validate_set(self, key: Any, value: Any) -> None:
|
| 169 |
-
from omegaconf import OmegaConf
|
| 170 |
-
|
| 171 |
-
vk = get_value_kind(value)
|
| 172 |
-
if vk == ValueKind.INTERPOLATION:
|
| 173 |
-
return
|
| 174 |
-
if _is_none(value):
|
| 175 |
-
self._validate_non_optional(key, value)
|
| 176 |
-
return
|
| 177 |
-
if vk == ValueKind.MANDATORY_MISSING or value is None:
|
| 178 |
-
return
|
| 179 |
-
|
| 180 |
-
target = self._get_node(key) if key is not None else self
|
| 181 |
-
|
| 182 |
-
target_has_ref_type = isinstance(
|
| 183 |
-
target, DictConfig
|
| 184 |
-
) and target._metadata.ref_type not in (Any, dict)
|
| 185 |
-
is_valid_target = target is None or not target_has_ref_type
|
| 186 |
-
|
| 187 |
-
if is_valid_target:
|
| 188 |
-
return
|
| 189 |
-
|
| 190 |
-
assert isinstance(target, Node)
|
| 191 |
-
|
| 192 |
-
target_type = target._metadata.ref_type
|
| 193 |
-
value_type = OmegaConf.get_type(value)
|
| 194 |
-
|
| 195 |
-
if is_dict(value_type) and is_dict(target_type):
|
| 196 |
-
return
|
| 197 |
-
if is_container_annotation(target_type) and not is_container_annotation(
|
| 198 |
-
value_type
|
| 199 |
-
):
|
| 200 |
-
raise ValidationError(
|
| 201 |
-
f"Cannot assign {type_str(value_type)} to {type_str(target_type)}"
|
| 202 |
-
)
|
| 203 |
-
|
| 204 |
-
if target_type is not None and value_type is not None:
|
| 205 |
-
origin = getattr(target_type, "__origin__", target_type)
|
| 206 |
-
if not issubclass(value_type, origin):
|
| 207 |
-
self._raise_invalid_value(value, value_type, target_type)
|
| 208 |
-
|
| 209 |
-
def _validate_merge(self, value: Any) -> None:
|
| 210 |
-
from omegaconf import OmegaConf
|
| 211 |
-
|
| 212 |
-
dest = self
|
| 213 |
-
src = value
|
| 214 |
-
|
| 215 |
-
self._validate_non_optional(None, src)
|
| 216 |
-
|
| 217 |
-
dest_obj_type = OmegaConf.get_type(dest)
|
| 218 |
-
src_obj_type = OmegaConf.get_type(src)
|
| 219 |
-
|
| 220 |
-
if dest._is_missing() and src._metadata.object_type not in (dict, None):
|
| 221 |
-
self._validate_set(key=None, value=_get_value(src))
|
| 222 |
-
|
| 223 |
-
if src._is_missing():
|
| 224 |
-
return
|
| 225 |
-
|
| 226 |
-
validation_error = (
|
| 227 |
-
dest_obj_type is not None
|
| 228 |
-
and src_obj_type is not None
|
| 229 |
-
and is_structured_config(dest_obj_type)
|
| 230 |
-
and not src._is_none()
|
| 231 |
-
and not is_dict(src_obj_type)
|
| 232 |
-
and not issubclass(src_obj_type, dest_obj_type)
|
| 233 |
-
)
|
| 234 |
-
if validation_error:
|
| 235 |
-
msg = (
|
| 236 |
-
f"Merge error: {type_str(src_obj_type)} is not a "
|
| 237 |
-
f"subclass of {type_str(dest_obj_type)}. value: {src}"
|
| 238 |
-
)
|
| 239 |
-
raise ValidationError(msg)
|
| 240 |
-
|
| 241 |
-
def _validate_non_optional(self, key: Optional[DictKeyType], value: Any) -> None:
|
| 242 |
-
if _is_none(value, resolve=True, throw_on_resolution_failure=False):
|
| 243 |
-
|
| 244 |
-
if key is not None:
|
| 245 |
-
child = self._get_node(key)
|
| 246 |
-
if child is not None:
|
| 247 |
-
assert isinstance(child, Node)
|
| 248 |
-
field_is_optional = child._is_optional()
|
| 249 |
-
else:
|
| 250 |
-
field_is_optional, _ = _resolve_optional(
|
| 251 |
-
self._metadata.element_type
|
| 252 |
-
)
|
| 253 |
-
else:
|
| 254 |
-
field_is_optional = self._is_optional()
|
| 255 |
-
|
| 256 |
-
if not field_is_optional:
|
| 257 |
-
self._format_and_raise(
|
| 258 |
-
key=key,
|
| 259 |
-
value=value,
|
| 260 |
-
cause=ValidationError("field '$FULL_KEY' is not Optional"),
|
| 261 |
-
)
|
| 262 |
-
|
| 263 |
-
def _raise_invalid_value(
|
| 264 |
-
self, value: Any, value_type: Any, target_type: Any
|
| 265 |
-
) -> None:
|
| 266 |
-
assert value_type is not None
|
| 267 |
-
assert target_type is not None
|
| 268 |
-
msg = (
|
| 269 |
-
f"Invalid type assigned: {type_str(value_type)} is not a "
|
| 270 |
-
f"subclass of {type_str(target_type)}. value: {value}"
|
| 271 |
-
)
|
| 272 |
-
raise ValidationError(msg)
|
| 273 |
-
|
| 274 |
-
def _validate_and_normalize_key(self, key: Any) -> DictKeyType:
|
| 275 |
-
return self._s_validate_and_normalize_key(self._metadata.key_type, key)
|
| 276 |
-
|
| 277 |
-
def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
|
| 278 |
-
if key_type is Any:
|
| 279 |
-
for t in DictKeyType.__args__: # type: ignore
|
| 280 |
-
if isinstance(key, t):
|
| 281 |
-
return key # type: ignore
|
| 282 |
-
raise KeyValidationError("Incompatible key type '$KEY_TYPE'")
|
| 283 |
-
elif key_type is bool and key in [0, 1]:
|
| 284 |
-
# Python treats True as 1 and False as 0 when used as dict keys
|
| 285 |
-
# assert hash(0) == hash(False)
|
| 286 |
-
# assert hash(1) == hash(True)
|
| 287 |
-
return bool(key)
|
| 288 |
-
elif key_type in (str, bytes, int, float, bool): # primitive type
|
| 289 |
-
if not isinstance(key, key_type):
|
| 290 |
-
raise KeyValidationError(
|
| 291 |
-
f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
|
| 292 |
-
)
|
| 293 |
-
|
| 294 |
-
return key # type: ignore
|
| 295 |
-
elif issubclass(key_type, Enum):
|
| 296 |
-
try:
|
| 297 |
-
return EnumNode.validate_and_convert_to_enum(key_type, key)
|
| 298 |
-
except ValidationError:
|
| 299 |
-
valid = ", ".join([x for x in key_type.__members__.keys()])
|
| 300 |
-
raise KeyValidationError(
|
| 301 |
-
f"Key '$KEY' is incompatible with the enum type '{key_type.__name__}', valid: [{valid}]"
|
| 302 |
-
)
|
| 303 |
-
else:
|
| 304 |
-
assert False, f"Unsupported key type {key_type}"
|
| 305 |
-
|
| 306 |
-
def __setitem__(self, key: DictKeyType, value: Any) -> None:
|
| 307 |
-
try:
|
| 308 |
-
self.__set_impl(key=key, value=value)
|
| 309 |
-
except AttributeError as e:
|
| 310 |
-
self._format_and_raise(
|
| 311 |
-
key=key, value=value, type_override=ConfigKeyError, cause=e
|
| 312 |
-
)
|
| 313 |
-
except Exception as e:
|
| 314 |
-
self._format_and_raise(key=key, value=value, cause=e)
|
| 315 |
-
|
| 316 |
-
def __set_impl(self, key: DictKeyType, value: Any) -> None:
|
| 317 |
-
key = self._validate_and_normalize_key(key)
|
| 318 |
-
self._set_item_impl(key, value)
|
| 319 |
-
|
| 320 |
-
# hide content while inspecting in debugger
|
| 321 |
-
def __dir__(self) -> Iterable[str]:
|
| 322 |
-
if self._is_missing() or self._is_none():
|
| 323 |
-
return []
|
| 324 |
-
return self.__dict__["_content"].keys() # type: ignore
|
| 325 |
-
|
| 326 |
-
def __setattr__(self, key: str, value: Any) -> None:
|
| 327 |
-
"""
|
| 328 |
-
Allow assigning attributes to DictConfig
|
| 329 |
-
:param key:
|
| 330 |
-
:param value:
|
| 331 |
-
:return:
|
| 332 |
-
"""
|
| 333 |
-
try:
|
| 334 |
-
self.__set_impl(key, value)
|
| 335 |
-
except Exception as e:
|
| 336 |
-
if isinstance(e, OmegaConfBaseException) and e._initialized:
|
| 337 |
-
raise e
|
| 338 |
-
self._format_and_raise(key=key, value=value, cause=e)
|
| 339 |
-
assert False
|
| 340 |
-
|
| 341 |
-
def __getattr__(self, key: str) -> Any:
|
| 342 |
-
"""
|
| 343 |
-
Allow accessing dictionary values as attributes
|
| 344 |
-
:param key:
|
| 345 |
-
:return:
|
| 346 |
-
"""
|
| 347 |
-
if key == "__name__":
|
| 348 |
-
raise AttributeError()
|
| 349 |
-
|
| 350 |
-
try:
|
| 351 |
-
return self._get_impl(
|
| 352 |
-
key=key, default_value=_DEFAULT_MARKER_, validate_key=False
|
| 353 |
-
)
|
| 354 |
-
except ConfigKeyError as e:
|
| 355 |
-
self._format_and_raise(
|
| 356 |
-
key=key, value=None, cause=e, type_override=ConfigAttributeError
|
| 357 |
-
)
|
| 358 |
-
except Exception as e:
|
| 359 |
-
self._format_and_raise(key=key, value=None, cause=e)
|
| 360 |
-
|
| 361 |
-
def __getitem__(self, key: DictKeyType) -> Any:
|
| 362 |
-
"""
|
| 363 |
-
Allow map style access
|
| 364 |
-
:param key:
|
| 365 |
-
:return:
|
| 366 |
-
"""
|
| 367 |
-
|
| 368 |
-
try:
|
| 369 |
-
return self._get_impl(key=key, default_value=_DEFAULT_MARKER_)
|
| 370 |
-
except AttributeError as e:
|
| 371 |
-
self._format_and_raise(
|
| 372 |
-
key=key, value=None, cause=e, type_override=ConfigKeyError
|
| 373 |
-
)
|
| 374 |
-
except Exception as e:
|
| 375 |
-
self._format_and_raise(key=key, value=None, cause=e)
|
| 376 |
-
|
| 377 |
-
def __delattr__(self, key: str) -> None:
|
| 378 |
-
"""
|
| 379 |
-
Allow deleting dictionary values as attributes
|
| 380 |
-
:param key:
|
| 381 |
-
:return:
|
| 382 |
-
"""
|
| 383 |
-
if self._get_flag("readonly"):
|
| 384 |
-
self._format_and_raise(
|
| 385 |
-
key=key,
|
| 386 |
-
value=None,
|
| 387 |
-
cause=ReadonlyConfigError(
|
| 388 |
-
"DictConfig in read-only mode does not support deletion"
|
| 389 |
-
),
|
| 390 |
-
)
|
| 391 |
-
try:
|
| 392 |
-
del self.__dict__["_content"][key]
|
| 393 |
-
except KeyError:
|
| 394 |
-
msg = "Attribute not found: '$KEY'"
|
| 395 |
-
self._format_and_raise(key=key, value=None, cause=ConfigAttributeError(msg))
|
| 396 |
-
|
| 397 |
-
def __delitem__(self, key: DictKeyType) -> None:
|
| 398 |
-
key = self._validate_and_normalize_key(key)
|
| 399 |
-
if self._get_flag("readonly"):
|
| 400 |
-
self._format_and_raise(
|
| 401 |
-
key=key,
|
| 402 |
-
value=None,
|
| 403 |
-
cause=ReadonlyConfigError(
|
| 404 |
-
"DictConfig in read-only mode does not support deletion"
|
| 405 |
-
),
|
| 406 |
-
)
|
| 407 |
-
if self._get_flag("struct"):
|
| 408 |
-
self._format_and_raise(
|
| 409 |
-
key=key,
|
| 410 |
-
value=None,
|
| 411 |
-
cause=ConfigTypeError(
|
| 412 |
-
"DictConfig in struct mode does not support deletion"
|
| 413 |
-
),
|
| 414 |
-
)
|
| 415 |
-
if self._is_typed() and self._get_node_flag("struct") is not False:
|
| 416 |
-
self._format_and_raise(
|
| 417 |
-
key=key,
|
| 418 |
-
value=None,
|
| 419 |
-
cause=ConfigTypeError(
|
| 420 |
-
f"{type_str(self._metadata.object_type)} (DictConfig) does not support deletion"
|
| 421 |
-
),
|
| 422 |
-
)
|
| 423 |
-
|
| 424 |
-
try:
|
| 425 |
-
del self.__dict__["_content"][key]
|
| 426 |
-
except KeyError:
|
| 427 |
-
msg = "Key not found: '$KEY'"
|
| 428 |
-
self._format_and_raise(key=key, value=None, cause=ConfigKeyError(msg))
|
| 429 |
-
|
| 430 |
-
def get(self, key: DictKeyType, default_value: Any = None) -> Any:
|
| 431 |
-
"""Return the value for `key` if `key` is in the dictionary, else
|
| 432 |
-
`default_value` (defaulting to `None`)."""
|
| 433 |
-
try:
|
| 434 |
-
return self._get_impl(key=key, default_value=default_value)
|
| 435 |
-
except KeyValidationError as e:
|
| 436 |
-
self._format_and_raise(key=key, value=None, cause=e)
|
| 437 |
-
|
| 438 |
-
def _get_impl(
|
| 439 |
-
self, key: DictKeyType, default_value: Any, validate_key: bool = True
|
| 440 |
-
) -> Any:
|
| 441 |
-
try:
|
| 442 |
-
node = self._get_child(
|
| 443 |
-
key=key, throw_on_missing_key=True, validate_key=validate_key
|
| 444 |
-
)
|
| 445 |
-
except (ConfigAttributeError, ConfigKeyError):
|
| 446 |
-
if default_value is not _DEFAULT_MARKER_:
|
| 447 |
-
return default_value
|
| 448 |
-
else:
|
| 449 |
-
raise
|
| 450 |
-
assert isinstance(node, Node)
|
| 451 |
-
return self._resolve_with_default(
|
| 452 |
-
key=key, value=node, default_value=default_value
|
| 453 |
-
)
|
| 454 |
-
|
| 455 |
-
def _get_node(
|
| 456 |
-
self,
|
| 457 |
-
key: DictKeyType,
|
| 458 |
-
validate_access: bool = True,
|
| 459 |
-
validate_key: bool = True,
|
| 460 |
-
throw_on_missing_value: bool = False,
|
| 461 |
-
throw_on_missing_key: bool = False,
|
| 462 |
-
) -> Optional[Node]:
|
| 463 |
-
try:
|
| 464 |
-
key = self._validate_and_normalize_key(key)
|
| 465 |
-
except KeyValidationError:
|
| 466 |
-
if validate_access and validate_key:
|
| 467 |
-
raise
|
| 468 |
-
else:
|
| 469 |
-
if throw_on_missing_key:
|
| 470 |
-
raise ConfigAttributeError
|
| 471 |
-
else:
|
| 472 |
-
return None
|
| 473 |
-
|
| 474 |
-
if validate_access:
|
| 475 |
-
self._validate_get(key)
|
| 476 |
-
|
| 477 |
-
value: Optional[Node] = self.__dict__["_content"].get(key)
|
| 478 |
-
if value is None:
|
| 479 |
-
if throw_on_missing_key:
|
| 480 |
-
raise ConfigKeyError(f"Missing key {key!s}")
|
| 481 |
-
elif throw_on_missing_value and value._is_missing():
|
| 482 |
-
raise MissingMandatoryValue("Missing mandatory value: $KEY")
|
| 483 |
-
return value
|
| 484 |
-
|
| 485 |
-
def pop(self, key: DictKeyType, default: Any = _DEFAULT_MARKER_) -> Any:
|
| 486 |
-
try:
|
| 487 |
-
if self._get_flag("readonly"):
|
| 488 |
-
raise ReadonlyConfigError("Cannot pop from read-only node")
|
| 489 |
-
if self._get_flag("struct"):
|
| 490 |
-
raise ConfigTypeError("DictConfig in struct mode does not support pop")
|
| 491 |
-
if self._is_typed() and self._get_node_flag("struct") is not False:
|
| 492 |
-
raise ConfigTypeError(
|
| 493 |
-
f"{type_str(self._metadata.object_type)} (DictConfig) does not support pop"
|
| 494 |
-
)
|
| 495 |
-
key = self._validate_and_normalize_key(key)
|
| 496 |
-
node = self._get_child(key=key, validate_access=False)
|
| 497 |
-
if node is not None:
|
| 498 |
-
assert isinstance(node, Node)
|
| 499 |
-
value = self._resolve_with_default(
|
| 500 |
-
key=key, value=node, default_value=default
|
| 501 |
-
)
|
| 502 |
-
|
| 503 |
-
del self[key]
|
| 504 |
-
return value
|
| 505 |
-
else:
|
| 506 |
-
if default is not _DEFAULT_MARKER_:
|
| 507 |
-
return default
|
| 508 |
-
else:
|
| 509 |
-
full = self._get_full_key(key=key)
|
| 510 |
-
if full != key:
|
| 511 |
-
raise ConfigKeyError(
|
| 512 |
-
f"Key not found: '{key!s}' (path: '{full}')"
|
| 513 |
-
)
|
| 514 |
-
else:
|
| 515 |
-
raise ConfigKeyError(f"Key not found: '{key!s}'")
|
| 516 |
-
except Exception as e:
|
| 517 |
-
self._format_and_raise(key=key, value=None, cause=e)
|
| 518 |
-
|
| 519 |
-
def keys(self) -> KeysView[DictKeyType]:
|
| 520 |
-
if self._is_missing() or self._is_interpolation() or self._is_none():
|
| 521 |
-
return {}.keys()
|
| 522 |
-
ret = self.__dict__["_content"].keys()
|
| 523 |
-
assert isinstance(ret, KeysView)
|
| 524 |
-
return ret
|
| 525 |
-
|
| 526 |
-
def __contains__(self, key: object) -> bool:
|
| 527 |
-
"""
|
| 528 |
-
A key is contained in a DictConfig if there is an associated value and
|
| 529 |
-
it is not a mandatory missing value ('???').
|
| 530 |
-
:param key:
|
| 531 |
-
:return:
|
| 532 |
-
"""
|
| 533 |
-
|
| 534 |
-
try:
|
| 535 |
-
key = self._validate_and_normalize_key(key)
|
| 536 |
-
except KeyValidationError:
|
| 537 |
-
return False
|
| 538 |
-
|
| 539 |
-
try:
|
| 540 |
-
node = self._get_child(key)
|
| 541 |
-
assert node is None or isinstance(node, Node)
|
| 542 |
-
except (KeyError, AttributeError):
|
| 543 |
-
node = None
|
| 544 |
-
|
| 545 |
-
if node is None:
|
| 546 |
-
return False
|
| 547 |
-
else:
|
| 548 |
-
try:
|
| 549 |
-
self._resolve_with_default(key=key, value=node)
|
| 550 |
-
return True
|
| 551 |
-
except InterpolationResolutionError:
|
| 552 |
-
# Interpolations that fail count as existing.
|
| 553 |
-
return True
|
| 554 |
-
except MissingMandatoryValue:
|
| 555 |
-
# Missing values count as *not* existing.
|
| 556 |
-
return False
|
| 557 |
-
|
| 558 |
-
def __iter__(self) -> Iterator[DictKeyType]:
|
| 559 |
-
return iter(self.keys())
|
| 560 |
-
|
| 561 |
-
def items(self) -> ItemsView[DictKeyType, Any]:
|
| 562 |
-
return dict(self.items_ex(resolve=True, keys=None)).items()
|
| 563 |
-
|
| 564 |
-
def setdefault(self, key: DictKeyType, default: Any = None) -> Any:
|
| 565 |
-
if key in self:
|
| 566 |
-
ret = self.__getitem__(key)
|
| 567 |
-
else:
|
| 568 |
-
ret = default
|
| 569 |
-
self.__setitem__(key, default)
|
| 570 |
-
return ret
|
| 571 |
-
|
| 572 |
-
def items_ex(
|
| 573 |
-
self, resolve: bool = True, keys: Optional[Sequence[DictKeyType]] = None
|
| 574 |
-
) -> List[Tuple[DictKeyType, Any]]:
|
| 575 |
-
items: List[Tuple[DictKeyType, Any]] = []
|
| 576 |
-
|
| 577 |
-
if self._is_none():
|
| 578 |
-
self._format_and_raise(
|
| 579 |
-
key=None,
|
| 580 |
-
value=None,
|
| 581 |
-
cause=TypeError("Cannot iterate a DictConfig object representing None"),
|
| 582 |
-
)
|
| 583 |
-
if self._is_missing():
|
| 584 |
-
raise MissingMandatoryValue("Cannot iterate a missing DictConfig")
|
| 585 |
-
|
| 586 |
-
for key in self.keys():
|
| 587 |
-
if resolve:
|
| 588 |
-
value = self[key]
|
| 589 |
-
else:
|
| 590 |
-
value = self.__dict__["_content"][key]
|
| 591 |
-
if isinstance(value, ValueNode):
|
| 592 |
-
value = value._value()
|
| 593 |
-
if keys is None or key in keys:
|
| 594 |
-
items.append((key, value))
|
| 595 |
-
|
| 596 |
-
return items
|
| 597 |
-
|
| 598 |
-
def __eq__(self, other: Any) -> bool:
|
| 599 |
-
if other is None:
|
| 600 |
-
return self.__dict__["_content"] is None
|
| 601 |
-
if is_primitive_dict(other) or is_structured_config(other):
|
| 602 |
-
other = DictConfig(other, flags={"allow_objects": True})
|
| 603 |
-
return DictConfig._dict_conf_eq(self, other)
|
| 604 |
-
if isinstance(other, DictConfig):
|
| 605 |
-
return DictConfig._dict_conf_eq(self, other)
|
| 606 |
-
if self._is_missing():
|
| 607 |
-
return _is_missing_literal(other)
|
| 608 |
-
return NotImplemented
|
| 609 |
-
|
| 610 |
-
def __ne__(self, other: Any) -> bool:
|
| 611 |
-
x = self.__eq__(other)
|
| 612 |
-
if x is not NotImplemented:
|
| 613 |
-
return not x
|
| 614 |
-
return NotImplemented
|
| 615 |
-
|
| 616 |
-
def __hash__(self) -> int:
|
| 617 |
-
return hash(str(self))
|
| 618 |
-
|
| 619 |
-
def _promote(self, type_or_prototype: Optional[Type[Any]]) -> None:
|
| 620 |
-
"""
|
| 621 |
-
Retypes a node.
|
| 622 |
-
This should only be used in rare circumstances, where you want to dynamically change
|
| 623 |
-
the runtime structured-type of a DictConfig.
|
| 624 |
-
It will change the type and add the additional fields based on the input class or object
|
| 625 |
-
"""
|
| 626 |
-
if type_or_prototype is None:
|
| 627 |
-
return
|
| 628 |
-
if not is_structured_config(type_or_prototype):
|
| 629 |
-
raise ValueError(f"Expected structured config class: {type_or_prototype}")
|
| 630 |
-
|
| 631 |
-
from omegaconf import OmegaConf
|
| 632 |
-
|
| 633 |
-
proto: DictConfig = OmegaConf.structured(type_or_prototype)
|
| 634 |
-
object_type = proto._metadata.object_type
|
| 635 |
-
# remove the type to prevent assignment validation from rejecting the promotion.
|
| 636 |
-
proto._metadata.object_type = None
|
| 637 |
-
self.merge_with(proto)
|
| 638 |
-
# restore the type.
|
| 639 |
-
self._metadata.object_type = object_type
|
| 640 |
-
|
| 641 |
-
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
|
| 642 |
-
try:
|
| 643 |
-
previous_content = self.__dict__["_content"]
|
| 644 |
-
self._set_value_impl(value, flags)
|
| 645 |
-
except Exception as e:
|
| 646 |
-
self.__dict__["_content"] = previous_content
|
| 647 |
-
raise e
|
| 648 |
-
|
| 649 |
-
def _set_value_impl(
|
| 650 |
-
self, value: Any, flags: Optional[Dict[str, bool]] = None
|
| 651 |
-
) -> None:
|
| 652 |
-
from omegaconf import MISSING, flag_override
|
| 653 |
-
|
| 654 |
-
if flags is None:
|
| 655 |
-
flags = {}
|
| 656 |
-
|
| 657 |
-
assert not isinstance(value, ValueNode)
|
| 658 |
-
self._validate_set(key=None, value=value)
|
| 659 |
-
|
| 660 |
-
if _is_none(value, resolve=True):
|
| 661 |
-
self.__dict__["_content"] = None
|
| 662 |
-
self._metadata.object_type = None
|
| 663 |
-
elif _is_interpolation(value, strict_interpolation_validation=True):
|
| 664 |
-
self.__dict__["_content"] = value
|
| 665 |
-
self._metadata.object_type = None
|
| 666 |
-
elif _is_missing_value(value):
|
| 667 |
-
self.__dict__["_content"] = MISSING
|
| 668 |
-
self._metadata.object_type = None
|
| 669 |
-
else:
|
| 670 |
-
self.__dict__["_content"] = {}
|
| 671 |
-
if is_structured_config(value):
|
| 672 |
-
self._metadata.object_type = None
|
| 673 |
-
ao = self._get_flag("allow_objects")
|
| 674 |
-
data = get_structured_config_data(value, allow_objects=ao)
|
| 675 |
-
with flag_override(self, ["struct", "readonly"], False):
|
| 676 |
-
for k, v in data.items():
|
| 677 |
-
self.__setitem__(k, v)
|
| 678 |
-
self._metadata.object_type = get_type_of(value)
|
| 679 |
-
|
| 680 |
-
elif isinstance(value, DictConfig):
|
| 681 |
-
self._metadata.flags = copy.deepcopy(flags)
|
| 682 |
-
with flag_override(self, ["struct", "readonly"], False):
|
| 683 |
-
for k, v in value.__dict__["_content"].items():
|
| 684 |
-
self.__setitem__(k, v)
|
| 685 |
-
self._metadata.object_type = value._metadata.object_type
|
| 686 |
-
|
| 687 |
-
elif isinstance(value, dict):
|
| 688 |
-
with flag_override(self, ["struct", "readonly"], False):
|
| 689 |
-
for k, v in value.items():
|
| 690 |
-
self.__setitem__(k, v)
|
| 691 |
-
self._metadata.object_type = dict
|
| 692 |
-
|
| 693 |
-
else: # pragma: no cover
|
| 694 |
-
msg = f"Unsupported value type: {value}"
|
| 695 |
-
raise ValidationError(msg)
|
| 696 |
-
|
| 697 |
-
@staticmethod
|
| 698 |
-
def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool:
|
| 699 |
-
|
| 700 |
-
d1_none = d1.__dict__["_content"] is None
|
| 701 |
-
d2_none = d2.__dict__["_content"] is None
|
| 702 |
-
if d1_none and d2_none:
|
| 703 |
-
return True
|
| 704 |
-
if d1_none != d2_none:
|
| 705 |
-
return False
|
| 706 |
-
|
| 707 |
-
assert isinstance(d1, DictConfig)
|
| 708 |
-
assert isinstance(d2, DictConfig)
|
| 709 |
-
if len(d1) != len(d2):
|
| 710 |
-
return False
|
| 711 |
-
if d1._is_missing() or d2._is_missing():
|
| 712 |
-
return d1._is_missing() is d2._is_missing()
|
| 713 |
-
|
| 714 |
-
for k, v in d1.items_ex(resolve=False):
|
| 715 |
-
if k not in d2.__dict__["_content"]:
|
| 716 |
-
return False
|
| 717 |
-
if not BaseContainer._item_eq(d1, k, d2, k):
|
| 718 |
-
return False
|
| 719 |
-
|
| 720 |
-
return True
|
| 721 |
-
|
| 722 |
-
def _to_object(self) -> Any:
|
| 723 |
-
"""
|
| 724 |
-
Instantiate an instance of `self._metadata.object_type`.
|
| 725 |
-
This requires `self` to be a structured config.
|
| 726 |
-
Nested subconfigs are converted by calling `OmegaConf.to_object`.
|
| 727 |
-
"""
|
| 728 |
-
from omegaconf import OmegaConf
|
| 729 |
-
|
| 730 |
-
object_type = self._metadata.object_type
|
| 731 |
-
assert is_structured_config(object_type)
|
| 732 |
-
init_field_names = set(get_structured_config_init_field_names(object_type))
|
| 733 |
-
|
| 734 |
-
init_field_items: Dict[str, Any] = {}
|
| 735 |
-
non_init_field_items: Dict[str, Any] = {}
|
| 736 |
-
for k in self.keys():
|
| 737 |
-
assert isinstance(k, str)
|
| 738 |
-
node = self._get_child(k)
|
| 739 |
-
assert isinstance(node, Node)
|
| 740 |
-
try:
|
| 741 |
-
node = node._dereference_node()
|
| 742 |
-
except InterpolationResolutionError as e:
|
| 743 |
-
self._format_and_raise(key=k, value=None, cause=e)
|
| 744 |
-
if node._is_missing():
|
| 745 |
-
if k not in init_field_names:
|
| 746 |
-
continue # MISSING is ignored for init=False fields
|
| 747 |
-
self._format_and_raise(
|
| 748 |
-
key=k,
|
| 749 |
-
value=None,
|
| 750 |
-
cause=MissingMandatoryValue(
|
| 751 |
-
"Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY"
|
| 752 |
-
),
|
| 753 |
-
)
|
| 754 |
-
if isinstance(node, Container):
|
| 755 |
-
v = OmegaConf.to_object(node)
|
| 756 |
-
else:
|
| 757 |
-
v = node._value()
|
| 758 |
-
|
| 759 |
-
if k in init_field_names:
|
| 760 |
-
init_field_items[k] = v
|
| 761 |
-
else:
|
| 762 |
-
non_init_field_items[k] = v
|
| 763 |
-
|
| 764 |
-
try:
|
| 765 |
-
result = object_type(**init_field_items)
|
| 766 |
-
except TypeError as exc:
|
| 767 |
-
self._format_and_raise(
|
| 768 |
-
key=None,
|
| 769 |
-
value=None,
|
| 770 |
-
cause=exc,
|
| 771 |
-
msg="Could not create instance of `$OBJECT_TYPE`: " + str(exc),
|
| 772 |
-
)
|
| 773 |
-
|
| 774 |
-
for k, v in non_init_field_items.items():
|
| 775 |
-
setattr(result, k, v)
|
| 776 |
-
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/errors.py
DELETED
|
@@ -1,141 +0,0 @@
|
|
| 1 |
-
from typing import Any, Optional, Type
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
class OmegaConfBaseException(Exception):
|
| 5 |
-
# would ideally be typed Optional[Node]
|
| 6 |
-
parent_node: Any
|
| 7 |
-
child_node: Any
|
| 8 |
-
key: Any
|
| 9 |
-
full_key: Optional[str]
|
| 10 |
-
value: Any
|
| 11 |
-
msg: Optional[str]
|
| 12 |
-
cause: Optional[Exception]
|
| 13 |
-
object_type: Optional[Type[Any]]
|
| 14 |
-
object_type_str: Optional[str]
|
| 15 |
-
ref_type: Optional[Type[Any]]
|
| 16 |
-
ref_type_str: Optional[str]
|
| 17 |
-
|
| 18 |
-
_initialized: bool = False
|
| 19 |
-
|
| 20 |
-
def __init__(self, *_args: Any, **_kwargs: Any) -> None:
|
| 21 |
-
self.parent_node = None
|
| 22 |
-
self.child_node = None
|
| 23 |
-
self.key = None
|
| 24 |
-
self.full_key = None
|
| 25 |
-
self.value = None
|
| 26 |
-
self.msg = None
|
| 27 |
-
self.object_type = None
|
| 28 |
-
self.ref_type = None
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class MissingMandatoryValue(OmegaConfBaseException):
|
| 32 |
-
"""Thrown when a variable flagged with '???' value is accessed to
|
| 33 |
-
indicate that the value was not set"""
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
class KeyValidationError(OmegaConfBaseException, ValueError):
|
| 37 |
-
"""
|
| 38 |
-
Thrown when an a key of invalid type is used
|
| 39 |
-
"""
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class ValidationError(OmegaConfBaseException, ValueError):
|
| 43 |
-
"""
|
| 44 |
-
Thrown when a value fails validation
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
class UnsupportedValueType(ValidationError, ValueError):
|
| 49 |
-
"""
|
| 50 |
-
Thrown when an input value is not of supported type
|
| 51 |
-
"""
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class ReadonlyConfigError(OmegaConfBaseException):
|
| 55 |
-
"""
|
| 56 |
-
Thrown when someone tries to modify a frozen config
|
| 57 |
-
"""
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
class InterpolationResolutionError(OmegaConfBaseException, ValueError):
|
| 61 |
-
"""
|
| 62 |
-
Base class for exceptions raised when resolving an interpolation.
|
| 63 |
-
"""
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
class UnsupportedInterpolationType(InterpolationResolutionError):
|
| 67 |
-
"""
|
| 68 |
-
Thrown when an attempt to use an unregistered interpolation is made
|
| 69 |
-
"""
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class InterpolationKeyError(InterpolationResolutionError):
|
| 73 |
-
"""
|
| 74 |
-
Thrown when a node does not exist when resolving an interpolation.
|
| 75 |
-
"""
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
class InterpolationToMissingValueError(InterpolationResolutionError):
|
| 79 |
-
"""
|
| 80 |
-
Thrown when a node interpolation points to a node that is set to ???.
|
| 81 |
-
"""
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
class InterpolationValidationError(InterpolationResolutionError, ValidationError):
|
| 85 |
-
"""
|
| 86 |
-
Thrown when the result of an interpolation fails the validation step.
|
| 87 |
-
"""
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
class ConfigKeyError(OmegaConfBaseException, KeyError):
|
| 91 |
-
"""
|
| 92 |
-
Thrown from DictConfig when a regular dict access would have caused a KeyError.
|
| 93 |
-
"""
|
| 94 |
-
|
| 95 |
-
msg: str
|
| 96 |
-
|
| 97 |
-
def __init__(self, msg: str) -> None:
|
| 98 |
-
super().__init__(msg)
|
| 99 |
-
self.msg = msg
|
| 100 |
-
|
| 101 |
-
def __str__(self) -> str:
|
| 102 |
-
"""
|
| 103 |
-
Workaround to nasty KeyError quirk: https://bugs.python.org/issue2651
|
| 104 |
-
"""
|
| 105 |
-
return self.msg
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
class ConfigAttributeError(OmegaConfBaseException, AttributeError):
|
| 109 |
-
"""
|
| 110 |
-
Thrown from a config object when a regular access would have caused an AttributeError.
|
| 111 |
-
"""
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
class ConfigTypeError(OmegaConfBaseException, TypeError):
|
| 115 |
-
"""
|
| 116 |
-
Thrown from a config object when a regular access would have caused a TypeError.
|
| 117 |
-
"""
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
class ConfigIndexError(OmegaConfBaseException, IndexError):
|
| 121 |
-
"""
|
| 122 |
-
Thrown from a config object when a regular access would have caused an IndexError.
|
| 123 |
-
"""
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
class ConfigValueError(OmegaConfBaseException, ValueError):
|
| 127 |
-
"""
|
| 128 |
-
Thrown from a config object when a regular access would have caused a ValueError.
|
| 129 |
-
"""
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
class ConfigCycleDetectedException(OmegaConfBaseException):
|
| 133 |
-
"""
|
| 134 |
-
Thrown when a cycle is detected in the graph made by config nodes.
|
| 135 |
-
"""
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
class GrammarParseError(OmegaConfBaseException):
|
| 139 |
-
"""
|
| 140 |
-
Thrown when failing to parse an expression according to the ANTLR grammar.
|
| 141 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/grammar/OmegaConfGrammarLexer.g4
DELETED
|
@@ -1,137 +0,0 @@
|
|
| 1 |
-
// Regenerate lexer and parser by running 'python setup.py antlr' at project root.
|
| 2 |
-
// See `OmegaConfGrammarParser.g4` for some important information regarding how to
|
| 3 |
-
// properly maintain this grammar.
|
| 4 |
-
|
| 5 |
-
lexer grammar OmegaConfGrammarLexer;
|
| 6 |
-
|
| 7 |
-
// Re-usable fragments.
|
| 8 |
-
fragment CHAR: [a-zA-Z];
|
| 9 |
-
fragment DIGIT: [0-9];
|
| 10 |
-
fragment INT_UNSIGNED: '0' | [1-9] (('_')? DIGIT)*;
|
| 11 |
-
fragment ESC_BACKSLASH: '\\\\'; // escaped backslash
|
| 12 |
-
|
| 13 |
-
/////////////////////////////
|
| 14 |
-
// DEFAULT_MODE (TOPLEVEL) //
|
| 15 |
-
/////////////////////////////
|
| 16 |
-
|
| 17 |
-
TOP_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
|
| 18 |
-
|
| 19 |
-
// Regular string: anything that does not contain any $ and does not end with \
|
| 20 |
-
// (this ensures this rule will not consume characters required to recognize other tokens).
|
| 21 |
-
ANY_STR: ~[$]* ~[\\$];
|
| 22 |
-
|
| 23 |
-
// Escaped interpolation: '\${', optionally preceded by an even number of \
|
| 24 |
-
ESC_INTER: ESC_BACKSLASH* '\\${';
|
| 25 |
-
|
| 26 |
-
// Backslashes that *may* be escaped (even number).
|
| 27 |
-
TOP_ESC: ESC_BACKSLASH+;
|
| 28 |
-
|
| 29 |
-
// Other backslashes that will not need escaping (odd number due to not matching the previous rule).
|
| 30 |
-
BACKSLASHES: '\\'+ -> type(ANY_STR);
|
| 31 |
-
|
| 32 |
-
// The dollar sign must be singled out so that we can recognize interpolations.
|
| 33 |
-
DOLLAR: '$' -> type(ANY_STR);
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
////////////////
|
| 37 |
-
// VALUE_MODE //
|
| 38 |
-
////////////////
|
| 39 |
-
|
| 40 |
-
mode VALUE_MODE;
|
| 41 |
-
|
| 42 |
-
INTER_OPEN: '${' WS? -> pushMode(INTERPOLATION_MODE);
|
| 43 |
-
BRACE_OPEN: '{' WS? -> pushMode(VALUE_MODE); // must keep track of braces to detect end of interpolation
|
| 44 |
-
BRACE_CLOSE: WS? '}' -> popMode;
|
| 45 |
-
QUOTE_OPEN_SINGLE: '\'' -> pushMode(QUOTED_SINGLE_MODE);
|
| 46 |
-
QUOTE_OPEN_DOUBLE: '"' -> pushMode(QUOTED_DOUBLE_MODE);
|
| 47 |
-
|
| 48 |
-
COMMA: WS? ',' WS?;
|
| 49 |
-
BRACKET_OPEN: '[' WS?;
|
| 50 |
-
BRACKET_CLOSE: WS? ']';
|
| 51 |
-
COLON: WS? ':' WS?;
|
| 52 |
-
|
| 53 |
-
// Numbers.
|
| 54 |
-
|
| 55 |
-
fragment POINT_FLOAT: INT_UNSIGNED '.' | INT_UNSIGNED? '.' DIGIT (('_')? DIGIT)*;
|
| 56 |
-
fragment EXPONENT_FLOAT: (INT_UNSIGNED | POINT_FLOAT) [eE] [+-]? DIGIT (('_')? DIGIT)*;
|
| 57 |
-
FLOAT: [+-]? (POINT_FLOAT | EXPONENT_FLOAT | [Ii][Nn][Ff] | [Nn][Aa][Nn]);
|
| 58 |
-
INT: [+-]? INT_UNSIGNED;
|
| 59 |
-
|
| 60 |
-
// Other reserved keywords.
|
| 61 |
-
|
| 62 |
-
BOOL:
|
| 63 |
-
[Tt][Rr][Uu][Ee] // TRUE
|
| 64 |
-
| [Ff][Aa][Ll][Ss][Ee]; // FALSE
|
| 65 |
-
|
| 66 |
-
NULL: [Nn][Uu][Ll][Ll];
|
| 67 |
-
|
| 68 |
-
UNQUOTED_CHAR: [/\-\\+.$%*@?|]; // other characters allowed in unquoted strings
|
| 69 |
-
ID: (CHAR|'_') (CHAR|DIGIT|'_'|'-')*;
|
| 70 |
-
ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' |
|
| 71 |
-
'\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+;
|
| 72 |
-
WS: [ \t]+;
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
////////////////////////
|
| 76 |
-
// INTERPOLATION_MODE //
|
| 77 |
-
////////////////////////
|
| 78 |
-
|
| 79 |
-
mode INTERPOLATION_MODE;
|
| 80 |
-
|
| 81 |
-
NESTED_INTER_OPEN: INTER_OPEN WS? -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
|
| 82 |
-
INTER_COLON: WS? ':' WS? -> type(COLON), mode(VALUE_MODE);
|
| 83 |
-
INTER_CLOSE: WS? '}' -> popMode;
|
| 84 |
-
|
| 85 |
-
DOT: '.';
|
| 86 |
-
INTER_BRACKET_OPEN: '[' -> type(BRACKET_OPEN);
|
| 87 |
-
INTER_BRACKET_CLOSE: ']' -> type(BRACKET_CLOSE);
|
| 88 |
-
INTER_ID: ID -> type(ID);
|
| 89 |
-
|
| 90 |
-
// Interpolation key, may contain any non special character.
|
| 91 |
-
// Note that we can allow '$' because the parser does not support interpolations that
|
| 92 |
-
// are only part of a key name, i.e., "${foo${bar}}" is not allowed. As a result, it
|
| 93 |
-
// is ok to "consume" all '$' characters within the `INTER_KEY` token.
|
| 94 |
-
INTER_KEY: ~[\\{}()[\]:. \t'"]+;
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
////////////////////////
|
| 98 |
-
// QUOTED_SINGLE_MODE //
|
| 99 |
-
////////////////////////
|
| 100 |
-
|
| 101 |
-
mode QUOTED_SINGLE_MODE;
|
| 102 |
-
|
| 103 |
-
// This mode is very similar to `DEFAULT_MODE` except for the handling of quotes.
|
| 104 |
-
|
| 105 |
-
QSINGLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
|
| 106 |
-
MATCHING_QUOTE_CLOSE: '\'' -> popMode;
|
| 107 |
-
|
| 108 |
-
// Regular string: anything that does not contain any $ *or quote* and does not end with \
|
| 109 |
-
QSINGLE_STR: ~['$]* ~['\\$] -> type(ANY_STR);
|
| 110 |
-
|
| 111 |
-
QSINGLE_ESC_INTER: ESC_INTER -> type(ESC_INTER);
|
| 112 |
-
|
| 113 |
-
// Escaped quote (optionally preceded by an even number of backslashes).
|
| 114 |
-
QSINGLE_ESC_QUOTE: ESC_BACKSLASH* '\\\'' -> type(ESC);
|
| 115 |
-
|
| 116 |
-
QUOTED_ESC: ESC_BACKSLASH+;
|
| 117 |
-
QSINGLE_BACKSLASHES: '\\'+ -> type(ANY_STR);
|
| 118 |
-
QSINGLE_DOLLAR: '$' -> type(ANY_STR);
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
////////////////////////
|
| 122 |
-
// QUOTED_DOUBLE_MODE //
|
| 123 |
-
////////////////////////
|
| 124 |
-
|
| 125 |
-
mode QUOTED_DOUBLE_MODE;
|
| 126 |
-
|
| 127 |
-
// Same as `QUOTED_SINGLE_MODE` but for double quotes.
|
| 128 |
-
|
| 129 |
-
QDOUBLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
|
| 130 |
-
QDOUBLE_CLOSE: '"' -> type(MATCHING_QUOTE_CLOSE), popMode;
|
| 131 |
-
|
| 132 |
-
QDOUBLE_STR: ~["$]* ~["\\$] -> type(ANY_STR);
|
| 133 |
-
QDOUBLE_ESC_INTER: ESC_INTER -> type(ESC_INTER);
|
| 134 |
-
QDOUBLE_ESC_QUOTE: ESC_BACKSLASH* '\\"' -> type(ESC);
|
| 135 |
-
QDOUBLE_ESC: ESC_BACKSLASH+ -> type(QUOTED_ESC);
|
| 136 |
-
QDOUBLE_BACKSLASHES: '\\'+ -> type(ANY_STR);
|
| 137 |
-
QDOUBLE_DOLLAR: '$' -> type(ANY_STR);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/grammar/OmegaConfGrammarParser.g4
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
// Regenerate parser by running 'python setup.py antlr' at project root.
|
| 2 |
-
|
| 3 |
-
// Maintenance guidelines when modifying this grammar:
|
| 4 |
-
//
|
| 5 |
-
// - Consider whether the regex pattern `SIMPLE_INTERPOLATION_PATTERN` found in
|
| 6 |
-
// `grammar_parser.py` should be updated as well.
|
| 7 |
-
//
|
| 8 |
-
// - Update Hydra's grammar accordingly.
|
| 9 |
-
//
|
| 10 |
-
// - Keep up-to-date the comments in the visitor (in `grammar_visitor.py`)
|
| 11 |
-
// that contain grammar excerpts (within each `visit...()` method).
|
| 12 |
-
//
|
| 13 |
-
// - Remember to update the documentation (including the tutorial notebook as
|
| 14 |
-
// well as grammar.rst)
|
| 15 |
-
|
| 16 |
-
parser grammar OmegaConfGrammarParser;
|
| 17 |
-
options {tokenVocab = OmegaConfGrammarLexer;}
|
| 18 |
-
|
| 19 |
-
// Main rules used to parse OmegaConf strings.
|
| 20 |
-
|
| 21 |
-
configValue: text EOF;
|
| 22 |
-
singleElement: element EOF;
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
// Composite text expression (may contain interpolations).
|
| 26 |
-
|
| 27 |
-
text: (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+;
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
// Elements.
|
| 31 |
-
|
| 32 |
-
element:
|
| 33 |
-
primitive
|
| 34 |
-
| quotedValue
|
| 35 |
-
| listContainer
|
| 36 |
-
| dictContainer
|
| 37 |
-
;
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
// Data structures.
|
| 41 |
-
|
| 42 |
-
listContainer: BRACKET_OPEN sequence? BRACKET_CLOSE; // [], [1,2,3], [a,b,[1,2]]
|
| 43 |
-
dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
|
| 44 |
-
dictKeyValuePair: dictKey COLON element;
|
| 45 |
-
sequence: (element (COMMA element?)*) | (COMMA element?)+;
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
// Interpolations.
|
| 49 |
-
|
| 50 |
-
interpolation: interpolationNode | interpolationResolver;
|
| 51 |
-
|
| 52 |
-
interpolationNode:
|
| 53 |
-
INTER_OPEN
|
| 54 |
-
DOT* // relative interpolation?
|
| 55 |
-
(configKey | BRACKET_OPEN configKey BRACKET_CLOSE) // foo, [foo]
|
| 56 |
-
(DOT configKey | BRACKET_OPEN configKey BRACKET_CLOSE)* // .foo, [foo], .foo[bar], [foo].bar[baz]
|
| 57 |
-
INTER_CLOSE;
|
| 58 |
-
interpolationResolver: INTER_OPEN resolverName COLON sequence? BRACE_CLOSE;
|
| 59 |
-
configKey: interpolation | ID | INTER_KEY;
|
| 60 |
-
resolverName: (interpolation | ID) (DOT (interpolation | ID))* ; // oc.env, myfunc, ns.${x}, ns1.ns2.f
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
// Primitive types.
|
| 64 |
-
|
| 65 |
-
// Ex: "hello world", 'hello ${world}'
|
| 66 |
-
quotedValue: (QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE) text? MATCHING_QUOTE_CLOSE;
|
| 67 |
-
|
| 68 |
-
primitive:
|
| 69 |
-
( ID // foo_10
|
| 70 |
-
| NULL // null, NULL
|
| 71 |
-
| INT // 0, 10, -20, 1_000_000
|
| 72 |
-
| FLOAT // 3.14, -20.0, 1e-1, -10e3
|
| 73 |
-
| BOOL // true, TrUe, false, False
|
| 74 |
-
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, |
|
| 75 |
-
| COLON // :
|
| 76 |
-
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
|
| 77 |
-
| WS // whitespaces
|
| 78 |
-
| interpolation
|
| 79 |
-
)+;
|
| 80 |
-
|
| 81 |
-
// Same as `primitive` except that `COLON` and interpolations are not allowed.
|
| 82 |
-
dictKey:
|
| 83 |
-
( ID // foo_10
|
| 84 |
-
| NULL // null, NULL
|
| 85 |
-
| INT // 0, 10, -20, 1_000_000
|
| 86 |
-
| FLOAT // 3.14, -20.0, 1e-1, -10e3
|
| 87 |
-
| BOOL // true, TrUe, false, False
|
| 88 |
-
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, |
|
| 89 |
-
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
|
| 90 |
-
| WS // whitespaces
|
| 91 |
-
)+;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/grammar/__init__.py
DELETED
|
File without changes
|
omegaconf/grammar/gen/__init__.py
DELETED
|
File without changes
|
omegaconf/grammar_parser.py
DELETED
|
@@ -1,144 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
import threading
|
| 3 |
-
from typing import Any
|
| 4 |
-
|
| 5 |
-
from antlr4 import CommonTokenStream, InputStream, ParserRuleContext
|
| 6 |
-
from antlr4.error.ErrorListener import ErrorListener
|
| 7 |
-
|
| 8 |
-
from .errors import GrammarParseError
|
| 9 |
-
|
| 10 |
-
# Import from visitor in order to check the presence of generated grammar files
|
| 11 |
-
# files in a single place.
|
| 12 |
-
from .grammar_visitor import ( # type: ignore
|
| 13 |
-
OmegaConfGrammarLexer,
|
| 14 |
-
OmegaConfGrammarParser,
|
| 15 |
-
)
|
| 16 |
-
|
| 17 |
-
# Used to cache grammar objects to avoid re-creating them on each call to `parse()`.
|
| 18 |
-
# We use a per-thread cache to make it thread-safe.
|
| 19 |
-
_grammar_cache = threading.local()
|
| 20 |
-
|
| 21 |
-
# Build regex pattern to efficiently identify typical interpolations.
|
| 22 |
-
# See test `test_match_simple_interpolation_pattern` for examples.
|
| 23 |
-
_config_key = r"[$\w]+" # foo, $0, $bar, $foo_$bar123$
|
| 24 |
-
_key_maybe_brackets = f"{_config_key}|\\[{_config_key}\\]" # foo, [foo], [$bar]
|
| 25 |
-
_node_access = f"\\.{_key_maybe_brackets}" # .foo, [foo], [$bar]
|
| 26 |
-
_node_path = f"(\\.)*({_key_maybe_brackets})({_node_access})*" # [foo].bar, .foo[bar]
|
| 27 |
-
_node_inter = f"\\${{\\s*{_node_path}\\s*}}" # node interpolation ${foo.bar}
|
| 28 |
-
_id = "[a-zA-Z_][\\w\\-]*" # foo, foo_bar, foo-bar, abc123
|
| 29 |
-
_resolver_name = f"({_id}(\\.{_id})*)?" # foo, ns.bar3, ns_1.ns_2.b0z
|
| 30 |
-
_arg = r"[a-zA-Z_0-9/\-\+.$%*@?|]+" # string representing a resolver argument
|
| 31 |
-
_args = f"{_arg}(\\s*,\\s*{_arg})*" # list of resolver arguments
|
| 32 |
-
_resolver_inter = f"\\${{\\s*{_resolver_name}\\s*:\\s*{_args}?\\s*}}" # ${foo:bar}
|
| 33 |
-
_inter = f"({_node_inter}|{_resolver_inter})" # any kind of interpolation
|
| 34 |
-
_outer = "([^$]|\\$(?!{))+" # any character except $ (unless not followed by {)
|
| 35 |
-
SIMPLE_INTERPOLATION_PATTERN = re.compile(
|
| 36 |
-
f"({_outer})?({_inter}({_outer})?)+$", flags=re.ASCII
|
| 37 |
-
)
|
| 38 |
-
# NOTE: SIMPLE_INTERPOLATION_PATTERN must not generate false positive matches:
|
| 39 |
-
# it must not accept anything that isn't a valid interpolation (per the
|
| 40 |
-
# interpolation grammar defined in `omegaconf/grammar/*.g4`).
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class OmegaConfErrorListener(ErrorListener): # type: ignore
|
| 44 |
-
def syntaxError(
|
| 45 |
-
self,
|
| 46 |
-
recognizer: Any,
|
| 47 |
-
offending_symbol: Any,
|
| 48 |
-
line: Any,
|
| 49 |
-
column: Any,
|
| 50 |
-
msg: Any,
|
| 51 |
-
e: Any,
|
| 52 |
-
) -> None:
|
| 53 |
-
raise GrammarParseError(str(e) if msg is None else msg) from e
|
| 54 |
-
|
| 55 |
-
def reportAmbiguity(
|
| 56 |
-
self,
|
| 57 |
-
recognizer: Any,
|
| 58 |
-
dfa: Any,
|
| 59 |
-
startIndex: Any,
|
| 60 |
-
stopIndex: Any,
|
| 61 |
-
exact: Any,
|
| 62 |
-
ambigAlts: Any,
|
| 63 |
-
configs: Any,
|
| 64 |
-
) -> None:
|
| 65 |
-
raise GrammarParseError("ANTLR error: Ambiguity") # pragma: no cover
|
| 66 |
-
|
| 67 |
-
def reportAttemptingFullContext(
|
| 68 |
-
self,
|
| 69 |
-
recognizer: Any,
|
| 70 |
-
dfa: Any,
|
| 71 |
-
startIndex: Any,
|
| 72 |
-
stopIndex: Any,
|
| 73 |
-
conflictingAlts: Any,
|
| 74 |
-
configs: Any,
|
| 75 |
-
) -> None:
|
| 76 |
-
# Note: for now we raise an error to be safe. However this is mostly a
|
| 77 |
-
# performance warning, so in the future this may be relaxed if we need
|
| 78 |
-
# to change the grammar in such a way that this warning cannot be
|
| 79 |
-
# avoided (another option would be to switch to SLL parsing mode).
|
| 80 |
-
raise GrammarParseError(
|
| 81 |
-
"ANTLR error: Attempting Full Context"
|
| 82 |
-
) # pragma: no cover
|
| 83 |
-
|
| 84 |
-
def reportContextSensitivity(
|
| 85 |
-
self,
|
| 86 |
-
recognizer: Any,
|
| 87 |
-
dfa: Any,
|
| 88 |
-
startIndex: Any,
|
| 89 |
-
stopIndex: Any,
|
| 90 |
-
prediction: Any,
|
| 91 |
-
configs: Any,
|
| 92 |
-
) -> None:
|
| 93 |
-
raise GrammarParseError("ANTLR error: ContextSensitivity") # pragma: no cover
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def parse(
|
| 97 |
-
value: str, parser_rule: str = "configValue", lexer_mode: str = "DEFAULT_MODE"
|
| 98 |
-
) -> ParserRuleContext:
|
| 99 |
-
"""
|
| 100 |
-
Parse interpolated string `value` (and return the parse tree).
|
| 101 |
-
"""
|
| 102 |
-
l_mode = getattr(OmegaConfGrammarLexer, lexer_mode)
|
| 103 |
-
istream = InputStream(value)
|
| 104 |
-
|
| 105 |
-
cached = getattr(_grammar_cache, "data", None)
|
| 106 |
-
if cached is None:
|
| 107 |
-
error_listener = OmegaConfErrorListener()
|
| 108 |
-
lexer = OmegaConfGrammarLexer(istream)
|
| 109 |
-
lexer.removeErrorListeners()
|
| 110 |
-
lexer.addErrorListener(error_listener)
|
| 111 |
-
lexer.mode(l_mode)
|
| 112 |
-
token_stream = CommonTokenStream(lexer)
|
| 113 |
-
parser = OmegaConfGrammarParser(token_stream)
|
| 114 |
-
parser.removeErrorListeners()
|
| 115 |
-
parser.addErrorListener(error_listener)
|
| 116 |
-
|
| 117 |
-
# The two lines below could be enabled in the future if we decide to switch
|
| 118 |
-
# to SLL prediction mode. Warning though, it has not been fully tested yet!
|
| 119 |
-
# from antlr4 import PredictionMode
|
| 120 |
-
# parser._interp.predictionMode = PredictionMode.SLL
|
| 121 |
-
|
| 122 |
-
# Note that although the input stream `istream` is implicitly cached within
|
| 123 |
-
# the lexer, it will be replaced by a new input next time the lexer is re-used.
|
| 124 |
-
_grammar_cache.data = lexer, token_stream, parser
|
| 125 |
-
|
| 126 |
-
else:
|
| 127 |
-
lexer, token_stream, parser = cached
|
| 128 |
-
# Replace the old input stream with the new one.
|
| 129 |
-
lexer.inputStream = istream
|
| 130 |
-
# Initialize the lexer / token stream / parser to process the new input.
|
| 131 |
-
lexer.mode(l_mode)
|
| 132 |
-
token_stream.setTokenSource(lexer)
|
| 133 |
-
parser.reset()
|
| 134 |
-
|
| 135 |
-
try:
|
| 136 |
-
return getattr(parser, parser_rule)()
|
| 137 |
-
except Exception as exc:
|
| 138 |
-
if type(exc) is Exception and str(exc) == "Empty Stack":
|
| 139 |
-
# This exception is raised by antlr when trying to pop a mode while
|
| 140 |
-
# no mode has been pushed. We convert it into an `GrammarParseError`
|
| 141 |
-
# to facilitate exception handling from the caller.
|
| 142 |
-
raise GrammarParseError("Empty Stack")
|
| 143 |
-
else:
|
| 144 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/grammar_visitor.py
DELETED
|
@@ -1,392 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import warnings
|
| 3 |
-
from itertools import zip_longest
|
| 4 |
-
from typing import (
|
| 5 |
-
TYPE_CHECKING,
|
| 6 |
-
Any,
|
| 7 |
-
Callable,
|
| 8 |
-
Dict,
|
| 9 |
-
Generator,
|
| 10 |
-
List,
|
| 11 |
-
Optional,
|
| 12 |
-
Set,
|
| 13 |
-
Tuple,
|
| 14 |
-
Union,
|
| 15 |
-
)
|
| 16 |
-
|
| 17 |
-
from antlr4 import TerminalNode
|
| 18 |
-
|
| 19 |
-
from .errors import InterpolationResolutionError
|
| 20 |
-
|
| 21 |
-
if TYPE_CHECKING:
|
| 22 |
-
from .base import Node # noqa F401
|
| 23 |
-
|
| 24 |
-
try:
|
| 25 |
-
from omegaconf.grammar.gen.OmegaConfGrammarLexer import OmegaConfGrammarLexer
|
| 26 |
-
from omegaconf.grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
|
| 27 |
-
from omegaconf.grammar.gen.OmegaConfGrammarParserVisitor import (
|
| 28 |
-
OmegaConfGrammarParserVisitor,
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
except ModuleNotFoundError: # pragma: no cover
|
| 32 |
-
print(
|
| 33 |
-
"Error importing OmegaConf's generated parsers, run `python setup.py antlr` to regenerate.",
|
| 34 |
-
file=sys.stderr,
|
| 35 |
-
)
|
| 36 |
-
sys.exit(1)
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class GrammarVisitor(OmegaConfGrammarParserVisitor):
|
| 40 |
-
def __init__(
|
| 41 |
-
self,
|
| 42 |
-
node_interpolation_callback: Callable[
|
| 43 |
-
[str, Optional[Set[int]]],
|
| 44 |
-
Optional["Node"],
|
| 45 |
-
],
|
| 46 |
-
resolver_interpolation_callback: Callable[..., Any],
|
| 47 |
-
memo: Optional[Set[int]],
|
| 48 |
-
**kw: Dict[Any, Any],
|
| 49 |
-
):
|
| 50 |
-
"""
|
| 51 |
-
Constructor.
|
| 52 |
-
|
| 53 |
-
:param node_interpolation_callback: Callback function that is called when
|
| 54 |
-
needing to resolve a node interpolation. This function should take a single
|
| 55 |
-
string input which is the key's dot path (ex: `"foo.bar"`).
|
| 56 |
-
|
| 57 |
-
:param resolver_interpolation_callback: Callback function that is called when
|
| 58 |
-
needing to resolve a resolver interpolation. This function should accept
|
| 59 |
-
three keyword arguments: `name` (str, the name of the resolver),
|
| 60 |
-
`args` (tuple, the inputs to the resolver), and `args_str` (tuple,
|
| 61 |
-
the string representation of the inputs to the resolver).
|
| 62 |
-
|
| 63 |
-
:param kw: Additional keyword arguments to be forwarded to parent class.
|
| 64 |
-
"""
|
| 65 |
-
super().__init__(**kw)
|
| 66 |
-
self.node_interpolation_callback = node_interpolation_callback
|
| 67 |
-
self.resolver_interpolation_callback = resolver_interpolation_callback
|
| 68 |
-
self.memo = memo
|
| 69 |
-
|
| 70 |
-
def aggregateResult(self, aggregate: List[Any], nextResult: Any) -> List[Any]:
|
| 71 |
-
raise NotImplementedError
|
| 72 |
-
|
| 73 |
-
def defaultResult(self) -> List[Any]:
|
| 74 |
-
# Raising an exception because not currently used (like `aggregateResult()`).
|
| 75 |
-
raise NotImplementedError
|
| 76 |
-
|
| 77 |
-
def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str:
|
| 78 |
-
from ._utils import _get_value
|
| 79 |
-
|
| 80 |
-
# interpolation | ID | INTER_KEY
|
| 81 |
-
assert ctx.getChildCount() == 1
|
| 82 |
-
child = ctx.getChild(0)
|
| 83 |
-
if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
|
| 84 |
-
res = _get_value(self.visitInterpolation(child))
|
| 85 |
-
if not isinstance(res, str):
|
| 86 |
-
raise InterpolationResolutionError(
|
| 87 |
-
f"The following interpolation is used to denote a config key and "
|
| 88 |
-
f"thus should return a string, but instead returned `{res}` of "
|
| 89 |
-
f"type `{type(res)}`: {ctx.getChild(0).getText()}"
|
| 90 |
-
)
|
| 91 |
-
return res
|
| 92 |
-
else:
|
| 93 |
-
assert isinstance(child, TerminalNode) and isinstance(
|
| 94 |
-
child.symbol.text, str
|
| 95 |
-
)
|
| 96 |
-
return child.symbol.text
|
| 97 |
-
|
| 98 |
-
def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any:
|
| 99 |
-
# text EOF
|
| 100 |
-
assert ctx.getChildCount() == 2
|
| 101 |
-
return self.visit(ctx.getChild(0))
|
| 102 |
-
|
| 103 |
-
def visitDictKey(self, ctx: OmegaConfGrammarParser.DictKeyContext) -> Any:
|
| 104 |
-
return self._createPrimitive(ctx)
|
| 105 |
-
|
| 106 |
-
def visitDictContainer(
|
| 107 |
-
self, ctx: OmegaConfGrammarParser.DictContainerContext
|
| 108 |
-
) -> Dict[Any, Any]:
|
| 109 |
-
# BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE
|
| 110 |
-
assert ctx.getChildCount() >= 2
|
| 111 |
-
return dict(
|
| 112 |
-
self.visitDictKeyValuePair(ctx.getChild(i))
|
| 113 |
-
for i in range(1, ctx.getChildCount() - 1, 2)
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any:
|
| 117 |
-
# primitive | quotedValue | listContainer | dictContainer
|
| 118 |
-
assert ctx.getChildCount() == 1
|
| 119 |
-
return self.visit(ctx.getChild(0))
|
| 120 |
-
|
| 121 |
-
def visitInterpolation(
|
| 122 |
-
self, ctx: OmegaConfGrammarParser.InterpolationContext
|
| 123 |
-
) -> Any:
|
| 124 |
-
assert ctx.getChildCount() == 1 # interpolationNode | interpolationResolver
|
| 125 |
-
return self.visit(ctx.getChild(0))
|
| 126 |
-
|
| 127 |
-
def visitInterpolationNode(
|
| 128 |
-
self, ctx: OmegaConfGrammarParser.InterpolationNodeContext
|
| 129 |
-
) -> Optional["Node"]:
|
| 130 |
-
# INTER_OPEN
|
| 131 |
-
# DOT* // relative interpolation?
|
| 132 |
-
# (configKey | BRACKET_OPEN configKey BRACKET_CLOSE) // foo, [foo]
|
| 133 |
-
# (DOT configKey | BRACKET_OPEN configKey BRACKET_CLOSE)* // .foo, [foo], .foo[bar], [foo].bar[baz]
|
| 134 |
-
# INTER_CLOSE;
|
| 135 |
-
|
| 136 |
-
assert ctx.getChildCount() >= 3
|
| 137 |
-
|
| 138 |
-
inter_key_tokens = [] # parsed elements of the dot path
|
| 139 |
-
for child in ctx.getChildren():
|
| 140 |
-
if isinstance(child, TerminalNode):
|
| 141 |
-
s = child.symbol
|
| 142 |
-
if s.type in [
|
| 143 |
-
OmegaConfGrammarLexer.DOT,
|
| 144 |
-
OmegaConfGrammarLexer.BRACKET_OPEN,
|
| 145 |
-
OmegaConfGrammarLexer.BRACKET_CLOSE,
|
| 146 |
-
]:
|
| 147 |
-
inter_key_tokens.append(s.text)
|
| 148 |
-
else:
|
| 149 |
-
assert s.type in (
|
| 150 |
-
OmegaConfGrammarLexer.INTER_OPEN,
|
| 151 |
-
OmegaConfGrammarLexer.INTER_CLOSE,
|
| 152 |
-
)
|
| 153 |
-
else:
|
| 154 |
-
assert isinstance(child, OmegaConfGrammarParser.ConfigKeyContext)
|
| 155 |
-
inter_key_tokens.append(self.visitConfigKey(child))
|
| 156 |
-
|
| 157 |
-
inter_key = "".join(inter_key_tokens)
|
| 158 |
-
return self.node_interpolation_callback(inter_key, self.memo)
|
| 159 |
-
|
| 160 |
-
def visitInterpolationResolver(
|
| 161 |
-
self, ctx: OmegaConfGrammarParser.InterpolationResolverContext
|
| 162 |
-
) -> Any:
|
| 163 |
-
|
| 164 |
-
# INTER_OPEN resolverName COLON sequence? BRACE_CLOSE
|
| 165 |
-
assert 4 <= ctx.getChildCount() <= 5
|
| 166 |
-
|
| 167 |
-
resolver_name = self.visit(ctx.getChild(1))
|
| 168 |
-
maybe_seq = ctx.getChild(3)
|
| 169 |
-
args = []
|
| 170 |
-
args_str = []
|
| 171 |
-
if isinstance(maybe_seq, TerminalNode): # means there are no args
|
| 172 |
-
assert maybe_seq.symbol.type == OmegaConfGrammarLexer.BRACE_CLOSE
|
| 173 |
-
else:
|
| 174 |
-
assert isinstance(maybe_seq, OmegaConfGrammarParser.SequenceContext)
|
| 175 |
-
for val, txt in self.visitSequence(maybe_seq):
|
| 176 |
-
args.append(val)
|
| 177 |
-
args_str.append(txt)
|
| 178 |
-
|
| 179 |
-
return self.resolver_interpolation_callback(
|
| 180 |
-
name=resolver_name,
|
| 181 |
-
args=tuple(args),
|
| 182 |
-
args_str=tuple(args_str),
|
| 183 |
-
)
|
| 184 |
-
|
| 185 |
-
def visitDictKeyValuePair(
|
| 186 |
-
self, ctx: OmegaConfGrammarParser.DictKeyValuePairContext
|
| 187 |
-
) -> Tuple[Any, Any]:
|
| 188 |
-
from ._utils import _get_value
|
| 189 |
-
|
| 190 |
-
assert ctx.getChildCount() == 3 # dictKey COLON element
|
| 191 |
-
key = self.visit(ctx.getChild(0))
|
| 192 |
-
colon = ctx.getChild(1)
|
| 193 |
-
assert (
|
| 194 |
-
isinstance(colon, TerminalNode)
|
| 195 |
-
and colon.symbol.type == OmegaConfGrammarLexer.COLON
|
| 196 |
-
)
|
| 197 |
-
value = _get_value(self.visitElement(ctx.getChild(2)))
|
| 198 |
-
return key, value
|
| 199 |
-
|
| 200 |
-
def visitListContainer(
|
| 201 |
-
self, ctx: OmegaConfGrammarParser.ListContainerContext
|
| 202 |
-
) -> List[Any]:
|
| 203 |
-
# BRACKET_OPEN sequence? BRACKET_CLOSE;
|
| 204 |
-
assert ctx.getChildCount() in (2, 3)
|
| 205 |
-
if ctx.getChildCount() == 2:
|
| 206 |
-
return []
|
| 207 |
-
sequence = ctx.getChild(1)
|
| 208 |
-
assert isinstance(sequence, OmegaConfGrammarParser.SequenceContext)
|
| 209 |
-
return list(val for val, _ in self.visitSequence(sequence)) # ignore raw text
|
| 210 |
-
|
| 211 |
-
def visitPrimitive(self, ctx: OmegaConfGrammarParser.PrimitiveContext) -> Any:
|
| 212 |
-
return self._createPrimitive(ctx)
|
| 213 |
-
|
| 214 |
-
def visitQuotedValue(self, ctx: OmegaConfGrammarParser.QuotedValueContext) -> str:
|
| 215 |
-
# (QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE) text? MATCHING_QUOTE_CLOSE
|
| 216 |
-
n = ctx.getChildCount()
|
| 217 |
-
assert n in [2, 3]
|
| 218 |
-
return str(self.visit(ctx.getChild(1))) if n == 3 else ""
|
| 219 |
-
|
| 220 |
-
def visitResolverName(self, ctx: OmegaConfGrammarParser.ResolverNameContext) -> str:
|
| 221 |
-
from ._utils import _get_value
|
| 222 |
-
|
| 223 |
-
# (interpolation | ID) (DOT (interpolation | ID))*
|
| 224 |
-
assert ctx.getChildCount() >= 1
|
| 225 |
-
items = []
|
| 226 |
-
for child in list(ctx.getChildren())[::2]:
|
| 227 |
-
if isinstance(child, TerminalNode):
|
| 228 |
-
assert child.symbol.type == OmegaConfGrammarLexer.ID
|
| 229 |
-
items.append(child.symbol.text)
|
| 230 |
-
else:
|
| 231 |
-
assert isinstance(child, OmegaConfGrammarParser.InterpolationContext)
|
| 232 |
-
item = _get_value(self.visitInterpolation(child))
|
| 233 |
-
if not isinstance(item, str):
|
| 234 |
-
raise InterpolationResolutionError(
|
| 235 |
-
f"The name of a resolver must be a string, but the interpolation "
|
| 236 |
-
f"{child.getText()} resolved to `{item}` which is of type "
|
| 237 |
-
f"{type(item)}"
|
| 238 |
-
)
|
| 239 |
-
items.append(item)
|
| 240 |
-
return ".".join(items)
|
| 241 |
-
|
| 242 |
-
def visitSequence(
|
| 243 |
-
self, ctx: OmegaConfGrammarParser.SequenceContext
|
| 244 |
-
) -> Generator[Any, None, None]:
|
| 245 |
-
from ._utils import _get_value
|
| 246 |
-
|
| 247 |
-
# (element (COMMA element?)*) | (COMMA element?)+
|
| 248 |
-
assert ctx.getChildCount() >= 1
|
| 249 |
-
|
| 250 |
-
# DEPRECATED: remove in 2.2 (revert #571)
|
| 251 |
-
def empty_str_warning() -> None:
|
| 252 |
-
txt = ctx.getText()
|
| 253 |
-
warnings.warn(
|
| 254 |
-
f"In the sequence `{txt}` some elements are missing: please replace "
|
| 255 |
-
f"them with empty quoted strings. "
|
| 256 |
-
f"See https://github.com/omry/omegaconf/issues/572 for details.",
|
| 257 |
-
category=UserWarning,
|
| 258 |
-
)
|
| 259 |
-
|
| 260 |
-
is_previous_comma = True # whether previous child was a comma (init to True)
|
| 261 |
-
for child in ctx.getChildren():
|
| 262 |
-
if isinstance(child, OmegaConfGrammarParser.ElementContext):
|
| 263 |
-
# Also preserve the original text representation of `child` so
|
| 264 |
-
# as to allow backward compatibility with old resolvers (registered
|
| 265 |
-
# with `legacy_register_resolver()`). Note that we cannot just cast
|
| 266 |
-
# the value to string later as for instance `null` would become "None".
|
| 267 |
-
yield _get_value(self.visitElement(child)), child.getText()
|
| 268 |
-
is_previous_comma = False
|
| 269 |
-
else:
|
| 270 |
-
assert (
|
| 271 |
-
isinstance(child, TerminalNode)
|
| 272 |
-
and child.symbol.type == OmegaConfGrammarLexer.COMMA
|
| 273 |
-
)
|
| 274 |
-
if is_previous_comma:
|
| 275 |
-
empty_str_warning()
|
| 276 |
-
yield "", ""
|
| 277 |
-
else:
|
| 278 |
-
is_previous_comma = True
|
| 279 |
-
if is_previous_comma:
|
| 280 |
-
# Trailing comma.
|
| 281 |
-
empty_str_warning()
|
| 282 |
-
yield "", ""
|
| 283 |
-
|
| 284 |
-
def visitSingleElement(
|
| 285 |
-
self, ctx: OmegaConfGrammarParser.SingleElementContext
|
| 286 |
-
) -> Any:
|
| 287 |
-
# element EOF
|
| 288 |
-
assert ctx.getChildCount() == 2
|
| 289 |
-
return self.visit(ctx.getChild(0))
|
| 290 |
-
|
| 291 |
-
def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any:
|
| 292 |
-
# (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+
|
| 293 |
-
|
| 294 |
-
# Single interpolation? If yes, return its resolved value "as is".
|
| 295 |
-
if ctx.getChildCount() == 1:
|
| 296 |
-
c = ctx.getChild(0)
|
| 297 |
-
if isinstance(c, OmegaConfGrammarParser.InterpolationContext):
|
| 298 |
-
return self.visitInterpolation(c)
|
| 299 |
-
|
| 300 |
-
# Otherwise, concatenate string representations together.
|
| 301 |
-
return self._unescape(list(ctx.getChildren()))
|
| 302 |
-
|
| 303 |
-
def _createPrimitive(
|
| 304 |
-
self,
|
| 305 |
-
ctx: Union[
|
| 306 |
-
OmegaConfGrammarParser.PrimitiveContext,
|
| 307 |
-
OmegaConfGrammarParser.DictKeyContext,
|
| 308 |
-
],
|
| 309 |
-
) -> Any:
|
| 310 |
-
# (ID | NULL | INT | FLOAT | BOOL | UNQUOTED_CHAR | COLON | ESC | WS | interpolation)+
|
| 311 |
-
if ctx.getChildCount() == 1:
|
| 312 |
-
child = ctx.getChild(0)
|
| 313 |
-
if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
|
| 314 |
-
return self.visitInterpolation(child)
|
| 315 |
-
assert isinstance(child, TerminalNode)
|
| 316 |
-
symbol = child.symbol
|
| 317 |
-
# Parse primitive types.
|
| 318 |
-
if symbol.type in (
|
| 319 |
-
OmegaConfGrammarLexer.ID,
|
| 320 |
-
OmegaConfGrammarLexer.UNQUOTED_CHAR,
|
| 321 |
-
OmegaConfGrammarLexer.COLON,
|
| 322 |
-
):
|
| 323 |
-
return symbol.text
|
| 324 |
-
elif symbol.type == OmegaConfGrammarLexer.NULL:
|
| 325 |
-
return None
|
| 326 |
-
elif symbol.type == OmegaConfGrammarLexer.INT:
|
| 327 |
-
return int(symbol.text)
|
| 328 |
-
elif symbol.type == OmegaConfGrammarLexer.FLOAT:
|
| 329 |
-
return float(symbol.text)
|
| 330 |
-
elif symbol.type == OmegaConfGrammarLexer.BOOL:
|
| 331 |
-
return symbol.text.lower() == "true"
|
| 332 |
-
elif symbol.type == OmegaConfGrammarLexer.ESC:
|
| 333 |
-
return self._unescape([child])
|
| 334 |
-
elif symbol.type == OmegaConfGrammarLexer.WS: # pragma: no cover
|
| 335 |
-
# A single WS should have been "consumed" by another token.
|
| 336 |
-
raise AssertionError("WS should never be reached")
|
| 337 |
-
assert False, symbol.type
|
| 338 |
-
# Concatenation of multiple items ==> un-escape the concatenation.
|
| 339 |
-
return self._unescape(list(ctx.getChildren()))
|
| 340 |
-
|
| 341 |
-
def _unescape(
|
| 342 |
-
self,
|
| 343 |
-
seq: List[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]],
|
| 344 |
-
) -> str:
|
| 345 |
-
"""
|
| 346 |
-
Concatenate all symbols / interpolations in `seq`, unescaping symbols as needed.
|
| 347 |
-
|
| 348 |
-
Interpolations are resolved and cast to string *WITHOUT* escaping their result
|
| 349 |
-
(it is assumed that whatever escaping is required was already handled during the
|
| 350 |
-
resolving of the interpolation).
|
| 351 |
-
"""
|
| 352 |
-
chrs = []
|
| 353 |
-
for node, next_node in zip_longest(seq, seq[1:]):
|
| 354 |
-
if isinstance(node, TerminalNode):
|
| 355 |
-
s = node.symbol
|
| 356 |
-
if s.type == OmegaConfGrammarLexer.ESC_INTER:
|
| 357 |
-
# `ESC_INTER` is of the form `\\...\${`: the formula below computes
|
| 358 |
-
# the number of characters to keep at the end of the string to remove
|
| 359 |
-
# the correct number of backslashes.
|
| 360 |
-
text = s.text[-(len(s.text) // 2 + 1) :]
|
| 361 |
-
elif (
|
| 362 |
-
# Character sequence identified as requiring un-escaping.
|
| 363 |
-
s.type == OmegaConfGrammarLexer.ESC
|
| 364 |
-
or (
|
| 365 |
-
# At top level, we need to un-escape backslashes that precede
|
| 366 |
-
# an interpolation.
|
| 367 |
-
s.type == OmegaConfGrammarLexer.TOP_ESC
|
| 368 |
-
and isinstance(
|
| 369 |
-
next_node, OmegaConfGrammarParser.InterpolationContext
|
| 370 |
-
)
|
| 371 |
-
)
|
| 372 |
-
or (
|
| 373 |
-
# In a quoted sring, we need to un-escape backslashes that
|
| 374 |
-
# either end the string, or are followed by an interpolation.
|
| 375 |
-
s.type == OmegaConfGrammarLexer.QUOTED_ESC
|
| 376 |
-
and (
|
| 377 |
-
next_node is None
|
| 378 |
-
or isinstance(
|
| 379 |
-
next_node, OmegaConfGrammarParser.InterpolationContext
|
| 380 |
-
)
|
| 381 |
-
)
|
| 382 |
-
)
|
| 383 |
-
):
|
| 384 |
-
text = s.text[1::2] # un-escape the sequence
|
| 385 |
-
else:
|
| 386 |
-
text = s.text # keep the original text
|
| 387 |
-
else:
|
| 388 |
-
assert isinstance(node, OmegaConfGrammarParser.InterpolationContext)
|
| 389 |
-
text = str(self.visitInterpolation(node))
|
| 390 |
-
chrs.append(text)
|
| 391 |
-
|
| 392 |
-
return "".join(chrs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/listconfig.py
DELETED
|
@@ -1,679 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import itertools
|
| 3 |
-
from typing import (
|
| 4 |
-
Any,
|
| 5 |
-
Callable,
|
| 6 |
-
Dict,
|
| 7 |
-
Iterable,
|
| 8 |
-
Iterator,
|
| 9 |
-
List,
|
| 10 |
-
MutableSequence,
|
| 11 |
-
Optional,
|
| 12 |
-
Tuple,
|
| 13 |
-
Type,
|
| 14 |
-
Union,
|
| 15 |
-
)
|
| 16 |
-
|
| 17 |
-
from ._utils import (
|
| 18 |
-
ValueKind,
|
| 19 |
-
_is_missing_literal,
|
| 20 |
-
_is_none,
|
| 21 |
-
_resolve_optional,
|
| 22 |
-
format_and_raise,
|
| 23 |
-
get_value_kind,
|
| 24 |
-
is_int,
|
| 25 |
-
is_primitive_list,
|
| 26 |
-
is_structured_config,
|
| 27 |
-
type_str,
|
| 28 |
-
)
|
| 29 |
-
from .base import Box, ContainerMetadata, Node
|
| 30 |
-
from .basecontainer import BaseContainer
|
| 31 |
-
from .errors import (
|
| 32 |
-
ConfigAttributeError,
|
| 33 |
-
ConfigTypeError,
|
| 34 |
-
ConfigValueError,
|
| 35 |
-
KeyValidationError,
|
| 36 |
-
MissingMandatoryValue,
|
| 37 |
-
ReadonlyConfigError,
|
| 38 |
-
ValidationError,
|
| 39 |
-
)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class ListConfig(BaseContainer, MutableSequence[Any]):
|
| 43 |
-
|
| 44 |
-
_content: Union[List[Node], None, str]
|
| 45 |
-
|
| 46 |
-
def __init__(
|
| 47 |
-
self,
|
| 48 |
-
content: Union[List[Any], Tuple[Any, ...], "ListConfig", str, None],
|
| 49 |
-
key: Any = None,
|
| 50 |
-
parent: Optional[Box] = None,
|
| 51 |
-
element_type: Union[Type[Any], Any] = Any,
|
| 52 |
-
is_optional: bool = True,
|
| 53 |
-
ref_type: Union[Type[Any], Any] = Any,
|
| 54 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 55 |
-
) -> None:
|
| 56 |
-
try:
|
| 57 |
-
if isinstance(content, ListConfig):
|
| 58 |
-
if flags is None:
|
| 59 |
-
flags = content._metadata.flags
|
| 60 |
-
super().__init__(
|
| 61 |
-
parent=parent,
|
| 62 |
-
metadata=ContainerMetadata(
|
| 63 |
-
ref_type=ref_type,
|
| 64 |
-
object_type=list,
|
| 65 |
-
key=key,
|
| 66 |
-
optional=is_optional,
|
| 67 |
-
element_type=element_type,
|
| 68 |
-
key_type=int,
|
| 69 |
-
flags=flags,
|
| 70 |
-
),
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
-
if isinstance(content, ListConfig):
|
| 74 |
-
metadata = copy.deepcopy(content._metadata)
|
| 75 |
-
metadata.key = key
|
| 76 |
-
metadata.ref_type = ref_type
|
| 77 |
-
metadata.optional = is_optional
|
| 78 |
-
metadata.element_type = element_type
|
| 79 |
-
self.__dict__["_metadata"] = metadata
|
| 80 |
-
self._set_value(value=content, flags=flags)
|
| 81 |
-
except Exception as ex:
|
| 82 |
-
format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))
|
| 83 |
-
|
| 84 |
-
def _validate_get(self, key: Any, value: Any = None) -> None:
|
| 85 |
-
if not isinstance(key, (int, slice)):
|
| 86 |
-
raise KeyValidationError(
|
| 87 |
-
"ListConfig indices must be integers or slices, not $KEY_TYPE"
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
def _validate_set(self, key: Any, value: Any) -> None:
|
| 91 |
-
from omegaconf import OmegaConf
|
| 92 |
-
|
| 93 |
-
self._validate_get(key, value)
|
| 94 |
-
|
| 95 |
-
if self._get_flag("readonly"):
|
| 96 |
-
raise ReadonlyConfigError("ListConfig is read-only")
|
| 97 |
-
|
| 98 |
-
if 0 <= key < self.__len__():
|
| 99 |
-
target = self._get_node(key)
|
| 100 |
-
if target is not None:
|
| 101 |
-
assert isinstance(target, Node)
|
| 102 |
-
if value is None and not target._is_optional():
|
| 103 |
-
raise ValidationError(
|
| 104 |
-
"$FULL_KEY is not optional and cannot be assigned None"
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
vk = get_value_kind(value)
|
| 108 |
-
if vk == ValueKind.MANDATORY_MISSING:
|
| 109 |
-
return
|
| 110 |
-
else:
|
| 111 |
-
is_optional, target_type = _resolve_optional(self._metadata.element_type)
|
| 112 |
-
value_type = OmegaConf.get_type(value)
|
| 113 |
-
|
| 114 |
-
if (value_type is None and not is_optional) or (
|
| 115 |
-
is_structured_config(target_type)
|
| 116 |
-
and value_type is not None
|
| 117 |
-
and not issubclass(value_type, target_type)
|
| 118 |
-
):
|
| 119 |
-
msg = (
|
| 120 |
-
f"Invalid type assigned: {type_str(value_type)} is not a "
|
| 121 |
-
f"subclass of {type_str(target_type)}. value: {value}"
|
| 122 |
-
)
|
| 123 |
-
raise ValidationError(msg)
|
| 124 |
-
|
| 125 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig":
|
| 126 |
-
res = ListConfig(None)
|
| 127 |
-
res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
|
| 128 |
-
res.__dict__["_flags_cache"] = copy.deepcopy(
|
| 129 |
-
self.__dict__["_flags_cache"], memo=memo
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
src_content = self.__dict__["_content"]
|
| 133 |
-
if isinstance(src_content, list):
|
| 134 |
-
content_copy: List[Optional[Node]] = []
|
| 135 |
-
for v in src_content:
|
| 136 |
-
old_parent = v.__dict__["_parent"]
|
| 137 |
-
try:
|
| 138 |
-
v.__dict__["_parent"] = None
|
| 139 |
-
vc = copy.deepcopy(v, memo=memo)
|
| 140 |
-
vc.__dict__["_parent"] = res
|
| 141 |
-
content_copy.append(vc)
|
| 142 |
-
finally:
|
| 143 |
-
v.__dict__["_parent"] = old_parent
|
| 144 |
-
else:
|
| 145 |
-
# None and strings can be assigned as is
|
| 146 |
-
content_copy = src_content
|
| 147 |
-
|
| 148 |
-
res.__dict__["_content"] = content_copy
|
| 149 |
-
res.__dict__["_parent"] = self.__dict__["_parent"]
|
| 150 |
-
|
| 151 |
-
return res
|
| 152 |
-
|
| 153 |
-
def copy(self) -> "ListConfig":
|
| 154 |
-
return copy.copy(self)
|
| 155 |
-
|
| 156 |
-
# hide content while inspecting in debugger
|
| 157 |
-
def __dir__(self) -> Iterable[str]:
|
| 158 |
-
if self._is_missing() or self._is_none():
|
| 159 |
-
return []
|
| 160 |
-
return [str(x) for x in range(0, len(self))]
|
| 161 |
-
|
| 162 |
-
def __setattr__(self, key: str, value: Any) -> None:
|
| 163 |
-
self._format_and_raise(
|
| 164 |
-
key=key,
|
| 165 |
-
value=value,
|
| 166 |
-
cause=ConfigAttributeError("ListConfig does not support attribute access"),
|
| 167 |
-
)
|
| 168 |
-
assert False
|
| 169 |
-
|
| 170 |
-
def __getattr__(self, key: str) -> Any:
|
| 171 |
-
# PyCharm is sometimes inspecting __members__, be sure to tell it we don't have that.
|
| 172 |
-
if key == "__members__":
|
| 173 |
-
raise AttributeError()
|
| 174 |
-
|
| 175 |
-
if key == "__name__":
|
| 176 |
-
raise AttributeError()
|
| 177 |
-
|
| 178 |
-
if is_int(key):
|
| 179 |
-
return self.__getitem__(int(key))
|
| 180 |
-
else:
|
| 181 |
-
self._format_and_raise(
|
| 182 |
-
key=key,
|
| 183 |
-
value=None,
|
| 184 |
-
cause=ConfigAttributeError(
|
| 185 |
-
"ListConfig does not support attribute access"
|
| 186 |
-
),
|
| 187 |
-
)
|
| 188 |
-
|
| 189 |
-
def __getitem__(self, index: Union[int, slice]) -> Any:
|
| 190 |
-
try:
|
| 191 |
-
if self._is_missing():
|
| 192 |
-
raise MissingMandatoryValue("ListConfig is missing")
|
| 193 |
-
self._validate_get(index, None)
|
| 194 |
-
if self._is_none():
|
| 195 |
-
raise TypeError(
|
| 196 |
-
"ListConfig object representing None is not subscriptable"
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
assert isinstance(self.__dict__["_content"], list)
|
| 200 |
-
if isinstance(index, slice):
|
| 201 |
-
result = []
|
| 202 |
-
start, stop, step = self._correct_index_params(index)
|
| 203 |
-
for slice_idx in itertools.islice(
|
| 204 |
-
range(0, len(self)), start, stop, step
|
| 205 |
-
):
|
| 206 |
-
val = self._resolve_with_default(
|
| 207 |
-
key=slice_idx, value=self.__dict__["_content"][slice_idx]
|
| 208 |
-
)
|
| 209 |
-
result.append(val)
|
| 210 |
-
if index.step and index.step < 0:
|
| 211 |
-
result.reverse()
|
| 212 |
-
return result
|
| 213 |
-
else:
|
| 214 |
-
return self._resolve_with_default(
|
| 215 |
-
key=index, value=self.__dict__["_content"][index]
|
| 216 |
-
)
|
| 217 |
-
except Exception as e:
|
| 218 |
-
self._format_and_raise(key=index, value=None, cause=e)
|
| 219 |
-
|
| 220 |
-
def _correct_index_params(self, index: slice) -> Tuple[int, int, int]:
|
| 221 |
-
start = index.start
|
| 222 |
-
stop = index.stop
|
| 223 |
-
step = index.step
|
| 224 |
-
if index.start and index.start < 0:
|
| 225 |
-
start = self.__len__() + index.start
|
| 226 |
-
if index.stop and index.stop < 0:
|
| 227 |
-
stop = self.__len__() + index.stop
|
| 228 |
-
if index.step and index.step < 0:
|
| 229 |
-
step = abs(step)
|
| 230 |
-
if start and stop:
|
| 231 |
-
if start > stop:
|
| 232 |
-
start, stop = stop + 1, start + 1
|
| 233 |
-
else:
|
| 234 |
-
start = stop = 0
|
| 235 |
-
elif not start and stop:
|
| 236 |
-
start = list(range(self.__len__() - 1, stop, -step))[0]
|
| 237 |
-
stop = None
|
| 238 |
-
elif start and not stop:
|
| 239 |
-
stop = start + 1
|
| 240 |
-
start = (stop - 1) % step
|
| 241 |
-
else:
|
| 242 |
-
start = (self.__len__() - 1) % step
|
| 243 |
-
return start, stop, step
|
| 244 |
-
|
| 245 |
-
def _set_at_index(self, index: Union[int, slice], value: Any) -> None:
|
| 246 |
-
self._set_item_impl(index, value)
|
| 247 |
-
|
| 248 |
-
def __setitem__(self, index: Union[int, slice], value: Any) -> None:
|
| 249 |
-
try:
|
| 250 |
-
if isinstance(index, slice):
|
| 251 |
-
_ = iter(value) # check iterable
|
| 252 |
-
self_indices = index.indices(len(self))
|
| 253 |
-
indexes = range(*self_indices)
|
| 254 |
-
|
| 255 |
-
# Ensure lengths match for extended slice assignment
|
| 256 |
-
if index.step not in (None, 1):
|
| 257 |
-
if len(indexes) != len(value):
|
| 258 |
-
raise ValueError(
|
| 259 |
-
f"attempt to assign sequence of size {len(value)}"
|
| 260 |
-
f" to extended slice of size {len(indexes)}"
|
| 261 |
-
)
|
| 262 |
-
|
| 263 |
-
# Initialize insertion offsets for empty slices
|
| 264 |
-
if len(indexes) == 0:
|
| 265 |
-
curr_index = self_indices[0] - 1
|
| 266 |
-
val_i = -1
|
| 267 |
-
|
| 268 |
-
work_copy = self.copy() # For atomicity manipulate a copy
|
| 269 |
-
|
| 270 |
-
# Delete and optionally replace non empty slices
|
| 271 |
-
only_removed = 0
|
| 272 |
-
for val_i, i in enumerate(indexes):
|
| 273 |
-
curr_index = i - only_removed
|
| 274 |
-
del work_copy[curr_index]
|
| 275 |
-
if val_i < len(value):
|
| 276 |
-
work_copy.insert(curr_index, value[val_i])
|
| 277 |
-
else:
|
| 278 |
-
only_removed += 1
|
| 279 |
-
|
| 280 |
-
# Insert any remaining input items
|
| 281 |
-
for val_i in range(val_i + 1, len(value)):
|
| 282 |
-
curr_index += 1
|
| 283 |
-
work_copy.insert(curr_index, value[val_i])
|
| 284 |
-
|
| 285 |
-
# Reinitialize self with work_copy
|
| 286 |
-
self.clear()
|
| 287 |
-
self.extend(work_copy)
|
| 288 |
-
else:
|
| 289 |
-
self._set_at_index(index, value)
|
| 290 |
-
except Exception as e:
|
| 291 |
-
self._format_and_raise(key=index, value=value, cause=e)
|
| 292 |
-
|
| 293 |
-
def append(self, item: Any) -> None:
|
| 294 |
-
content = self.__dict__["_content"]
|
| 295 |
-
index = len(content)
|
| 296 |
-
content.append(None)
|
| 297 |
-
try:
|
| 298 |
-
self._set_item_impl(index, item)
|
| 299 |
-
except Exception as e:
|
| 300 |
-
del content[index]
|
| 301 |
-
self._format_and_raise(key=index, value=item, cause=e)
|
| 302 |
-
assert False
|
| 303 |
-
|
| 304 |
-
def _update_keys(self) -> None:
|
| 305 |
-
for i in range(len(self)):
|
| 306 |
-
node = self._get_node(i)
|
| 307 |
-
if node is not None:
|
| 308 |
-
assert isinstance(node, Node)
|
| 309 |
-
node._metadata.key = i
|
| 310 |
-
|
| 311 |
-
def insert(self, index: int, item: Any) -> None:
|
| 312 |
-
from omegaconf.omegaconf import _maybe_wrap
|
| 313 |
-
|
| 314 |
-
try:
|
| 315 |
-
if self._get_flag("readonly"):
|
| 316 |
-
raise ReadonlyConfigError("Cannot insert into a read-only ListConfig")
|
| 317 |
-
if self._is_none():
|
| 318 |
-
raise TypeError(
|
| 319 |
-
"Cannot insert into ListConfig object representing None"
|
| 320 |
-
)
|
| 321 |
-
if self._is_missing():
|
| 322 |
-
raise MissingMandatoryValue("Cannot insert into missing ListConfig")
|
| 323 |
-
|
| 324 |
-
try:
|
| 325 |
-
assert isinstance(self.__dict__["_content"], list)
|
| 326 |
-
# insert place holder
|
| 327 |
-
self.__dict__["_content"].insert(index, None)
|
| 328 |
-
is_optional, ref_type = _resolve_optional(self._metadata.element_type)
|
| 329 |
-
node = _maybe_wrap(
|
| 330 |
-
ref_type=ref_type,
|
| 331 |
-
key=index,
|
| 332 |
-
value=item,
|
| 333 |
-
is_optional=is_optional,
|
| 334 |
-
parent=self,
|
| 335 |
-
)
|
| 336 |
-
self._validate_set(key=index, value=node)
|
| 337 |
-
self._set_at_index(index, node)
|
| 338 |
-
self._update_keys()
|
| 339 |
-
except Exception:
|
| 340 |
-
del self.__dict__["_content"][index]
|
| 341 |
-
self._update_keys()
|
| 342 |
-
raise
|
| 343 |
-
except Exception as e:
|
| 344 |
-
self._format_and_raise(key=index, value=item, cause=e)
|
| 345 |
-
assert False
|
| 346 |
-
|
| 347 |
-
def extend(self, lst: Iterable[Any]) -> None:
|
| 348 |
-
assert isinstance(lst, (tuple, list, ListConfig))
|
| 349 |
-
for x in lst:
|
| 350 |
-
self.append(x)
|
| 351 |
-
|
| 352 |
-
def remove(self, x: Any) -> None:
|
| 353 |
-
del self[self.index(x)]
|
| 354 |
-
|
| 355 |
-
def __delitem__(self, key: Union[int, slice]) -> None:
|
| 356 |
-
if self._get_flag("readonly"):
|
| 357 |
-
self._format_and_raise(
|
| 358 |
-
key=key,
|
| 359 |
-
value=None,
|
| 360 |
-
cause=ReadonlyConfigError(
|
| 361 |
-
"Cannot delete item from read-only ListConfig"
|
| 362 |
-
),
|
| 363 |
-
)
|
| 364 |
-
del self.__dict__["_content"][key]
|
| 365 |
-
self._update_keys()
|
| 366 |
-
|
| 367 |
-
def clear(self) -> None:
|
| 368 |
-
del self[:]
|
| 369 |
-
|
| 370 |
-
def index(
|
| 371 |
-
self, x: Any, start: Optional[int] = None, end: Optional[int] = None
|
| 372 |
-
) -> int:
|
| 373 |
-
if start is None:
|
| 374 |
-
start = 0
|
| 375 |
-
if end is None:
|
| 376 |
-
end = len(self)
|
| 377 |
-
assert start >= 0
|
| 378 |
-
assert end <= len(self)
|
| 379 |
-
found_idx = -1
|
| 380 |
-
for idx in range(start, end):
|
| 381 |
-
item = self[idx]
|
| 382 |
-
if x == item:
|
| 383 |
-
found_idx = idx
|
| 384 |
-
break
|
| 385 |
-
if found_idx != -1:
|
| 386 |
-
return found_idx
|
| 387 |
-
else:
|
| 388 |
-
self._format_and_raise(
|
| 389 |
-
key=None,
|
| 390 |
-
value=None,
|
| 391 |
-
cause=ConfigValueError("Item not found in ListConfig"),
|
| 392 |
-
)
|
| 393 |
-
assert False
|
| 394 |
-
|
| 395 |
-
def count(self, x: Any) -> int:
|
| 396 |
-
c = 0
|
| 397 |
-
for item in self:
|
| 398 |
-
if item == x:
|
| 399 |
-
c = c + 1
|
| 400 |
-
return c
|
| 401 |
-
|
| 402 |
-
def _get_node(
|
| 403 |
-
self,
|
| 404 |
-
key: Union[int, slice],
|
| 405 |
-
validate_access: bool = True,
|
| 406 |
-
validate_key: bool = True,
|
| 407 |
-
throw_on_missing_value: bool = False,
|
| 408 |
-
throw_on_missing_key: bool = False,
|
| 409 |
-
) -> Union[Optional[Node], List[Optional[Node]]]:
|
| 410 |
-
try:
|
| 411 |
-
if self._is_none():
|
| 412 |
-
raise TypeError(
|
| 413 |
-
"Cannot get_node from a ListConfig object representing None"
|
| 414 |
-
)
|
| 415 |
-
if self._is_missing():
|
| 416 |
-
raise MissingMandatoryValue("Cannot get_node from a missing ListConfig")
|
| 417 |
-
assert isinstance(self.__dict__["_content"], list)
|
| 418 |
-
if validate_access:
|
| 419 |
-
self._validate_get(key)
|
| 420 |
-
|
| 421 |
-
value = self.__dict__["_content"][key]
|
| 422 |
-
if value is not None:
|
| 423 |
-
if isinstance(key, slice):
|
| 424 |
-
assert isinstance(value, list)
|
| 425 |
-
for v in value:
|
| 426 |
-
if throw_on_missing_value and v._is_missing():
|
| 427 |
-
raise MissingMandatoryValue("Missing mandatory value")
|
| 428 |
-
else:
|
| 429 |
-
assert isinstance(value, Node)
|
| 430 |
-
if throw_on_missing_value and value._is_missing():
|
| 431 |
-
raise MissingMandatoryValue("Missing mandatory value: $KEY")
|
| 432 |
-
return value
|
| 433 |
-
except (IndexError, TypeError, MissingMandatoryValue, KeyValidationError) as e:
|
| 434 |
-
if isinstance(e, MissingMandatoryValue) and throw_on_missing_value:
|
| 435 |
-
raise
|
| 436 |
-
if validate_access:
|
| 437 |
-
self._format_and_raise(key=key, value=None, cause=e)
|
| 438 |
-
assert False
|
| 439 |
-
else:
|
| 440 |
-
return None
|
| 441 |
-
|
| 442 |
-
def get(self, index: int, default_value: Any = None) -> Any:
|
| 443 |
-
try:
|
| 444 |
-
if self._is_none():
|
| 445 |
-
raise TypeError("Cannot get from a ListConfig object representing None")
|
| 446 |
-
if self._is_missing():
|
| 447 |
-
raise MissingMandatoryValue("Cannot get from a missing ListConfig")
|
| 448 |
-
self._validate_get(index, None)
|
| 449 |
-
assert isinstance(self.__dict__["_content"], list)
|
| 450 |
-
return self._resolve_with_default(
|
| 451 |
-
key=index,
|
| 452 |
-
value=self.__dict__["_content"][index],
|
| 453 |
-
default_value=default_value,
|
| 454 |
-
)
|
| 455 |
-
except Exception as e:
|
| 456 |
-
self._format_and_raise(key=index, value=None, cause=e)
|
| 457 |
-
assert False
|
| 458 |
-
|
| 459 |
-
def pop(self, index: int = -1) -> Any:
|
| 460 |
-
try:
|
| 461 |
-
if self._get_flag("readonly"):
|
| 462 |
-
raise ReadonlyConfigError("Cannot pop from read-only ListConfig")
|
| 463 |
-
if self._is_none():
|
| 464 |
-
raise TypeError("Cannot pop from a ListConfig object representing None")
|
| 465 |
-
if self._is_missing():
|
| 466 |
-
raise MissingMandatoryValue("Cannot pop from a missing ListConfig")
|
| 467 |
-
|
| 468 |
-
assert isinstance(self.__dict__["_content"], list)
|
| 469 |
-
node = self._get_child(index)
|
| 470 |
-
assert isinstance(node, Node)
|
| 471 |
-
ret = self._resolve_with_default(key=index, value=node, default_value=None)
|
| 472 |
-
del self.__dict__["_content"][index]
|
| 473 |
-
self._update_keys()
|
| 474 |
-
return ret
|
| 475 |
-
except KeyValidationError as e:
|
| 476 |
-
self._format_and_raise(
|
| 477 |
-
key=index, value=None, cause=e, type_override=ConfigTypeError
|
| 478 |
-
)
|
| 479 |
-
assert False
|
| 480 |
-
except Exception as e:
|
| 481 |
-
self._format_and_raise(key=index, value=None, cause=e)
|
| 482 |
-
assert False
|
| 483 |
-
|
| 484 |
-
def sort(
|
| 485 |
-
self, key: Optional[Callable[[Any], Any]] = None, reverse: bool = False
|
| 486 |
-
) -> None:
|
| 487 |
-
try:
|
| 488 |
-
if self._get_flag("readonly"):
|
| 489 |
-
raise ReadonlyConfigError("Cannot sort a read-only ListConfig")
|
| 490 |
-
if self._is_none():
|
| 491 |
-
raise TypeError("Cannot sort a ListConfig object representing None")
|
| 492 |
-
if self._is_missing():
|
| 493 |
-
raise MissingMandatoryValue("Cannot sort a missing ListConfig")
|
| 494 |
-
|
| 495 |
-
if key is None:
|
| 496 |
-
|
| 497 |
-
def key1(x: Any) -> Any:
|
| 498 |
-
return x._value()
|
| 499 |
-
|
| 500 |
-
else:
|
| 501 |
-
|
| 502 |
-
def key1(x: Any) -> Any:
|
| 503 |
-
return key(x._value()) # type: ignore
|
| 504 |
-
|
| 505 |
-
assert isinstance(self.__dict__["_content"], list)
|
| 506 |
-
self.__dict__["_content"].sort(key=key1, reverse=reverse)
|
| 507 |
-
|
| 508 |
-
except Exception as e:
|
| 509 |
-
self._format_and_raise(key=None, value=None, cause=e)
|
| 510 |
-
assert False
|
| 511 |
-
|
| 512 |
-
def __eq__(self, other: Any) -> bool:
|
| 513 |
-
if isinstance(other, (list, tuple)) or other is None:
|
| 514 |
-
other = ListConfig(other, flags={"allow_objects": True})
|
| 515 |
-
return ListConfig._list_eq(self, other)
|
| 516 |
-
if other is None or isinstance(other, ListConfig):
|
| 517 |
-
return ListConfig._list_eq(self, other)
|
| 518 |
-
if self._is_missing():
|
| 519 |
-
return _is_missing_literal(other)
|
| 520 |
-
return NotImplemented
|
| 521 |
-
|
| 522 |
-
def __ne__(self, other: Any) -> bool:
|
| 523 |
-
x = self.__eq__(other)
|
| 524 |
-
if x is not NotImplemented:
|
| 525 |
-
return not x
|
| 526 |
-
return NotImplemented
|
| 527 |
-
|
| 528 |
-
def __hash__(self) -> int:
|
| 529 |
-
return hash(str(self))
|
| 530 |
-
|
| 531 |
-
def __iter__(self) -> Iterator[Any]:
|
| 532 |
-
return self._iter_ex(resolve=True)
|
| 533 |
-
|
| 534 |
-
class ListIterator(Iterator[Any]):
|
| 535 |
-
def __init__(self, lst: Any, resolve: bool) -> None:
|
| 536 |
-
self.resolve = resolve
|
| 537 |
-
self.iterator = iter(lst.__dict__["_content"])
|
| 538 |
-
self.index = 0
|
| 539 |
-
from .nodes import ValueNode
|
| 540 |
-
|
| 541 |
-
self.ValueNode = ValueNode
|
| 542 |
-
|
| 543 |
-
def __next__(self) -> Any:
|
| 544 |
-
|
| 545 |
-
x = next(self.iterator)
|
| 546 |
-
if self.resolve:
|
| 547 |
-
x = x._dereference_node()
|
| 548 |
-
if x._is_missing():
|
| 549 |
-
raise MissingMandatoryValue(f"Missing value at index {self.index}")
|
| 550 |
-
|
| 551 |
-
self.index = self.index + 1
|
| 552 |
-
if isinstance(x, self.ValueNode):
|
| 553 |
-
return x._value()
|
| 554 |
-
else:
|
| 555 |
-
# Must be omegaconf.Container. not checking for perf reasons.
|
| 556 |
-
if x._is_none():
|
| 557 |
-
return None
|
| 558 |
-
return x
|
| 559 |
-
|
| 560 |
-
def __repr__(self) -> str: # pragma: no cover
|
| 561 |
-
return f"ListConfig.ListIterator(resolve={self.resolve})"
|
| 562 |
-
|
| 563 |
-
def _iter_ex(self, resolve: bool) -> Iterator[Any]:
|
| 564 |
-
try:
|
| 565 |
-
if self._is_none():
|
| 566 |
-
raise TypeError("Cannot iterate a ListConfig object representing None")
|
| 567 |
-
if self._is_missing():
|
| 568 |
-
raise MissingMandatoryValue("Cannot iterate a missing ListConfig")
|
| 569 |
-
|
| 570 |
-
return ListConfig.ListIterator(self, resolve)
|
| 571 |
-
except (TypeError, MissingMandatoryValue) as e:
|
| 572 |
-
self._format_and_raise(key=None, value=None, cause=e)
|
| 573 |
-
assert False
|
| 574 |
-
|
| 575 |
-
def __add__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
|
| 576 |
-
# res is sharing this list's parent to allow interpolation to work as expected
|
| 577 |
-
res = ListConfig(parent=self._get_parent(), content=[])
|
| 578 |
-
res.extend(self)
|
| 579 |
-
res.extend(other)
|
| 580 |
-
return res
|
| 581 |
-
|
| 582 |
-
def __radd__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
|
| 583 |
-
# res is sharing this list's parent to allow interpolation to work as expected
|
| 584 |
-
res = ListConfig(parent=self._get_parent(), content=[])
|
| 585 |
-
res.extend(other)
|
| 586 |
-
res.extend(self)
|
| 587 |
-
return res
|
| 588 |
-
|
| 589 |
-
def __iadd__(self, other: Iterable[Any]) -> "ListConfig":
|
| 590 |
-
self.extend(other)
|
| 591 |
-
return self
|
| 592 |
-
|
| 593 |
-
def __contains__(self, item: Any) -> bool:
|
| 594 |
-
if self._is_none():
|
| 595 |
-
raise TypeError(
|
| 596 |
-
"Cannot check if an item is in a ListConfig object representing None"
|
| 597 |
-
)
|
| 598 |
-
if self._is_missing():
|
| 599 |
-
raise MissingMandatoryValue(
|
| 600 |
-
"Cannot check if an item is in missing ListConfig"
|
| 601 |
-
)
|
| 602 |
-
|
| 603 |
-
lst = self.__dict__["_content"]
|
| 604 |
-
for x in lst:
|
| 605 |
-
x = x._dereference_node()
|
| 606 |
-
if x == item:
|
| 607 |
-
return True
|
| 608 |
-
return False
|
| 609 |
-
|
| 610 |
-
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
|
| 611 |
-
try:
|
| 612 |
-
previous_content = self.__dict__["_content"]
|
| 613 |
-
previous_metadata = self.__dict__["_metadata"]
|
| 614 |
-
self._set_value_impl(value, flags)
|
| 615 |
-
except Exception as e:
|
| 616 |
-
self.__dict__["_content"] = previous_content
|
| 617 |
-
self.__dict__["_metadata"] = previous_metadata
|
| 618 |
-
raise e
|
| 619 |
-
|
| 620 |
-
def _set_value_impl(
|
| 621 |
-
self, value: Any, flags: Optional[Dict[str, bool]] = None
|
| 622 |
-
) -> None:
|
| 623 |
-
from omegaconf import MISSING, flag_override
|
| 624 |
-
|
| 625 |
-
if flags is None:
|
| 626 |
-
flags = {}
|
| 627 |
-
|
| 628 |
-
vk = get_value_kind(value, strict_interpolation_validation=True)
|
| 629 |
-
if _is_none(value):
|
| 630 |
-
if not self._is_optional():
|
| 631 |
-
raise ValidationError(
|
| 632 |
-
"Non optional ListConfig cannot be constructed from None"
|
| 633 |
-
)
|
| 634 |
-
self.__dict__["_content"] = None
|
| 635 |
-
self._metadata.object_type = None
|
| 636 |
-
elif vk is ValueKind.MANDATORY_MISSING:
|
| 637 |
-
self.__dict__["_content"] = MISSING
|
| 638 |
-
self._metadata.object_type = None
|
| 639 |
-
elif vk == ValueKind.INTERPOLATION:
|
| 640 |
-
self.__dict__["_content"] = value
|
| 641 |
-
self._metadata.object_type = None
|
| 642 |
-
else:
|
| 643 |
-
if not (is_primitive_list(value) or isinstance(value, ListConfig)):
|
| 644 |
-
type_ = type(value)
|
| 645 |
-
msg = f"Invalid value assigned: {type_.__name__} is not a ListConfig, list or tuple."
|
| 646 |
-
raise ValidationError(msg)
|
| 647 |
-
|
| 648 |
-
self.__dict__["_content"] = []
|
| 649 |
-
if isinstance(value, ListConfig):
|
| 650 |
-
self._metadata.flags = copy.deepcopy(flags)
|
| 651 |
-
# disable struct and readonly for the construction phase
|
| 652 |
-
# retaining other flags like allow_objects. The real flags are restored at the end of this function
|
| 653 |
-
with flag_override(self, ["struct", "readonly"], False):
|
| 654 |
-
for item in value._iter_ex(resolve=False):
|
| 655 |
-
self.append(item)
|
| 656 |
-
elif is_primitive_list(value):
|
| 657 |
-
with flag_override(self, ["struct", "readonly"], False):
|
| 658 |
-
for item in value:
|
| 659 |
-
self.append(item)
|
| 660 |
-
self._metadata.object_type = list
|
| 661 |
-
|
| 662 |
-
@staticmethod
|
| 663 |
-
def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool:
|
| 664 |
-
l1_none = l1.__dict__["_content"] is None
|
| 665 |
-
l2_none = l2.__dict__["_content"] is None
|
| 666 |
-
if l1_none and l2_none:
|
| 667 |
-
return True
|
| 668 |
-
if l1_none != l2_none:
|
| 669 |
-
return False
|
| 670 |
-
|
| 671 |
-
assert isinstance(l1, ListConfig)
|
| 672 |
-
assert isinstance(l2, ListConfig)
|
| 673 |
-
if len(l1) != len(l2):
|
| 674 |
-
return False
|
| 675 |
-
for i in range(len(l1)):
|
| 676 |
-
if not BaseContainer._item_eq(l1, i, l2, i):
|
| 677 |
-
return False
|
| 678 |
-
|
| 679 |
-
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/nodes.py
DELETED
|
@@ -1,545 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import math
|
| 3 |
-
import sys
|
| 4 |
-
from abc import abstractmethod
|
| 5 |
-
from enum import Enum
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from typing import Any, Dict, Optional, Type, Union
|
| 8 |
-
|
| 9 |
-
from omegaconf._utils import (
|
| 10 |
-
ValueKind,
|
| 11 |
-
_is_interpolation,
|
| 12 |
-
get_type_of,
|
| 13 |
-
get_value_kind,
|
| 14 |
-
is_primitive_container,
|
| 15 |
-
type_str,
|
| 16 |
-
)
|
| 17 |
-
from omegaconf.base import Box, DictKeyType, Metadata, Node
|
| 18 |
-
from omegaconf.errors import ReadonlyConfigError, UnsupportedValueType, ValidationError
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class ValueNode(Node):
|
| 22 |
-
_val: Any
|
| 23 |
-
|
| 24 |
-
def __init__(self, parent: Optional[Box], value: Any, metadata: Metadata):
|
| 25 |
-
from omegaconf import read_write
|
| 26 |
-
|
| 27 |
-
super().__init__(parent=parent, metadata=metadata)
|
| 28 |
-
with read_write(self):
|
| 29 |
-
self._set_value(value) # lgtm [py/init-calls-subclass]
|
| 30 |
-
|
| 31 |
-
def _value(self) -> Any:
|
| 32 |
-
return self._val
|
| 33 |
-
|
| 34 |
-
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
|
| 35 |
-
if self._get_flag("readonly"):
|
| 36 |
-
raise ReadonlyConfigError("Cannot set value of read-only config node")
|
| 37 |
-
|
| 38 |
-
if isinstance(value, str) and get_value_kind(
|
| 39 |
-
value, strict_interpolation_validation=True
|
| 40 |
-
) in (
|
| 41 |
-
ValueKind.INTERPOLATION,
|
| 42 |
-
ValueKind.MANDATORY_MISSING,
|
| 43 |
-
):
|
| 44 |
-
self._val = value
|
| 45 |
-
else:
|
| 46 |
-
self._val = self.validate_and_convert(value)
|
| 47 |
-
|
| 48 |
-
def _strict_validate_type(self, value: Any) -> None:
|
| 49 |
-
ref_type = self._metadata.ref_type
|
| 50 |
-
if isinstance(ref_type, type) and type(value) is not ref_type:
|
| 51 |
-
type_hint = type_str(self._metadata.type_hint)
|
| 52 |
-
raise ValidationError(
|
| 53 |
-
f"Value '$VALUE' of type '$VALUE_TYPE' is incompatible with type hint '{type_hint}'"
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
def validate_and_convert(self, value: Any) -> Any:
|
| 57 |
-
"""
|
| 58 |
-
Validates input and converts to canonical form
|
| 59 |
-
:param value: input value
|
| 60 |
-
:return: converted value ("100" may be converted to 100 for example)
|
| 61 |
-
"""
|
| 62 |
-
if value is None:
|
| 63 |
-
if self._is_optional():
|
| 64 |
-
return None
|
| 65 |
-
ref_type_str = type_str(self._metadata.ref_type)
|
| 66 |
-
raise ValidationError(
|
| 67 |
-
f"Incompatible value '{value}' for field of type '{ref_type_str}'"
|
| 68 |
-
)
|
| 69 |
-
|
| 70 |
-
# Subclasses can assume that `value` is not None in
|
| 71 |
-
# `_validate_and_convert_impl()` and in `_strict_validate_type()`.
|
| 72 |
-
if self._get_flag("convert") is False:
|
| 73 |
-
self._strict_validate_type(value)
|
| 74 |
-
return value
|
| 75 |
-
else:
|
| 76 |
-
return self._validate_and_convert_impl(value)
|
| 77 |
-
|
| 78 |
-
@abstractmethod
|
| 79 |
-
def _validate_and_convert_impl(self, value: Any) -> Any:
|
| 80 |
-
...
|
| 81 |
-
|
| 82 |
-
def __str__(self) -> str:
|
| 83 |
-
return str(self._val)
|
| 84 |
-
|
| 85 |
-
def __repr__(self) -> str:
|
| 86 |
-
return repr(self._val) if hasattr(self, "_val") else "__INVALID__"
|
| 87 |
-
|
| 88 |
-
def __eq__(self, other: Any) -> bool:
|
| 89 |
-
if isinstance(other, AnyNode):
|
| 90 |
-
return self._val == other._val # type: ignore
|
| 91 |
-
else:
|
| 92 |
-
return self._val == other # type: ignore
|
| 93 |
-
|
| 94 |
-
def __ne__(self, other: Any) -> bool:
|
| 95 |
-
x = self.__eq__(other)
|
| 96 |
-
assert x is not NotImplemented
|
| 97 |
-
return not x
|
| 98 |
-
|
| 99 |
-
def __hash__(self) -> int:
|
| 100 |
-
return hash(self._val)
|
| 101 |
-
|
| 102 |
-
def _deepcopy_impl(self, res: Any, memo: Dict[int, Any]) -> None:
|
| 103 |
-
res.__dict__["_metadata"] = copy.deepcopy(self._metadata, memo=memo)
|
| 104 |
-
# shallow copy for value to support non-copyable value
|
| 105 |
-
res.__dict__["_val"] = self._val
|
| 106 |
-
|
| 107 |
-
# parent is retained, but not copied
|
| 108 |
-
res.__dict__["_parent"] = self._parent
|
| 109 |
-
|
| 110 |
-
def _is_optional(self) -> bool:
|
| 111 |
-
return self._metadata.optional
|
| 112 |
-
|
| 113 |
-
def _is_interpolation(self) -> bool:
|
| 114 |
-
return _is_interpolation(self._value())
|
| 115 |
-
|
| 116 |
-
def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
|
| 117 |
-
parent = self._get_parent()
|
| 118 |
-
if parent is None:
|
| 119 |
-
if self._metadata.key is None:
|
| 120 |
-
return ""
|
| 121 |
-
else:
|
| 122 |
-
return str(self._metadata.key)
|
| 123 |
-
else:
|
| 124 |
-
return parent._get_full_key(self._metadata.key)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
class AnyNode(ValueNode):
|
| 128 |
-
def __init__(
|
| 129 |
-
self,
|
| 130 |
-
value: Any = None,
|
| 131 |
-
key: Any = None,
|
| 132 |
-
parent: Optional[Box] = None,
|
| 133 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 134 |
-
):
|
| 135 |
-
super().__init__(
|
| 136 |
-
parent=parent,
|
| 137 |
-
value=value,
|
| 138 |
-
metadata=Metadata(
|
| 139 |
-
ref_type=Any, object_type=None, key=key, optional=True, flags=flags
|
| 140 |
-
),
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
def _validate_and_convert_impl(self, value: Any) -> Any:
|
| 144 |
-
from ._utils import is_primitive_type_annotation
|
| 145 |
-
|
| 146 |
-
# allow_objects is internal and not an official API. use at your own risk.
|
| 147 |
-
# Please be aware that this support is subject to change without notice.
|
| 148 |
-
# If this is deemed useful and supportable it may become an official API.
|
| 149 |
-
|
| 150 |
-
if self._get_flag(
|
| 151 |
-
"allow_objects"
|
| 152 |
-
) is not True and not is_primitive_type_annotation(value):
|
| 153 |
-
t = get_type_of(value)
|
| 154 |
-
raise UnsupportedValueType(
|
| 155 |
-
f"Value '{t.__name__}' is not a supported primitive type"
|
| 156 |
-
)
|
| 157 |
-
return value
|
| 158 |
-
|
| 159 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "AnyNode":
|
| 160 |
-
res = AnyNode()
|
| 161 |
-
self._deepcopy_impl(res, memo)
|
| 162 |
-
return res
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
class StringNode(ValueNode):
|
| 166 |
-
def __init__(
|
| 167 |
-
self,
|
| 168 |
-
value: Any = None,
|
| 169 |
-
key: Any = None,
|
| 170 |
-
parent: Optional[Box] = None,
|
| 171 |
-
is_optional: bool = True,
|
| 172 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 173 |
-
):
|
| 174 |
-
super().__init__(
|
| 175 |
-
parent=parent,
|
| 176 |
-
value=value,
|
| 177 |
-
metadata=Metadata(
|
| 178 |
-
key=key,
|
| 179 |
-
optional=is_optional,
|
| 180 |
-
ref_type=str,
|
| 181 |
-
object_type=str,
|
| 182 |
-
flags=flags,
|
| 183 |
-
),
|
| 184 |
-
)
|
| 185 |
-
|
| 186 |
-
def _validate_and_convert_impl(self, value: Any) -> str:
|
| 187 |
-
from omegaconf import OmegaConf
|
| 188 |
-
|
| 189 |
-
if (
|
| 190 |
-
OmegaConf.is_config(value)
|
| 191 |
-
or is_primitive_container(value)
|
| 192 |
-
or isinstance(value, bytes)
|
| 193 |
-
):
|
| 194 |
-
raise ValidationError("Cannot convert '$VALUE_TYPE' to string: '$VALUE'")
|
| 195 |
-
return str(value)
|
| 196 |
-
|
| 197 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "StringNode":
|
| 198 |
-
res = StringNode()
|
| 199 |
-
self._deepcopy_impl(res, memo)
|
| 200 |
-
return res
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
class PathNode(ValueNode):
|
| 204 |
-
def __init__(
|
| 205 |
-
self,
|
| 206 |
-
value: Any = None,
|
| 207 |
-
key: Any = None,
|
| 208 |
-
parent: Optional[Box] = None,
|
| 209 |
-
is_optional: bool = True,
|
| 210 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 211 |
-
):
|
| 212 |
-
super().__init__(
|
| 213 |
-
parent=parent,
|
| 214 |
-
value=value,
|
| 215 |
-
metadata=Metadata(
|
| 216 |
-
key=key,
|
| 217 |
-
optional=is_optional,
|
| 218 |
-
ref_type=Path,
|
| 219 |
-
object_type=Path,
|
| 220 |
-
flags=flags,
|
| 221 |
-
),
|
| 222 |
-
)
|
| 223 |
-
|
| 224 |
-
def _strict_validate_type(self, value: Any) -> None:
|
| 225 |
-
if not isinstance(value, Path):
|
| 226 |
-
raise ValidationError(
|
| 227 |
-
"Value '$VALUE' of type '$VALUE_TYPE' is not an instance of 'pathlib.Path'"
|
| 228 |
-
)
|
| 229 |
-
|
| 230 |
-
def _validate_and_convert_impl(self, value: Any) -> Path:
|
| 231 |
-
if not isinstance(value, (str, Path)):
|
| 232 |
-
raise ValidationError(
|
| 233 |
-
"Value '$VALUE' of type '$VALUE_TYPE' could not be converted to Path"
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
return Path(value)
|
| 237 |
-
|
| 238 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "PathNode":
|
| 239 |
-
res = PathNode()
|
| 240 |
-
self._deepcopy_impl(res, memo)
|
| 241 |
-
return res
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
class IntegerNode(ValueNode):
|
| 245 |
-
def __init__(
|
| 246 |
-
self,
|
| 247 |
-
value: Any = None,
|
| 248 |
-
key: Any = None,
|
| 249 |
-
parent: Optional[Box] = None,
|
| 250 |
-
is_optional: bool = True,
|
| 251 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 252 |
-
):
|
| 253 |
-
super().__init__(
|
| 254 |
-
parent=parent,
|
| 255 |
-
value=value,
|
| 256 |
-
metadata=Metadata(
|
| 257 |
-
key=key,
|
| 258 |
-
optional=is_optional,
|
| 259 |
-
ref_type=int,
|
| 260 |
-
object_type=int,
|
| 261 |
-
flags=flags,
|
| 262 |
-
),
|
| 263 |
-
)
|
| 264 |
-
|
| 265 |
-
def _validate_and_convert_impl(self, value: Any) -> int:
|
| 266 |
-
try:
|
| 267 |
-
if type(value) in (str, int):
|
| 268 |
-
val = int(value)
|
| 269 |
-
else:
|
| 270 |
-
raise ValueError()
|
| 271 |
-
except ValueError:
|
| 272 |
-
raise ValidationError(
|
| 273 |
-
"Value '$VALUE' of type '$VALUE_TYPE' could not be converted to Integer"
|
| 274 |
-
)
|
| 275 |
-
return val
|
| 276 |
-
|
| 277 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "IntegerNode":
|
| 278 |
-
res = IntegerNode()
|
| 279 |
-
self._deepcopy_impl(res, memo)
|
| 280 |
-
return res
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
class BytesNode(ValueNode):
|
| 284 |
-
def __init__(
|
| 285 |
-
self,
|
| 286 |
-
value: Any = None,
|
| 287 |
-
key: Any = None,
|
| 288 |
-
parent: Optional[Box] = None,
|
| 289 |
-
is_optional: bool = True,
|
| 290 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 291 |
-
):
|
| 292 |
-
super().__init__(
|
| 293 |
-
parent=parent,
|
| 294 |
-
value=value,
|
| 295 |
-
metadata=Metadata(
|
| 296 |
-
key=key,
|
| 297 |
-
optional=is_optional,
|
| 298 |
-
ref_type=bytes,
|
| 299 |
-
object_type=bytes,
|
| 300 |
-
flags=flags,
|
| 301 |
-
),
|
| 302 |
-
)
|
| 303 |
-
|
| 304 |
-
def _validate_and_convert_impl(self, value: Any) -> bytes:
|
| 305 |
-
if not isinstance(value, bytes):
|
| 306 |
-
raise ValidationError(
|
| 307 |
-
"Value '$VALUE' of type '$VALUE_TYPE' is not of type 'bytes'"
|
| 308 |
-
)
|
| 309 |
-
return value
|
| 310 |
-
|
| 311 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "BytesNode":
|
| 312 |
-
res = BytesNode()
|
| 313 |
-
self._deepcopy_impl(res, memo)
|
| 314 |
-
return res
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
class FloatNode(ValueNode):
|
| 318 |
-
def __init__(
|
| 319 |
-
self,
|
| 320 |
-
value: Any = None,
|
| 321 |
-
key: Any = None,
|
| 322 |
-
parent: Optional[Box] = None,
|
| 323 |
-
is_optional: bool = True,
|
| 324 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 325 |
-
):
|
| 326 |
-
super().__init__(
|
| 327 |
-
parent=parent,
|
| 328 |
-
value=value,
|
| 329 |
-
metadata=Metadata(
|
| 330 |
-
key=key,
|
| 331 |
-
optional=is_optional,
|
| 332 |
-
ref_type=float,
|
| 333 |
-
object_type=float,
|
| 334 |
-
flags=flags,
|
| 335 |
-
),
|
| 336 |
-
)
|
| 337 |
-
|
| 338 |
-
def _validate_and_convert_impl(self, value: Any) -> float:
|
| 339 |
-
try:
|
| 340 |
-
if type(value) in (float, str, int):
|
| 341 |
-
return float(value)
|
| 342 |
-
else:
|
| 343 |
-
raise ValueError()
|
| 344 |
-
except ValueError:
|
| 345 |
-
raise ValidationError(
|
| 346 |
-
"Value '$VALUE' of type '$VALUE_TYPE' could not be converted to Float"
|
| 347 |
-
)
|
| 348 |
-
|
| 349 |
-
def __eq__(self, other: Any) -> bool:
|
| 350 |
-
if isinstance(other, ValueNode):
|
| 351 |
-
other_val = other._val
|
| 352 |
-
else:
|
| 353 |
-
other_val = other
|
| 354 |
-
if self._val is None and other is None:
|
| 355 |
-
return True
|
| 356 |
-
if self._val is None and other is not None:
|
| 357 |
-
return False
|
| 358 |
-
if self._val is not None and other is None:
|
| 359 |
-
return False
|
| 360 |
-
nan1 = math.isnan(self._val) if isinstance(self._val, float) else False
|
| 361 |
-
nan2 = math.isnan(other_val) if isinstance(other_val, float) else False
|
| 362 |
-
return self._val == other_val or (nan1 and nan2)
|
| 363 |
-
|
| 364 |
-
def __hash__(self) -> int:
|
| 365 |
-
return hash(self._val)
|
| 366 |
-
|
| 367 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "FloatNode":
|
| 368 |
-
res = FloatNode()
|
| 369 |
-
self._deepcopy_impl(res, memo)
|
| 370 |
-
return res
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
class BooleanNode(ValueNode):
|
| 374 |
-
def __init__(
|
| 375 |
-
self,
|
| 376 |
-
value: Any = None,
|
| 377 |
-
key: Any = None,
|
| 378 |
-
parent: Optional[Box] = None,
|
| 379 |
-
is_optional: bool = True,
|
| 380 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 381 |
-
):
|
| 382 |
-
super().__init__(
|
| 383 |
-
parent=parent,
|
| 384 |
-
value=value,
|
| 385 |
-
metadata=Metadata(
|
| 386 |
-
key=key,
|
| 387 |
-
optional=is_optional,
|
| 388 |
-
ref_type=bool,
|
| 389 |
-
object_type=bool,
|
| 390 |
-
flags=flags,
|
| 391 |
-
),
|
| 392 |
-
)
|
| 393 |
-
|
| 394 |
-
def _validate_and_convert_impl(self, value: Any) -> bool:
|
| 395 |
-
if isinstance(value, bool):
|
| 396 |
-
return value
|
| 397 |
-
if isinstance(value, int):
|
| 398 |
-
return value != 0
|
| 399 |
-
elif isinstance(value, str):
|
| 400 |
-
try:
|
| 401 |
-
return self._validate_and_convert_impl(int(value))
|
| 402 |
-
except ValueError as e:
|
| 403 |
-
if value.lower() in ("yes", "y", "on", "true"):
|
| 404 |
-
return True
|
| 405 |
-
elif value.lower() in ("no", "n", "off", "false"):
|
| 406 |
-
return False
|
| 407 |
-
else:
|
| 408 |
-
raise ValidationError(
|
| 409 |
-
"Value '$VALUE' is not a valid bool (type $VALUE_TYPE)"
|
| 410 |
-
).with_traceback(sys.exc_info()[2]) from e
|
| 411 |
-
else:
|
| 412 |
-
raise ValidationError(
|
| 413 |
-
"Value '$VALUE' is not a valid bool (type $VALUE_TYPE)"
|
| 414 |
-
)
|
| 415 |
-
|
| 416 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "BooleanNode":
|
| 417 |
-
res = BooleanNode()
|
| 418 |
-
self._deepcopy_impl(res, memo)
|
| 419 |
-
return res
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
class EnumNode(ValueNode): # lgtm [py/missing-equals] : Intentional.
|
| 423 |
-
"""
|
| 424 |
-
NOTE: EnumNode is serialized to yaml as a string ("Color.BLUE"), not as a fully qualified yaml type.
|
| 425 |
-
this means serialization to YAML of a typed config (with EnumNode) will not retain the type of the Enum
|
| 426 |
-
when loaded.
|
| 427 |
-
This is intentional, Please open an issue against OmegaConf if you wish to discuss this decision.
|
| 428 |
-
"""
|
| 429 |
-
|
| 430 |
-
def __init__(
|
| 431 |
-
self,
|
| 432 |
-
enum_type: Type[Enum],
|
| 433 |
-
value: Optional[Union[Enum, str]] = None,
|
| 434 |
-
key: Any = None,
|
| 435 |
-
parent: Optional[Box] = None,
|
| 436 |
-
is_optional: bool = True,
|
| 437 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 438 |
-
):
|
| 439 |
-
if not isinstance(enum_type, type) or not issubclass(enum_type, Enum):
|
| 440 |
-
raise ValidationError(
|
| 441 |
-
f"EnumNode can only operate on Enum subclasses ({enum_type})"
|
| 442 |
-
)
|
| 443 |
-
self.fields: Dict[str, str] = {}
|
| 444 |
-
self.enum_type: Type[Enum] = enum_type
|
| 445 |
-
for name, constant in enum_type.__members__.items():
|
| 446 |
-
self.fields[name] = constant.value
|
| 447 |
-
super().__init__(
|
| 448 |
-
parent=parent,
|
| 449 |
-
value=value,
|
| 450 |
-
metadata=Metadata(
|
| 451 |
-
key=key,
|
| 452 |
-
optional=is_optional,
|
| 453 |
-
ref_type=enum_type,
|
| 454 |
-
object_type=enum_type,
|
| 455 |
-
flags=flags,
|
| 456 |
-
),
|
| 457 |
-
)
|
| 458 |
-
|
| 459 |
-
def _strict_validate_type(self, value: Any) -> None:
|
| 460 |
-
ref_type = self._metadata.ref_type
|
| 461 |
-
if not isinstance(value, ref_type):
|
| 462 |
-
type_hint = type_str(self._metadata.type_hint)
|
| 463 |
-
raise ValidationError(
|
| 464 |
-
f"Value '$VALUE' of type '$VALUE_TYPE' is incompatible with type hint '{type_hint}'"
|
| 465 |
-
)
|
| 466 |
-
|
| 467 |
-
def _validate_and_convert_impl(self, value: Any) -> Enum:
|
| 468 |
-
return self.validate_and_convert_to_enum(enum_type=self.enum_type, value=value)
|
| 469 |
-
|
| 470 |
-
@staticmethod
|
| 471 |
-
def validate_and_convert_to_enum(enum_type: Type[Enum], value: Any) -> Enum:
|
| 472 |
-
if not isinstance(value, (str, int)) and not isinstance(value, enum_type):
|
| 473 |
-
raise ValidationError(
|
| 474 |
-
f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type}"
|
| 475 |
-
)
|
| 476 |
-
|
| 477 |
-
if isinstance(value, enum_type):
|
| 478 |
-
return value
|
| 479 |
-
|
| 480 |
-
try:
|
| 481 |
-
if isinstance(value, (float, bool)):
|
| 482 |
-
raise ValueError
|
| 483 |
-
|
| 484 |
-
if isinstance(value, int):
|
| 485 |
-
return enum_type(value)
|
| 486 |
-
|
| 487 |
-
if isinstance(value, str):
|
| 488 |
-
prefix = f"{enum_type.__name__}."
|
| 489 |
-
if value.startswith(prefix):
|
| 490 |
-
value = value[len(prefix) :]
|
| 491 |
-
return enum_type[value]
|
| 492 |
-
|
| 493 |
-
assert False
|
| 494 |
-
|
| 495 |
-
except (ValueError, KeyError) as e:
|
| 496 |
-
valid = ", ".join([x for x in enum_type.__members__.keys()])
|
| 497 |
-
raise ValidationError(
|
| 498 |
-
f"Invalid value '$VALUE', expected one of [{valid}]"
|
| 499 |
-
).with_traceback(sys.exc_info()[2]) from e
|
| 500 |
-
|
| 501 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "EnumNode":
|
| 502 |
-
res = EnumNode(enum_type=self.enum_type)
|
| 503 |
-
self._deepcopy_impl(res, memo)
|
| 504 |
-
return res
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
class InterpolationResultNode(ValueNode):
|
| 508 |
-
"""
|
| 509 |
-
Special node type, used to wrap interpolation results.
|
| 510 |
-
"""
|
| 511 |
-
|
| 512 |
-
def __init__(
|
| 513 |
-
self,
|
| 514 |
-
value: Any,
|
| 515 |
-
key: Any = None,
|
| 516 |
-
parent: Optional[Box] = None,
|
| 517 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 518 |
-
):
|
| 519 |
-
super().__init__(
|
| 520 |
-
parent=parent,
|
| 521 |
-
value=value,
|
| 522 |
-
metadata=Metadata(
|
| 523 |
-
ref_type=Any, object_type=None, key=key, optional=True, flags=flags
|
| 524 |
-
),
|
| 525 |
-
)
|
| 526 |
-
# In general we should not try to write into interpolation results.
|
| 527 |
-
if flags is None or "readonly" not in flags:
|
| 528 |
-
self._set_flag("readonly", True)
|
| 529 |
-
|
| 530 |
-
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
|
| 531 |
-
if self._get_flag("readonly"):
|
| 532 |
-
raise ReadonlyConfigError("Cannot set value of read-only config node")
|
| 533 |
-
self._val = self.validate_and_convert(value)
|
| 534 |
-
|
| 535 |
-
def _validate_and_convert_impl(self, value: Any) -> Any:
|
| 536 |
-
# Interpolation results may be anything.
|
| 537 |
-
return value
|
| 538 |
-
|
| 539 |
-
def __deepcopy__(self, memo: Dict[int, Any]) -> "InterpolationResultNode":
|
| 540 |
-
# Currently there should be no need to deep-copy such nodes.
|
| 541 |
-
raise NotImplementedError
|
| 542 |
-
|
| 543 |
-
def _is_interpolation(self) -> bool:
|
| 544 |
-
# The result of an interpolation cannot be itself an interpolation.
|
| 545 |
-
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/omegaconf.py
DELETED
|
@@ -1,1157 +0,0 @@
|
|
| 1 |
-
"""OmegaConf module"""
|
| 2 |
-
import copy
|
| 3 |
-
import inspect
|
| 4 |
-
import io
|
| 5 |
-
import os
|
| 6 |
-
import pathlib
|
| 7 |
-
import sys
|
| 8 |
-
import warnings
|
| 9 |
-
from collections import defaultdict
|
| 10 |
-
from contextlib import contextmanager
|
| 11 |
-
from enum import Enum
|
| 12 |
-
from textwrap import dedent
|
| 13 |
-
from typing import (
|
| 14 |
-
IO,
|
| 15 |
-
Any,
|
| 16 |
-
Callable,
|
| 17 |
-
Dict,
|
| 18 |
-
Generator,
|
| 19 |
-
Iterable,
|
| 20 |
-
List,
|
| 21 |
-
Optional,
|
| 22 |
-
Set,
|
| 23 |
-
Tuple,
|
| 24 |
-
Type,
|
| 25 |
-
Union,
|
| 26 |
-
overload,
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
import yaml
|
| 30 |
-
|
| 31 |
-
from . import DictConfig, DictKeyType, ListConfig
|
| 32 |
-
from ._utils import (
|
| 33 |
-
_DEFAULT_MARKER_,
|
| 34 |
-
_ensure_container,
|
| 35 |
-
_get_value,
|
| 36 |
-
format_and_raise,
|
| 37 |
-
get_dict_key_value_types,
|
| 38 |
-
get_list_element_type,
|
| 39 |
-
get_omega_conf_dumper,
|
| 40 |
-
get_type_of,
|
| 41 |
-
is_attr_class,
|
| 42 |
-
is_dataclass,
|
| 43 |
-
is_dict_annotation,
|
| 44 |
-
is_int,
|
| 45 |
-
is_list_annotation,
|
| 46 |
-
is_primitive_container,
|
| 47 |
-
is_primitive_dict,
|
| 48 |
-
is_primitive_list,
|
| 49 |
-
is_structured_config,
|
| 50 |
-
is_tuple_annotation,
|
| 51 |
-
is_union_annotation,
|
| 52 |
-
nullcontext,
|
| 53 |
-
split_key,
|
| 54 |
-
type_str,
|
| 55 |
-
)
|
| 56 |
-
from .base import Box, Container, Node, SCMode, UnionNode
|
| 57 |
-
from .basecontainer import BaseContainer
|
| 58 |
-
from .errors import (
|
| 59 |
-
MissingMandatoryValue,
|
| 60 |
-
OmegaConfBaseException,
|
| 61 |
-
UnsupportedInterpolationType,
|
| 62 |
-
ValidationError,
|
| 63 |
-
)
|
| 64 |
-
from .nodes import (
|
| 65 |
-
AnyNode,
|
| 66 |
-
BooleanNode,
|
| 67 |
-
BytesNode,
|
| 68 |
-
EnumNode,
|
| 69 |
-
FloatNode,
|
| 70 |
-
IntegerNode,
|
| 71 |
-
PathNode,
|
| 72 |
-
StringNode,
|
| 73 |
-
ValueNode,
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
MISSING: Any = "???"
|
| 77 |
-
|
| 78 |
-
Resolver = Callable[..., Any]
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def II(interpolation: str) -> Any:
|
| 82 |
-
"""
|
| 83 |
-
Equivalent to ``${interpolation}``
|
| 84 |
-
|
| 85 |
-
:param interpolation:
|
| 86 |
-
:return: input ``${node}`` with type Any
|
| 87 |
-
"""
|
| 88 |
-
return "${" + interpolation + "}"
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
def SI(interpolation: str) -> Any:
|
| 92 |
-
"""
|
| 93 |
-
Use this for String interpolation, for example ``"http://${host}:${port}"``
|
| 94 |
-
|
| 95 |
-
:param interpolation: interpolation string
|
| 96 |
-
:return: input interpolation with type ``Any``
|
| 97 |
-
"""
|
| 98 |
-
return interpolation
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def register_default_resolvers() -> None:
|
| 102 |
-
from omegaconf.resolvers import oc
|
| 103 |
-
|
| 104 |
-
OmegaConf.register_new_resolver("oc.create", oc.create)
|
| 105 |
-
OmegaConf.register_new_resolver("oc.decode", oc.decode)
|
| 106 |
-
OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated)
|
| 107 |
-
OmegaConf.register_new_resolver("oc.env", oc.env)
|
| 108 |
-
OmegaConf.register_new_resolver("oc.select", oc.select)
|
| 109 |
-
OmegaConf.register_new_resolver("oc.dict.keys", oc.dict.keys)
|
| 110 |
-
OmegaConf.register_new_resolver("oc.dict.values", oc.dict.values)
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
class OmegaConf:
|
| 114 |
-
"""OmegaConf primary class"""
|
| 115 |
-
|
| 116 |
-
def __init__(self) -> None:
|
| 117 |
-
raise NotImplementedError("Use one of the static construction functions")
|
| 118 |
-
|
| 119 |
-
@staticmethod
|
| 120 |
-
def structured(
|
| 121 |
-
obj: Any,
|
| 122 |
-
parent: Optional[BaseContainer] = None,
|
| 123 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 124 |
-
) -> Any:
|
| 125 |
-
return OmegaConf.create(obj, parent, flags)
|
| 126 |
-
|
| 127 |
-
@staticmethod
|
| 128 |
-
@overload
|
| 129 |
-
def create(
|
| 130 |
-
obj: str,
|
| 131 |
-
parent: Optional[BaseContainer] = None,
|
| 132 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 133 |
-
) -> Union[DictConfig, ListConfig]:
|
| 134 |
-
...
|
| 135 |
-
|
| 136 |
-
@staticmethod
|
| 137 |
-
@overload
|
| 138 |
-
def create(
|
| 139 |
-
obj: Union[List[Any], Tuple[Any, ...]],
|
| 140 |
-
parent: Optional[BaseContainer] = None,
|
| 141 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 142 |
-
) -> ListConfig:
|
| 143 |
-
...
|
| 144 |
-
|
| 145 |
-
@staticmethod
|
| 146 |
-
@overload
|
| 147 |
-
def create(
|
| 148 |
-
obj: DictConfig,
|
| 149 |
-
parent: Optional[BaseContainer] = None,
|
| 150 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 151 |
-
) -> DictConfig:
|
| 152 |
-
...
|
| 153 |
-
|
| 154 |
-
@staticmethod
|
| 155 |
-
@overload
|
| 156 |
-
def create(
|
| 157 |
-
obj: ListConfig,
|
| 158 |
-
parent: Optional[BaseContainer] = None,
|
| 159 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 160 |
-
) -> ListConfig:
|
| 161 |
-
...
|
| 162 |
-
|
| 163 |
-
@staticmethod
|
| 164 |
-
@overload
|
| 165 |
-
def create(
|
| 166 |
-
obj: Optional[Dict[Any, Any]] = None,
|
| 167 |
-
parent: Optional[BaseContainer] = None,
|
| 168 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 169 |
-
) -> DictConfig:
|
| 170 |
-
...
|
| 171 |
-
|
| 172 |
-
@staticmethod
|
| 173 |
-
def create( # noqa F811
|
| 174 |
-
obj: Any = _DEFAULT_MARKER_,
|
| 175 |
-
parent: Optional[BaseContainer] = None,
|
| 176 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 177 |
-
) -> Union[DictConfig, ListConfig]:
|
| 178 |
-
return OmegaConf._create_impl(
|
| 179 |
-
obj=obj,
|
| 180 |
-
parent=parent,
|
| 181 |
-
flags=flags,
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
@staticmethod
|
| 185 |
-
def load(file_: Union[str, pathlib.Path, IO[Any]]) -> Union[DictConfig, ListConfig]:
|
| 186 |
-
from ._utils import get_yaml_loader
|
| 187 |
-
|
| 188 |
-
if isinstance(file_, (str, pathlib.Path)):
|
| 189 |
-
with io.open(os.path.abspath(file_), "r", encoding="utf-8") as f:
|
| 190 |
-
obj = yaml.load(f, Loader=get_yaml_loader())
|
| 191 |
-
elif getattr(file_, "read", None):
|
| 192 |
-
obj = yaml.load(file_, Loader=get_yaml_loader())
|
| 193 |
-
else:
|
| 194 |
-
raise TypeError("Unexpected file type")
|
| 195 |
-
|
| 196 |
-
if obj is not None and not isinstance(obj, (list, dict, str)):
|
| 197 |
-
raise IOError( # pragma: no cover
|
| 198 |
-
f"Invalid loaded object type: {type(obj).__name__}"
|
| 199 |
-
)
|
| 200 |
-
|
| 201 |
-
ret: Union[DictConfig, ListConfig]
|
| 202 |
-
if obj is None:
|
| 203 |
-
ret = OmegaConf.create()
|
| 204 |
-
else:
|
| 205 |
-
ret = OmegaConf.create(obj)
|
| 206 |
-
return ret
|
| 207 |
-
|
| 208 |
-
@staticmethod
|
| 209 |
-
def save(
|
| 210 |
-
config: Any, f: Union[str, pathlib.Path, IO[Any]], resolve: bool = False
|
| 211 |
-
) -> None:
|
| 212 |
-
"""
|
| 213 |
-
Save as configuration object to a file
|
| 214 |
-
|
| 215 |
-
:param config: omegaconf.Config object (DictConfig or ListConfig).
|
| 216 |
-
:param f: filename or file object
|
| 217 |
-
:param resolve: True to save a resolved config (defaults to False)
|
| 218 |
-
"""
|
| 219 |
-
if is_dataclass(config) or is_attr_class(config):
|
| 220 |
-
config = OmegaConf.create(config)
|
| 221 |
-
data = OmegaConf.to_yaml(config, resolve=resolve)
|
| 222 |
-
if isinstance(f, (str, pathlib.Path)):
|
| 223 |
-
with io.open(os.path.abspath(f), "w", encoding="utf-8") as file:
|
| 224 |
-
file.write(data)
|
| 225 |
-
elif hasattr(f, "write"):
|
| 226 |
-
f.write(data)
|
| 227 |
-
f.flush()
|
| 228 |
-
else:
|
| 229 |
-
raise TypeError("Unexpected file type")
|
| 230 |
-
|
| 231 |
-
@staticmethod
|
| 232 |
-
def from_cli(args_list: Optional[List[str]] = None) -> DictConfig:
|
| 233 |
-
if args_list is None:
|
| 234 |
-
# Skip program name
|
| 235 |
-
args_list = sys.argv[1:]
|
| 236 |
-
return OmegaConf.from_dotlist(args_list)
|
| 237 |
-
|
| 238 |
-
@staticmethod
|
| 239 |
-
def from_dotlist(dotlist: List[str]) -> DictConfig:
|
| 240 |
-
"""
|
| 241 |
-
Creates config from the content sys.argv or from the specified args list of not None
|
| 242 |
-
|
| 243 |
-
:param dotlist: A list of dotlist-style strings, e.g. ``["foo.bar=1", "baz=qux"]``.
|
| 244 |
-
:return: A ``DictConfig`` object created from the dotlist.
|
| 245 |
-
"""
|
| 246 |
-
conf = OmegaConf.create()
|
| 247 |
-
conf.merge_with_dotlist(dotlist)
|
| 248 |
-
return conf
|
| 249 |
-
|
| 250 |
-
@staticmethod
|
| 251 |
-
def merge(
|
| 252 |
-
*configs: Union[
|
| 253 |
-
DictConfig,
|
| 254 |
-
ListConfig,
|
| 255 |
-
Dict[DictKeyType, Any],
|
| 256 |
-
List[Any],
|
| 257 |
-
Tuple[Any, ...],
|
| 258 |
-
Any,
|
| 259 |
-
],
|
| 260 |
-
) -> Union[ListConfig, DictConfig]:
|
| 261 |
-
"""
|
| 262 |
-
Merge a list of previously created configs into a single one
|
| 263 |
-
|
| 264 |
-
:param configs: Input configs
|
| 265 |
-
:return: the merged config object.
|
| 266 |
-
"""
|
| 267 |
-
assert len(configs) > 0
|
| 268 |
-
target = copy.deepcopy(configs[0])
|
| 269 |
-
target = _ensure_container(target)
|
| 270 |
-
assert isinstance(target, (DictConfig, ListConfig))
|
| 271 |
-
|
| 272 |
-
with flag_override(target, "readonly", False):
|
| 273 |
-
target.merge_with(*configs[1:])
|
| 274 |
-
turned_readonly = target._get_flag("readonly") is True
|
| 275 |
-
|
| 276 |
-
if turned_readonly:
|
| 277 |
-
OmegaConf.set_readonly(target, True)
|
| 278 |
-
|
| 279 |
-
return target
|
| 280 |
-
|
| 281 |
-
@staticmethod
|
| 282 |
-
def unsafe_merge(
|
| 283 |
-
*configs: Union[
|
| 284 |
-
DictConfig,
|
| 285 |
-
ListConfig,
|
| 286 |
-
Dict[DictKeyType, Any],
|
| 287 |
-
List[Any],
|
| 288 |
-
Tuple[Any, ...],
|
| 289 |
-
Any,
|
| 290 |
-
],
|
| 291 |
-
) -> Union[ListConfig, DictConfig]:
|
| 292 |
-
"""
|
| 293 |
-
Merge a list of previously created configs into a single one
|
| 294 |
-
This is much faster than OmegaConf.merge() as the input configs are not copied.
|
| 295 |
-
However, the input configs must not be used after this operation as will become inconsistent.
|
| 296 |
-
|
| 297 |
-
:param configs: Input configs
|
| 298 |
-
:return: the merged config object.
|
| 299 |
-
"""
|
| 300 |
-
assert len(configs) > 0
|
| 301 |
-
target = configs[0]
|
| 302 |
-
target = _ensure_container(target)
|
| 303 |
-
assert isinstance(target, (DictConfig, ListConfig))
|
| 304 |
-
|
| 305 |
-
with flag_override(
|
| 306 |
-
target, ["readonly", "no_deepcopy_set_nodes"], [False, True]
|
| 307 |
-
):
|
| 308 |
-
target.merge_with(*configs[1:])
|
| 309 |
-
turned_readonly = target._get_flag("readonly") is True
|
| 310 |
-
|
| 311 |
-
if turned_readonly:
|
| 312 |
-
OmegaConf.set_readonly(target, True)
|
| 313 |
-
|
| 314 |
-
return target
|
| 315 |
-
|
| 316 |
-
@staticmethod
|
| 317 |
-
def register_resolver(name: str, resolver: Resolver) -> None:
|
| 318 |
-
warnings.warn(
|
| 319 |
-
dedent(
|
| 320 |
-
"""\
|
| 321 |
-
register_resolver() is deprecated.
|
| 322 |
-
See https://github.com/omry/omegaconf/issues/426 for migration instructions.
|
| 323 |
-
"""
|
| 324 |
-
),
|
| 325 |
-
stacklevel=2,
|
| 326 |
-
)
|
| 327 |
-
return OmegaConf.legacy_register_resolver(name, resolver)
|
| 328 |
-
|
| 329 |
-
# This function will eventually be deprecated and removed.
|
| 330 |
-
@staticmethod
|
| 331 |
-
def legacy_register_resolver(name: str, resolver: Resolver) -> None:
|
| 332 |
-
assert callable(resolver), "resolver must be callable"
|
| 333 |
-
# noinspection PyProtectedMember
|
| 334 |
-
assert (
|
| 335 |
-
name not in BaseContainer._resolvers
|
| 336 |
-
), f"resolver '{name}' is already registered"
|
| 337 |
-
|
| 338 |
-
def resolver_wrapper(
|
| 339 |
-
config: BaseContainer,
|
| 340 |
-
parent: BaseContainer,
|
| 341 |
-
node: Node,
|
| 342 |
-
args: Tuple[Any, ...],
|
| 343 |
-
args_str: Tuple[str, ...],
|
| 344 |
-
) -> Any:
|
| 345 |
-
cache = OmegaConf.get_cache(config)[name]
|
| 346 |
-
# "Un-escape " spaces and commas.
|
| 347 |
-
args_unesc = [x.replace(r"\ ", " ").replace(r"\,", ",") for x in args_str]
|
| 348 |
-
|
| 349 |
-
# Nested interpolations behave in a potentially surprising way with
|
| 350 |
-
# legacy resolvers (they remain as strings, e.g., "${foo}"). If any
|
| 351 |
-
# input looks like an interpolation we thus raise an exception.
|
| 352 |
-
try:
|
| 353 |
-
bad_arg = next(i for i in args_unesc if "${" in i)
|
| 354 |
-
except StopIteration:
|
| 355 |
-
pass
|
| 356 |
-
else:
|
| 357 |
-
raise ValueError(
|
| 358 |
-
f"Resolver '{name}' was called with argument '{bad_arg}' that appears "
|
| 359 |
-
f"to be an interpolation. Nested interpolations are not supported for "
|
| 360 |
-
f"resolvers registered with `[legacy_]register_resolver()`, please use "
|
| 361 |
-
f"`register_new_resolver()` instead (see "
|
| 362 |
-
f"https://github.com/omry/omegaconf/issues/426 for migration instructions)."
|
| 363 |
-
)
|
| 364 |
-
key = args_str
|
| 365 |
-
val = cache[key] if key in cache else resolver(*args_unesc)
|
| 366 |
-
cache[key] = val
|
| 367 |
-
return val
|
| 368 |
-
|
| 369 |
-
# noinspection PyProtectedMember
|
| 370 |
-
BaseContainer._resolvers[name] = resolver_wrapper
|
| 371 |
-
|
| 372 |
-
@staticmethod
|
| 373 |
-
def register_new_resolver(
|
| 374 |
-
name: str,
|
| 375 |
-
resolver: Resolver,
|
| 376 |
-
*,
|
| 377 |
-
replace: bool = False,
|
| 378 |
-
use_cache: bool = False,
|
| 379 |
-
) -> None:
|
| 380 |
-
"""
|
| 381 |
-
Register a resolver.
|
| 382 |
-
|
| 383 |
-
:param name: Name of the resolver.
|
| 384 |
-
:param resolver: Callable whose arguments are provided in the interpolation,
|
| 385 |
-
e.g., with ${foo:x,0,${y.z}} these arguments are respectively "x" (str),
|
| 386 |
-
0 (int) and the value of ``y.z``.
|
| 387 |
-
:param replace: If set to ``False`` (default), then a ``ValueError`` is raised if
|
| 388 |
-
an existing resolver has already been registered with the same name.
|
| 389 |
-
If set to ``True``, then the new resolver replaces the previous one.
|
| 390 |
-
NOTE: The cache on existing config objects is not affected, use
|
| 391 |
-
``OmegaConf.clear_cache(cfg)`` to clear it.
|
| 392 |
-
:param use_cache: Whether the resolver's outputs should be cached. The cache is
|
| 393 |
-
based only on the string literals representing the resolver arguments, e.g.,
|
| 394 |
-
${foo:${bar}} will always return the same value regardless of the value of
|
| 395 |
-
``bar`` if the cache is enabled for ``foo``.
|
| 396 |
-
"""
|
| 397 |
-
if not callable(resolver):
|
| 398 |
-
raise TypeError("resolver must be callable")
|
| 399 |
-
if not name:
|
| 400 |
-
raise ValueError("cannot use an empty resolver name")
|
| 401 |
-
|
| 402 |
-
if not replace and OmegaConf.has_resolver(name):
|
| 403 |
-
raise ValueError(f"resolver '{name}' is already registered")
|
| 404 |
-
|
| 405 |
-
try:
|
| 406 |
-
sig: Optional[inspect.Signature] = inspect.signature(resolver)
|
| 407 |
-
except ValueError:
|
| 408 |
-
sig = None
|
| 409 |
-
|
| 410 |
-
def _should_pass(special: str) -> bool:
|
| 411 |
-
ret = sig is not None and special in sig.parameters
|
| 412 |
-
if ret and use_cache:
|
| 413 |
-
raise ValueError(
|
| 414 |
-
f"use_cache=True is incompatible with functions that receive the {special}"
|
| 415 |
-
)
|
| 416 |
-
return ret
|
| 417 |
-
|
| 418 |
-
pass_parent = _should_pass("_parent_")
|
| 419 |
-
pass_node = _should_pass("_node_")
|
| 420 |
-
pass_root = _should_pass("_root_")
|
| 421 |
-
|
| 422 |
-
def resolver_wrapper(
|
| 423 |
-
config: BaseContainer,
|
| 424 |
-
parent: Container,
|
| 425 |
-
node: Node,
|
| 426 |
-
args: Tuple[Any, ...],
|
| 427 |
-
args_str: Tuple[str, ...],
|
| 428 |
-
) -> Any:
|
| 429 |
-
if use_cache:
|
| 430 |
-
cache = OmegaConf.get_cache(config)[name]
|
| 431 |
-
try:
|
| 432 |
-
return cache[args_str]
|
| 433 |
-
except KeyError:
|
| 434 |
-
pass
|
| 435 |
-
|
| 436 |
-
# Call resolver.
|
| 437 |
-
kwargs: Dict[str, Node] = {}
|
| 438 |
-
if pass_parent:
|
| 439 |
-
kwargs["_parent_"] = parent
|
| 440 |
-
if pass_node:
|
| 441 |
-
kwargs["_node_"] = node
|
| 442 |
-
if pass_root:
|
| 443 |
-
kwargs["_root_"] = config
|
| 444 |
-
|
| 445 |
-
ret = resolver(*args, **kwargs)
|
| 446 |
-
|
| 447 |
-
if use_cache:
|
| 448 |
-
cache[args_str] = ret
|
| 449 |
-
return ret
|
| 450 |
-
|
| 451 |
-
# noinspection PyProtectedMember
|
| 452 |
-
BaseContainer._resolvers[name] = resolver_wrapper
|
| 453 |
-
|
| 454 |
-
@classmethod
|
| 455 |
-
def has_resolver(cls, name: str) -> bool:
|
| 456 |
-
return cls._get_resolver(name) is not None
|
| 457 |
-
|
| 458 |
-
# noinspection PyProtectedMember
|
| 459 |
-
@staticmethod
|
| 460 |
-
def clear_resolvers() -> None:
|
| 461 |
-
"""
|
| 462 |
-
Clear(remove) all OmegaConf resolvers, then re-register OmegaConf's default resolvers.
|
| 463 |
-
"""
|
| 464 |
-
BaseContainer._resolvers = {}
|
| 465 |
-
register_default_resolvers()
|
| 466 |
-
|
| 467 |
-
@classmethod
|
| 468 |
-
def clear_resolver(cls, name: str) -> bool:
|
| 469 |
-
"""
|
| 470 |
-
Clear(remove) any resolver only if it exists.
|
| 471 |
-
|
| 472 |
-
Returns a bool: True if resolver is removed and False if not removed.
|
| 473 |
-
|
| 474 |
-
.. warning:
|
| 475 |
-
This method can remove deafult resolvers as well.
|
| 476 |
-
|
| 477 |
-
:param name: Name of the resolver.
|
| 478 |
-
:return: A bool (``True`` if resolver is removed, ``False`` if not found before removing).
|
| 479 |
-
"""
|
| 480 |
-
if cls.has_resolver(name):
|
| 481 |
-
BaseContainer._resolvers.pop(name)
|
| 482 |
-
return True
|
| 483 |
-
else:
|
| 484 |
-
# return False if resolver does not exist
|
| 485 |
-
return False
|
| 486 |
-
|
| 487 |
-
@staticmethod
|
| 488 |
-
def get_cache(conf: BaseContainer) -> Dict[str, Any]:
|
| 489 |
-
return conf._metadata.resolver_cache
|
| 490 |
-
|
| 491 |
-
@staticmethod
|
| 492 |
-
def set_cache(conf: BaseContainer, cache: Dict[str, Any]) -> None:
|
| 493 |
-
conf._metadata.resolver_cache = copy.deepcopy(cache)
|
| 494 |
-
|
| 495 |
-
@staticmethod
|
| 496 |
-
def clear_cache(conf: BaseContainer) -> None:
|
| 497 |
-
OmegaConf.set_cache(conf, defaultdict(dict, {}))
|
| 498 |
-
|
| 499 |
-
@staticmethod
|
| 500 |
-
def copy_cache(from_config: BaseContainer, to_config: BaseContainer) -> None:
|
| 501 |
-
OmegaConf.set_cache(to_config, OmegaConf.get_cache(from_config))
|
| 502 |
-
|
| 503 |
-
@staticmethod
|
| 504 |
-
def set_readonly(conf: Node, value: Optional[bool]) -> None:
|
| 505 |
-
# noinspection PyProtectedMember
|
| 506 |
-
conf._set_flag("readonly", value)
|
| 507 |
-
|
| 508 |
-
@staticmethod
|
| 509 |
-
def is_readonly(conf: Node) -> Optional[bool]:
|
| 510 |
-
# noinspection PyProtectedMember
|
| 511 |
-
return conf._get_flag("readonly")
|
| 512 |
-
|
| 513 |
-
@staticmethod
|
| 514 |
-
def set_struct(conf: Container, value: Optional[bool]) -> None:
|
| 515 |
-
# noinspection PyProtectedMember
|
| 516 |
-
conf._set_flag("struct", value)
|
| 517 |
-
|
| 518 |
-
@staticmethod
|
| 519 |
-
def is_struct(conf: Container) -> Optional[bool]:
|
| 520 |
-
# noinspection PyProtectedMember
|
| 521 |
-
return conf._get_flag("struct")
|
| 522 |
-
|
| 523 |
-
@staticmethod
|
| 524 |
-
def masked_copy(conf: DictConfig, keys: Union[str, List[str]]) -> DictConfig:
|
| 525 |
-
"""
|
| 526 |
-
Create a masked copy of of this config that contains a subset of the keys
|
| 527 |
-
|
| 528 |
-
:param conf: DictConfig object
|
| 529 |
-
:param keys: keys to preserve in the copy
|
| 530 |
-
:return: The masked ``DictConfig`` object.
|
| 531 |
-
"""
|
| 532 |
-
from .dictconfig import DictConfig
|
| 533 |
-
|
| 534 |
-
if not isinstance(conf, DictConfig):
|
| 535 |
-
raise ValueError("masked_copy is only supported for DictConfig")
|
| 536 |
-
|
| 537 |
-
if isinstance(keys, str):
|
| 538 |
-
keys = [keys]
|
| 539 |
-
content = {key: value for key, value in conf.items_ex(resolve=False, keys=keys)}
|
| 540 |
-
return DictConfig(content=content)
|
| 541 |
-
|
| 542 |
-
@staticmethod
|
| 543 |
-
def to_container(
|
| 544 |
-
cfg: Any,
|
| 545 |
-
*,
|
| 546 |
-
resolve: bool = False,
|
| 547 |
-
throw_on_missing: bool = False,
|
| 548 |
-
enum_to_str: bool = False,
|
| 549 |
-
structured_config_mode: SCMode = SCMode.DICT,
|
| 550 |
-
) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
|
| 551 |
-
"""
|
| 552 |
-
Resursively converts an OmegaConf config to a primitive container (dict or list).
|
| 553 |
-
|
| 554 |
-
:param cfg: the config to convert
|
| 555 |
-
:param resolve: True to resolve all values
|
| 556 |
-
:param throw_on_missing: When True, raise MissingMandatoryValue if any missing values are present.
|
| 557 |
-
When False (the default), replace missing values with the string "???" in the output container.
|
| 558 |
-
:param enum_to_str: True to convert Enum keys and values to strings
|
| 559 |
-
:param structured_config_mode: Specify how Structured Configs (DictConfigs backed by a dataclass) are handled.
|
| 560 |
-
- By default (``structured_config_mode=SCMode.DICT``) structured configs are converted to plain dicts.
|
| 561 |
-
- If ``structured_config_mode=SCMode.DICT_CONFIG``, structured config nodes will remain as DictConfig.
|
| 562 |
-
- If ``structured_config_mode=SCMode.INSTANTIATE``, this function will instantiate structured configs
|
| 563 |
-
(DictConfigs backed by a dataclass), by creating an instance of the underlying dataclass.
|
| 564 |
-
|
| 565 |
-
See also OmegaConf.to_object.
|
| 566 |
-
:return: A dict or a list representing this config as a primitive container.
|
| 567 |
-
"""
|
| 568 |
-
if not OmegaConf.is_config(cfg):
|
| 569 |
-
raise ValueError(
|
| 570 |
-
f"Input cfg is not an OmegaConf config object ({type_str(type(cfg))})"
|
| 571 |
-
)
|
| 572 |
-
|
| 573 |
-
return BaseContainer._to_content(
|
| 574 |
-
cfg,
|
| 575 |
-
resolve=resolve,
|
| 576 |
-
throw_on_missing=throw_on_missing,
|
| 577 |
-
enum_to_str=enum_to_str,
|
| 578 |
-
structured_config_mode=structured_config_mode,
|
| 579 |
-
)
|
| 580 |
-
|
| 581 |
-
@staticmethod
|
| 582 |
-
def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
|
| 583 |
-
"""
|
| 584 |
-
Resursively converts an OmegaConf config to a primitive container (dict or list).
|
| 585 |
-
Any DictConfig objects backed by dataclasses or attrs classes are instantiated
|
| 586 |
-
as instances of those backing classes.
|
| 587 |
-
|
| 588 |
-
This is an alias for OmegaConf.to_container(..., resolve=True, throw_on_missing=True,
|
| 589 |
-
structured_config_mode=SCMode.INSTANTIATE)
|
| 590 |
-
|
| 591 |
-
:param cfg: the config to convert
|
| 592 |
-
:return: A dict or a list or dataclass representing this config.
|
| 593 |
-
"""
|
| 594 |
-
return OmegaConf.to_container(
|
| 595 |
-
cfg=cfg,
|
| 596 |
-
resolve=True,
|
| 597 |
-
throw_on_missing=True,
|
| 598 |
-
enum_to_str=False,
|
| 599 |
-
structured_config_mode=SCMode.INSTANTIATE,
|
| 600 |
-
)
|
| 601 |
-
|
| 602 |
-
@staticmethod
|
| 603 |
-
def is_missing(cfg: Any, key: DictKeyType) -> bool:
|
| 604 |
-
assert isinstance(cfg, Container)
|
| 605 |
-
try:
|
| 606 |
-
node = cfg._get_child(key)
|
| 607 |
-
if node is None:
|
| 608 |
-
return False
|
| 609 |
-
assert isinstance(node, Node)
|
| 610 |
-
return node._is_missing()
|
| 611 |
-
except (UnsupportedInterpolationType, KeyError, AttributeError):
|
| 612 |
-
return False
|
| 613 |
-
|
| 614 |
-
@staticmethod
|
| 615 |
-
def is_interpolation(node: Any, key: Optional[Union[int, str]] = None) -> bool:
|
| 616 |
-
if key is not None:
|
| 617 |
-
assert isinstance(node, Container)
|
| 618 |
-
target = node._get_child(key)
|
| 619 |
-
else:
|
| 620 |
-
target = node
|
| 621 |
-
if target is not None:
|
| 622 |
-
assert isinstance(target, Node)
|
| 623 |
-
return target._is_interpolation()
|
| 624 |
-
return False
|
| 625 |
-
|
| 626 |
-
@staticmethod
|
| 627 |
-
def is_list(obj: Any) -> bool:
|
| 628 |
-
from . import ListConfig
|
| 629 |
-
|
| 630 |
-
return isinstance(obj, ListConfig)
|
| 631 |
-
|
| 632 |
-
@staticmethod
|
| 633 |
-
def is_dict(obj: Any) -> bool:
|
| 634 |
-
from . import DictConfig
|
| 635 |
-
|
| 636 |
-
return isinstance(obj, DictConfig)
|
| 637 |
-
|
| 638 |
-
@staticmethod
|
| 639 |
-
def is_config(obj: Any) -> bool:
|
| 640 |
-
from . import Container
|
| 641 |
-
|
| 642 |
-
return isinstance(obj, Container)
|
| 643 |
-
|
| 644 |
-
@staticmethod
|
| 645 |
-
def get_type(obj: Any, key: Optional[str] = None) -> Optional[Type[Any]]:
|
| 646 |
-
if key is not None:
|
| 647 |
-
c = obj._get_child(key)
|
| 648 |
-
else:
|
| 649 |
-
c = obj
|
| 650 |
-
return OmegaConf._get_obj_type(c)
|
| 651 |
-
|
| 652 |
-
@staticmethod
|
| 653 |
-
def select(
|
| 654 |
-
cfg: Container,
|
| 655 |
-
key: str,
|
| 656 |
-
*,
|
| 657 |
-
default: Any = _DEFAULT_MARKER_,
|
| 658 |
-
throw_on_resolution_failure: bool = True,
|
| 659 |
-
throw_on_missing: bool = False,
|
| 660 |
-
) -> Any:
|
| 661 |
-
"""
|
| 662 |
-
:param cfg: Config node to select from
|
| 663 |
-
:param key: Key to select
|
| 664 |
-
:param default: Default value to return if key is not found
|
| 665 |
-
:param throw_on_resolution_failure: Raise an exception if an interpolation
|
| 666 |
-
resolution error occurs, otherwise return None
|
| 667 |
-
:param throw_on_missing: Raise an exception if an attempt to select a missing key (with the value '???')
|
| 668 |
-
is made, otherwise return None
|
| 669 |
-
:return: selected value or None if not found.
|
| 670 |
-
"""
|
| 671 |
-
from ._impl import select_value
|
| 672 |
-
|
| 673 |
-
try:
|
| 674 |
-
return select_value(
|
| 675 |
-
cfg=cfg,
|
| 676 |
-
key=key,
|
| 677 |
-
default=default,
|
| 678 |
-
throw_on_resolution_failure=throw_on_resolution_failure,
|
| 679 |
-
throw_on_missing=throw_on_missing,
|
| 680 |
-
)
|
| 681 |
-
except Exception as e:
|
| 682 |
-
format_and_raise(node=cfg, key=key, value=None, cause=e, msg=str(e))
|
| 683 |
-
|
| 684 |
-
@staticmethod
|
| 685 |
-
def update(
|
| 686 |
-
cfg: Container,
|
| 687 |
-
key: str,
|
| 688 |
-
value: Any = None,
|
| 689 |
-
*,
|
| 690 |
-
merge: bool = True,
|
| 691 |
-
force_add: bool = False,
|
| 692 |
-
) -> None:
|
| 693 |
-
"""
|
| 694 |
-
Updates a dot separated key sequence to a value
|
| 695 |
-
|
| 696 |
-
:param cfg: input config to update
|
| 697 |
-
:param key: key to update (can be a dot separated path)
|
| 698 |
-
:param value: value to set, if value if a list or a dict it will be merged or set
|
| 699 |
-
depending on merge_config_values
|
| 700 |
-
:param merge: If value is a dict or a list, True (default) to merge
|
| 701 |
-
into the destination, False to replace the destination.
|
| 702 |
-
:param force_add: insert the entire path regardless of Struct flag or Structured Config nodes.
|
| 703 |
-
"""
|
| 704 |
-
|
| 705 |
-
split = split_key(key)
|
| 706 |
-
root = cfg
|
| 707 |
-
for i in range(len(split) - 1):
|
| 708 |
-
k = split[i]
|
| 709 |
-
# if next_root is a primitive (string, int etc) replace it with an empty map
|
| 710 |
-
next_root, key_ = _select_one(root, k, throw_on_missing=False)
|
| 711 |
-
if not isinstance(next_root, Container):
|
| 712 |
-
if force_add:
|
| 713 |
-
with flag_override(root, "struct", False):
|
| 714 |
-
root[key_] = {}
|
| 715 |
-
else:
|
| 716 |
-
root[key_] = {}
|
| 717 |
-
root = root[key_]
|
| 718 |
-
|
| 719 |
-
last = split[-1]
|
| 720 |
-
|
| 721 |
-
assert isinstance(
|
| 722 |
-
root, Container
|
| 723 |
-
), f"Unexpected type for root: {type(root).__name__}"
|
| 724 |
-
|
| 725 |
-
last_key: Union[str, int] = last
|
| 726 |
-
if isinstance(root, ListConfig):
|
| 727 |
-
last_key = int(last)
|
| 728 |
-
|
| 729 |
-
ctx = flag_override(root, "struct", False) if force_add else nullcontext()
|
| 730 |
-
with ctx:
|
| 731 |
-
if merge and (OmegaConf.is_config(value) or is_primitive_container(value)):
|
| 732 |
-
assert isinstance(root, BaseContainer)
|
| 733 |
-
node = root._get_child(last_key)
|
| 734 |
-
if OmegaConf.is_config(node):
|
| 735 |
-
assert isinstance(node, BaseContainer)
|
| 736 |
-
node.merge_with(value)
|
| 737 |
-
return
|
| 738 |
-
|
| 739 |
-
if OmegaConf.is_dict(root):
|
| 740 |
-
assert isinstance(last_key, str)
|
| 741 |
-
root.__setattr__(last_key, value)
|
| 742 |
-
elif OmegaConf.is_list(root):
|
| 743 |
-
assert isinstance(last_key, int)
|
| 744 |
-
root.__setitem__(last_key, value)
|
| 745 |
-
else:
|
| 746 |
-
assert False
|
| 747 |
-
|
| 748 |
-
@staticmethod
|
| 749 |
-
def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str:
|
| 750 |
-
"""
|
| 751 |
-
returns a yaml dump of this config object.
|
| 752 |
-
|
| 753 |
-
:param cfg: Config object, Structured Config type or instance
|
| 754 |
-
:param resolve: if True, will return a string with the interpolations resolved, otherwise
|
| 755 |
-
interpolations are preserved
|
| 756 |
-
:param sort_keys: If True, will print dict keys in sorted order. default False.
|
| 757 |
-
:return: A string containing the yaml representation.
|
| 758 |
-
"""
|
| 759 |
-
cfg = _ensure_container(cfg)
|
| 760 |
-
container = OmegaConf.to_container(cfg, resolve=resolve, enum_to_str=True)
|
| 761 |
-
return yaml.dump( # type: ignore
|
| 762 |
-
container,
|
| 763 |
-
default_flow_style=False,
|
| 764 |
-
allow_unicode=True,
|
| 765 |
-
sort_keys=sort_keys,
|
| 766 |
-
Dumper=get_omega_conf_dumper(),
|
| 767 |
-
)
|
| 768 |
-
|
| 769 |
-
@staticmethod
|
| 770 |
-
def resolve(cfg: Container) -> None:
|
| 771 |
-
"""
|
| 772 |
-
Resolves all interpolations in the given config object in-place.
|
| 773 |
-
|
| 774 |
-
:param cfg: An OmegaConf container (DictConfig, ListConfig)
|
| 775 |
-
Raises a ValueError if the input object is not an OmegaConf container.
|
| 776 |
-
"""
|
| 777 |
-
import omegaconf._impl
|
| 778 |
-
|
| 779 |
-
if not OmegaConf.is_config(cfg):
|
| 780 |
-
# Since this function is mutating the input object in-place, it doesn't make sense to
|
| 781 |
-
# auto-convert the input object to an OmegaConf container
|
| 782 |
-
raise ValueError(
|
| 783 |
-
f"Invalid config type ({type(cfg).__name__}), expected an OmegaConf Container"
|
| 784 |
-
)
|
| 785 |
-
omegaconf._impl._resolve(cfg)
|
| 786 |
-
|
| 787 |
-
@staticmethod
|
| 788 |
-
def missing_keys(cfg: Any) -> Set[str]:
|
| 789 |
-
"""
|
| 790 |
-
Returns a set of missing keys in a dotlist style.
|
| 791 |
-
|
| 792 |
-
:param cfg: An ``OmegaConf.Container``,
|
| 793 |
-
or a convertible object via ``OmegaConf.create`` (dict, list, ...).
|
| 794 |
-
:return: set of strings of the missing keys.
|
| 795 |
-
:raises ValueError: On input not representing a config.
|
| 796 |
-
"""
|
| 797 |
-
cfg = _ensure_container(cfg)
|
| 798 |
-
missings: Set[str] = set()
|
| 799 |
-
|
| 800 |
-
def gather(_cfg: Container) -> None:
|
| 801 |
-
itr: Iterable[Any]
|
| 802 |
-
if isinstance(_cfg, ListConfig):
|
| 803 |
-
itr = range(len(_cfg))
|
| 804 |
-
else:
|
| 805 |
-
itr = _cfg
|
| 806 |
-
|
| 807 |
-
for key in itr:
|
| 808 |
-
if OmegaConf.is_missing(_cfg, key):
|
| 809 |
-
missings.add(_cfg._get_full_key(key))
|
| 810 |
-
elif OmegaConf.is_config(_cfg[key]):
|
| 811 |
-
gather(_cfg[key])
|
| 812 |
-
|
| 813 |
-
gather(cfg)
|
| 814 |
-
return missings
|
| 815 |
-
|
| 816 |
-
# === private === #
|
| 817 |
-
|
| 818 |
-
@staticmethod
|
| 819 |
-
def _create_impl( # noqa F811
|
| 820 |
-
obj: Any = _DEFAULT_MARKER_,
|
| 821 |
-
parent: Optional[BaseContainer] = None,
|
| 822 |
-
flags: Optional[Dict[str, bool]] = None,
|
| 823 |
-
) -> Union[DictConfig, ListConfig]:
|
| 824 |
-
try:
|
| 825 |
-
from ._utils import get_yaml_loader
|
| 826 |
-
from .dictconfig import DictConfig
|
| 827 |
-
from .listconfig import ListConfig
|
| 828 |
-
|
| 829 |
-
if obj is _DEFAULT_MARKER_:
|
| 830 |
-
obj = {}
|
| 831 |
-
if isinstance(obj, str):
|
| 832 |
-
obj = yaml.load(obj, Loader=get_yaml_loader())
|
| 833 |
-
if obj is None:
|
| 834 |
-
return OmegaConf.create({}, parent=parent, flags=flags)
|
| 835 |
-
elif isinstance(obj, str):
|
| 836 |
-
return OmegaConf.create({obj: None}, parent=parent, flags=flags)
|
| 837 |
-
else:
|
| 838 |
-
assert isinstance(obj, (list, dict))
|
| 839 |
-
return OmegaConf.create(obj, parent=parent, flags=flags)
|
| 840 |
-
|
| 841 |
-
else:
|
| 842 |
-
if (
|
| 843 |
-
is_primitive_dict(obj)
|
| 844 |
-
or OmegaConf.is_dict(obj)
|
| 845 |
-
or is_structured_config(obj)
|
| 846 |
-
or obj is None
|
| 847 |
-
):
|
| 848 |
-
if isinstance(obj, DictConfig):
|
| 849 |
-
return DictConfig(
|
| 850 |
-
content=obj,
|
| 851 |
-
parent=parent,
|
| 852 |
-
ref_type=obj._metadata.ref_type,
|
| 853 |
-
is_optional=obj._metadata.optional,
|
| 854 |
-
key_type=obj._metadata.key_type,
|
| 855 |
-
element_type=obj._metadata.element_type,
|
| 856 |
-
flags=flags,
|
| 857 |
-
)
|
| 858 |
-
else:
|
| 859 |
-
obj_type = OmegaConf.get_type(obj)
|
| 860 |
-
key_type, element_type = get_dict_key_value_types(obj_type)
|
| 861 |
-
return DictConfig(
|
| 862 |
-
content=obj,
|
| 863 |
-
parent=parent,
|
| 864 |
-
key_type=key_type,
|
| 865 |
-
element_type=element_type,
|
| 866 |
-
flags=flags,
|
| 867 |
-
)
|
| 868 |
-
elif is_primitive_list(obj) or OmegaConf.is_list(obj):
|
| 869 |
-
if isinstance(obj, ListConfig):
|
| 870 |
-
return ListConfig(
|
| 871 |
-
content=obj,
|
| 872 |
-
parent=parent,
|
| 873 |
-
element_type=obj._metadata.element_type,
|
| 874 |
-
ref_type=obj._metadata.ref_type,
|
| 875 |
-
is_optional=obj._metadata.optional,
|
| 876 |
-
flags=flags,
|
| 877 |
-
)
|
| 878 |
-
else:
|
| 879 |
-
obj_type = OmegaConf.get_type(obj)
|
| 880 |
-
element_type = get_list_element_type(obj_type)
|
| 881 |
-
return ListConfig(
|
| 882 |
-
content=obj,
|
| 883 |
-
parent=parent,
|
| 884 |
-
element_type=element_type,
|
| 885 |
-
ref_type=Any,
|
| 886 |
-
is_optional=True,
|
| 887 |
-
flags=flags,
|
| 888 |
-
)
|
| 889 |
-
else:
|
| 890 |
-
if isinstance(obj, type):
|
| 891 |
-
raise ValidationError(
|
| 892 |
-
f"Input class '{obj.__name__}' is not a structured config. "
|
| 893 |
-
"did you forget to decorate it as a dataclass?"
|
| 894 |
-
)
|
| 895 |
-
else:
|
| 896 |
-
raise ValidationError(
|
| 897 |
-
f"Object of unsupported type: '{type(obj).__name__}'"
|
| 898 |
-
)
|
| 899 |
-
except OmegaConfBaseException as e:
|
| 900 |
-
format_and_raise(node=None, key=None, value=None, msg=str(e), cause=e)
|
| 901 |
-
assert False
|
| 902 |
-
|
| 903 |
-
@staticmethod
|
| 904 |
-
def _get_obj_type(c: Any) -> Optional[Type[Any]]:
|
| 905 |
-
if is_structured_config(c):
|
| 906 |
-
return get_type_of(c)
|
| 907 |
-
elif c is None:
|
| 908 |
-
return None
|
| 909 |
-
elif isinstance(c, DictConfig):
|
| 910 |
-
if c._is_none():
|
| 911 |
-
return None
|
| 912 |
-
elif c._is_missing():
|
| 913 |
-
return None
|
| 914 |
-
else:
|
| 915 |
-
if is_structured_config(c._metadata.object_type):
|
| 916 |
-
return c._metadata.object_type
|
| 917 |
-
else:
|
| 918 |
-
return dict
|
| 919 |
-
elif isinstance(c, ListConfig):
|
| 920 |
-
return list
|
| 921 |
-
elif isinstance(c, ValueNode):
|
| 922 |
-
return type(c._value())
|
| 923 |
-
elif isinstance(c, UnionNode):
|
| 924 |
-
return type(_get_value(c))
|
| 925 |
-
elif isinstance(c, dict):
|
| 926 |
-
return dict
|
| 927 |
-
elif isinstance(c, (list, tuple)):
|
| 928 |
-
return list
|
| 929 |
-
else:
|
| 930 |
-
return get_type_of(c)
|
| 931 |
-
|
| 932 |
-
@staticmethod
|
| 933 |
-
def _get_resolver(
|
| 934 |
-
name: str,
|
| 935 |
-
) -> Optional[
|
| 936 |
-
Callable[
|
| 937 |
-
[Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]],
|
| 938 |
-
Any,
|
| 939 |
-
]
|
| 940 |
-
]:
|
| 941 |
-
# noinspection PyProtectedMember
|
| 942 |
-
return (
|
| 943 |
-
BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None
|
| 944 |
-
)
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
# register all default resolvers
|
| 948 |
-
register_default_resolvers()
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
@contextmanager
|
| 952 |
-
def flag_override(
|
| 953 |
-
config: Node,
|
| 954 |
-
names: Union[List[str], str],
|
| 955 |
-
values: Union[List[Optional[bool]], Optional[bool]],
|
| 956 |
-
) -> Generator[Node, None, None]:
|
| 957 |
-
if isinstance(names, str):
|
| 958 |
-
names = [names]
|
| 959 |
-
if values is None or isinstance(values, bool):
|
| 960 |
-
values = [values]
|
| 961 |
-
|
| 962 |
-
prev_states = [config._get_node_flag(name) for name in names]
|
| 963 |
-
|
| 964 |
-
try:
|
| 965 |
-
config._set_flag(names, values)
|
| 966 |
-
yield config
|
| 967 |
-
finally:
|
| 968 |
-
config._set_flag(names, prev_states)
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
@contextmanager
|
| 972 |
-
def read_write(config: Node) -> Generator[Node, None, None]:
|
| 973 |
-
prev_state = config._get_node_flag("readonly")
|
| 974 |
-
try:
|
| 975 |
-
OmegaConf.set_readonly(config, False)
|
| 976 |
-
yield config
|
| 977 |
-
finally:
|
| 978 |
-
OmegaConf.set_readonly(config, prev_state)
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
@contextmanager
|
| 982 |
-
def open_dict(config: Container) -> Generator[Container, None, None]:
|
| 983 |
-
prev_state = config._get_node_flag("struct")
|
| 984 |
-
try:
|
| 985 |
-
OmegaConf.set_struct(config, False)
|
| 986 |
-
yield config
|
| 987 |
-
finally:
|
| 988 |
-
OmegaConf.set_struct(config, prev_state)
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
# === private === #
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
def _node_wrap(
|
| 995 |
-
parent: Optional[Box],
|
| 996 |
-
is_optional: bool,
|
| 997 |
-
value: Any,
|
| 998 |
-
key: Any,
|
| 999 |
-
ref_type: Any = Any,
|
| 1000 |
-
) -> Node:
|
| 1001 |
-
node: Node
|
| 1002 |
-
if is_dict_annotation(ref_type) or (is_primitive_dict(value) and ref_type is Any):
|
| 1003 |
-
key_type, element_type = get_dict_key_value_types(ref_type)
|
| 1004 |
-
node = DictConfig(
|
| 1005 |
-
content=value,
|
| 1006 |
-
key=key,
|
| 1007 |
-
parent=parent,
|
| 1008 |
-
ref_type=ref_type,
|
| 1009 |
-
is_optional=is_optional,
|
| 1010 |
-
key_type=key_type,
|
| 1011 |
-
element_type=element_type,
|
| 1012 |
-
)
|
| 1013 |
-
elif (is_list_annotation(ref_type) or is_tuple_annotation(ref_type)) or (
|
| 1014 |
-
type(value) in (list, tuple) and ref_type is Any
|
| 1015 |
-
):
|
| 1016 |
-
element_type = get_list_element_type(ref_type)
|
| 1017 |
-
node = ListConfig(
|
| 1018 |
-
content=value,
|
| 1019 |
-
key=key,
|
| 1020 |
-
parent=parent,
|
| 1021 |
-
is_optional=is_optional,
|
| 1022 |
-
element_type=element_type,
|
| 1023 |
-
ref_type=ref_type,
|
| 1024 |
-
)
|
| 1025 |
-
elif is_structured_config(ref_type) or is_structured_config(value):
|
| 1026 |
-
key_type, element_type = get_dict_key_value_types(value)
|
| 1027 |
-
node = DictConfig(
|
| 1028 |
-
ref_type=ref_type,
|
| 1029 |
-
is_optional=is_optional,
|
| 1030 |
-
content=value,
|
| 1031 |
-
key=key,
|
| 1032 |
-
parent=parent,
|
| 1033 |
-
key_type=key_type,
|
| 1034 |
-
element_type=element_type,
|
| 1035 |
-
)
|
| 1036 |
-
elif is_union_annotation(ref_type):
|
| 1037 |
-
node = UnionNode(
|
| 1038 |
-
content=value,
|
| 1039 |
-
ref_type=ref_type,
|
| 1040 |
-
is_optional=is_optional,
|
| 1041 |
-
key=key,
|
| 1042 |
-
parent=parent,
|
| 1043 |
-
)
|
| 1044 |
-
elif ref_type == Any or ref_type is None:
|
| 1045 |
-
node = AnyNode(value=value, key=key, parent=parent)
|
| 1046 |
-
elif isinstance(ref_type, type) and issubclass(ref_type, Enum):
|
| 1047 |
-
node = EnumNode(
|
| 1048 |
-
enum_type=ref_type,
|
| 1049 |
-
value=value,
|
| 1050 |
-
key=key,
|
| 1051 |
-
parent=parent,
|
| 1052 |
-
is_optional=is_optional,
|
| 1053 |
-
)
|
| 1054 |
-
elif ref_type == int:
|
| 1055 |
-
node = IntegerNode(value=value, key=key, parent=parent, is_optional=is_optional)
|
| 1056 |
-
elif ref_type == float:
|
| 1057 |
-
node = FloatNode(value=value, key=key, parent=parent, is_optional=is_optional)
|
| 1058 |
-
elif ref_type == bool:
|
| 1059 |
-
node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional)
|
| 1060 |
-
elif ref_type == str:
|
| 1061 |
-
node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional)
|
| 1062 |
-
elif ref_type == bytes:
|
| 1063 |
-
node = BytesNode(value=value, key=key, parent=parent, is_optional=is_optional)
|
| 1064 |
-
elif ref_type == pathlib.Path:
|
| 1065 |
-
node = PathNode(value=value, key=key, parent=parent, is_optional=is_optional)
|
| 1066 |
-
else:
|
| 1067 |
-
if parent is not None and parent._get_flag("allow_objects") is True:
|
| 1068 |
-
if type(value) in (list, tuple):
|
| 1069 |
-
node = ListConfig(
|
| 1070 |
-
content=value,
|
| 1071 |
-
key=key,
|
| 1072 |
-
parent=parent,
|
| 1073 |
-
ref_type=ref_type,
|
| 1074 |
-
is_optional=is_optional,
|
| 1075 |
-
)
|
| 1076 |
-
elif is_primitive_dict(value):
|
| 1077 |
-
node = DictConfig(
|
| 1078 |
-
content=value,
|
| 1079 |
-
key=key,
|
| 1080 |
-
parent=parent,
|
| 1081 |
-
ref_type=ref_type,
|
| 1082 |
-
is_optional=is_optional,
|
| 1083 |
-
)
|
| 1084 |
-
else:
|
| 1085 |
-
node = AnyNode(value=value, key=key, parent=parent)
|
| 1086 |
-
else:
|
| 1087 |
-
raise ValidationError(f"Unexpected type annotation: {type_str(ref_type)}")
|
| 1088 |
-
return node
|
| 1089 |
-
|
| 1090 |
-
|
| 1091 |
-
def _maybe_wrap(
|
| 1092 |
-
ref_type: Any,
|
| 1093 |
-
key: Any,
|
| 1094 |
-
value: Any,
|
| 1095 |
-
is_optional: bool,
|
| 1096 |
-
parent: Optional[BaseContainer],
|
| 1097 |
-
) -> Node:
|
| 1098 |
-
# if already a node, update key and parent and return as is.
|
| 1099 |
-
# NOTE: that this mutate the input node!
|
| 1100 |
-
if isinstance(value, Node):
|
| 1101 |
-
value._set_key(key)
|
| 1102 |
-
value._set_parent(parent)
|
| 1103 |
-
return value
|
| 1104 |
-
else:
|
| 1105 |
-
return _node_wrap(
|
| 1106 |
-
ref_type=ref_type,
|
| 1107 |
-
parent=parent,
|
| 1108 |
-
is_optional=is_optional,
|
| 1109 |
-
value=value,
|
| 1110 |
-
key=key,
|
| 1111 |
-
)
|
| 1112 |
-
|
| 1113 |
-
|
| 1114 |
-
def _select_one(
|
| 1115 |
-
c: Container, key: str, throw_on_missing: bool, throw_on_type_error: bool = True
|
| 1116 |
-
) -> Tuple[Optional[Node], Union[str, int]]:
|
| 1117 |
-
from .dictconfig import DictConfig
|
| 1118 |
-
from .listconfig import ListConfig
|
| 1119 |
-
|
| 1120 |
-
ret_key: Union[str, int] = key
|
| 1121 |
-
assert isinstance(c, Container), f"Unexpected type: {c}"
|
| 1122 |
-
if c._is_none():
|
| 1123 |
-
return None, ret_key
|
| 1124 |
-
|
| 1125 |
-
if isinstance(c, DictConfig):
|
| 1126 |
-
assert isinstance(ret_key, str)
|
| 1127 |
-
val = c._get_child(ret_key, validate_access=False)
|
| 1128 |
-
elif isinstance(c, ListConfig):
|
| 1129 |
-
assert isinstance(ret_key, str)
|
| 1130 |
-
if not is_int(ret_key):
|
| 1131 |
-
if throw_on_type_error:
|
| 1132 |
-
raise TypeError(
|
| 1133 |
-
f"Index '{ret_key}' ({type(ret_key).__name__}) is not an int"
|
| 1134 |
-
)
|
| 1135 |
-
else:
|
| 1136 |
-
val = None
|
| 1137 |
-
else:
|
| 1138 |
-
ret_key = int(ret_key)
|
| 1139 |
-
if ret_key < 0 or ret_key + 1 > len(c):
|
| 1140 |
-
val = None
|
| 1141 |
-
else:
|
| 1142 |
-
val = c._get_child(ret_key)
|
| 1143 |
-
else:
|
| 1144 |
-
assert False
|
| 1145 |
-
|
| 1146 |
-
if val is not None:
|
| 1147 |
-
assert isinstance(val, Node)
|
| 1148 |
-
if val._is_missing():
|
| 1149 |
-
if throw_on_missing:
|
| 1150 |
-
raise MissingMandatoryValue(
|
| 1151 |
-
f"Missing mandatory value: {c._get_full_key(ret_key)}"
|
| 1152 |
-
)
|
| 1153 |
-
else:
|
| 1154 |
-
return val, ret_key
|
| 1155 |
-
|
| 1156 |
-
assert val is None or isinstance(val, Node)
|
| 1157 |
-
return val, ret_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/py.typed
DELETED
|
File without changes
|
omegaconf/resolvers/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
from omegaconf.resolvers import oc
|
| 2 |
-
|
| 3 |
-
__all__ = [
|
| 4 |
-
"oc",
|
| 5 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/resolvers/oc/__init__.py
DELETED
|
@@ -1,113 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import string
|
| 3 |
-
import warnings
|
| 4 |
-
from typing import Any, Optional
|
| 5 |
-
|
| 6 |
-
from omegaconf import Container, Node
|
| 7 |
-
from omegaconf._utils import _DEFAULT_MARKER_, _get_value
|
| 8 |
-
from omegaconf.basecontainer import BaseContainer
|
| 9 |
-
from omegaconf.errors import ConfigKeyError
|
| 10 |
-
from omegaconf.grammar_parser import parse
|
| 11 |
-
from omegaconf.resolvers.oc import dict
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def create(obj: Any, _parent_: Container) -> Any:
|
| 15 |
-
"""Create a config object from `obj`, similar to `OmegaConf.create`"""
|
| 16 |
-
from omegaconf import OmegaConf
|
| 17 |
-
|
| 18 |
-
assert isinstance(_parent_, BaseContainer)
|
| 19 |
-
return OmegaConf.create(obj, parent=_parent_)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]:
|
| 23 |
-
"""
|
| 24 |
-
:param key: Environment variable key
|
| 25 |
-
:param default: Optional default value to use in case the key environment variable is not set.
|
| 26 |
-
If default is not a string, it is converted with str(default).
|
| 27 |
-
None default is returned as is.
|
| 28 |
-
:return: The environment variable 'key'. If the environment variable is not set and a default is
|
| 29 |
-
provided, the default is used. If used, the default is converted to a string with str(default).
|
| 30 |
-
If the default is None, None is returned (without a string conversion).
|
| 31 |
-
"""
|
| 32 |
-
try:
|
| 33 |
-
return os.environ[key]
|
| 34 |
-
except KeyError:
|
| 35 |
-
if default is not _DEFAULT_MARKER_:
|
| 36 |
-
return str(default) if default is not None else None
|
| 37 |
-
else:
|
| 38 |
-
raise KeyError(f"Environment variable '{key}' not found")
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def decode(expr: Optional[str], _parent_: Container, _node_: Node) -> Any:
|
| 42 |
-
"""
|
| 43 |
-
Parse and evaluate `expr` according to the `singleElement` rule of the grammar.
|
| 44 |
-
|
| 45 |
-
If `expr` is `None`, then return `None`.
|
| 46 |
-
"""
|
| 47 |
-
if expr is None:
|
| 48 |
-
return None
|
| 49 |
-
|
| 50 |
-
if not isinstance(expr, str):
|
| 51 |
-
raise TypeError(
|
| 52 |
-
f"`oc.decode` can only take strings or None as input, "
|
| 53 |
-
f"but `{expr}` is of type {type(expr).__name__}"
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
parse_tree = parse(expr, parser_rule="singleElement", lexer_mode="VALUE_MODE")
|
| 57 |
-
val = _parent_.resolve_parse_tree(parse_tree, node=_node_)
|
| 58 |
-
return _get_value(val)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def deprecated(
|
| 62 |
-
key: str,
|
| 63 |
-
message: str = "'$OLD_KEY' is deprecated. Change your code and config to use '$NEW_KEY'",
|
| 64 |
-
*,
|
| 65 |
-
_parent_: Container,
|
| 66 |
-
_node_: Node,
|
| 67 |
-
) -> Any:
|
| 68 |
-
from omegaconf._impl import select_node
|
| 69 |
-
|
| 70 |
-
if not isinstance(key, str):
|
| 71 |
-
raise TypeError(
|
| 72 |
-
f"oc.deprecated: interpolation key type is not a string ({type(key).__name__})"
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
if not isinstance(message, str):
|
| 76 |
-
raise TypeError(
|
| 77 |
-
f"oc.deprecated: interpolation message type is not a string ({type(message).__name__})"
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
full_key = _node_._get_full_key(key=None)
|
| 81 |
-
target_node = select_node(_parent_, key, absolute_key=True)
|
| 82 |
-
if target_node is None:
|
| 83 |
-
raise ConfigKeyError(
|
| 84 |
-
f"In oc.deprecated resolver at '{full_key}': Key not found: '{key}'"
|
| 85 |
-
)
|
| 86 |
-
new_key = target_node._get_full_key(key=None)
|
| 87 |
-
msg = string.Template(message).safe_substitute(
|
| 88 |
-
OLD_KEY=full_key,
|
| 89 |
-
NEW_KEY=new_key,
|
| 90 |
-
)
|
| 91 |
-
warnings.warn(category=UserWarning, message=msg)
|
| 92 |
-
return target_node
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def select(
|
| 96 |
-
key: str,
|
| 97 |
-
default: Any = _DEFAULT_MARKER_,
|
| 98 |
-
*,
|
| 99 |
-
_parent_: Container,
|
| 100 |
-
) -> Any:
|
| 101 |
-
from omegaconf._impl import select_value
|
| 102 |
-
|
| 103 |
-
return select_value(cfg=_parent_, key=key, absolute_key=True, default=default)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
__all__ = [
|
| 107 |
-
"create",
|
| 108 |
-
"decode",
|
| 109 |
-
"deprecated",
|
| 110 |
-
"dict",
|
| 111 |
-
"env",
|
| 112 |
-
"select",
|
| 113 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/resolvers/oc/dict.py
DELETED
|
@@ -1,83 +0,0 @@
|
|
| 1 |
-
from typing import Any, List
|
| 2 |
-
|
| 3 |
-
from omegaconf import AnyNode, Container, DictConfig, ListConfig
|
| 4 |
-
from omegaconf._utils import Marker
|
| 5 |
-
from omegaconf.basecontainer import BaseContainer
|
| 6 |
-
from omegaconf.errors import ConfigKeyError
|
| 7 |
-
|
| 8 |
-
_DEFAULT_SELECT_MARKER_: Any = Marker("_DEFAULT_SELECT_MARKER_")
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def keys(
|
| 12 |
-
key: str,
|
| 13 |
-
_parent_: Container,
|
| 14 |
-
) -> ListConfig:
|
| 15 |
-
from omegaconf import OmegaConf
|
| 16 |
-
|
| 17 |
-
assert isinstance(_parent_, BaseContainer)
|
| 18 |
-
|
| 19 |
-
in_dict = _get_and_validate_dict_input(
|
| 20 |
-
key, parent=_parent_, resolver_name="oc.dict.keys"
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_)
|
| 24 |
-
assert isinstance(ret, ListConfig)
|
| 25 |
-
return ret
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig:
|
| 29 |
-
assert isinstance(_parent_, BaseContainer)
|
| 30 |
-
in_dict = _get_and_validate_dict_input(
|
| 31 |
-
key, parent=_parent_, resolver_name="oc.dict.values"
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
content = in_dict._content
|
| 35 |
-
assert isinstance(content, dict)
|
| 36 |
-
|
| 37 |
-
ret = ListConfig([])
|
| 38 |
-
if key.startswith("."):
|
| 39 |
-
key = f".{key}" # extra dot to compensate for extra level of nesting within ret ListConfig
|
| 40 |
-
for k in content:
|
| 41 |
-
ref_node = AnyNode(f"${{{key}.{k!s}}}")
|
| 42 |
-
ret.append(ref_node)
|
| 43 |
-
|
| 44 |
-
# Finalize result by setting proper type and parent.
|
| 45 |
-
element_type: Any = in_dict._metadata.element_type
|
| 46 |
-
ret._metadata.element_type = element_type
|
| 47 |
-
ret._metadata.ref_type = List[element_type]
|
| 48 |
-
ret._set_parent(_parent_)
|
| 49 |
-
|
| 50 |
-
return ret
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def _get_and_validate_dict_input(
|
| 54 |
-
key: str,
|
| 55 |
-
parent: BaseContainer,
|
| 56 |
-
resolver_name: str,
|
| 57 |
-
) -> DictConfig:
|
| 58 |
-
from omegaconf._impl import select_value
|
| 59 |
-
|
| 60 |
-
if not isinstance(key, str):
|
| 61 |
-
raise TypeError(
|
| 62 |
-
f"`{resolver_name}` requires a string as input, but obtained `{key}` "
|
| 63 |
-
f"of type: {type(key).__name__}"
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
in_dict = select_value(
|
| 67 |
-
parent,
|
| 68 |
-
key,
|
| 69 |
-
throw_on_missing=True,
|
| 70 |
-
absolute_key=True,
|
| 71 |
-
default=_DEFAULT_SELECT_MARKER_,
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
if in_dict is _DEFAULT_SELECT_MARKER_:
|
| 75 |
-
raise ConfigKeyError(f"Key not found: '{key}'")
|
| 76 |
-
|
| 77 |
-
if not isinstance(in_dict, DictConfig):
|
| 78 |
-
raise TypeError(
|
| 79 |
-
f"`{resolver_name}` cannot be applied to objects of type: "
|
| 80 |
-
f"{type(in_dict).__name__}"
|
| 81 |
-
)
|
| 82 |
-
|
| 83 |
-
return in_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
omegaconf/version.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
import sys # pragma: no cover
|
| 2 |
-
|
| 3 |
-
__version__ = "2.3.0"
|
| 4 |
-
|
| 5 |
-
msg = """OmegaConf 2.0 and above is compatible with Python 3.6 and newer.
|
| 6 |
-
You have the following options:
|
| 7 |
-
1. Upgrade to Python 3.6 or newer.
|
| 8 |
-
This is highly recommended. new features will not be added to OmegaConf 1.4.
|
| 9 |
-
2. Continue using OmegaConf 1.4:
|
| 10 |
-
You can pip install 'OmegaConf<1.5' to do that.
|
| 11 |
-
"""
|
| 12 |
-
if sys.version_info < (3, 6):
|
| 13 |
-
raise ImportError(msg) # pragma: no cover
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -7,4 +7,4 @@ tensorboard
|
|
| 7 |
slider==0.8.1
|
| 8 |
torch_tb_profiler
|
| 9 |
rosu_pp_py
|
| 10 |
-
|
|
|
|
| 7 |
slider==0.8.1
|
| 8 |
torch_tb_profiler
|
| 9 |
rosu_pp_py
|
| 10 |
+
omegaconf
|