XYScanNet_Demo / models /networks.py
HanzhouLiu
Add application file
b56342d
raw
history blame contribute delete
518 Bytes
import torch.nn as nn
from models.XYScanNet import XYScanNet
from models.XYScanNetP import XYScanNetP
def get_generator(model_config):
generator_name = model_config['g_name']
if generator_name == 'XYScanNet':
model_g = XYScanNet()
elif generator_name == 'XYScanNetP':
model_g = XYScanNetP()
else:
raise ValueError("Generator Network [%s] not recognized." % generator_name)
return nn.DataParallel(model_g)
def get_nets(model_config):
return get_generator(model_config)