Spaces:
Runtime error
Runtime error
| import torch | |
| import os | |
| from collections import OrderedDict | |
| import random | |
| from . import model_utils | |
class BaseModel:
    """Shared scaffolding for the concrete models in this package.

    Subclasses are expected to register their networks as attributes named
    ``net_<name>`` and list the ``<name>`` strings in ``self.models_name``;
    losses are looked up as ``loss_<name>`` attributes. The hooks
    ``feed_batch``, ``forward`` and ``optimize_paras`` are no-ops here and
    are meant to be overridden.
    """

    # Checkpoint that earlier revisions of ``load_ckpt`` loaded
    # unconditionally. It is kept only as a fallback so deployments that
    # ship just this file keep working.
    fallback_ckpt_path = 'checkpoints/30_net_gen.pth'

    def __init__(self):
        super().__init__()
        self.name = "Base"

    def initialize(self, opt):
        """Store the options object and derive device / mode from it.

        Args:
            opt: options object; this method reads ``opt.gpu_ids`` and
                ``opt.mode``.
        """
        self.opt = opt
        self.gpu_ids = self.opt.gpu_ids
        # First listed GPU when any are configured, otherwise CPU.
        self.device = torch.device('cuda:%d' % self.gpu_ids[0] if self.gpu_ids else 'cpu')
        self.is_train = self.opt.mode == "train"
        # Subclasses fill this with the names of their networks.
        self.models_name = []

    def setup(self):
        """Switch the networks to train/eval mode; build losses when training."""
        if not self.is_train:
            self.set_eval()
            return

        self.set_train()
        # Loss modules, placed on the model's device.
        self.criterionGAN = model_utils.GANLoss(gan_type=self.opt.gan_type).to(self.device)
        self.criterionL1 = torch.nn.L1Loss().to(self.device)
        self.criterionMSE = torch.nn.MSELoss().to(self.device)
        self.criterionTV = model_utils.TVLoss().to(self.device)
        # NOTE: earlier code constructed torch.nn.DataParallel wrappers around
        # the criteria here and threw the result away; the criteria were never
        # actually wrapped, so those statements were dropped.

        # Subclasses populate these during their own setup.
        self.losses_name = []
        self.optims = []
        self.schedulers = []

    def set_eval(self):
        """Put every registered ``net_<name>`` into eval mode."""
        print("Set model to Test state.")
        for name in self.models_name:
            if isinstance(name, str):
                getattr(self, 'net_' + name).eval()
                print("Set net_%s to EVAL." % name)
        self.is_train = False

    def set_train(self):
        """Put every registered ``net_<name>`` into train mode."""
        print("Set model to Train state.")
        for name in self.models_name:
            if isinstance(name, str):
                getattr(self, 'net_' + name).train()
                print("Set net_%s to TRAIN." % name)
        self.is_train = True

    def set_requires_grad(self, parameters, requires_grad=False):
        """Set ``requires_grad`` on one object or on each object in a list.

        Args:
            parameters: a single object or a ``list`` of objects; ``None``
                entries are skipped.
            requires_grad: value assigned to each object's ``requires_grad``.
        """
        if not isinstance(parameters, list):
            parameters = [parameters]
        for param in parameters:
            if param is not None:
                param.requires_grad = requires_grad

    def get_latest_visuals(self, visuals_name):
        """Return an ``OrderedDict`` of the named attributes that exist on self.

        Names that are not strings or that are not attributes of this object
        are silently skipped.
        """
        visual_ret = OrderedDict()
        for name in visuals_name:
            if isinstance(name, str) and hasattr(self, name):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_latest_losses(self, losses_name):
        """Return an ``OrderedDict`` mapping each name to ``float(self.loss_<name>)``.

        Raises:
            AttributeError: if a requested ``loss_<name>`` attribute is missing.
        """
        errors_ret = OrderedDict()
        for name in losses_name:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    def feed_batch(self, batch):
        """Hook for subclasses; does nothing here."""
        pass

    def forward(self):
        """Hook for subclasses; does nothing here."""
        pass

    def optimize_paras(self):
        """Hook for subclasses; does nothing here."""
        pass

    def update_learning_rate(self):
        """Step every scheduler and return the first optimizer's current lr."""
        for scheduler in self.schedulers:
            scheduler.step()
        return self.optims[0].param_groups[0]['lr']

    def save_ckpt(self, epoch, models_name):
        """Save each ``net_<name>`` to ``<opt.ckpt_dir>/<epoch>_net_<name>.pth``.

        Weights are written from CPU so the file can be loaded under a
        different GPU configuration.

        Args:
            epoch: value formatted into the file name.
            models_name: iterable of network names; non-string entries are
                skipped.
        """
        os.makedirs(self.opt.ckpt_dir, exist_ok=True)
        use_gpu = len(self.gpu_ids) > 0 and torch.cuda.is_available()
        for name in models_name:
            if not isinstance(name, str):
                continue
            save_path = os.path.join(self.opt.ckpt_dir, '%s_net_%s.pth' % (epoch, name))
            net = getattr(self, 'net_' + name)
            # Save the bare module's weights whether or not it is wrapped, so
            # keys carry no "module." prefix and unwrapped nets do not crash.
            module = net.module if isinstance(net, torch.nn.DataParallel) else net
            torch.save(module.cpu().state_dict(), save_path)
            if use_gpu:
                # .cpu() moved the live parameters; put them back on the GPU.
                module.to(self.gpu_ids[0])

    def load_ckpt(self, epoch, models_name):
        """Load weights for each ``net_<name>`` from its checkpoint file.

        The file ``<opt.ckpt_dir>/<epoch>_net_<name>.pth`` is used when it
        exists; otherwise ``fallback_ckpt_path`` is tried. Only keys present
        in the target network are passed to ``load_state_dict``.

        Args:
            epoch: value formatted into the file name.
            models_name: iterable of network names; non-string entries are
                skipped.

        Raises:
            FileNotFoundError: if neither the requested file nor the fallback
                file exists.
        """
        for name in models_name:
            if not isinstance(name, str):
                continue
            load_filename = '%s_net_%s.pth' % (epoch, name)
            load_path = os.path.join(self.opt.ckpt_dir, load_filename)
            if not os.path.isfile(load_path):
                if not os.path.isfile(self.fallback_ckpt_path):
                    raise FileNotFoundError("File '%s' does not exist." % load_path)
                print("[Warn] '%s' not found, falling back to '%s'."
                      % (load_path, self.fallback_ckpt_path))
                load_path = self.fallback_ckpt_path

            # Load onto CPU: load_state_dict copies into the network's own
            # parameters, so this works on hosts without CUDA as well.
            pretrained_state_dict = torch.load(load_path, map_location='cpu')
            if hasattr(pretrained_state_dict, '_metadata'):
                del pretrained_state_dict._metadata

            net = getattr(self, 'net_' + name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            # Keep only the keys the network knows about.
            known_keys = net.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_state_dict.items() if k in known_keys}
            net.load_state_dict(pretrained_dict)
            print("[Info] Successfully load trained weights for net_%s." % name)

    def clean_ckpt(self, epoch, models_name):
        """Delete ``<opt.ckpt_dir>/<epoch>_net_<name>.pth`` for each name, if present."""
        for name in models_name:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.opt.ckpt_dir, load_filename)
                if os.path.isfile(load_path):
                    os.remove(load_path)

    def gradient_penalty(self, input_img, generate_img):
        """Return ``mean((||d net_dis(x) / dx||_2 - 1) ** 2)`` at random interpolates.

        ``x`` is a per-sample random convex combination of ``input_img`` and
        ``generate_img`` (both detached). ``self.net_dis`` must return a
        2-tuple whose first element is differentiated with respect to ``x``.

        Args:
            input_img: tensor with four dimensions (alpha is broadcast from
                shape ``(N, 1, 1, 1)``).
            generate_img: tensor combined elementwise with ``input_img``.

        Returns:
            Scalar tensor; the graph is kept (``create_graph=True``) so the
            penalty itself can be back-propagated.
        """
        # One mixing coefficient per sample, broadcast over the other dims.
        alpha = torch.rand(input_img.size(0), 1, 1, 1).to(self.device)
        inter_img = (alpha * input_img.detach()
                     + (1 - alpha) * generate_img.detach()).requires_grad_(True)
        inter_img_prob, _ = self.net_dis(inter_img)

        # dy/dx with y = inter_img_prob, x = inter_img.
        dydx = torch.autograd.grad(outputs=inter_img_prob,
                                   inputs=inter_img,
                                   grad_outputs=torch.ones_like(inter_img_prob),
                                   retain_graph=True,
                                   create_graph=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)