File size: 4,793 Bytes
7734c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Copyright (c) Meta Platforms, Inc. and affiliates.
import functools
from typing import Any, Callable, Union

from omegaconf import DictConfig, ListConfig, OmegaConf
from hydra.utils import instantiate

# What find_class_or_callable_from_target returns: a class or any other callable.
ClassOrCallableType = Union[type, Callable[..., Any]]
# What it accepts: a dotted path string (resolved via `locate`) or the class/callable itself.
TargetType = Union[str, ClassOrCallableType]


def dump_config(config: DictConfig, path: str = "./config.yaml") -> None:
    """Serialize ``config`` to ``path`` as YAML with sorted keys.

    Args:
        config: OmegaConf config to write.
        path: Destination file; it is created or overwritten.
    """
    txt = OmegaConf.to_yaml(config, sort_keys=True)
    # Write explicit UTF-8 instead of the platform default encoding (e.g.
    # cp1252 on Windows), which cannot encode every non-ASCII string a config
    # may hold; OmegaConf.load reads files back as UTF-8.
    with open(path, "w", encoding="utf-8") as f:
        f.write(txt)


def locate(path: str) -> Any:
    """Resolve a dotted path such as ``"pkg.module.Class"`` to the object it names.

    The longest importable module prefix of ``path`` is imported first. For
    multi-component paths the last component is never tried as a module; for a
    single-component path the whole path is. The remaining components are then
    resolved with ``getattr``, and a component that is not yet an attribute is
    imported as a submodule first. When no prefix is importable, resolution
    starts from :mod:`builtins`, so bare names such as ``"int"`` or ``"len"``
    work.

    Args:
        path: Dotted path to resolve. Empty components (``"a..b"``) are ignored.

    Returns:
        The located object (module, class, function or other attribute).

    Raises:
        ImportError: If ``path`` is empty or has no non-empty component, or a
            component can be neither found as an attribute nor imported.
    """
    if path == "":
        raise ImportError("Empty path")

    import builtins
    from importlib import import_module

    parts = [part for part in path.split(".") if part]
    if not parts:
        # e.g. "." or "...": there is nothing to resolve.
        raise ImportError(f"Invalid path '{path}'")

    # Import the longest module prefix, longest first. n == 0 is never tried
    # here because import_module("") always raises; the builtins fallback is
    # selected explicitly in the `else` branch instead.
    for n in range(max(len(parts) - 1, 1), 0, -1):
        try:
            obj = import_module(".".join(parts[:n]))
        except Exception:
            continue
        break
    else:
        n = 0
        obj = builtins

    # Walk the remaining components as attributes.
    prefix = ".".join(parts[:n])
    for part in parts[n:]:
        prefix = f"{prefix}.{part}" if prefix else part
        if not hasattr(obj, part):
            # A submodule only becomes an attribute of its package once imported.
            try:
                import_module(prefix)
            except Exception as e:
                raise ImportError(
                    f"Encountered error: `{e}` when loading module '{path}'"
                ) from e
        obj = getattr(obj, part)

    return obj


def full_instance_name(instance: Any) -> str:
    """Return the qualified name of the class of ``instance`` (via ``full_class_name``)."""
    cls = instance.__class__
    return full_class_name(cls)


def full_class_name(klass: Any) -> str:
    """Return ``klass``'s module-qualified name, e.g. ``"collections.OrderedDict"``.

    Classes living in the ``builtins`` module are returned bare (``"str"``
    rather than ``"builtins.str"``).
    """
    module_name = klass.__module__
    if module_name != "builtins":
        return module_name + "." + klass.__qualname__
    return klass.__qualname__


def ensure_is_subclass(child_class: type, parent_class: type) -> None:
    """Raise unless ``child_class`` is a class deriving from ``parent_class``.

    Args:
        child_class: Object to check. Non-class callables (which
            ``find_class_or_callable_from_target`` may return) are rejected.
        parent_class: Required base class.

    Raises:
        RuntimeError: If ``child_class`` is not a class, or is not a subclass
            of ``parent_class``.
    """
    if not isinstance(child_class, type):
        # issubclass() would otherwise raise an unhelpful TypeError here.
        raise RuntimeError(
            f"{child_class!r} is not a class and therefore cannot be a subclass of {full_class_name(parent_class)}"
        )
    if not issubclass(child_class, parent_class):
        raise RuntimeError(
            f"class {full_class_name(child_class)} should be a subclass of {full_class_name(parent_class)}"
        )


def find_class_or_callable_from_target(
    target: TargetType,
) -> ClassOrCallableType:
    """Turn ``target`` into a class or callable.

    Strings are treated as dotted paths and resolved with ``locate``; any other
    value is used as-is.

    Raises:
        ValueError: If the resolved object is neither a class nor callable.
    """
    resolved = locate(target) if isinstance(target, str) else target

    if isinstance(resolved, type) or callable(resolved):
        return resolved
    raise ValueError(f"Invalid type ({type(resolved)}) found for {target}")


