| | import numpy as np |
| | import torch |
| | import os |
| | from collections import OrderedDict |
| | from torch.autograd import Variable |
| | import itertools |
| | import util.util as util |
| | from util.util import weights_init, get_model_list, vgg_preprocess, load_vgg16, get_scheduler |
| | from util.image_pool import ImagePool |
| | from .base_model import BaseModel |
| | from . import networks |
| | from .unit_network import * |
| | import sys |
| |
|
def get_config(config):
    """Load a YAML experiment configuration file.

    Parameters
    ----------
    config : str
        Path to the YAML file describing the experiment (network dims,
        loss weights, optimizer settings, ...).

    Returns
    -------
    dict
        The parsed configuration.
    """
    import yaml
    with open(config, 'r') as stream:
        # BUG FIX: yaml.load() without an explicit Loader is unsafe on
        # untrusted files and raises TypeError on PyYAML >= 6.
        # safe_load handles plain config files like this one.
        return yaml.safe_load(stream)
| |
|
class UNITModel(BaseModel):
    """UNIT-style unsupervised image-to-image translation model.

    Two VAE generators (``gen_a``/``gen_b``) share a latent space and two
    multi-scale discriminators (``dis_a``/``dis_b``) provide adversarial
    supervision.  Written against the old (pre-0.4) PyTorch Variable API
    (``volatile``, ``.data[0]``), consistent with the rest of this codebase.
    """

    def name(self):
        """Return the printable model name."""
        return 'UNITModel'

    def initialize(self, opt):
        """Build networks, optimizers and schedulers from ``opt``.

        Reads the experiment YAML referenced by ``opt.config`` and allocates
        reusable input tensors for domains A and B.
        """
        BaseModel.initialize(self, opt)
        self.config = get_config(opt.config)
        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        self.gen_a = VAEGen(self.config['input_dim_a'], self.config['gen'])
        # BUG FIX: gen_b must be built for domain B's channel count;
        # 'input_dim_a' was passed for both generators.
        self.gen_b = VAEGen(self.config['input_dim_b'], self.config['gen'])

        # Parameter-free normalization used by compute_vgg_loss.
        # BUG FIX: compute_vgg_loss referenced self.instancenorm, which was
        # never created anywhere in the original class.
        self.instancenorm = torch.nn.InstanceNorm2d(512, affine=False)

        if self.isTrain:
            self.dis_a = MsImageDis(self.config['input_dim_a'], self.config['dis'])
            self.dis_b = MsImageDis(self.config['input_dim_b'], self.config['dis'])
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.gen_a, 'G_A', which_epoch)
            self.load_network(self.gen_b, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.dis_a, 'D_A', which_epoch)
                self.load_network(self.dis_b, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = self.config['lr']
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # Separate Adam optimizers over trainable parameters only.
            beta1 = self.config['beta1']
            beta2 = self.config['beta2']
            dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
            gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
            self.dis_opt = torch.optim.Adam(
                [p for p in dis_params if p.requires_grad],
                lr=self.config['lr'], betas=(beta1, beta2),
                weight_decay=self.config['weight_decay'])
            self.gen_opt = torch.optim.Adam(
                [p for p in gen_params if p.requires_grad],
                lr=self.config['lr'], betas=(beta1, beta2),
                weight_decay=self.config['weight_decay'])
            self.dis_scheduler = get_scheduler(self.dis_opt, self.config)
            self.gen_scheduler = get_scheduler(self.gen_opt, self.config)

            self.dis_a.apply(weights_init('gaussian'))
            self.dis_b.apply(weights_init('gaussian'))

            # Frozen VGG16 for the optional perceptual loss.
            if 'vgg_w' in self.config.keys() and self.config['vgg_w'] > 0:
                self.vgg = load_vgg16(self.config['vgg_model_path'] + '/models')
                self.vgg.eval()
                for param in self.vgg.parameters():
                    param.requires_grad = False

            # Discriminators only exist in training mode, so only move them
            # to the GPU here.
            self.dis_a.cuda()
            self.dis_b.cuda()

        self.gen_a.cuda()
        self.gen_b.cuda()

        print('---------- Networks initialized -------------')
        networks.print_network(self.gen_a)
        networks.print_network(self.gen_b)
        if self.isTrain:
            networks.print_network(self.dis_a)
            networks.print_network(self.dis_b)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one data batch into the pre-allocated input tensors.

        ``input`` is a dict with keys 'A', 'B', 'A_paths', 'B_paths'; the
        ``which_direction`` option decides which side feeds which tensor.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        self.real_A = Variable(self.input_A.cuda())
        self.real_B = Variable(self.input_B.cuda())

    def test(self):
        """Run inference: reconstructions, cross-domain translations, cycles.

        Results are cached on ``self`` for get_current_visuals(); cycle
        outputs are None when ``recon_x_cyc_w`` is 0.
        """
        self.real_A = Variable(self.input_A.cuda(), volatile=True)
        self.real_B = Variable(self.input_B.cuda(), volatile=True)
        # BUG FIX: the original referenced undefined names x_a / x_b
        # (NameError) and used "+ x*1", which added the real image to every
        # decoded output; mirror gen_update's no-op "+ 0*x" instead.
        x_a, x_b = self.real_A, self.real_B
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        x_a_recon = self.gen_a.decode(h_a + n_a) + 0*x_a
        x_b_recon = self.gen_b.decode(h_b + n_b) + 0*x_b
        x_ba = self.gen_a.decode(h_b + n_b) + 0*x_b
        x_ab = self.gen_b.decode(h_a + n_a) + 0*x_a
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) + 0*x_ab if self.config['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) + 0*x_ba if self.config['recon_x_cyc_w'] > 0 else None
        self.x_a_recon, self.x_ab, self.x_aba = x_a_recon, x_ab, x_aba
        self.x_b_recon, self.x_ba, self.x_bab = x_b_recon, x_ba, x_bab

    def get_image_paths(self):
        """Return the file paths of the current batch (set by set_input)."""
        return self.image_paths

    def optimize_parameters(self):
        """One training step: update generators first, then discriminators."""
        self.gen_update(self.real_A, self.real_B)
        self.dis_update(self.real_A, self.real_B)

    def recon_criterion(self, input, target):
        """L1 reconstruction loss (mean absolute error)."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate a pair of images across domains in eval mode.

        Returns
        -------
        (x_ab, x_ba) : A rendered in domain B, and B rendered in domain A.
        """
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b)
        x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        """KL-style regularizer on a latent code: mean of mu**2.

        With a unit-variance Gaussian assumption this reduces the KL term to
        penalizing the squared magnitude of the mean.
        """
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b):
        """Generator step: reconstruction, KL, cycle, adversarial, VGG losses."""
        self.gen_opt.zero_grad()
        # Encode both domains into the shared latent space (code + noise).
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # Within-domain reconstruction.  "+ 0*x" is numerically a no-op;
        # NOTE(review): kept from the original, presumably a graph/shape
        # workaround -- confirm before removing.
        x_a_recon = self.gen_a.decode(h_a + n_a) + 0*x_a
        x_b_recon = self.gen_b.decode(h_b + n_b) + 0*x_b
        # Cross-domain translation.
        x_ba = self.gen_a.decode(h_b + n_b) + 0*x_b
        x_ab = self.gen_b.decode(h_a + n_a) + 0*x_a
        # Encode the translations back for the cycle terms.
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # Cycle reconstructions are skipped entirely when unweighted.
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) + 0*x_ab if self.config['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) + 0*x_ba if self.config['recon_x_cyc_w'] > 0 else None

        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        # BUG FIX: the original unconditionally called
        # recon_criterion(x_aba, x_a), which crashed on None whenever
        # recon_x_cyc_w == 0; the loss weight is 0 in that case anyway.
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # Adversarial: each translation must fool its target discriminator.
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # Optional perceptual (VGG feature) loss.
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if self.config['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if self.config['vgg_w'] > 0 else 0

        self.loss_gen_total = self.config['gan_w'] * self.loss_gen_adv_a + \
                              self.config['gan_w'] * self.loss_gen_adv_b + \
                              self.config['recon_x_w'] * self.loss_gen_recon_x_a + \
                              self.config['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              self.config['recon_x_w'] * self.loss_gen_recon_x_b + \
                              self.config['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              self.config['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              self.config['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              self.config['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              self.config['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              self.config['vgg_w'] * self.loss_gen_vgg_a + \
                              self.config['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()
        # Cache outputs for get_current_visuals().
        self.x_a_recon, self.x_ab, self.x_aba = x_a_recon, x_ab, x_aba
        self.x_b_recon, self.x_ba, self.x_bab = x_b_recon, x_ba, x_bab

    def compute_vgg_loss(self, vgg, img, target):
        """Perceptual loss: MSE between instance-normalized VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def dis_update(self, x_a, x_b):
        """Discriminator step: real vs. detached fake for both domains."""
        self.dis_opt.zero_grad()
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # detach(): discriminator gradients must not flow into the generators.
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = self.config['gan_w'] * self.loss_dis_a + self.config['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def get_current_errors(self):
        """Return the latest scalar losses as an OrderedDict of floats."""
        def to_float(loss):
            # Losses may be Variables (old API: .data[0]) or plain numbers
            # (e.g. the 0 placeholder when a loss is disabled).
            return loss.data[0] if hasattr(loss, 'data') else float(loss)

        D_A = to_float(self.loss_dis_a)
        G_A = to_float(self.loss_gen_adv_a)
        kl_A = to_float(self.loss_gen_recon_kl_a)
        Cyc_A = to_float(self.loss_gen_cyc_x_a)
        D_B = to_float(self.loss_dis_b)
        G_B = to_float(self.loss_gen_adv_b)
        kl_B = to_float(self.loss_gen_recon_kl_b)
        Cyc_B = to_float(self.loss_gen_cyc_x_b)
        if self.config['vgg_w'] > 0:
            # BUG FIX: the original returned the raw Variables here while
            # every other entry was a float.
            vgg_A = to_float(self.loss_gen_vgg_a)
            vgg_B = to_float(self.loss_gen_vgg_b)
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('kl_A', kl_A), ('vgg_A', vgg_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('kl_B', kl_B), ('vgg_B', vgg_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('kl_A', kl_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('kl_B', kl_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        """Return the cached images of the last pass as displayable arrays.

        NOTE(review): x_aba / x_bab are None when recon_x_cyc_w == 0; callers
        appear to only request visuals with cycles enabled -- confirm.
        """
        real_A = util.tensor2im(self.real_A.data)
        recon_A = util.tensor2im(self.x_a_recon.data)
        A_B = util.tensor2im(self.x_ab.data)
        ABA = util.tensor2im(self.x_aba.data)
        real_B = util.tensor2im(self.real_B.data)
        recon_B = util.tensor2im(self.x_b_recon.data)
        B_A = util.tensor2im(self.x_ba.data)
        # BUG FIX: BAB was rendered from x_b_recon (a duplicate of recon_B)
        # instead of the B->A->B cycle output.
        BAB = util.tensor2im(self.x_bab.data)
        return OrderedDict([('real_A', real_A), ('A_B', A_B), ('recon_A', recon_A), ('ABA', ABA),
                            ('real_B', real_B), ('B_A', B_A), ('recon_B', recon_B), ('BAB', BAB)])

    def save(self, label):
        """Checkpoint both generators and both discriminators under ``label``."""
        self.save_network(self.gen_a, 'G_A', label, self.gpu_ids)
        self.save_network(self.dis_a, 'D_A', label, self.gpu_ids)
        self.save_network(self.gen_b, 'G_B', label, self.gpu_ids)
        self.save_network(self.dis_b, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay the learning rate and push it into both optimizers."""
        lrd = self.config['lr'] / self.opt.niter_decay
        lr = self.old_lr - lrd
        # BUG FIX: param_groups lives on the optimizers, not on the nn.Module
        # generators/discriminators; the original iterated
        # self.gen_a.param_groups etc. and raised AttributeError.
        for optimizer in (self.gen_opt, self.dis_opt):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr