Spaces:
Runtime error
Runtime error
| # Copyright (c) Alibaba. All rights reserved. | |
| import inspect | |
| import warnings | |
| import functools | |
| from functools import partial | |
| from typing import Any, Dict, Optional | |
| from collections import abc | |
| from inspect import getfullargspec | |
def is_seq_of(seq, expected_type, seq_type=None):
    """Tell whether ``seq`` is a sequence whose items all match a type.

    Args:
        seq (Sequence): The object to inspect.
        expected_type (type): Type (or tuple of types) every item must be an
            instance of.
        seq_type (type, optional): Required container type. When omitted,
            any ``collections.abc.Sequence`` is accepted.

    Returns:
        bool: True when ``seq`` is of the required container type and every
        item is an instance of ``expected_type``; False otherwise.
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
        container_type = seq_type
    else:
        container_type = abc.Sequence

    # Reject the wrong container kind before looking at any element.
    if not isinstance(seq, container_type):
        return False
    return all(isinstance(element, expected_type) for element in seq)
def deprecated_api_warning(name_dict, cls_name=None):
    """A decorator that detects deprecated argument names and renames them.

    Keyword arguments passed under a deprecated name are forwarded to the
    decorated function under the replacement name, and a
    ``DeprecationWarning`` is emitted.

    Args:
        name_dict (dict): Mapping whose keys (str) are deprecated argument
            names and whose values (str) are the expected argument names.
        cls_name (str, optional): Class name used to prefix the function
            name in warning messages. Default: None.

    Returns:
        func: A decorator that wraps a function with the renaming logic.
    """

    def api_warning_wrapper(old_func):

        # Keep __name__/__doc__/__wrapped__ of the decorated function so
        # introspection and documentation still see the original.
        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get name of the function
            func_name = old_func.__name__
            if cls_name is not None:
                func_name = f'{cls_name}.{func_name}'
            if args:
                # Names of the parameters that were filled positionally.
                arg_names = args_info.args[:len(args)]
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in arg_names:
                        # Positional values need no renaming; only warn.
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
            if kwargs:
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in kwargs:
                        assert dst_arg_name not in kwargs, (
                            f'The expected behavior is to replace '
                            f'the deprecated key `{src_arg_name}` to '
                            f'new key `{dst_arg_name}`, but got them '
                            f'in the arguments at the same time, which '
                            f'is confusing. `{src_arg_name}` will be '
                            f'deprecated in the future, please '
                            f'use `{dst_arg_name}` instead.')
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        kwargs[dst_arg_name] = kwargs.pop(src_arg_name)

            # apply converted arguments to the decorated method
            return old_func(*args, **kwargs)

        return new_func

    return api_warning_wrapper
def build_from_cfg(cfg: Dict,
                   registry: 'Registry',
                   default_args: Optional[Dict] = None) -> Any:
    """Instantiate a class, or call a function, described by a config dict.

    The ``type`` entry selects what to build: a string is looked up in
    ``registry``, while a class or function is used directly. All remaining
    entries (plus any ``default_args`` not already present in ``cfg``) are
    passed as keyword arguments.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='ResNet'), MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Returns the result of calling the function

    Args:
        cfg (dict): Config dict. It should at least contain the key "type"
            unless ``default_args`` provides it.
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments;
            values in ``cfg`` take precedence.

    Returns:
        object: The constructed object, or the called function's result.

    Raises:
        TypeError: If ``cfg``, ``registry``, ``default_args`` or the resolved
            ``type`` entry has an unsupported type.
        KeyError: If no "type" key is available, or the named type is not
            found in ``registry``.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')

    type_missing = 'type' not in cfg and (
        default_args is None or 'type' not in default_args)
    if type_missing:
        raise KeyError(
            '`cfg` or `default_args` must contain the key "type", '
            f'but got {cfg}\n{default_args}')

    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if default_args is not None and not isinstance(default_args, dict):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Work on a copy so the caller's config is never mutated.
    call_kwargs = cfg.copy()
    if default_args is not None:
        for key in default_args:
            call_kwargs.setdefault(key, default_args[key])

    target = call_kwargs.pop('type')
    if isinstance(target, str):
        factory = registry.get(target)
        if factory is None:
            raise KeyError(
                f'{target} is not in the {registry.name} registry')
    elif not (inspect.isclass(target) or inspect.isfunction(target)):
        raise TypeError(
            f'type must be a str or valid type, but got {type(target)}')
    else:
        factory = target

    try:
        return factory(**call_kwargs)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{factory.__name__}: {e}')