def find_and_ensure_is_subclass(target: TargetType, type_: type) -> ClassOrCallableType:
    """Resolve ``target`` and return it after checking it subclasses ``type_``.

    Resolution follows ``find_class_or_callable_from_target``; the check follows
    ``ensure_is_subclass`` and propagates whatever it raises.
    """
    resolved = find_class_or_callable_from_target(target)
    ensure_is_subclass(resolved, type_)
    return resolved


class StrictPartial:
    """Bind leading arguments to a target given as a dotted path or callable.

    The target is resolved once, at construction, with
    ``find_class_or_callable_from_target``. Calling the instance behaves like
    ``functools.partial``: call-time positional args are appended after the
    bound ones, and call-time keyword args override bound keywords.
    """

    # `path` is positional-only, so a keyword literally named `path` is
    # forwarded to the target instead of clashing with this parameter,
    # e.g. StrictPartial("a.b.c", ..., path="/a/b/c").
    def __init__(self, path, /, *args, **kwargs):
        target = find_class_or_callable_from_target(path)
        bound = functools.partial(target, *args, **kwargs)
        self._partial = bound

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        bound = self._partial
        return bound(*args, **kwargs)


class RecursivePartial:
    """Store a config and run Hydra's ``instantiate`` on it each time it is called.

    At construction every ``_rpartial_target_`` key in the (nested) config is
    renamed to ``_target_``. Presumably the config uses the alternate key so
    that an outer Hydra instantiation leaves it alone until call time (TODO
    confirm against the configs using this class).
    """

    @staticmethod
    def replace_keys(config, key_mapping):
        """Return a rebuilt copy of ``config`` with mapping keys renamed.

        Nested DictConfig/ListConfig nodes are rebuilt recursively; keys not in
        ``key_mapping`` are kept unchanged.

        Raises:
            RuntimeError: If a leaf value's type is not exactly bool, str, int,
                float or NoneType.
        """

        def recurse(data):
            if isinstance(data, DictConfig):
                return DictConfig(
                    {
                        key_mapping[k] if k in key_mapping else k: recurse(v)
                        for k, v in data.items()
                    }
                )
            if isinstance(data, ListConfig):
                return ListConfig([recurse(item) for item in data])
            # Exact type match: subclasses (e.g. IntEnum members) are rejected.
            if type(data) in {bool, str, int, float, type(None)}:
                return data
            raise RuntimeError(f"unknow type found : {type(data)}")

        return recurse(config)

    def __init__(self, config):
        # replace_keys builds new nodes, so the caller's config is not modified.
        self.config = RecursivePartial.replace_keys(
            config, {"_rpartial_target_": "_target_"}
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Instantiate the stored config, forwarding call-time arguments.

        ``args`` and ``kwargs`` are handed to ``hydra.utils.instantiate``, which
        passes positional args to the target and merges ``kwargs`` over the
        config values. This mirrors ``StrictPartial``, which also forwards
        call-time arguments.
        """
        return instantiate(self.config, *args, **kwargs)


class Partial(StrictPartial):
    """``StrictPartial`` whose ``path`` may also be supplied by keyword.

    Here ``path`` is an ordinary parameter, so ``Partial(path="pkg.fn", ...)``
    is accepted. The trade-off is that a keyword named ``path`` can no longer be
    bound for the target itself.
    """

    def __init__(self, path, *args, **kwargs):
        super(Partial, self).__init__(path, *args, **kwargs)


def subkey(mapping, key):
    """Return ``mapping[key]``; lookup errors (e.g. KeyError) propagate."""
    value = mapping[key]
    return value


def make_set(*args):
    """Collect the positional arguments into a new ``set``."""
    return {*args}


def make_tuple(*args):
    """Return the positional arguments as a tuple."""
    # `*args` already arrives as a tuple; tuple() of an exact tuple is that same object.
    return args


def make_list_from_kwargs(**kwargs):
    """Return the keyword-argument values, in call order, skipping ``None``.

    None/null entries are dropped to avoid issues with callbacks.
    """
    kept = []
    for value in kwargs.values():
        if value is None:
            continue
        kept.append(value)
    return kept


def make_string(value):
    """Return ``str(value)`` (a callable wrapper, presumably for config targets)."""
    text = str(value)
    return text


def make_dict(**kwargs):
    """Return a new ``dict`` built from the keyword arguments."""
    return {**kwargs}


def get_item(data, key: str):
    """Return ``data[key]``; lookup errors (e.g. KeyError) propagate."""
    item = data[key]
    return item


def get_attr(data, key: str):
    """Return attribute ``key`` of ``data``; AttributeError propagates."""
    attribute = getattr(data, key)
    return attribute