File size: 7,259 Bytes
f71ac1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""Parser for config files that can be used with absl flags."""

from __future__ import annotations

import logging
import re
import sys
import traceback
from typing import Any

from absl import flags
from ml_collections import ConfigDict, FieldReference
from ml_collections.config_flags.config_flags import (
    _ConfigFlag,
    _ErrorConfig,
    _LockConfig,
)

from vis4d.config import copy_and_resolve_references
from vis4d.config.registry import get_config_by_name


class ConfigFileParser(flags.ArgumentParser):  # type: ignore
    """Flag argument parser that turns a config path into a config object.

    The config is obtained through `get_config_by_name`, which is asked to
    call the method named `method_name` of the referenced config.
    """

    def __init__(
        self,
        name: str,
        lock_config: bool = True,
        method_name: str = "get_config",
    ) -> None:
        """Initializes the parser.

        Args:
            name (str): The name of the flag (e.g. config for --config flag)
            lock_config (bool, optional): Whether or not to lock the config
                after it has been loaded. Defaults to True.
            method_name (str, optional): Name of the method to call in the
                config. Defaults to "get_config".
        """
        self.name = name
        self._lock_config = lock_config
        self.method_name = method_name

    def parse(  # pylint: disable=arguments-renamed
        self, path: str
    ) -> ConfigDict | _ErrorConfig:
        """Loads a config module from `path` and returns the `method_name()`.

        This implementation is based on the original ml_collections and
        modified to allow for a custom method name.

        If a colon is present in `path`, everything to the right of the first
        colon is passed to `method_name` as an argument. This allows the
        structure of what is returned to be modified, which is useful when
        performing complex hyperparameter sweeps.

        Args:
          path: string, path pointing to the config file to execute. May also
              contain a config_string argument, e.g. be of the form
              "config.py:some_configuration".

        Returns:
          Result of calling `method_name` in the specified module, or an
          `_ErrorConfig` wrapping the error if loading raised an `IOError`.

        Raises:
          TypeError: If loading the config raised a `TypeError`. The message
              contains the formatted traceback of the original error.
          ValueError: If loading the config raised a `ValueError`. The
              message contains the formatted traceback of the original error.
        """
        # `config_args` holds exactly one element iff extra configuration
        # args are present after the first colon, otherwise it is empty.
        config_name, *config_args = path.split(":", 1)

        try:
            config = get_config_by_name(
                config_name,
                *config_args,
                method_name=self.method_name,
            )
            if config is None:
                logging.warning(
                    "%s:%s() returned None, did you forget a return "
                    "statement?",
                    path,
                    self.method_name,
                )
        except IOError as e:
            # Don't raise the error unless/until the config is
            # actually accessed.
            return _ErrorConfig(e)
        # Third party flags library catches TypeError and ValueError
        # and rethrows,
        # removing useful information unless it is added here (b/63877430):
        except (TypeError, ValueError) as e:
            message = (
                "Error whilst parsing config file:\n\n"
                + traceback.format_exc()
            )
            try:
                wrapped = type(e)(message)
            except TypeError:
                # Some subclasses (e.g. UnicodeDecodeError) cannot be built
                # from a single message argument. Fall back to the plain base
                # class so the informative message is not replaced by an
                # unrelated constructor error.
                base = TypeError if isinstance(e, TypeError) else ValueError
                wrapped = base(message)
            raise wrapped from e

        if self._lock_config:
            _LockConfig(config)

        return config

    def flag_type(self) -> str:
        """Returns the type of the flag."""
        return "config object"


def DEFINE_config_file(  # pylint: disable=invalid-name
    name: str,
    default: str | None = None,
    help_string: str = "path to config file [.py |.yaml].",
    lock_config: bool = False,
    method_name: str = "get_config",
) -> flags.FlagHolder:  # type: ignore
    """Defines a config-file flag on the global absl ``FLAGS`` object.

    The flag value is parsed by a ``ConfigFileParser`` and the flag is
    registered under the module that called this function.

    Args:
        name (str): The name of the flag (e.g. config for --config flag)
        default (str | None, optional): Default Value. Defaults to None.
        help_string (str, optional): Help String.
            Defaults to "path to config file [.py |.yaml].".
        lock_config (bool, optional): Whether or not to lock the returned
            config. Defaults to False.
        method_name (str, optional): Name of the method to call in the config.
            Defaults to "get_config".

    Returns:
        flags.FlagHolder: Flag holder instance.
    """
    # Look up the name of the calling module (the frame at depth 1 in the
    # call stack). This must stay directly inside this function so that
    # depth 1 refers to the caller.
    caller_globals = sys._getframe(  # pylint: disable=protected-access
        1
    ).f_globals
    module_name = caller_globals.get("__name__", None)
    if module_name == "__main__":
        # When called from the entry script, use the script path instead.
        module_name = sys.argv[0]

    config_flag = _ConfigFlag(
        parser=ConfigFileParser(
            name=name, lock_config=lock_config, method_name=method_name
        ),
        serializer=flags.ArgumentSerializer(),
        name=name,
        default=default,
        help_string=help_string,
        flag_values=flags.FLAGS,
    )
    return flags.DEFINE_flag(config_flag, flags.FLAGS, module_name=module_name)


def pprints_config(data: ConfigDict) -> str:
    """Renders a ConfigDict as a string with a YAML-like layout.

    Unlike the __repr__ of ConfigDict, python objects stored in the config
    are not encoded in a binary format; their string representation is
    printed instead.

    The config is first passed through `copy_and_resolve_references` and the
    result is then formatted recursively by `_pprints_config`.

    Args:
        data (ConfigDict): Configuration dict to convert to string

    Returns:
        str: A string representation of the ConfigDict
    """
    resolved = copy_and_resolve_references(data)
    return _pprints_config(resolved)


def _pprints_config(  # type: ignore
    data: Any, prefix: str = "", n_indents: int = 1
) -> str:
    """Converts a ConfigDict into a string with a YAML like structure.

    This is the recursive implementation of 'pprints_config' and will be called
    recursively for every element in the dict.

    This function differs from __repr__ of ConfigDict in that it will not
    encode python classes using binary formats but just prints the str()
    of these objects.

    Args:
        data (Any): Configuration dict or object to convert to
            string
        prefix (str): Prefix to print on each new line
        n_indents (int): Number of spaces to append for each nested property.
            The value is forwarded to every level of the recursion.

    Returns:
        str: A string representation of the ConfigDict
    """
    if isinstance(data, FieldReference):
        data = data.get()

    # Leaf values are printed using their plain string conversion.
    if not isinstance(data, (ConfigDict, dict, list, tuple)):
        return str(data)

    indent = " " * n_indents
    parts = ["\n"]

    if isinstance(data, (ConfigDict, dict)):
        child_prefix = prefix + indent
        for key in data:
            rendered = _pprints_config(
                data[key], prefix=child_prefix, n_indents=n_indents
            )
            # str(key) keeps non-string keys of plain dicts from raising a
            # TypeError during concatenation.
            parts.append(prefix + str(key) + ": " + rendered + "\n")
    else:
        # Only list and tuple are left after the container check above.
        child_prefix = prefix + " " + indent
        for value in data:
            parts.append(prefix + "- ")
            if isinstance(value, (ConfigDict, dict)):
                parts.append("\n")
            parts.append(
                _pprints_config(
                    value, prefix=child_prefix, n_indents=n_indents
                )
                + "\n"
            )
        parts.append(" \n")  # Add newline after list for better readability.

    # Clean up some formatting issues using regex: collapse runs of blank
    # lines, then pull the first entry of a nested item up onto its "- " line.
    string_repr = re.sub(r"\n\n+", "\n", "".join(parts))
    return re.sub(r"- +\n +", "- ", string_repr)