gpu_symbol / engine /core /workspace.py
himipo's picture
first
63e9186
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import inspect
import importlib
import functools
from collections import defaultdict
from typing import Any, Dict, Optional, List
# Default registry filled by ``@register()``: functions map to their wrapper,
# classes map to the schema dict built by ``extract_schema``. Because it is a
# defaultdict, indexing a missing key inserts an empty dict.
GLOBAL_CONFIG: Dict[str, Any] = defaultdict(dict)


def register(dct: Any = GLOBAL_CONFIG, name=None, force=False):
    """Return a decorator that records a function or class in a registry.

    Args:
        dct: Where to register.
            - a dict: ``foo`` is stored as a key-value pair under its
              registration name;
            - a class: a decorated function is attached to it as an attribute.
        name: Registration key; defaults to ``foo.__name__``.
        force: If True, overwrite an existing entry instead of failing the
            duplicate check.

    Returns:
        The decorator. Applied to a function it returns a ``functools.wraps``
        wrapper (which is what gets registered); applied to a class it stores
        ``extract_schema(cls)`` in ``dct`` and returns the class unchanged.

    Raises:
        AssertionError: the registration name is already taken and ``force``
            is False.
        AttributeError: a function is registered into something that is
            neither a dict nor a class.
        ValueError: the decorated object is neither a function nor a class.
    """
    def decorator(foo):
        register_name = foo.__name__ if name is None else name

        # Check the key that is actually written below; checking
        # ``foo.__name__`` missed collisions when a custom ``name`` was given.
        if not force:
            if inspect.isclass(dct):
                assert not hasattr(dct, register_name), \
                    f'module {dct.__name__} has {register_name}'
            else:
                assert register_name not in dct, \
                    f'{register_name} has been already registered'

        if inspect.isfunction(foo):
            @functools.wraps(foo)
            def wrap_func(*args, **kwargs):
                return foo(*args, **kwargs)

            if isinstance(dct, dict):
                dct[register_name] = wrap_func
            elif inspect.isclass(dct):
                setattr(dct, register_name, wrap_func)
            else:
                raise AttributeError(
                    f'Cannot register {foo.__name__} into '
                    f'{type(dct).__name__}: expected a dict or a class')
            return wrap_func

        if inspect.isclass(foo):
            dct[register_name] = extract_schema(foo)
            return foo

        raise ValueError(f'Do not support {type(foo)} register')

    return decorator
def extract_schema(module: type):
    """Describe a class's ``__init__`` signature as a registry schema dict.

    Args:
        module (type): The class being registered.

    Return:
        Dict with bookkeeping entries
            ``_name``     -- ``module.__name__``
            ``_pymodule`` -- the module object named by ``module.__module__``
            ``_inject``   -- ``module.__inject__`` (``[]`` if absent)
            ``_share``    -- ``module.__share__`` (``[]`` if absent)
            ``_kwargs``   -- ``{arg: default}`` for each positional-or-keyword
                             ``__init__`` argument (``None`` if no default)
        followed by one top-level ``arg: default`` entry per such argument.
        Keyword-only arguments are not included.

    Raises:
        ValueError: an argument listed in ``__share__`` has no default value.
    """
    argspec = inspect.getfullargspec(module.__init__)
    arg_names = [arg for arg in argspec.args if arg != 'self']
    defaults = argspec.defaults or ()
    # Defaults bind to the trailing arguments, so the first `num_requires`
    # arguments are the ones without a default.
    num_requires = len(arg_names) - len(defaults)

    schema = {
        '_name': module.__name__,
        '_pymodule': importlib.import_module(module.__module__),
        '_inject': getattr(module, '__inject__', []),
        '_share': getattr(module, '__share__', []),
        '_kwargs': {},
    }

    for i, name in enumerate(arg_names):
        has_default = i >= num_requires
        # Explicit check instead of `assert`: under `python -O` the assert is
        # removed and a negative index would silently pick another
        # argument's default.
        if name in schema['_share'] and not has_default:
            raise ValueError(f'share config must have default value: {name!r}')
        value = defaults[i - num_requires] if has_default else None
        schema[name] = value
        schema['_kwargs'][name] = value

    return schema
