| """ |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
| import inspect |
| import importlib |
| import functools |
| from collections import defaultdict |
| from typing import Any, Dict, Optional, List |
|
|
|
|
# Default registry used by ``register`` and ``create``. Being a
# ``defaultdict(dict)``, reading a missing key with ``[]`` inserts ``{}``.
GLOBAL_CONFIG = defaultdict(dict)


def register(dct: Any = GLOBAL_CONFIG, name=None, force=False):
    """Return a decorator that registers a function or a class.

    Args:
        dct: Registry target. If it is a dict, the decorated object is
            stored in it as a key-value pair; if it is a class, a decorated
            function is attached to it as an attribute.
        name: Registration key / attribute name. Defaults to the decorated
            object's ``__name__``.
        force: If True, skip the duplicate-registration check and overwrite
            any existing entry.

    Functions are wrapped with ``functools.wraps``; the wrapper is stored and
    also returned. Classes are stored as the schema produced by
    ``extract_schema`` (via ``dct[key] = ...``, so ``dct`` must support item
    assignment) and are returned unchanged.

    Raises:
        AssertionError: if ``force`` is False and the key is already taken.
        AttributeError: if a function is registered into something that is
            neither a dict nor a class.
        ValueError: if the decorated object is neither a function nor a class.
    """
    def decorator(foo):
        register_name = foo.__name__ if name is None else name

        # Check the key that will actually be written, so custom names are
        # protected against collisions as well.
        if not force:
            if inspect.isclass(dct):
                assert not hasattr(dct, register_name), \
                    f'module {dct.__name__} has {register_name}'
            else:
                assert register_name not in dct, \
                    f'{register_name} has been already registered'

        if inspect.isfunction(foo):
            @functools.wraps(foo)
            def wrap_func(*args, **kwargs):
                return foo(*args, **kwargs)

            if isinstance(dct, dict):
                dct[register_name] = wrap_func
            elif inspect.isclass(dct):
                setattr(dct, register_name, wrap_func)
            else:
                raise AttributeError(
                    f'Cannot register {register_name} into '
                    f'{type(dct).__name__}; expected a dict or a class')
            return wrap_func

        if inspect.isclass(foo):
            dct[register_name] = extract_schema(foo)
            return foo

        raise ValueError(f'Do not support {type(foo)} register')

    return decorator
|
|
|
|
|
|
def extract_schema(module: type):
    """Build the registry schema of a class from its ``__init__`` signature.

    Args:
        module (type): the class being registered.

    Returns:
        dict: a schema whose bookkeeping keys are
            ``_name`` (``module.__name__``),
            ``_pymodule`` (the imported module named by ``module.__module__``),
            ``_inject`` / ``_share`` (the class's ``__inject__`` / ``__share__``,
            defaulting to ``[]``), and
            ``_kwargs`` (parameter name -> default value).
            Each configurable ``__init__`` parameter, positional or
            keyword-only, also appears as a top-level key holding the same
            default. Parameters without a default map to ``None``.

    Raises:
        AssertionError: if a name listed in ``__share__`` has no default.
    """
    argspec = inspect.getfullargspec(module.__init__)
    arg_names = [arg for arg in argspec.args if arg != 'self']
    positional_defaults = argspec.defaults or ()
    num_required = len(arg_names) - len(positional_defaults)

    # parameter name -> (has_default, default_value), in signature order.
    params = {}
    for i, arg in enumerate(arg_names):
        if i >= num_required:
            params[arg] = (True, positional_defaults[i - num_required])
        else:
            params[arg] = (False, None)

    # Keyword-only parameters (declared after ``*`` / ``*args``) are
    # configurable too; previously they were silently left out.
    kwonly_defaults = argspec.kwonlydefaults or {}
    for arg in argspec.kwonlyargs:
        params[arg] = (arg in kwonly_defaults, kwonly_defaults.get(arg))

    schema = {
        '_name': module.__name__,
        '_pymodule': importlib.import_module(module.__module__),
        '_inject': getattr(module, '__inject__', []),
        '_share': getattr(module, '__share__', []),
        '_kwargs': {},
    }
    for arg, (has_default, value) in params.items():
        if arg in schema['_share']:
            assert has_default, 'share config must have default value.'
        schema[arg] = value
        schema['_kwargs'][arg] = value

    return schema
|
|
|
|
def _reset_schema(schema: dict, *overrides: dict) -> None:
    """Reset a registered schema in place, then apply ``overrides`` in order.

    Every key not starting with ``_`` is removed, the constructor defaults
    saved in ``schema['_kwargs']`` are restored, and each mapping in
    ``overrides`` is merged on top with ``dict.update`` (later ones win).
    """
    for key in [k for k in schema if not k.startswith('_')]:
        del schema[key]
    schema.update(schema['_kwargs'])
    for override in overrides:
        schema.update(override)


def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs):
    """Build an object from its entry in ``global_cfg``.

    Args:
        type_or_name: a registered class, or the key its entry is stored
            under in ``global_cfg``.
        global_cfg: the registry / config mapping to resolve names against.
        **kwargs: extra values merged over the entry's configuration before
            shared and injected arguments are resolved. For a ``type``-style
            entry they are written into the target schema (in place);
            otherwise they are applied to a per-call copy of the schema.

    Resolution of the entry found under the name:
        * an object with an instance ``__dict__`` (not a config dict) is
          returned unchanged;
        * a dict with a ``type`` key resets the schema registered under that
          type to its defaults, applies the entry and ``kwargs`` to it in
          place, and builds it recursively;
        * otherwise the entry is a schema. Names in ``_share`` take their
          value from ``global_cfg`` when present there. Names in ``_inject``
          are resolved:
            - a ``str`` names a ``global_cfg`` key; a dict found there is
              built with ``create``, and any other value is passed through;
            - a ``dict`` must carry ``type``, and that schema is reconfigured
              in place and built;
            - ``None`` is left as is.
          The class is then called with every key not starting with ``_``.

    Raises:
        AssertionError: if ``type_or_name`` is neither a class nor a str.
        ValueError: for unregistered names/types or unsupported inject values.
    """
    # isinstance (not ``type(...) in``) so classes with a metaclass such as
    # ABCMeta, and str subclasses, are accepted.
    assert isinstance(type_or_name, (type, str)), 'create should be modules or name.'

    name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__

    if name not in global_cfg:
        raise ValueError('The module {} is not registered'.format(name))

    cfg = global_cfg[name]

    # Plain dicts have no instance ``__dict__``; anything that does is treated
    # as an already-built object.
    if hasattr(cfg, '__dict__'):
        return cfg

    if isinstance(cfg, dict) and 'type' in cfg:
        type_name = cfg['type']
        # Explicit check: indexing a defaultdict registry would silently insert
        # an empty entry for an unknown type.
        if type_name not in global_cfg:
            raise ValueError('The module {} is not registered'.format(type_name))
        schema: dict = global_cfg[type_name]
        # NOTE: the shared schema is mutated in place (existing design).
        _reset_schema(schema, cfg, kwargs)
        return create(schema.pop('type'), global_cfg)

    # Call-site overrides must not leak into the shared schema, so merge them
    # into a copy.
    if kwargs:
        cfg = {**cfg, **kwargs}

    # Look the class up by its real name, which can differ from the registry
    # key when it was registered with ``name=...``.
    module = getattr(cfg['_pymodule'], cfg.get('_name', name))
    module_kwargs = dict(cfg)

    # Shared values: a top-level ``global_cfg`` entry wins over the schema's own value.
    for k in cfg['_share']:
        if k in global_cfg:
            module_kwargs[k] = global_cfg[k]
        else:
            module_kwargs[k] = cfg[k]

    # Injected sub-modules.
    for k in cfg['_inject']:
        spec = cfg[k]

        if spec is None:
            continue

        if isinstance(spec, str):
            if spec not in global_cfg:
                raise ValueError(f'Missing inject config of {spec}.')

            target = global_cfg[spec]

            if isinstance(target, dict):
                # Resolve through the key itself. This handles both schema dicts
                # and ``type``-style config dicts (which have no ``_name``).
                module_kwargs[k] = create(spec, global_cfg)
            else:
                module_kwargs[k] = target

        elif isinstance(spec, dict):
            if 'type' not in spec:
                raise ValueError('Missing inject for `type` style.')

            type_name = str(spec['type'])
            if type_name not in global_cfg:
                raise ValueError(f'Missing {type_name} in inspect stage.')

            # NOTE: reconfigures the shared schema in place, so injecting the
            # same type twice with different configs relies on call order.
            schema: dict = global_cfg[type_name]
            _reset_schema(schema, spec)
            module_kwargs[k] = create(schema.pop('type'), global_cfg)

        else:
            raise ValueError(f'Inject does not support {spec}')

    # Drop bookkeeping keys (``_name``, ``_pymodule``, ...) before the call.
    module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith('_')}

    return module(**module_kwargs)
|
|