File size: 7,118 Bytes
f43af3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from collections import defaultdict

from .log_utils import default_logger as logger


class Registrable:
    """Give a base class a named registry of its subclasses.

    Any class that inherits from ``Registrable`` can register subclasses by
    decorating them with ``@BaseClass.register(name)``. Afterwards,
    ``BaseClass.list_available()`` returns the registered names and
    ``BaseClass.by_name(name)`` returns the corresponding class (or the method
    registered as its constructor).

    The registry stores the classes themselves, not instances.

    Registries are keyed by the exact class on which ``register`` is called, so
    ``Sub.register(...)`` fills a registry separate from ``Base.register(...)``.
    ``resolve_class_name`` falls back to searching every registry when a name is
    not found in the caller's own.
    """

    # Shared by all Registrable subclasses: maps the class on which ``register``
    # was called -> {name: (registered class, constructor method name or None)}.
    _registry = defaultdict(dict)
    # Optional registered name that ``list_available`` reports first.
    _default_impl = None

    @classmethod
    def register(cls, name, constructor=None, overwrite=False):
        """Return a decorator that registers a class under ``name`` in ``cls``'s registry.

        Args:
            name (str): The name to register the class under.
            constructor (str, optional): Name of an attribute of the registered
                class (intended to be a ``classmethod``) that ``by_name`` returns
                instead of the class itself. Defaults to None (return the class).
            overwrite (bool, optional): If True, replace whatever is already
                registered under ``name`` and log that at INFO level. If False
                (the default), raise instead.

        Returns:
            A decorator that records the decorated class and returns it unchanged.

        Raises:
            RuntimeError: raised by the returned decorator when ``name`` is
                already registered for ``cls`` and ``overwrite`` is False.

        Example::

            class Transform(Registrable):
                ...

            @Transform.register("shift-transform")
            class ShiftTransform(Transform):
                def __init__(self, param1: int, param2: str):
                    ...

            Transform.by_name("shift-transform")  # -> ShiftTransform

        To expose an alternate construction path, pass the name of a
        ``@classmethod`` as ``constructor``; ``by_name`` then returns that
        attribute of the class instead of the class itself.
        """
        registry = Registrable._registry[cls]

        def add_subclass_to_registry(subclass):
            # Refuse to silently shadow an existing entry unless asked to.
            if name in registry:
                existing = registry[name][0].__name__
                if not overwrite:
                    raise RuntimeError(
                        f"Cannot register {name} as {subclass.__name__}; "
                        f"name already in use for {existing}"
                    )
                logger.info(
                    f"{name} has already been registered as {existing}, but "
                    f"overwrite=True, so overwriting with {subclass.__name__}"
                )
            registry[name] = (subclass, constructor)
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls, name):
        """Return the callable that constructs the class registered as ``name``.

        This is the registered class itself unless a ``constructor`` name was
        given at registration time, in which case it is that attribute of the
        class (e.g. a bound classmethod) -- so it is not necessarily ``__init__``.

        Raises:
            RuntimeError: if ``name`` cannot be resolved (see ``resolve_class_name``).
            AttributeError: if the registered constructor name does not exist on
                the class.
        """
        logger.debug("instantiating registered subclass %s of %s", name, cls)
        subclass, constructor = cls.resolve_class_name(name)
        if not constructor:
            return subclass
        return getattr(subclass, constructor)

    @classmethod
    def resolve_class_name(cls, name):
        """Return ``(subclass, constructor_name_or_None)`` for ``name``.

        Lookup order:

        1. ``cls``'s own registry;
        2. every other registry, in the order registries were created;
        3. if ``name`` contains a ``.``, treat it as an absolute
           ``"package.module.ClassName"`` path, import the module, and return the
           attribute. A separate constructor cannot be used this way (that
           requires ``register``), so the constructor is always None here.

        Raises:
            RuntimeError: if ``name`` is not registered and cannot be imported as
                an absolute dotted path. The underlying import/attribute error, if
                any, is attached as ``__cause__``.
        """
        own_registry = Registrable._registry[cls]
        if name in own_registry:
            return own_registry[name]
        # Fall back to registries of other Registrable classes.
        for other_registry in Registrable._registry.values():
            if name in other_registry:
                return other_registry[name]

        if "." not in name:
            # Not a qualified class name either.
            raise RuntimeError(
                f"{name} is not a registered name for {cls.__name__}. "
                "You probably need to use the --include-package flag "
                "to load your custom code. Alternatively, you can specify your choices "
                """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
                "in which case they will be automatically imported correctly."
            )

        module_path, _, class_name = name.rpartition(".")
        # importlib.import_module raises ValueError/TypeError for empty or
        # relative module paths; report those through the same RuntimeError contract.
        if not module_path or module_path.startswith("."):
            raise RuntimeError(
                f"tried to interpret {name} as a path to a class "
                f"but {module_path!r} is not an absolute module path"
            )

        import importlib

        try:
            module = importlib.import_module(module_path)
        except ModuleNotFoundError as err:
            # Chain the cause: the missing module may be a transitive import.
            raise RuntimeError(
                f"tried to interpret {name} as a path to a class "
                f"but unable to import module {module_path}"
            ) from err

        try:
            subclass = getattr(module, class_name)
        except AttributeError as err:
            raise RuntimeError(
                f"tried to interpret {name} as a path to a class "
                f"but unable to find class {class_name} in {module_path}"
            ) from err
        return subclass, None

    @classmethod
    def list_available(cls):
        """Return the names registered for ``cls``, with ``_default_impl`` first if set.

        Raises:
            RuntimeError: if ``_default_impl`` is set but is not a registered name.
        """
        names = list(Registrable._registry[cls])
        default = cls._default_impl

        if default is None:
            return names
        if default not in names:
            raise RuntimeError(f"Default implementation {default} is not registered")
        return [default] + [n for n in names if n != default]

    @classmethod
    def registry_dict(cls):
        """Return ``cls``'s registry, mapping name -> (class, constructor name or None).

        This is the live dict, not a copy: mutating it mutates the registry.
        """
        return Registrable._registry[cls]