Spaces:
Running
Running
| import inspect | |
| import os | |
| from collections import OrderedDict | |
| from typing import Optional, Iterable, Callable | |
# One-shot guard for the build-failure diagnostic in ``Registry.build``. It
# starts as True and is flipped to False the first time a builder raises, so the
# diagnostic is produced at most once per process.
# NOTE(review): presumably meant to annotate only the innermost failure when
# builders call ``build`` recursively; the flag is never reset -- confirm intent.
_innest_error = True

# Registration-site tracing is opt-in: it is enabled when the ``DIENGINEREGTRACE``
# environment variable equals ``ON`` (case-insensitive). Read once at import time.
_DI_ENGINE_REG_TRACE_IS_ON = os.environ.get('DIENGINEREGTRACE', 'OFF').upper() == 'ON'


class Registry(dict):
    """
    Overview:
        A helper class for managing registered modules. It extends ``dict`` and
        adds a ``register`` function that can be used as a plain call or as a
        decorator.
    Interfaces:
        ``__init__``, ``register``, ``get``, ``build``, ``query``, ``query_details``
    Examples:
        Creating a registry:
        >>> some_registry = Registry({"default": default_module})

        There are two ways of registering new modules:
        1): the normal way is just calling the register function:
        >>> def foo():
        >>>     ...
        >>> some_registry.register("foo_module", foo)
        2): used as a decorator when declaring the module:
        >>> @some_registry.register("foo_module")
        >>> @some_registry.register("foo_module_nickname")
        >>> def foo():
        >>>     ...

        Access of a module is just like using a dictionary, eg:
        >>> f = some_registry["foo_module"]
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize the Registry object.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                dict.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                dict.
        """
        super(Registry, self).__init__(*args, **kwargs)
        # Maps a registered name to the (filename, lineno) of its ``register`` call.
        # Only filled by ``register`` and only while tracing is enabled.
        self.__trace__ = dict()

    def register(
            self,
            module_name: Optional[str] = None,
            module: Optional[Callable] = None,
            force_overwrite: bool = False
    ) -> Optional[Callable]:
        """
        Overview:
            Register the module, either directly (``module`` given) or as a decorator (``module`` omitted).
        Arguments:
            - module_name (:obj:`Optional[str]`): The name of the module. In decorator mode it may be omitted, \
                in which case the decorated object's ``__name__`` is used.
            - module (:obj:`Optional[Callable]`): The module to be registered.
            - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
        Returns:
            - decorator (:obj:`Optional[Callable]`): ``None`` when ``module`` is given; otherwise a decorator \
                that registers its argument and returns it unchanged.
        Raises:
            - AssertionError: If ``module`` is given without ``module_name``, or if the name is already \
                registered and ``force_overwrite`` is False.
        """
        origin = None
        if _DI_ENGINE_REG_TRACE_IS_ON:
            # Index 1 is the caller of ``register``: the call site, or the line
            # holding the decorator expression.
            caller = inspect.stack()[1]
            origin = (caller.filename, caller.lineno)

        # used as function call
        if module is not None:
            assert module_name is not None
            Registry._register_generic(self, module_name, module, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[module_name] = origin
            return None

        # used as decorator
        def register_fn(fn: Callable) -> Callable:
            name = fn.__name__ if module_name is None else module_name
            Registry._register_generic(self, name, fn, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[name] = origin
            return fn

        return register_fn

    def _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) -> None:
        """
        Overview:
            Store ``module`` under ``module_name`` in ``module_dict``. This function takes no ``self``; \
            within this class it is always invoked as ``Registry._register_generic(registry, ...)``.
        Arguments:
            - module_dict (:obj:`dict`): The dict to store the module.
            - module_name (:obj:`str`): The name of the module.
            - module (:obj:`Callable`): The module to be registered.
            - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
        Raises:
            - AssertionError: If ``module_name`` already exists and ``force_overwrite`` is False.
        """
        if not force_overwrite:
            assert module_name not in module_dict, module_name
        module_dict[module_name] = module

    def get(self, module_name: str) -> Callable:
        """
        Overview:
            Get the module. Unlike ``dict.get`` this takes no default and raises ``KeyError`` for an \
            unknown name.
        Arguments:
            - module_name (:obj:`str`): The name of the module.
        Returns:
            - module (:obj:`Callable`): The object registered under ``module_name``.
        """
        return self[module_name]

    def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
        """
        Overview:
            Build the object by calling the builder registered under ``obj_type``.
        Arguments:
            - obj_type (:obj:`str`): The type of the object.
            - obj_args (:obj:`Tuple`): The arguments passed to the object.
            - obj_kwargs (:obj:`Dict`): The keyword arguments passed to the object.
        Returns:
            - obj (:obj:`object`): Whatever the registered builder returns.
        Raises:
            - KeyError: If ``obj_type`` is not registered.
            - Exception: Any exception raised by the builder is re-raised unchanged; the first such failure \
                in the process additionally gets a diagnostic describing expected and given arguments.
        """
        global _innest_error

        # Keep the lookup separate from the call so that a KeyError raised *inside*
        # the builder is not mistaken for an unknown ``obj_type``.
        try:
            build_fn = self[obj_type]
        except KeyError:
            raise KeyError("not support buildable-object type: {}".format(obj_type))

        try:
            return build_fn(*obj_args, **obj_kwargs)
        except Exception as e:
            if _innest_error:
                _innest_error = False
                message = Registry._format_build_failure(build_fn, obj_type, obj_args, obj_kwargs)
                Registry._report_build_failure(e, message)
            raise

    @staticmethod
    def _format_build_failure(build_fn: Callable, obj_type: str, obj_args: tuple, obj_kwargs: dict) -> str:
        """
        Overview:
            Render the expected signature of ``build_fn`` next to the arguments it was actually given.
        """
        try:
            argspec = inspect.getfullargspec(build_fn)
        except TypeError:
            # Raised for callables that cannot be introspected; the diagnostic
            # must never replace the real build error.
            argspec = '<unavailable>'
        message = 'for {}(alias={})'.format(build_fn, obj_type)
        message += '\nExpected args are:{}'.format(argspec)
        message += '\nGiven args are:{}/{}'.format(obj_args, obj_kwargs.keys())
        message += '\nGiven args details are:{}/{}'.format(obj_args, obj_kwargs)
        return message

    @staticmethod
    def _report_build_failure(exc: BaseException, message: str) -> None:
        """
        Overview:
            Surface ``message`` without changing the type or args of ``exc``: attach it as an exception \
            note when ``add_note`` exists (Python 3.11+), otherwise emit it through ``logging``.
        """
        add_note = getattr(exc, 'add_note', None)
        if callable(add_note):
            add_note(message)
        else:
            import logging
            logging.getLogger(__name__).error(message)

    def query(self) -> Iterable:
        """
        Overview:
            All registered module names.
        Returns:
            - names (:obj:`Iterable`): The keys view of the registry.
        """
        return self.keys()

    def query_details(self, aliases: Optional[Iterable] = None) -> OrderedDict:
        """
        Overview:
            Get the details of the registered modules, i.e. where each one was registered.
        Arguments:
            - aliases (:obj:`Optional[Iterable]`): The aliases of the modules. Defaults to every key in \
                the registry.
        Returns:
            - details (:obj:`OrderedDict`): Maps each alias to the ``(filename, lineno)`` recorded by \
                ``register``.
        Raises:
            - AssertionError: If tracing is not enabled.
            - KeyError: If an alias has no trace entry (trace entries are only written by ``register``).
        """
        assert _DI_ENGINE_REG_TRACE_IS_ON, "please exec 'export DIENGINEREGTRACE=ON' first"
        if aliases is None:
            aliases = self.keys()
        return OrderedDict((alias, self.__trace__[alias]) for alias in aliases)