| import torch |
| from .base_model import BaseModel |
| from . import networks |
|
|
|
|
class Pix2Pix4DepthModel(BaseModel):
    """Pix2pix-style conditional GAN that maps a two-channel input to a one-channel output.

    The generator input ``real_A`` is the channel-wise concatenation of an
    "outer" and an "inner" map (presumably two depth estimates of the same
    scene -- confirm against the 'depthmerge' dataset), and the generator
    output ``fake_B`` is a single-channel map.

    Defaults installed by ``modify_commandline_options``:
        --input_nc 2, --output_nc 1, --norm none, --netG unet_1024,
        --dataset_mode depthmerge, and (training only) --pool_size 0,
        --gan_mode vanilla, --lambda_L1 1000.

    The training objective is: GAN loss + lambda_L1 * ||G(A) - B||_1.

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """

    # Spatial size every training tensor is resized to before use.
    _TRAIN_SIZE = (1024, 1024)

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add model-specific options and override defaults of existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase; the
                               training-only defaults and the ``--lambda_L1``
                               option are registered only when this is True.

        Returns:
            the modified parser.

        Defaults set here: 2 input channels, 1 output channel, no
        normalization layer, the 'unet_1024' generator and the 'depthmerge'
        dataset mode. For training, the image buffer is disabled
        (``pool_size=0``) and the vanilla GAN loss is selected.
        """
        parser.set_defaults(input_nc=2, output_nc=1, norm='none',
                            netG='unet_1024', dataset_mode='depthmerge')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=1000,
                                help='weight for L1 loss')
        return parser

    def __init__(self, opt):
        """Build the generator and, for training, the discriminator, losses and optimizers.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be
                                  a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)

        # Names of the losses tracked by this model (matching self.loss_<name>).
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']

        # Names of the image attributes exposed as visuals, and of the
        # networks (self.net<name>) this model owns.
        if self.isTrain:
            self.visual_names = ['outer', 'inner', 'fake_B', 'real_B']
            self.model_names = ['G', 'D']
        else:
            # At test time only the generator and its output are used.
            self.visual_names = ['fake_B']
            self.model_names = ['G']

        # NOTE: apart from the channel counts, the generator configuration is
        # fixed here and does not follow the corresponding command-line options.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, 64, 'unet_1024', 'none',
                                      False, 'normal', 0.02, self.gpu_ids)

        if self.isTrain:
            # Conditional discriminator: it sees the generator input
            # concatenated with the (real or fake) output, hence the summed
            # channel count.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type,
                                          opt.init_gain, self.gpu_ids)

            # Loss functions.
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # Optimizers; the learning rates are fixed here (the discriminator
            # uses a much smaller one than the generator).
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-4,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=2e-06,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def _load_resized(self, batch, key):
        """Move ``batch[key]`` to the model device and resize it bilinearly to ``_TRAIN_SIZE``."""
        tensor = batch[key].to(self.device)
        return torch.nn.functional.interpolate(
            tensor, self._TRAIN_SIZE, mode='bilinear', align_corners=False)

    def set_input_train(self, input):
        """Unpack a batch produced by the dataset and prepare the network inputs.

        Parameters:
            input (dict) -- must contain 'data_outer', 'data_inner' and
                            'image_path'; when training it must also contain
                            'data_gtfake', which becomes the target ``real_B``.

        All tensors are moved to ``self.device`` and resized to 1024x1024.
        ``real_A`` is the concatenation of outer and inner along dim 1.
        """
        self.outer = self._load_resized(input, 'data_outer')
        self.inner = self._load_resized(input, 'data_inner')

        self.image_paths = input['image_path']

        if self.isTrain:
            self.gtfake = self._load_resized(input, 'data_gtfake')
            self.real_B = self.gtfake

        self.real_A = torch.cat((self.outer, self.inner), 1)

    @staticmethod
    def _rescale_to_unit(tensor):
        """Linearly rescale ``tensor`` so that its minimum maps to 0 and its maximum to 1.

        A constant tensor has a zero value range; dividing by it would yield
        0/0 = NaN everywhere. In that case the range is replaced by 1, so the
        result is all zeros instead of NaN.
        """
        low = torch.min(tensor)
        span = torch.max(tensor) - low
        span = torch.where(span == 0, torch.ones_like(span), span)
        return (tensor - low) / span

    def set_input(self, outer, inner):
        """Build the generator input from two NumPy arrays.

        Parameters:
            outer, inner -- NumPy arrays (converted with ``torch.from_numpy``).
                            Two leading singleton dimensions are added to each;
                            presumably they are 2-D H x W maps of equal size,
                            giving (1, 1, H, W) tensors -- confirm with callers.

        Each array is independently min-max rescaled to [0, 1], mapped to
        [-1, 1] with ``normalize``, and the pair is concatenated along dim 1
        (outer first) into ``self.real_A`` on ``self.device``.
        """
        inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0)
        outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0)

        inner = self.normalize(self._rescale_to_unit(inner))
        outer = self.normalize(self._rescale_to_unit(outer))

        self.real_A = torch.cat((outer, inner), 1).to(self.device)

    def normalize(self, input):
        """Return ``input * 2 - 1`` (maps the range [0, 1] onto [-1, 1])."""
        return input * 2 - 1

    def forward(self):
        """Run the generator on ``real_A`` and store the result in ``fake_B``."""
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        """Calculate the GAN loss for the discriminator and backpropagate it."""
        # Fake pair: detach so that no gradient flows back into the generator.
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real pair.
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Average the two terms and compute gradients.
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Calculate the GAN and L1 losses for the generator and backpropagate them."""
        # GAN term: the generator wants the discriminator to label G(A) as real.
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Reconstruction term: lambda_L1 * ||G(A) - B||_1.
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1

        self.loss_G = self.loss_G_L1 + self.loss_G_GAN
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: forward pass, then update D, then update G."""
        self.forward()  # compute fake_B = G(real_A)

        # Update the discriminator.
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # Update the generator; D's gradients are switched off first since
        # only G is being optimized in this half of the step.
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()