def _reset_schema(schema: dict, *overrides: dict):
    """Restore ``schema``'s default args, merge ``overrides``, pop ``type``.

    Every non-underscore key is deleted, the defaults saved in
    ``schema['_kwargs']`` are copied back, and each mapping in ``overrides``
    is merged in order. NOTE: ``schema`` is the shared registry entry and is
    mutated in place; the merged values remain there until the next reset.

    Returns:
        The value popped from the ``type`` key (supplied by ``overrides``),
        i.e. the registry name to instantiate next.
    """
    for key in [k for k in schema if not k.startswith('_')]:
        del schema[key]
    schema.update(schema['_kwargs'])  # restore default args
    for extra in overrides:
        schema.update(extra)  # load config args
    return schema.pop('type')


def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs):
    """Build an object described by the registry ``global_cfg``.

    Args:
        type_or_name: A registered class or its registry name.
        global_cfg: Registry mapping names to class schemas (see
            ``extract_schema``), ``type``-style config dicts, shared values,
            or objects that have a ``__dict__``.
        **kwargs: Extra arguments merged into a ``type``-style config before
            instantiation. They are not applied when ``type_or_name`` names a
            class schema directly.

    Returns:
        The registry entry itself if it has a ``__dict__`` (i.e. it is not a
        plain dict); otherwise a new instance built from the schema, with
        ``__share__`` and ``__inject__`` arguments resolved from the registry.

    Raises:
        AssertionError: ``type_or_name`` is neither a class nor a string.
        ValueError: the name, the ``type`` it refers to, or an injected
            dependency is not registered, or an inject value is malformed.
    """
    assert isinstance(type_or_name, (type, str)), 'create should be modules or name.'
    name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__

    if name not in global_cfg:
        raise ValueError('The module {} is not registered'.format(name))
    cfg = global_cfg[name]
    # Entries with a __dict__ (anything but a plain dict schema) are returned as-is.
    if hasattr(cfg, '__dict__'):
        return cfg

    # `type`-style config such as {'type': 'SomeClass', 'arg': 1}: reload the
    # target's schema with these values (plus kwargs) and build the target.
    if isinstance(cfg, dict) and 'type' in cfg:
        # Check membership first: indexing a defaultdict registry with an
        # unknown key would insert an empty entry and fail later with KeyError.
        if cfg['type'] not in global_cfg:
            raise ValueError('The module {} is not registered'.format(cfg['type']))
        target = _reset_schema(global_cfg[cfg['type']], cfg, kwargs)
        return create(target, global_cfg)

    module = getattr(cfg['_pymodule'], name)
    module_kwargs = dict(cfg)

    # Shared values: a top-level registry entry wins over the schema's own value.
    for key in cfg['_share']:
        module_kwargs[key] = global_cfg[key] if key in global_cfg else cfg[key]

    # Injected dependencies are built recursively from the registry.
    for key in cfg['_inject']:
        spec = cfg[key]
        if spec is None:
            continue

        if isinstance(spec, str):
            if spec not in global_cfg:
                raise ValueError(f'Missing inject config of {spec}.')
            dep = global_cfg[spec]
            module_kwargs[key] = create(dep['_name'], global_cfg) if isinstance(dep, dict) else dep
        elif isinstance(spec, dict):
            if 'type' not in spec:
                raise ValueError('Missing inject for `type` style.')
            dep_type = str(spec['type'])
            if dep_type not in global_cfg:
                raise ValueError(f'Missing {dep_type} in inspect stage.')
            target = _reset_schema(global_cfg[dep_type], spec)
            module_kwargs[key] = create(target, global_cfg)
        else:
            raise ValueError(f'Inject does not support {spec}')

    # Drop bookkeeping entries (`_name`, `_pymodule`, ...) before calling.
    module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith('_')}
    return module(**module_kwargs)