import torch
import models.networks as networks
from models.networks.loss import VGG16, SpatialCorrelativeLoss, SSIMLoss, LaplacianLoss
from models.networks.time_utils import get_day_night_weights
import util.util as util
import torch.nn as nn
import os
import torch.nn.functional as F
import math
import random


class TVLoss(nn.Module):
    """Total-variation loss: mean squared difference between adjacent pixels.

    Returns the sum of squared horizontal and vertical neighbour differences,
    normalised by the total number of elements (B * C * H * W).
    """

    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, x):
        batch_size, c, h, w = x.size()
        tv_h = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum()
        tv_w = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum()
        return (tv_h + tv_w) / (batch_size * c * h * w)


class ThresholdedL1Loss(nn.Module):
    """Hinge penalty on per-image mean brightness exceeding `threshold`.

    Average brightness is taken over channels and spatial dims; only the
    excess above the threshold is penalised (ReLU), then averaged over batch.
    """

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def forward(self, input_tensor):
        avg_brightness = torch.mean(input_tensor, dim=[1, 2, 3])
        loss = F.relu(avg_brightness - self.threshold)
        return torch.mean(loss)


def hour_to_rad(hour):
    """Convert an hour-of-day value to an angle in radians (12h -> pi)."""
    return (hour / 12.0) * math.pi


def create_light_direction_map(phi, size, num_channels, device, w_day):
    """Build a per-pixel light-direction map from an angle `phi`.

    Args:
        phi: tensor of shape (B,) — angle in radians (presumably a sun/time
            angle; see `hour_to_rad` — TODO confirm with caller).
        size: (H, W) spatial size of the map.
        num_channels: 1 (scalar dot-product map) or 2 (separate x/y maps).
        device: torch device for the coordinate grids.
        w_day: tensor of shape (B,) — daylight weight used to scale intensity.

    Returns:
        Tensor of shape (B, num_channels, H, W).

    Raises:
        ValueError: if num_channels is not 1 or 2.
    """
    B = phi.shape[0]
    H, W = size
    # Normalised pixel coordinate grids in [-1, 1].
    x_coords = torch.linspace(-1, 1, W, device=device)
    y_coords = torch.linspace(-1, 1, H, device=device)
    y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')
    x_grid = x_grid.unsqueeze(0).expand(B, -1, -1)
    y_grid = y_grid.unsqueeze(0).expand(B, -1, -1)
    # Shift by -90 degrees so phi=0 maps to a vertical light vector.
    phi_adjusted = phi - math.pi / 2
    light_vec_x = torch.cos(phi_adjusted).view(B, 1, 1)
    light_vec_y = torch.sin(phi_adjusted).view(B, 1, 1)
    # Scale by the daylight weight so the map fades out at night.
    directional_intensity = w_day.view(B, 1, 1)
    if num_channels == 1:
        # Dot product of pixel position with the light vector.
        light_map = (x_grid * light_vec_x + y_grid * light_vec_y) * directional_intensity
        return light_map.unsqueeze(1)
    elif num_channels == 2:
        light_map_x = x_grid * light_vec_x * directional_intensity
        light_map_y = y_grid * light_vec_y * directional_intensity
        return torch.stack([light_map_x, light_map_y], dim=1)
    else:
        raise ValueError("num_channels for light map must be 1 or 2")


