Spaces:
Running
Running
| import torch | |
| from . import model_util | |
| from .pix2pix_model import define_G as pix2pix_G | |
| from .pix2pixHD_model import define_G as pix2pixHD_G | |
| # from .video_model import MosaicNet | |
| # from .videoHD_model import MosaicNet as MosaicNet_HD | |
| from .BiSeNet_model import BiSeNet | |
| from .BVDNet import define_G as video_G | |
def show_paramsnumber(net, netname='net'):
    """Print how many parameters ``net`` has, in millions rounded to 2 decimals.

    Args:
        net: object exposing ``parameters()`` whose items provide ``numel()``
            (e.g. a ``torch.nn.Module``).
        netname: label printed in front of the count.
    """
    total = 0
    for tensor in net.parameters():
        total += tensor.numel()
    millions = round(total / 1e6, 2)
    print(''.join([netname, ' parameters: ', str(millions), 'M']))
def pix2pix(opt):
    """Build a pix2pix / pix2pixHD generator, load its weights and return it in eval mode.

    Args:
        opt: options object. Reads ``opt.netG`` (``'HD'`` selects the pix2pixHD
            ``'global'`` generator; any other value is passed to the pix2pix
            ``define_G`` as the generator name), ``opt.model_path`` (checkpoint
            file holding a state dict) and ``opt.gpu_id`` (forwarded to
            ``model_util.todevice``).

    Returns:
        The generator after ``model_util.todevice`` and ``eval()``.
    """
    if opt.netG == 'HD':
        netG = pix2pixHD_G(3, 3, 64, 'global', 4)
    else:
        netG = pix2pix_G(3, 3, 64, opt.netG, norm='batch', use_dropout=True,
                         init_type='normal', gpu_ids=[])
    show_paramsnumber(netG, 'netG')
    # Deserialize onto the CPU first (as style() does) so a checkpoint saved
    # from a GPU can still be restored on a CPU-only machine; the model is
    # moved to its target device right after.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    netG = model_util.todevice(netG, opt.gpu_id)
    netG.eval()
    return netG
def style(opt):
    """Build a 'resnet_9blocks' style-transfer generator, load its weights and return it in eval mode.

    Args:
        opt: options object. Reads ``opt.edges`` (truthy selects the
            1-input-channel variant built with dropout; otherwise a
            3-input-channel variant without dropout), ``opt.model_path``
            (checkpoint file) and ``opt.gpu_id`` (forwarded to
            ``model_util.todevice``).

    Returns:
        The generator after ``model_util.todevice`` and ``eval()``.
    """
    in_channels, dropout = (1, True) if opt.edges else (3, False)
    netG = pix2pix_G(in_channels, 3, 64, 'resnet_9blocks', norm='instance',
                     use_dropout=dropout, init_type='normal', gpu_ids=[])

    # Loading procedure adapted from pytorch-CycleGAN-and-pix2pix
    # (models/base_model.py) so that older pretrained checkpoints still load.
    if isinstance(netG, torch.nn.DataParallel):
        netG = netG.module
    state_dict = torch.load(opt.model_path, map_location='cpu')
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    # Patch InstanceNorm entries of checkpoints saved before PyTorch 0.4.
    # Iterate over a snapshot of the keys because the patching mutates state_dict.
    for key in list(state_dict):
        model_util.patch_instance_norm_state_dict(state_dict, netG, key.split('.'))
    netG.load_state_dict(state_dict)

    netG = model_util.todevice(netG, opt.gpu_id)
    netG.eval()
    return netG
def video(opt):
    """Build the BVDNet video generator, load its weights and return it in eval mode.

    Args:
        opt: options object. Reads ``opt.gpu_id`` (passed both to the
            generator constructor and to ``model_util.todevice``) and
            ``opt.model_path`` (checkpoint file holding a state dict).

    Returns:
        The generator after ``model_util.todevice`` and ``eval()``.
    """
    netG = video_G(N=2, n_blocks=4, gpu_id=opt.gpu_id)
    show_paramsnumber(netG, 'netG')
    # Deserialize onto the CPU first (as style() does) so a GPU-saved
    # checkpoint also loads on CPU-only machines; todevice moves it afterwards.
    netG.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    netG = model_util.todevice(netG, opt.gpu_id)
    netG.eval()
    return netG
def bisenet(opt, type='roi'):
    """Build a BiSeNet segmentation network, load its weights and return it in eval mode.

    Args:
        opt: options object. Reads ``opt.model_path`` (checkpoint used when
            ``type == 'roi'``), ``opt.mosaic_position_model_path`` (checkpoint
            used when ``type == 'mosaic'``) and ``opt.gpu_id`` (forwarded to
            ``model_util.todevice``).
        type: ``'roi'`` or ``'mosaic'``; selects which checkpoint to load.
            (The name shadows the builtin but is kept so keyword callers keep working.)

    Returns:
        The network after ``model_util.todevice`` and ``eval()``.

    Raises:
        ValueError: if ``type`` is neither ``'roi'`` nor ``'mosaic'`` (previously
            such a call silently returned a network with no weights loaded).
    """
    if type == 'roi':
        model_path = opt.model_path
    elif type == 'mosaic':
        model_path = opt.mosaic_position_model_path
    else:
        raise ValueError("bisenet: type must be 'roi' or 'mosaic', got {!r}".format(type))

    net = BiSeNet(num_classes=1, context_path='resnet18', train_flag=False)
    show_paramsnumber(net, 'segment')
    # Deserialize onto the CPU first (as style() does) so a GPU-saved
    # checkpoint also loads on CPU-only machines; todevice moves it afterwards.
    net.load_state_dict(torch.load(model_path, map_location='cpu'))
    net = model_util.todevice(net, opt.gpu_id)
    net.eval()
    return net