| | |
| | import inspect |
| | import warnings |
| | from functools import partial |
| |
|
| | from .misc import is_seq_of |
| |
|
| |
|
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    The ``type`` entry of ``cfg`` (taken from ``default_args`` if ``cfg``
    lacks it) selects the class: a string is looked up in ``registry``, a
    class is used as-is. Every other entry is passed to the class as a
    keyword argument; entries of ``default_args`` only fill keys that are
    missing from ``cfg``.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg``, ``registry`` or ``default_args`` has the wrong
            type, or the resolved ``type`` is neither a str nor a class.
        KeyError: If neither ``cfg`` nor ``default_args`` provides ``type``,
            or a string ``type`` is not found in ``registry``.
        Exception: An exception raised by the constructor is re-raised as the
            same type with the class name prefixed to its message (chained to
            the original). If that type cannot be rebuilt from one message
            argument, the original exception propagates unchanged.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # Validate default_args before it is used in the membership test below,
    # so a wrong type reports this message instead of an incidental one.
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    # Copy so the caller's cfg is never mutated by pop/setdefault below.
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # A bare constructor error (e.g. TypeError) does not name the class
        # being built, so prefix it. Some exception types cannot be rebuilt
        # from a single message argument; in that case re-raise the original
        # rather than masking it with an error from the rebuild attempt.
        try:
            wrapped = type(e)(f'{obj_cls.__name__}: {e}')
        except Exception:
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
| |
|
| |
|
class Registry:
    """A registry to map strings to classes.

    Registered object could be built from registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the
            top-level package name of the module that creates the registry,
            e.g. mmdet, mmcls, mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        # infer_scope() looks a fixed number of frames up the stack
        # (infer_scope -> __init__ -> caller), so it must be called directly
        # from this method.
        self._scope = self.infer_scope() if scope is None else scope

        # Validate ``parent`` before reading anything from it, so a wrong
        # type fails on this check rather than with an AttributeError.
        if parent is not None:
            assert isinstance(parent, Registry)

        # Build function priority: explicit ``build_func`` > the parent's
        # build function > ``build_from_cfg``.
        if build_func is not None:
            self.build_func = build_func
        elif parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = build_from_cfg

        if parent is not None:
            parent._add_children(self)
        self.parent = parent

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The top-level package name of the module that instantiates the
        registry (two frames above this call) will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # Frames: [0] infer_scope, [1] Registry.__init__, [2] the code that
        # creates the registry. Walking ``f_back`` reaches the same frame as
        # ``inspect.stack()[2][0]`` without building a FrameInfo record (with
        # source context) for every frame on the stack.
        frame = inspect.currentframe().f_back.f_back
        try:
            module = inspect.getmodule(frame)
            if module is not None:
                module_name = module.__name__
            else:
                # getmodule() returns None when it cannot determine the
                # module; fall back to the frame's own ``__name__`` global.
                module_name = frame.f_globals.get('__name__', '')
        finally:
            # Release the frame reference promptly.
            del frame
        return module_name.split('.')[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        head, sep, tail = key.partition('.')
        if sep:
            return head, tail
        return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        An unscoped key (or one whose scope is this registry's) is looked up
        locally. A key scoped to a direct child is delegated to that child.
        Any other scope is retried from the root registry.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class, or None if it is not found.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            return self._module_dict.get(real_key)
        if scope in self._children:
            return self._children[scope].get(real_key)
        if self.parent is None:
            # Already at the root and no child owns this scope.
            return None
        # Retry the full key from the root registry.
        root = self.parent
        while root.parent is not None:
            root = root.parent
        return root.get(key)

    def build(self, *args, **kwargs):
        """Call ``self.build_func`` with this registry as ``registry``."""
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        """Map each name in ``module_name`` to ``module_class``.

        Without ``force``, a name that is already registered (or repeated in
        the same call) raises KeyError, and nothing from this call is
        registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            names = [module_name]
        else:
            names = list(module_name)
        # Validate every name before touching the dict, so a conflict on a
        # later alias cannot leave the earlier aliases half-registered.
        pending = {}
        for name in names:
            if not force and (name in self._module_dict or name in pending):
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            pending[name] = module_class
        self._module_dict.update(pending)

    def deprecated_register_module(self, cls=None, force=False):
        """Old-style registration API; warns, then registers ``cls``.

        Called with ``cls=None`` it returns a decorator (a partial of itself).
        """
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            ``module`` when it is given, otherwise a decorator that registers
            and returns the decorated class.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # Backward compatibility: a class passed as the first positional
        # argument is the deprecated register_module(module) form.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # Used as a normal method: register ``module`` immediately.
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # Used as a decorator: register the decorated class on application.
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register
| |
|