class Pix2PixModel(torch.nn.Module):
    """Day->night image translation model with a time-conditioned generator.

    Wraps generator G and conditional discriminator D2; computes generator
    and discriminator losses (GAN, feature-matching, VGG interpolation
    anchors, identity, temporal continuity, R1 regularisation).
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register model-specific command-line options and return the parser."""
        networks.modify_commandline_options(parser, is_train)
        parser.add_argument('--daylight_curve_steepness', type=float, default=0.75)
        parser.add_argument('--lambda_vgg_interp', type=float, default=15.0)
        parser.add_argument('--attn_layers', type=str, default='4,7,9')
        parser.add_argument('--patch_nums', type=float, default=128)
        parser.add_argument('--patch_size', type=int, default=32)
        parser.add_argument('--loss_mode', type=str, default='cos')
        parser.add_argument('--use_norm', action='store_true')
        parser.add_argument('--learned_attn', action='store_true', default=False)
        parser.add_argument('--T', type=float, default=0.07)
        parser.add_argument('--lambda_spatial', type=float, default=10.0)
        parser.add_argument('--lambda_G', type=float, default=1.0)
        if is_train:
            parser.add_argument('--night_loss_warmup_iters', type=int, default=10000)
            parser.add_argument('--lambda_content_focus', type=float, default=5.0)
            parser.add_argument('--lambda_continuity', type=float, default=10.0)
            parser.add_argument('--lambda_style', type=float, default=30.0)
            parser.add_argument('--lambda_identity', type=float, default=10.0)
            parser.add_argument('--lambda_r1', type=float, default=10.0)
            parser.add_argument('--lambda_latent_diversity', type=float, default=1.0)
            parser.add_argument('--use_patch_vgg_loss', action='store_true')
            parser.add_argument('--lambda_patch_vgg', type=float, default=10.0)
            parser.add_argument('--patch_vgg_num_patches', type=int, default=64)
            parser.add_argument('--patch_vgg_size', type=int, default=64)
            parser.add_argument('--lambda_ppl', type=float, default=2.0)
            parser.add_argument('--ppl_reg_every', type=int, default=4)
            parser.add_argument('--r1_reg_every', type=int, default=16)
            parser.add_argument('--lambda_ssim', type=float, default=1.0)
            # Guard against double registration when another options source
            # has already defined these flags.
            if not hasattr(parser.parse_known_args()[0], 'lambda_z_reg'):
                parser.add_argument('--lambda_z_reg', type=float, default=0.0)
            if not hasattr(parser.parse_known_args()[0], 'lambda_night_content'):
                parser.add_argument('--lambda_night_content', type=float, default=0.0)
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.attn_layers = [int(i) for i in self.opt.attn_layers.split(',')]
        self.device = self._get_current_device_for_loss_init()
        self.netG, self.netD, self.netD2, self.netE = self.initialize_networks(opt)
        # FIX: default loss modules to None unconditionally so the
        # `self.criterionVGG is not None` checks in g_main_loss /
        # compute_generator_loss / calculate_patch_vgg_loss /
        # calculate_style_loss never raise AttributeError when the
        # corresponding lambdas are zero or the model is in inference mode.
        self.criterionVGG = None
        self.criterionSSIM = None
        self.vgg_for_spatial = None
        self.criterionSpatial = None
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(
                opt.gan_mode,
                tensor=torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor,
                opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            self.criterionIdentity = torch.nn.L1Loss()
            # VGG network is shared by several losses; build it if any is active.
            if (not opt.no_vgg_loss and opt.lambda_vgg > 0) or \
                    (hasattr(opt, 'lambda_style') and opt.lambda_style > 0) or \
                    (hasattr(opt, 'lambda_vgg_interp') and opt.lambda_vgg_interp > 0) or \
                    (hasattr(opt, 'lambda_night_content') and opt.lambda_night_content > 0) or \
                    (hasattr(opt, 'lambda_continuity') and opt.lambda_continuity > 0):
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
                if self.use_gpu():
                    self.criterionVGG.to(self.device)
            if opt.lambda_ssim > 0:
                self.criterionSSIM = SSIMLoss().to(self.device)
            if opt.lambda_spatial > 0:
                # Frozen VGG16 feature extractor for the spatial-correlative loss.
                self.vgg_for_spatial = VGG16().to(self.device)
                self.vgg_for_spatial.eval()
                for param in self.vgg_for_spatial.parameters():
                    param.requires_grad = False
                self.criterionSpatial = SpatialCorrelativeLoss(
                    loss_mode=opt.loss_mode, patch_nums=opt.patch_nums,
                    patch_size=opt.patch_size, norm=opt.use_norm,
                    use_conv=opt.learned_attn, T=opt.T).to(self.device)
            # Lazy-regularisation intervals (StyleGAN2-style PPL / R1).
            self.g_reg_interval = opt.ppl_reg_every
            self.d_reg_interval = opt.r1_reg_every
            self.pl_mean = torch.tensor(0.0, device=self.device)

    def _get_current_device_for_loss_init(self):
        """Return the device losses should live on (first GPU id, or CPU)."""
        return torch.device(
            f'cuda:{self.opt.gpu_ids[0]}' if self.use_gpu() and self.opt.gpu_ids else 'cpu')

    def forward(self, data, mode, iter_count=0, arbitrary_input=False):
        """Dispatch entry point.

        Args:
            data: dict with 'day' image, optional 'night' image and 'phi2'.
            mode: 'inference' | 'generator' | 'discriminator'.
            iter_count: global training iteration (drives warmup and lazy reg).
            arbitrary_input: passed through to the generator.

        Returns:
            mode='inference': the generated image.
            mode='generator': (loss_for_backward, display_losses, images dict).
            mode='discriminator': (d1_losses, d2_losses).

        Raises:
            ValueError: for an unknown mode.
        """
        content_image_orig = data['day'].to(self.device)
        # Fall back to the day image when no night reference is provided.
        night_image = data.get('night', content_image_orig).to(self.device)
        phi2_for_mixing = data.get('phi2', None)
        # self.opt.phi is expected to be a tensor of per-sample angles,
        # set externally before forward() — TODO confirm with caller.
        phi_for_interp = self.opt.phi.to(self.device, dtype=content_image_orig.dtype)
        w_day, w_night = get_day_night_weights(phi_for_interp, self.opt.daylight_curve_steepness)
        cos_phi = torch.cos(phi_for_interp)
        sin_phi = torch.sin(phi_for_interp)
        _, _, H, W = content_image_orig.shape
        light_map = create_light_direction_map(
            phi_for_interp, size=(H, W), num_channels=self.opt.light_map_channels,
            device=self.device, w_day=w_day)
        # Generator input = RGB content concatenated with the light map.
        content_image_with_light = torch.cat([content_image_orig, light_map], dim=1)
        net_g_module = self.netG.module if isinstance(self.netG, torch.nn.DataParallel) else self.netG
        if hasattr(net_g_module, 'opt'):
            # Keep the (possibly wrapped) generator's phi in sync.
            net_g_module.opt.phi = self.opt.phi
        if mode == 'inference':
            with torch.no_grad():
                fake_image, _ = net_g_module(
                    content_image_with_light, w_day, w_night, cos_phi, sin_phi,
                    phi2=None, arbitrary_input=arbitrary_input)
            return fake_image
        # Pixel-space interpolation between day content and night reference,
        # used as the "real" target for the discriminator.
        w_day_pix = w_day.view(-1, 1, 1, 1).expand_as(content_image_orig)
        style_image_target = w_day_pix * content_image_orig + (1.0 - w_day_pix) * night_image
        if mode == 'generator':
            g_loss_for_backward, g_losses_for_display, generated_images = self.compute_generator_loss(
                content_image_with_light, content_image_orig, style_image_target, night_image,
                w_day, w_night, cos_phi, sin_phi, iter_count, phi2_for_mixing,
                arbitrary_input=arbitrary_input)
            generated_images['style_target_interpolated'] = style_image_target
            generated_images['light_map_viz'] = light_map
            return g_loss_for_backward, g_losses_for_display, generated_images
        elif mode == 'discriminator':
            d1_losses, d2_losses = self.compute_discriminator_loss(
                content_image_with_light, content_image_orig, style_image_target,
                w_day, w_night, cos_phi, sin_phi, iter_count, phi2_for_mixing,
                arbitrary_input=arbitrary_input)
            return d1_losses, d2_losses
        else:
            raise ValueError(f"|mode| is invalid: {mode}")

    def g_main_loss(self, content_with_light, content_orig, style_target, night_image,
                    w_day, w_night, cos_phi, sin_phi, phi2_for_mixing, arbitrary_input,
                    iter_count):
        """Core generator losses: VGG anchor interpolation, GAN, feat-match, identity.

        Returns:
            (total_loss, G_losses dict, fake_image, identity_images_dict)
        """
        G_losses = {}
        fake_image, identity_images_dict = self.netG(
            content_with_light, w_day, w_night, cos_phi, sin_phi,
            phi2=phi2_for_mixing, arbitrary_input=arbitrary_input)
        if fake_image is None:
            raise RuntimeError("Generator output is None.")
        # FIX: guard against ZeroDivisionError when night_loss_warmup_iters == 0
        # (warmup then counts as completed immediately).
        warmup_factor = min(1.0, iter_count / max(1, self.opt.night_loss_warmup_iters))
        # NOTE(review): this constant (no-grad) tensor is also summed into
        # total_loss below — it does not affect gradients, only the reported
        # scalar; confirm whether that is intended.
        G_losses['Night_Loss_Warmup'] = torch.tensor(warmup_factor, device=self.device)
        if self.opt.lambda_vgg_interp > 0 and self.criterionVGG is not None:
            # Build fixed VGG-feature anchors for day / progressively darker
            # night images, then pull the fake image's features toward a
            # phi-dependent interpolation between them.
            with torch.no_grad():
                vgg_feats_day = self.criterionVGG.vgg(content_orig)
                soft_night_anchor_img = (content_orig * 0.1) - 0.85
                vgg_feats_soft_night = self.criterionVGG.vgg(soft_night_anchor_img)
                hard_night_anchor_img = torch.full_like(content_orig, -0.9, device=self.device)
                vgg_feats_hard_night = self.criterionVGG.vgg(hard_night_anchor_img)
                extreme_night_anchor_img = torch.full_like(content_orig, -1.0, device=self.device)
                vgg_feats_extreme_night = self.criterionVGG.vgg(extreme_night_anchor_img)
            vgg_feats_fake = self.criterionVGG.vgg(fake_image)
            vgg_interp_loss = 0
            phi_deg = self.opt.phi * (180.0 / math.pi)
            phi_deg = phi_deg.view(-1, 1, 1, 1)
            # Fold angles above 180 deg back, making the schedule symmetric.
            phi_deg_sym = torch.where(phi_deg > 180.0, 360.0 - phi_deg, phi_deg)
            # Piecewise-linear blend factors for the three angle ranges:
            # [0,15] extreme->hard, (15,45] hard->soft, (45,180] soft->day.
            alpha_p1 = (phi_deg_sym / 15.0).clamp(0.0, 1.0)
            alpha_p2 = ((phi_deg_sym - 15.0) / 30.0).clamp(0.0, 1.0)
            alpha_p3 = ((phi_deg_sym - 45.0) / 135.0).clamp(0.0, 1.0)
            for i in range(len(vgg_feats_fake)):
                interp_p1 = (1.0 - alpha_p1) * vgg_feats_extreme_night[i] + alpha_p1 * vgg_feats_hard_night[i]
                interp_p2 = (1.0 - alpha_p2) * vgg_feats_hard_night[i] + alpha_p2 * vgg_feats_soft_night[i]
                interp_p3 = (1.0 - alpha_p3) * vgg_feats_soft_night[i] + alpha_p3 * vgg_feats_day[i]
                mask_p1 = (phi_deg_sym <= 15.0).expand_as(vgg_feats_fake[i])
                mask_p2 = ((phi_deg_sym > 15.0) & (phi_deg_sym <= 45.0)).expand_as(vgg_feats_fake[i])
                target_feat = torch.where(mask_p1, interp_p1, torch.where(mask_p2, interp_p2, interp_p3))
                vgg_interp_loss += self.criterionFeat(vgg_feats_fake[i], target_feat.detach())
            G_losses['VGG_Interp'] = vgg_interp_loss * self.opt.lambda_vgg_interp
        # Night anchor image blended from hard (flat -1) and soft (dimmed
        # content) anchors, weighted toward hard near phi = 0 deg.
        with torch.no_grad():
            hard_night_anchor = torch.full_like(content_orig, -1.0, device=self.device)
            soft_night_anchor = (content_orig * 0.1) - 0.85
            phi_deg_blend = self.opt.phi * (180.0 / math.pi)
            blend_weight = (F.relu(20.0 - phi_deg_blend) / 20.0).clamp(max=1.0).view(-1, 1, 1, 1)
            night_anchor_image = blend_weight * hard_night_anchor + (1.0 - blend_weight) * soft_night_anchor
        if self.netD2 is not None:
            pred_fake_d2_features, _ = self.discriminate(fake_image, content_orig, 'D2')
            G_losses['GAN_D2'] = self.criterionGAN(
                pred_fake_d2_features, True, for_discriminator=False) * self.opt.lambda_G
            if not self.opt.no_ganFeat_loss and self.opt.lambda_feat > 0:
                # Feature-matching target: day/night mix of content and anchor.
                real_target_for_d = (w_day.view(-1, 1, 1, 1) * content_orig +
                                     w_night.view(-1, 1, 1, 1) * night_anchor_image)
                with torch.no_grad():
                    pred_real_d2_features, _ = self.discriminate(real_target_for_d, content_orig, 'D2')
                G_losses['GAN_Feat_D2'] = self.feat_matching_loss(
                    pred_fake_d2_features, pred_real_d2_features) * self.opt.lambda_feat
        if self.opt.lambda_identity > 0:
            i_cc = identity_images_dict.get('identity_content')
            if i_cc is not None:
                # Identity reconstruction, weighted by daylight amount.
                loss_identity = self.criterionIdentity(i_cc, content_orig)
                G_losses['Identity'] = (loss_identity * w_day).mean() * self.opt.lambda_identity
        if hasattr(self.opt, 'lambda_z_reg') and self.opt.lambda_z_reg > 0:
            netG_module = self.netG.module if isinstance(self.netG, torch.nn.DataParallel) else self.netG
            if hasattr(netG_module, 'z_day') and hasattr(netG_module, 'z_night'):
                z_day, z_night = netG_module.z_day, netG_module.z_night
                G_losses['Z_Reg'] = self.criterionFeat(z_day, z_night) * self.opt.lambda_z_reg
        total_loss = sum(l for l in G_losses.values() if l is not None and isinstance(l, torch.Tensor))
        return total_loss, G_losses, fake_image, identity_images_dict

    def compute_generator_loss(self, content_with_light, content_orig, style_target,
                               night_image, w_day, w_night, cos_phi, sin_phi, iter_count,
                               phi2_for_mixing, arbitrary_input=False):
        """Full generator objective: main losses plus temporal-continuity term.

        Returns (loss_for_backward, display_losses, images dict).
        """
        if self.netD2:
            # Freeze D2 while optimising G.
            for p in self.netD2.parameters():
                p.requires_grad = False
        g_loss_main, g_losses_for_display, fake_image_1, identity_images_dict = self.g_main_loss(
            content_with_light, content_orig, style_target, night_image,
            w_day, w_night, cos_phi, sin_phi, phi2_for_mixing,
            arbitrary_input=arbitrary_input, iter_count=iter_count)
        if self.opt.lambda_continuity > 0:
            # Continuity: output at a slightly perturbed phi (+/-0.1 rad)
            # should stay close to the current output.
            netG_module = self.netG.module if isinstance(self.netG, torch.nn.DataParallel) else self.netG
            phi_1 = self.opt.phi.clone()
            delta_phi = (torch.rand_like(phi_1) * 0.2) - 0.1
            phi_2 = (phi_1 + delta_phi)
            w_day_2, w_night_2 = get_day_night_weights(phi_2, self.opt.daylight_curve_steepness)
            cos_phi_2 = torch.cos(phi_2)
            sin_phi_2 = torch.sin(phi_2)
            _, _, H, W = content_orig.shape
            light_map_2 = create_light_direction_map(
                phi_2, (H, W), self.opt.light_map_channels, self.device, w_day_2)
            content_with_light_2 = torch.cat([content_orig, light_map_2], dim=1)
            with torch.no_grad():
                fake_image_2, _ = netG_module(
                    content_with_light_2, w_day_2, w_night_2, cos_phi_2, sin_phi_2,
                    phi2=None, arbitrary_input=arbitrary_input)
            loss_l1 = self.criterionIdentity(fake_image_1, fake_image_2.detach())
            loss_vgg = (self.criterionVGG(fake_image_1, fake_image_2.detach())
                        if self.criterionVGG is not None
                        else torch.tensor(0.0, device=self.device))
            continuity_loss = loss_l1 + loss_vgg
            g_losses_for_display["Continuity"] = continuity_loss.clone().detach()
            g_loss_main += continuity_loss * self.opt.lambda_continuity
        # NOTE(review): PPL regularisation is a stub — both branches report 0
        # and no path-length penalty is actually computed.
        if self.opt.lambda_ppl > 0 and (iter_count + 1) % self.g_reg_interval == 0:
            g_losses_for_display["PPL"] = torch.tensor(0.0, device=self.device)
        else:
            g_losses_for_display["PPL"] = torch.tensor(0.0, device=self.device)
        g_losses_for_display['G_total'] = g_loss_main.clone().detach()
        return g_loss_main, g_losses_for_display, {**identity_images_dict, 'synthesized_image': fake_image_1}

    def d_main_loss(self, content_with_light, content_orig, style_target, w_day, w_night,
                    cos_phi, sin_phi, phi2_for_mixing, arbitrary_input=False):
        """Standard D2 real/fake hinge losses against a detached generator output.

        Returns (summed loss tensor, per-term dict); zero loss when D2 is absent.
        """
        if not self.netD2:
            return torch.tensor(0.0, device=self.device), {}
        # Re-enable D2 gradients (frozen during the G step).
        for p in self.netD2.parameters():
            p.requires_grad = True
        D2_losses = {}
        with torch.no_grad():
            fake_image, _ = self.netG(
                content_with_light, w_day, w_night, cos_phi, sin_phi,
                phi2=phi2_for_mixing, arbitrary_input=arbitrary_input)
            fake_image = fake_image.detach()
        pred_fake_d2_full, _ = self.discriminate(fake_image, content_orig, 'D2')
        D2_losses['D2_Fake'] = self.criterionGAN(pred_fake_d2_full, False, for_discriminator=True)
        pred_real_d2_full, _ = self.discriminate(style_target, content_orig, 'D2')
        D2_losses['D2_real'] = self.criterionGAN(pred_real_d2_full, True, for_discriminator=True)
        return sum(D2_losses.values()), D2_losses

    def compute_discriminator_loss(self, content_with_light, content_orig, style_target,
                                   w_day, w_night, cos_phi, sin_phi, iter_count,
                                   phi2_for_mixing, arbitrary_input=False):
        """D2 objective with lazily-applied R1 gradient penalty.

        Returns (d1_losses, d2_losses); d1 is always empty (netD is unused).
        """
        if not self.netD2:
            return {}, {}
        d_loss_main, d2_losses = self.d_main_loss(
            content_with_light, content_orig, style_target, w_day, w_night,
            cos_phi, sin_phi, phi2_for_mixing, arbitrary_input=arbitrary_input)
        if self.opt.lambda_r1 > 0 and (iter_count + 1) % self.d_reg_interval == 0:
            r1_loss = self.d_r1_regularize(style_target.clone(), content_orig)
            d2_losses['D2_R1'] = r1_loss.clone().detach()
            # Lazy regularisation: scale by the interval to compensate for
            # only applying R1 every d_reg_interval steps.
            weighted_r1_loss = self.opt.lambda_r1 / 2 * r1_loss * self.d_reg_interval
            d_loss_main += weighted_r1_loss
        else:
            d2_losses['D2_R1'] = torch.tensor(0.0, device=self.device)
        # NOTE(review): d_loss_main (which carries the non-detached R1 term)
        # is discarded here — the caller only receives the dicts, and the
        # 'D2_R1' entry is detached. Verify the trainer actually backprops
        # the R1 penalty; if it sums d2_losses, the penalty is a no-op.
        return {}, d2_losses

    def d_r1_regularize(self, style_target, content_orig):
        """R1 penalty of D2's real logit w.r.t. the (grad-enabled) real input."""
        if not self.netD2:
            return torch.tensor(0.0, device=self.device)
        style_target.requires_grad = True
        _, pred_real_d2_logit = self.discriminate(style_target, content_orig, 'D2')
        r1_loss = self.r1_penalty(pred_real_d2_logit, style_target)
        return r1_loss

    def create_optimizers(self, opt):
        """Create Adam optimizers for G and D2.

        Returns (optimizer_G, optimizer_D, optimizer_D) — the D optimizer is
        intentionally returned twice to fill both D slots expected by the
        caller (TODO confirm against trainer).
        """
        G_params = list(self.netG.parameters())
        optimizer_G = torch.optim.Adam(G_params, lr=opt.lr, betas=(opt.beta1, opt.beta2))
        D_params = list(self.netD2.parameters()) if self.netD2 else []
        optimizer_D = torch.optim.Adam(
            D_params, lr=opt.lr, betas=(opt.beta1, opt.beta2)) if D_params else None
        return optimizer_G, optimizer_D, optimizer_D

    def r1_penalty(self, real_pred, real_img):
        """Mean squared gradient norm of `real_pred` w.r.t. `real_img` (R1)."""
        grad_real = torch.autograd.grad(
            outputs=real_pred.sum(), inputs=real_img,
            create_graph=True, retain_graph=True)[0]
        return grad_real.pow(2).reshape(real_img.shape[0], -1).sum(1).mean()

    def discriminate(self, image_to_judge, condition, type):
        """Run D2 on an image conditioned on `condition`.

        `type` is currently unused (kept for call-site compatibility; it also
        shadows the builtin — candidate for renaming in a coordinated change).

        Returns (features_or_scales, final_logit); swin-style discriminators
        return their native output unchanged.
        """
        target_D = self.netD2
        if 'swin' in self.opt.netD.lower():
            return target_D(image_to_judge, condition)
        else:
            # Conditional discriminator: concatenate image and condition.
            d_input = torch.cat((image_to_judge, condition), dim=1)
            discriminator_out = target_D(d_input)
            if isinstance(discriminator_out, list):
                # Multi-scale output: last element of the last scale is the logit.
                logit = (discriminator_out[-1][-1]
                         if isinstance(discriminator_out[-1], list)
                         else discriminator_out[-1])
                return discriminator_out, logit
            else:
                return [discriminator_out], discriminator_out

    def feat_matching_loss(self, pred_fake_scales, pred_real_scales):
        """L1 feature matching over discriminator scales (logits excluded)."""
        loss = 0.0
        for fake_feat_scale, real_feat_scale in zip(pred_fake_scales, pred_real_scales):
            if isinstance(fake_feat_scale, list):
                # Skip the final element of each scale (the logit).
                for f, r in zip(fake_feat_scale[:-1], real_feat_scale[:-1]):
                    loss += self.criterionFeat(f, r.detach())
            else:
                loss += self.criterionFeat(fake_feat_scale, real_feat_scale.detach())
        return loss

    def use_gpu(self):
        """True when at least one GPU id is configured."""
        return len(self.opt.gpu_ids) > 0

    def save(self, epoch):
        """Persist G (and D2 when present) checkpoints for `epoch`."""
        util.save_network(self.netG, 'G', epoch, self.opt)
        if self.netD2:
            util.save_network(self.netD2, 'D2', epoch, self.opt)

    def initialize_networks(self, opt):
        """Build G and (train-only) D2, loading weights when resuming/testing.

        Returns (netG, None, netD2, None) — the netD and netE slots are unused.
        """
        netG = networks.define_G(opt)
        netD2 = networks.define_D(opt) if opt.isTrain else None
        if not opt.isTrain or opt.continue_train:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD2 = util.load_network(netD2, 'D2', opt.which_epoch, opt)
        return netG, None, netD2, None

    def calculate_patch_vgg_loss(self, fake, real):
        """VGG loss over random aligned patches of `fake` vs `real`.

        Returns a zero tensor when the loss is disabled, VGG is unavailable,
        or the images are smaller than the patch size.
        """
        if not (hasattr(self.opt, 'use_patch_vgg_loss') and self.opt.use_patch_vgg_loss
                and self.opt.lambda_patch_vgg > 0 and self.criterionVGG is not None):
            return torch.tensor(0.0, device=fake.device)
        patch_size, num_patches = self.opt.patch_vgg_size, self.opt.patch_vgg_num_patches
        B, C, H, W = fake.shape
        if patch_size > H or patch_size > W:
            return torch.tensor(0.0, device=fake.device)
        rand_h = torch.randint(0, H - patch_size + 1, (num_patches,), device=fake.device)
        rand_w = torch.randint(0, W - patch_size + 1, (num_patches,), device=fake.device)
        fake_patches, real_patches = [], []
        # Same crop positions for fake and real so patches stay aligned.
        for i in range(num_patches):
            for b in range(B):
                fake_patches.append(
                    fake[b:b + 1, :, rand_h[i]:rand_h[i] + patch_size, rand_w[i]:rand_w[i] + patch_size])
                real_patches.append(
                    real[b:b + 1, :, rand_h[i]:rand_h[i] + patch_size, rand_w[i]:rand_w[i] + patch_size])
        if not fake_patches:
            return torch.tensor(0.0, device=fake.device)
        return self.criterionVGG(torch.cat(fake_patches, dim=0),
                                 torch.cat(real_patches, dim=0).detach())

    def calculate_style_loss(self, fake, target):
        """Gram-matrix style loss over VGG features; zero when VGG is unavailable."""
        if self.criterionVGG is None:
            return torch.tensor(0.0, device=fake.device)
        fake_vgg_feats = self.criterionVGG.vgg(fake)
        target_vgg_feats = self.criterionVGG.vgg(target)
        style_loss = 0.0
        for f_f, f_t in zip(fake_vgg_feats, target_vgg_feats):
            style_loss += self.criterionFeat(self.gram_matrix(f_f),
                                             self.gram_matrix(f_t).detach())
        return style_loss

    def gram_matrix(self, input_tensor):
        """Batched Gram matrix of (B, C, H, W) features, normalised by C*H*W."""
        b, c, h, w = input_tensor.size()
        features = input_tensor.view(b, c, h * w)
        features_t = features.transpose(1, 2)
        G = features.bmm(features_t)
        return G.div(c * h * w)