seifbenayed commited on
Commit
26e131d
·
1 Parent(s): 094d6ae
Files changed (2) hide show
  1. register/misc.py +352 -0
  2. register/register.py +318 -0
register/misc.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #-*- coding: utf-8 -*-
2
+ import collections.abc
3
+ import functools
4
+ import itertools
5
+ import subprocess
6
+ import warnings
7
+ from collections import abc
8
+ from importlib import import_module
9
+ from inspect import getfullargspec
10
+ from itertools import repeat
11
+
12
+
13
+ # From PyTorch internals
14
+ def _ntuple(n):
15
+
16
+ def parse(x):
17
+ if isinstance(x, collections.abc.Iterable):
18
+ return x
19
+ return tuple(repeat(x, n))
20
+
21
+ return parse
22
+
23
+
24
+ to_1tuple = _ntuple(1)
25
+ to_2tuple = _ntuple(2)
26
+ to_3tuple = _ntuple(3)
27
+ to_4tuple = _ntuple(4)
28
+ to_ntuple = _ntuple
29
+
30
+
31
+ def is_str(x):
32
+ """Whether the input is an string instance.
33
+ Note: This method is deprecated since python 2 is no longer supported.
34
+ """
35
+ return isinstance(x, str)
36
+
37
+
38
+ def import_modules_from_strings(imports, allow_failed_imports=False):
39
+ """Import modules from the given list of strings.
40
+ Args:
41
+ imports (list | str | None): The given module names to be imported.
42
+ allow_failed_imports (bool): If True, the failed imports will return
43
+ None. Otherwise, an ImportError is raise. Default: False.
44
+ Returns:
45
+ list[module] | module | None: The imported modules.
46
+ Examples:
47
+ >>> osp, sys = import_modules_from_strings(
48
+ ... ['os.path', 'sys'])
49
+ >>> import os.path as osp_
50
+ >>> import sys as sys_
51
+ >>> assert osp == osp_
52
+ >>> assert sys == sys_
53
+ """
54
+ if not imports:
55
+ return
56
+ single_import = False
57
+ if isinstance(imports, str):
58
+ single_import = True
59
+ imports = [imports]
60
+ if not isinstance(imports, list):
61
+ raise TypeError(
62
+ f'custom_imports must be a list but got type {type(imports)}')
63
+ imported = []
64
+ for imp in imports:
65
+ if not isinstance(imp, str):
66
+ raise TypeError(
67
+ f'{imp} is of type {type(imp)} and cannot be imported.')
68
+ try:
69
+ imported_tmp = import_module(imp)
70
+ except ImportError:
71
+ if allow_failed_imports:
72
+ warnings.warn(f'{imp} failed to import and is ignored.',
73
+ UserWarning)
74
+ imported_tmp = None
75
+ else:
76
+ raise ImportError
77
+ imported.append(imported_tmp)
78
+ if single_import:
79
+ imported = imported[0]
80
+ return imported
81
+
82
+
83
+ def iter_cast(inputs, dst_type, return_type=None):
84
+ """Cast elements of an iterable object into some type.
85
+ Args:
86
+ inputs (Iterable): The input object.
87
+ dst_type (type): Destination type.
88
+ return_type (type, optional): If specified, the output object will be
89
+ converted to this type, otherwise an iterator.
90
+ Returns:
91
+ iterator or specified type: The converted object.
92
+ """
93
+ if not isinstance(inputs, abc.Iterable):
94
+ raise TypeError('inputs must be an iterable object')
95
+ if not isinstance(dst_type, type):
96
+ raise TypeError('"dst_type" must be a valid type')
97
+
98
+ out_iterable = map(dst_type, inputs)
99
+
100
+ if return_type is None:
101
+ return out_iterable
102
+ else:
103
+ return return_type(out_iterable)
104
+
105
+
106
+ def list_cast(inputs, dst_type):
107
+ """Cast elements of an iterable object into a list of some type.
108
+ A partial method of :func:`iter_cast`.
109
+ """
110
+ return iter_cast(inputs, dst_type, return_type=list)
111
+
112
+
113
+ def tuple_cast(inputs, dst_type):
114
+ """Cast elements of an iterable object into a tuple of some type.
115
+ A partial method of :func:`iter_cast`.
116
+ """
117
+ return iter_cast(inputs, dst_type, return_type=tuple)
118
+
119
+
120
+ def is_seq_of(seq, expected_type, seq_type=None):
121
+ """Check whether it is a sequence of some type.
122
+ Args:
123
+ seq (Sequence): The sequence to be checked.
124
+ expected_type (type): Expected type of sequence items.
125
+ seq_type (type, optional): Expected sequence type.
126
+ Returns:
127
+ bool: Whether the sequence is valid.
128
+ """
129
+ if seq_type is None:
130
+ exp_seq_type = abc.Sequence
131
+ else:
132
+ assert isinstance(seq_type, type)
133
+ exp_seq_type = seq_type
134
+ if not isinstance(seq, exp_seq_type):
135
+ return False
136
+ for item in seq:
137
+ if not isinstance(item, expected_type):
138
+ return False
139
+ return True
140
+
141
+
142
+ def is_list_of(seq, expected_type):
143
+ """Check whether it is a list of some type.
144
+ A partial method of :func:`is_seq_of`.
145
+ """
146
+ return is_seq_of(seq, expected_type, seq_type=list)
147
+
148
+
149
+ def is_tuple_of(seq, expected_type):
150
+ """Check whether it is a tuple of some type.
151
+ A partial method of :func:`is_seq_of`.
152
+ """
153
+ return is_seq_of(seq, expected_type, seq_type=tuple)
154
+
155
+
156
+ def slice_list(in_list, lens):
157
+ """Slice a list into several sub lists by a list of given length.
158
+ Args:
159
+ in_list (list): The list to be sliced.
160
+ lens(int or list): The expected length of each out list.
161
+ Returns:
162
+ list: A list of sliced list.
163
+ """
164
+ if isinstance(lens, int):
165
+ assert len(in_list) % lens == 0
166
+ lens = [lens] * int(len(in_list) / lens)
167
+ if not isinstance(lens, list):
168
+ raise TypeError('"indices" must be an integer or a list of integers')
169
+ elif sum(lens) != len(in_list):
170
+ raise ValueError('sum of lens and list length does not '
171
+ f'match: {sum(lens)} != {len(in_list)}')
172
+ out_list = []
173
+ idx = 0
174
+ for i in range(len(lens)):
175
+ out_list.append(in_list[idx:idx + lens[i]])
176
+ idx += lens[i]
177
+ return out_list
178
+
179
+
180
+ def concat_list(in_list):
181
+ """Concatenate a list of list into a single list.
182
+ Args:
183
+ in_list (list): The list of list to be merged.
184
+ Returns:
185
+ list: The concatenated flat list.
186
+ """
187
+ return list(itertools.chain(*in_list))
188
+
189
+
190
+ def check_prerequisites(
191
+ prerequisites,
192
+ checker,
193
+ msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
194
+ 'found, please install them first.'): # yapf: disable
195
+ """A decorator factory to check if prerequisites are satisfied.
196
+ Args:
197
+ prerequisites (str of list[str]): Prerequisites to be checked.
198
+ checker (callable): The checker method that returns True if a
199
+ prerequisite is meet, False otherwise.
200
+ msg_tmpl (str): The message template with two variables.
201
+ Returns:
202
+ decorator: A specific decorator.
203
+ """
204
+
205
+ def wrap(func):
206
+
207
+ @functools.wraps(func)
208
+ def wrapped_func(*args, **kwargs):
209
+ requirements = [prerequisites] if isinstance(
210
+ prerequisites, str) else prerequisites
211
+ missing = []
212
+ for item in requirements:
213
+ if not checker(item):
214
+ missing.append(item)
215
+ if missing:
216
+ print(msg_tmpl.format(', '.join(missing), func.__name__))
217
+ raise RuntimeError('Prerequisites not meet.')
218
+ else:
219
+ return func(*args, **kwargs)
220
+
221
+ return wrapped_func
222
+
223
+ return wrap
224
+
225
+
226
+ def _check_py_package(package):
227
+ try:
228
+ import_module(package)
229
+ except ImportError:
230
+ return False
231
+ else:
232
+ return True
233
+
234
+
235
+ def _check_executable(cmd):
236
+ if subprocess.call(f'which {cmd}', shell=True) != 0:
237
+ return False
238
+ else:
239
+ return True
240
+
241
+
242
+ def requires_package(prerequisites):
243
+ """A decorator to check if some python packages are installed.
244
+ Example:
245
+ >>> @requires_package('numpy')
246
+ >>> func(arg1, args):
247
+ >>> return numpy.zeros(1)
248
+ array([0.])
249
+ >>> @requires_package(['numpy', 'non_package'])
250
+ >>> func(arg1, args):
251
+ >>> return numpy.zeros(1)
252
+ ImportError
253
+ """
254
+ return check_prerequisites(prerequisites, checker=_check_py_package)
255
+
256
+
257
+ def requires_executable(prerequisites):
258
+ """A decorator to check if some executable files are installed.
259
+ Example:
260
+ >>> @requires_executable('ffmpeg')
261
+ >>> func(arg1, args):
262
+ >>> print(1)
263
+ 1
264
+ """
265
+ return check_prerequisites(prerequisites, checker=_check_executable)
266
+
267
+
268
+ def deprecated_api_warning(name_dict, cls_name=None):
269
+ """A decorator to check if some arguments are deprecate and try to replace
270
+ deprecate src_arg_name to dst_arg_name.
271
+ Args:
272
+ name_dict(dict):
273
+ key (str): Deprecate argument names.
274
+ val (str): Expected argument names.
275
+ Returns:
276
+ func: New function.
277
+ """
278
+
279
+ def api_warning_wrapper(old_func):
280
+
281
+ @functools.wraps(old_func)
282
+ def new_func(*args, **kwargs):
283
+ # get the arg spec of the decorated method
284
+ args_info = getfullargspec(old_func)
285
+ # get name of the function
286
+ func_name = old_func.__name__
287
+ if cls_name is not None:
288
+ func_name = f'{cls_name}.{func_name}'
289
+ if args:
290
+ arg_names = args_info.args[:len(args)]
291
+ for src_arg_name, dst_arg_name in name_dict.items():
292
+ if src_arg_name in arg_names:
293
+ warnings.warn(
294
+ f'"{src_arg_name}" is deprecated in '
295
+ f'`{func_name}`, please use "{dst_arg_name}" '
296
+ 'instead', DeprecationWarning)
297
+ arg_names[arg_names.index(src_arg_name)] = dst_arg_name
298
+ if kwargs:
299
+ for src_arg_name, dst_arg_name in name_dict.items():
300
+ if src_arg_name in kwargs:
301
+
302
+ assert dst_arg_name not in kwargs, (
303
+ f'The expected behavior is to replace '
304
+ f'the deprecated key `{src_arg_name}` to '
305
+ f'new key `{dst_arg_name}`, but got them '
306
+ f'in the arguments at the same time, which '
307
+ f'is confusing. `{src_arg_name} will be '
308
+ f'deprecated in the future, please '
309
+ f'use `{dst_arg_name}` instead.')
310
+
311
+ warnings.warn(
312
+ f'"{src_arg_name}" is deprecated in '
313
+ f'`{func_name}`, please use "{dst_arg_name}" '
314
+ 'instead', DeprecationWarning)
315
+ kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
316
+
317
+ # apply converted arguments to the decorated method
318
+ output = old_func(*args, **kwargs)
319
+ return output
320
+
321
+ return new_func
322
+
323
+ return api_warning_wrapper
324
+
325
+
326
+ def is_method_overridden(method, base_class, derived_class):
327
+ """Check if a method of base class is overridden in derived class.
328
+ Args:
329
+ method (str): the method name to check.
330
+ base_class (type): the class of the base class.
331
+ derived_class (type | Any): the class or instance of the derived class.
332
+ """
333
+ assert isinstance(base_class, type), \
334
+ "base_class doesn't accept instance, Please pass class instead."
335
+
336
+ if not isinstance(derived_class, type):
337
+ derived_class = derived_class.__class__
338
+
339
+ base_method = getattr(base_class, method)
340
+ derived_method = getattr(derived_class, method)
341
+ return derived_method != base_method
342
+
343
+
344
+ def has_method(obj: object, method: str) -> bool:
345
+ """Check whether the object has a method.
346
+ Args:
347
+ method (str): The method name to check.
348
+ obj (object): The object to check.
349
+ Returns:
350
+ bool: True if the object has the method else False.
351
+ """
352
+ return hasattr(obj, method) and callable(getattr(obj, method))
register/register.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #-*- coding: utf-8 -*-
2
+ import inspect
3
+ import warnings
4
+ from functools import partial
5
+ from typing import Any, Dict, Optional
6
+
7
+ from .misc import deprecated_api_warning, is_seq_of
8
+
9
+
10
+ def build_from_cfg(cfg: Dict,
11
+ registry: 'Registry',
12
+ default_args: Optional[Dict] = None) -> Any:
13
+ """Build a module from config dict when it is a class configuration, or
14
+ call a function from config dict when it is a function configuration.
15
+ Example:
16
+ >>> MODELS = Registry('models')
17
+ >>> @MODELS.register_module()
18
+ >>> class ResNet:
19
+ >>> pass
20
+ >>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
21
+ >>> # Returns an instantiated object
22
+ >>> @MODELS.register_module()
23
+ >>> def resnet50():
24
+ >>> pass
25
+ >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
26
+ >>> # Return a result of the calling function
27
+ Args:
28
+ cfg (dict): Config dict. It should at least contain the key "type".
29
+ registry (:obj:`Registry`): The registry to search the type from.
30
+ default_args (dict, optional): Default initialization arguments.
31
+ Returns:
32
+ object: The constructed object.
33
+ """
34
+ if not isinstance(cfg, dict):
35
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
36
+ if 'type' not in cfg:
37
+ if default_args is None or 'type' not in default_args:
38
+ raise KeyError(
39
+ '`cfg` or `default_args` must contain the key "type", '
40
+ f'but got {cfg}\n{default_args}')
41
+ if not isinstance(registry, Registry):
42
+ raise TypeError('registry must be an mmcv.Registry object, '
43
+ f'but got {type(registry)}')
44
+ if not (isinstance(default_args, dict) or default_args is None):
45
+ raise TypeError('default_args must be a dict or None, '
46
+ f'but got {type(default_args)}')
47
+
48
+ args = cfg.copy()
49
+
50
+ if default_args is not None:
51
+ for name, value in default_args.items():
52
+ args.setdefault(name, value)
53
+
54
+ obj_type = args.pop('type')
55
+ if isinstance(obj_type, str):
56
+ obj_cls = registry.get(obj_type)
57
+ if obj_cls is None:
58
+ raise KeyError(
59
+ f'{obj_type} is not in the {registry.name} registry')
60
+ elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
61
+ obj_cls = obj_type
62
+ else:
63
+ raise TypeError(
64
+ f'type must be a str or valid type, but got {type(obj_type)}')
65
+ try:
66
+ return obj_cls(**args)
67
+ except Exception as e:
68
+ # Normal TypeError does not print class name.
69
+ raise type(e)(f'{obj_cls.__name__}: {e}')
70
+
71
+
72
+ class Registry:
73
+ """A registry to map strings to classes or functions.
74
+ Registered object could be built from registry. Meanwhile, registered
75
+ functions could be called from registry.
76
+ Example:
77
+ >>> MODELS = Registry('models')
78
+ >>> @MODELS.register_module()
79
+ >>> class ResNet:
80
+ >>> pass
81
+ >>> resnet = MODELS.build(dict(type='ResNet'))
82
+ >>> @MODELS.register_module()
83
+ >>> def resnet50():
84
+ >>> pass
85
+ >>> resnet = MODELS.build(dict(type='resnet50'))
86
+ Please refer to
87
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
88
+ advanced usage.
89
+ Args:
90
+ name (str): Registry name.
91
+ build_func(func, optional): Build function to construct instance from
92
+ Registry, func:`build_from_cfg` is used if neither ``parent`` or
93
+ ``build_func`` is specified. If ``parent`` is specified and
94
+ ``build_func`` is not given, ``build_func`` will be inherited
95
+ from ``parent``. Default: None.
96
+ parent (Registry, optional): Parent registry. The class registered in
97
+ children registry could be built from parent. Default: None.
98
+ scope (str, optional): The scope of registry. It is the key to search
99
+ for children registry. If not specified, scope will be the name of
100
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
101
+ Default: None.
102
+ """
103
+
104
+ def __init__(self, name, build_func=None, parent=None, scope=None):
105
+ self._name = name
106
+ self._module_dict = dict()
107
+ self._children = dict()
108
+ self._scope = self.infer_scope() if scope is None else scope
109
+
110
+ # self.build_func will be set with the following priority:
111
+ # 1. build_func
112
+ # 2. parent.build_func
113
+ # 3. build_from_cfg
114
+ if build_func is None:
115
+ if parent is not None:
116
+ self.build_func = parent.build_func
117
+ else:
118
+ self.build_func = build_from_cfg
119
+ else:
120
+ self.build_func = build_func
121
+ if parent is not None:
122
+ assert isinstance(parent, Registry)
123
+ parent._add_children(self)
124
+ self.parent = parent
125
+ else:
126
+ self.parent = None
127
+
128
+ def __len__(self):
129
+ return len(self._module_dict)
130
+
131
+ def __contains__(self, key):
132
+ return self.get(key) is not None
133
+
134
+ def __repr__(self):
135
+ format_str = self.__class__.__name__ + \
136
+ f'(name={self._name}, ' \
137
+ f'items={self._module_dict})'
138
+ return format_str
139
+
140
+ @staticmethod
141
+ def infer_scope():
142
+ """Infer the scope of registry.
143
+ The name of the package where registry is defined will be returned.
144
+ Example:
145
+ >>> # in mmdet/models/backbone/resnet.py
146
+ >>> MODELS = Registry('models')
147
+ >>> @MODELS.register_module()
148
+ >>> class ResNet:
149
+ >>> pass
150
+ The scope of ``ResNet`` will be ``mmdet``.
151
+ Returns:
152
+ str: The inferred scope name.
153
+ """
154
+ # We access the caller using inspect.currentframe() instead of
155
+ # inspect.stack() for performance reasons. See details in PR #1844
156
+ frame = inspect.currentframe()
157
+ # get the frame where `infer_scope()` is called
158
+ infer_scope_caller = frame.f_back.f_back
159
+ filename = inspect.getmodule(infer_scope_caller).__name__
160
+ split_filename = filename.split('.')
161
+ return split_filename[0]
162
+
163
+ @staticmethod
164
+ def split_scope_key(key):
165
+ """Split scope and key.
166
+ The first scope will be split from key.
167
+ Examples:
168
+ >>> Registry.split_scope_key('mmdet.ResNet')
169
+ 'mmdet', 'ResNet'
170
+ >>> Registry.split_scope_key('ResNet')
171
+ None, 'ResNet'
172
+ Return:
173
+ tuple[str | None, str]: The former element is the first scope of
174
+ the key, which can be ``None``. The latter is the remaining key.
175
+ """
176
+ split_index = key.find('.')
177
+ if split_index != -1:
178
+ return key[:split_index], key[split_index + 1:]
179
+ else:
180
+ return None, key
181
+
182
+ @property
183
+ def name(self):
184
+ return self._name
185
+
186
+ @property
187
+ def scope(self):
188
+ return self._scope
189
+
190
+ @property
191
+ def module_dict(self):
192
+ return self._module_dict
193
+
194
+ @property
195
+ def children(self):
196
+ return self._children
197
+
198
+ def get(self, key):
199
+ """Get the registry record.
200
+ Args:
201
+ key (str): The class name in string format.
202
+ Returns:
203
+ class: The corresponding class.
204
+ """
205
+ scope, real_key = self.split_scope_key(key)
206
+ if scope is None or scope == self._scope:
207
+ # get from self
208
+ if real_key in self._module_dict:
209
+ return self._module_dict[real_key]
210
+ else:
211
+ # get from self._children
212
+ if scope in self._children:
213
+ return self._children[scope].get(real_key)
214
+ else:
215
+ # goto root
216
+ parent = self.parent
217
+ while parent.parent is not None:
218
+ parent = parent.parent
219
+ return parent.get(key)
220
+
221
+ def build(self, *args, **kwargs):
222
+ return self.build_func(*args, **kwargs, registry=self)
223
+
224
+ def _add_children(self, registry):
225
+ """Add children for a registry.
226
+ The ``registry`` will be added as children based on its scope.
227
+ The parent registry could build objects from children registry.
228
+ Example:
229
+ >>> models = Registry('models')
230
+ >>> mmdet_models = Registry('models', parent=models)
231
+ >>> @mmdet_models.register_module()
232
+ >>> class ResNet:
233
+ >>> pass
234
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
235
+ """
236
+
237
+ assert isinstance(registry, Registry)
238
+ assert registry.scope is not None
239
+ assert registry.scope not in self.children, \
240
+ f'scope {registry.scope} exists in {self.name} registry'
241
+ self.children[registry.scope] = registry
242
+
243
+ @deprecated_api_warning(name_dict=dict(module_class='module'))
244
+ def _register_module(self, module, module_name=None, force=False):
245
+ if not inspect.isclass(module) and not inspect.isfunction(module):
246
+ raise TypeError('module must be a class or a function, '
247
+ f'but got {type(module)}')
248
+
249
+ if module_name is None:
250
+ module_name = module.__name__
251
+ if isinstance(module_name, str):
252
+ module_name = [module_name]
253
+ for name in module_name:
254
+ if not force and name in self._module_dict:
255
+ raise KeyError(f'{name} is already registered '
256
+ f'in {self.name}')
257
+ self._module_dict[name] = module
258
+
259
+ def deprecated_register_module(self, cls=None, force=False):
260
+ warnings.warn(
261
+ 'The old API of register_module(module, force=False) '
262
+ 'is deprecated and will be removed, please use the new API '
263
+ 'register_module(name=None, force=False, module=None) instead.',
264
+ DeprecationWarning)
265
+ if cls is None:
266
+ return partial(self.deprecated_register_module, force=force)
267
+ self._register_module(cls, force=force)
268
+ return cls
269
+
270
+ def register_module(self, name=None, force=False, module=None):
271
+ """Register a module.
272
+ A record will be added to `self._module_dict`, whose key is the class
273
+ name or the specified name, and value is the class itself.
274
+ It can be used as a decorator or a normal function.
275
+ Example:
276
+ >>> backbones = Registry('backbone')
277
+ >>> @backbones.register_module()
278
+ >>> class ResNet:
279
+ >>> pass
280
+ >>> backbones = Registry('backbone')
281
+ >>> @backbones.register_module(name='mnet')
282
+ >>> class MobileNet:
283
+ >>> pass
284
+ >>> backbones = Registry('backbone')
285
+ >>> class ResNet:
286
+ >>> pass
287
+ >>> backbones.register_module(ResNet)
288
+ Args:
289
+ name (str | None): The module name to be registered. If not
290
+ specified, the class name will be used.
291
+ force (bool, optional): Whether to override an existing class with
292
+ the same name. Default: False.
293
+ module (type): Module class or function to be registered.
294
+ """
295
+ if not isinstance(force, bool):
296
+ raise TypeError(f'force must be a boolean, but got {type(force)}')
297
+ # NOTE: This is a walkaround to be compatible with the old api,
298
+ # while it may introduce unexpected bugs.
299
+ if isinstance(name, type):
300
+ return self.deprecated_register_module(name, force=force)
301
+
302
+ # raise the error ahead of time
303
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
304
+ raise TypeError(
305
+ 'name must be either of None, an instance of str or a sequence'
306
+ f' of str, but got {type(name)}')
307
+
308
+ # use it as a normal method: x.register_module(module=SomeClass)
309
+ if module is not None:
310
+ self._register_module(module=module, module_name=name, force=force)
311
+ return module
312
+
313
+ # use it as a decorator: @x.register_module()
314
+ def _register(module):
315
+ self._register_module(module=module, module_name=name, force=force)
316
+ return module
317
+
318
+ return _register