Spaces:
Sleeping
Sleeping
File size: 2,103 Bytes
95b1715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import copy
import dataclasses
import inspect
import typing

import omegaconf
class ClassRegistry:
    """Registry of classes plus config dataclasses derived from their ``__init__``.

    Attributes:
        classes: maps a registered name to the registered class.
        args: maps a registered name to the dataclass generated from that
            class's ``__init__`` signature.
        arg_keys: the ``arg_keys`` most recently passed (truthy) to
            ``make_dataclass_from_init``; ``None`` until that happens.
    """

    def __init__(self):
        self.classes = dict()
        self.args = dict()
        self.arg_keys = None

    def __getitem__(self, item):
        """Return the class registered under ``item`` (``KeyError`` if absent)."""
        return self.classes[item]

    @staticmethod
    def _default_field(value):
        """Wrap ``value`` in a dataclass field that is legal for its type.

        ``dataclasses`` rejects unhashable defaults (list/dict/set on every
        version, any unhashable object since Python 3.11), so those are
        supplied through a ``default_factory`` that hands every instance its
        own deep copy. Hashable values stay plain ``default=`` values.
        """
        if type(value).__hash__ is None:
            return dataclasses.field(default_factory=lambda: copy.deepcopy(value))
        return dataclasses.field(default=value)

    def _field_specs(self, func, stop_args):
        """Build ``make_dataclass`` field specs from the signature of ``func``.

        Parameters named in ``stop_args`` are skipped. For the rest:
        no default -> ``typing.Any`` with ``omegaconf.MISSING`` (OmegaConf's
        marker for a mandatory value); default ``None`` ->
        ``typing.Optional[typing.Any]``; otherwise the field is typed as
        ``type(default)`` and keeps that default.

        A fresh list (with fresh ``Field`` objects) is returned on every call
        so several dataclasses never share the same ``Field`` instances.
        """
        specs = []
        for param_name, param in inspect.signature(func).parameters.items():
            if param_name in stop_args:
                continue
            default = param.default
            if default is inspect.Parameter.empty:
                specs.append((param_name, typing.Any, omegaconf.MISSING))
            elif default is None:
                specs.append((param_name, typing.Optional[typing.Any], None))
            else:
                specs.append(
                    (param_name, type(default), self._default_field(default))
                )
        return specs

    def make_dataclass_from_init(self, func, name, arg_keys, stop_args):
        """Create a dataclass called ``name`` mirroring the parameters of ``func``.

        Args:
            func: callable whose signature is inspected (typically ``__init__``).
            name: name of the dataclass that is returned.
            arg_keys: if truthy, an iterable of names. One dataclass per name
                is generated from ``func`` and the returned dataclass has one
                field per name holding such a nested dataclass. The value is
                also stored on ``self.arg_keys``.
            stop_args: parameter names to leave out (e.g. ``"self"``).

        Returns:
            The newly created dataclass type.
        """
        if not arg_keys:
            return dataclasses.make_dataclass(
                name, self._field_specs(func, stop_args)
            )

        self.arg_keys = arg_keys
        nested = {
            key: dataclasses.make_dataclass(key, self._field_specs(func, stop_args))
            for key in arg_keys
        }
        # Generated dataclasses are unhashable, so an instance cannot be used
        # as ``default=`` on Python 3.11+; the class itself is the factory.
        return dataclasses.make_dataclass(
            name,
            [
                (key, nested_cls, dataclasses.field(default_factory=nested_cls))
                for key, nested_cls in nested.items()
            ],
        )

    def make_dataclass_from_classes(self, name):
        """Create a dataclass with one field per registered class.

        Each registered class is instantiated with no arguments right away,
        and that instance becomes the field default (through a copying
        factory when the instance is unhashable).
        """
        return dataclasses.make_dataclass(
            name,
            [
                (key, cls, self._default_field(cls()))
                for key, cls in self.classes.items()
            ],
        )

    def make_dataclass_from_args(self, name):
        """Create a dataclass with one field per registered argument dataclass.

        Every field defaults to a fresh instance of the corresponding
        dataclass stored in ``self.args``.
        """
        return dataclasses.make_dataclass(
            name,
            [
                (key, args_cls, dataclasses.field(default_factory=args_cls))
                for key, args_cls in self.args.items()
            ],
        )

    def add_to_registry(
        self, name, arg_keys=None, stop_args=("self", "args", "kwargs")
    ):
        """Return a class decorator that registers the class under ``name``.

        The decorator stores the class in ``self.classes[name]``, stores a
        dataclass generated from ``cls.__init__`` in ``self.args[name]``
        (see ``make_dataclass_from_init`` for ``arg_keys`` / ``stop_args``)
        and returns the class unchanged.
        """

        def add_class_by_name(cls):
            self.classes[name] = cls
            self.args[name] = self.make_dataclass_from_init(
                cls.__init__, name, arg_keys, stop_args
            )
            return cls

        return add_class_by_name
|