import importlib.util
import os
import sys

# Registry mapping an architecture name (the part of ``cfg.network`` before
# the first underscore) to a callable invoked as
# ``factory(num_layers, heads, head_conv)``. It is empty in this module.
_network_factory = {
}


def get_network(cfg):
    """Build a network from the local ``_network_factory`` registry.

    ``cfg.network`` is split at its first underscore into an architecture
    name and a layer count (e.g. ``'ro_34'`` -> ``'ro'`` and ``34``). When
    there is no underscore, the whole string is the architecture name and
    the layer count is ``0``.

    Args:
        cfg: Configuration object exposing ``network`` (e.g. ``'ro_34'``),
            ``heads`` (e.g. ``{'ct_hm': 9, 'wh': 2}``) and ``head_conv``.

    Returns:
        Whatever the registered factory returns for
        ``(num_layers, heads, head_conv)``.

    Raises:
        KeyError: If the architecture name is not in ``_network_factory``.
        ValueError: If the text after the first underscore is not an integer.
    """
    arch = cfg.network
    heads = cfg.heads
    head_conv = cfg.head_conv

    # partition() splits at the first underscore, like the find()-based slicing
    # it replaces; ``sep`` is empty when the name has no underscore.
    name, sep, depth = arch.partition('_')
    num_layers = int(depth) if sep else 0

    try:
        get_model = _network_factory[name]
    except KeyError:
        raise KeyError(
            'unknown network architecture {!r}; registered architectures: {}'.format(
                name, sorted(_network_factory)
            )
        ) from None

    return get_model(num_layers, heads, head_conv)


def make_network(cfg):
    """Load the task package for ``cfg.task`` from disk and build its network.

    The package file ``lib/networks/<task>/__init__.py`` is resolved relative
    to the current working directory and loaded under the module name
    ``lib.networks.<task>``. The result comes from that package's own
    ``get_network(cfg)``, not from the ``get_network`` defined in this file
    (for example ``lib/networks/snake/__init__.py`` when ``cfg.task`` is
    ``'snake'``).

    Args:
        cfg: Configuration object exposing ``task`` (e.g. ``'snake'``) and
            ``model``; it is passed unchanged to the task package.

    Returns:
        The value returned by the task package's ``get_network(cfg)``.

    Raises:
        ImportError: If no import spec or loader can be created for the path.
        FileNotFoundError: If the package file does not exist.
    """
    module = '.'.join(['lib.networks', cfg.task])
    path = os.path.join('lib/networks', cfg.task, '__init__.py')
    print("网络路径:", path)
    print("模型ID:", cfg.task)
    print("model:", cfg.model)

    # importlib is used in place of the deprecated ``imp`` module.
    spec = importlib.util.spec_from_file_location(module, path)
    if spec is None or spec.loader is None:
        raise ImportError(
            'cannot create an import spec for {!r} from {!r}'.format(module, path)
        )
    module_obj = importlib.util.module_from_spec(spec)

    # Register the module before executing it, as the importlib recipe for
    # loading a source file does. Otherwise a relative import inside the
    # package's __init__.py makes the import system load the package a second
    # time under the same name.
    previous = sys.modules.get(module)
    sys.modules[module] = module_obj
    try:
        spec.loader.exec_module(module_obj)
    except BaseException:
        # Do not leave a half-initialised module behind; restore prior state.
        if previous is None:
            sys.modules.pop(module, None)
        else:
            sys.modules[module] = previous
        raise

    # NOTE: this calls the dynamically loaded package's get_network(), not the
    # get_network() defined above in this file.
    return module_obj.get_network(cfg)