Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import importlib | |
| from lib.models.networks.discriminator import MultiscaleDiscriminator, ImageDiscriminator | |
| from lib.models.networks.generator import ModulateGenerator | |
| from lib.models.networks.encoder import ResSEAudioEncoder, ResNeXtEncoder, ResSESyncEncoder, FanEncoder | |
| # import util.util as util | |
def find_class_in_module(target_cls_name, module):
    """Return the attribute of ``module`` whose name matches ``target_cls_name``.

    The target name is normalised by removing underscores and lower-casing
    it; each attribute name of the imported module is lower-cased (its
    underscores are kept) and compared against that normalised target.

    Args:
        target_cls_name: Name to look for, e.g. ``"Ordered_Dict"``.
        module: Dotted import path of the module to search.

    Returns:
        The matching module attribute. If several attribute names match,
        the last one in the module namespace wins.

    Raises:
        SystemExit: With exit status 1 when nothing matches. The error
            message is printed before exiting.
        ImportError: Propagated from ``importlib.import_module`` when the
            module itself cannot be imported.
    """
    target_cls_name = target_cls_name.replace('_', '').lower()
    clslib = importlib.import_module(module)

    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            # Deliberately no ``break``: keeps the historical
            # "last match wins" behaviour.
            cls = clsobj

    if cls is None:
        print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
        # A failed lookup is an error, so exit with a non-zero status.
        # Raising SystemExit directly also avoids depending on the
        # ``exit`` helper injected by the ``site`` module.
        raise SystemExit(1)

    return cls
def find_network_using_name(target_network_name, filename):
    """Resolve a network object from ``lib.models.networks.<filename>``.

    The name searched for is ``target_network_name`` concatenated with
    ``filename``; the lookup itself is delegated to
    ``find_class_in_module``.

    Args:
        target_network_name: Prefix of the name to look up.
        filename: Name of the submodule under ``lib.models.networks``;
            also used as the suffix of the name to look up.

    Returns:
        Whatever ``find_class_in_module`` returns for that name and module.
    """
    wanted_name = target_network_name + filename
    return find_class_in_module(wanted_name, 'lib.models.networks.' + filename)
def define_networks(opt, name, _type):
    """Find a network by name and type, then instantiate it with ``opt``.

    Args:
        opt: Passed unchanged as the sole argument when calling the
            resolved network object.
        name: Network name prefix forwarded to ``find_network_using_name``.
        _type: Submodule name / name suffix forwarded to
            ``find_network_using_name``.

    Returns:
        The result of calling the resolved network object with ``opt``.
    """
    network_cls = find_network_using_name(name, _type)
    return network_cls(opt)