class Registry:
    """A registry to map strings to classes or functions.

    Registered objects can be built from the registry, and registered
    functions can be called from the registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = MODELS.build(dict(type='resnet50'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func (func, optional): Build function to construct instance
            from Registry, func:`build_from_cfg` is used if neither
            ``parent`` nor ``build_func`` is specified. If ``parent`` is
            specified and ``build_func`` is not given, ``build_func`` will
            be inherited from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered
            in children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to
            search for children registry. If not specified, scope will be
            the top-level package name of the module that creates the
            registry, e.g. mmdet, mmcls, mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        # Validate the parent before any of its attributes are read.
        if parent is not None:
            assert isinstance(parent, Registry)

        self._name = name
        self._module_dict = dict()
        self._children = dict()
        # NOTE: infer_scope() must be called directly from __init__ because
        # it walks a fixed number of frames up the call stack.
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

        if parent is not None:
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        """Return the number of names registered directly in this registry."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return whether ``key`` resolves to a registered object."""
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Example:
            >>> # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            str: The inferred scope name.
        """
        # We access the caller using inspect.currentframe() instead of
        # inspect.stack() for performance reasons. See details in PR #1844
        frame = inspect.currentframe()
        # Two frames up: skip this function and the function that called it
        # (``__init__``) to reach the code that instantiates the registry.
        infer_scope_caller = frame.f_back.f_back
        module = inspect.getmodule(infer_scope_caller)
        if module is not None:
            module_name = module.__name__
        else:
            # getmodule() cannot always resolve a frame to a module; fall
            # back to the ``__name__`` global of the calling frame.
            module_name = infer_scope_caller.f_globals.get(
                '__name__', '__main__')
        return module_name.split('.')[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            tuple[str | None, str]: The former element is the first scope of
            the key, which can be ``None``. The latter is the remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        """str: The registry name."""
        return self._name

    @property
    def scope(self):
        """str: The scope of this registry."""
        return self._scope

    @property
    def module_dict(self):
        """dict: Mapping from registered name to class or function."""
        return self._module_dict

    @property
    def children(self):
        """dict: Mapping from child scope to child registry."""
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format, optionally prefixed
                with a scope (``'scope.Name'``).

        Returns:
            class: The corresponding class or function, or ``None`` if the
            key cannot be resolved.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            return self._module_dict.get(real_key)
        if scope in self._children:
            # get from self._children
            return self._children[scope].get(real_key)
        if self.parent is None:
            # Already at the root and the scope is unknown: nothing to find.
            return None
        # goto root
        root = self.parent
        while root.parent is not None:
            root = root.parent
        return root.get(key)

    def build(self, *args, **kwargs):
        """Call ``build_func`` with the given arguments and ``registry=self``."""
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module, module_name=None, force=False):
        """Store ``module`` under one or several names.

        Args:
            module (type | function): Class or function to register.
            module_name (str | Sequence[str], optional): Name(s) to register
                under; ``module.__name__`` is used when omitted.
            force (bool): Whether to overwrite an existing entry.

        Raises:
            TypeError: If ``module`` is neither a class nor a function.
            KeyError: If a name is already registered and ``force`` is False.
        """
        if not inspect.isclass(module) and not inspect.isfunction(module):
            raise TypeError('module must be a class or a function, '
                            f'but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module

    def deprecated_register_module(self, cls=None, force=False):
        """Register ``cls`` via the old positional API, emitting a warning."""
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.',
            DeprecationWarning)
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)

        Args:
            name (str | Sequence[str] | None): The module name(s) to be
                registered. If not specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class or function to be registered.

        Returns:
            The registered ``module`` when it is given directly; otherwise a
            decorator that registers and returns the decorated object.

        Raises:
            TypeError: If ``force`` is not a bool or ``name`` has an
                unsupported type.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register