# CR-Net: models/pix2pix_model.py
import torch
import models.networks as networks
from models.networks.loss import VGG16, SpatialCorrelativeLoss, SSIMLoss
from models.networks.time_utils import get_day_night_weights
import util.util as util
import torch.nn as nn
import torch.nn.functional as F
import math
class TVLoss(nn.Module):
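    """Total-variation loss: summed squared differences between neighboring pixels, normalized by tensor size."""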
def __init__(self):
super(TVLoss, self).__init__()
def forward(self, x):
batch_size, c, h, w = x.size()
tv_h = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum()
tv_w = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum()
return (tv_h + tv_w) / (batch_size * c * h * w)
class ThresholdedL1Loss(nn.Module):
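    """Hinge penalty on average brightness: penalizes the per-image mean only above ``threshold``."""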
def __init__(self, threshold):
super().__init__()
self.threshold = threshold
def forward(self, input_tensor):
avg_brightness = torch.mean(input_tensor, dim=[1, 2, 3])
loss = F.relu(avg_brightness - self.threshold)
return torch.mean(loss)
def hour_to_rad(hour):
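    """Map an hour of day to a phase angle in radians: 24 h -> 2*pi (noon -> pi)."""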
return (hour / 12.0) * math.pi
def create_light_direction_map(phi, size, num_channels, device, w_day):
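    """Build a per-pixel light-direction map from the phase ``phi``.

    Each pixel's normalized (x, y) coordinate is projected onto the light
    direction derived from phi, scaled by the day weight ``w_day``.
    Returns (B, 1, H, W) for num_channels=1 or (B, 2, H, W) for num_channels=2.
    """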
B = phi.shape[0]
H, W = size
x_coords = torch.linspace(-1, 1, W, device=device)
y_coords = torch.linspace(-1, 1, H, device=device)
y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')
x_grid = x_grid.unsqueeze(0).expand(B, -1, -1)
y_grid = y_grid.unsqueeze(0).expand(B, -1, -1)
phi_adjusted = phi - math.pi / 2
light_vec_x = torch.cos(phi_adjusted).view(B, 1, 1)
light_vec_y = torch.sin(phi_adjusted).view(B, 1, 1)
directional_intensity = w_day.view(B, 1, 1)
if num_channels == 1:
light_map = (x_grid * light_vec_x + y_grid * light_vec_y) * directional_intensity
return light_map.unsqueeze(1)
elif num_channels == 2:
light_map_x = x_grid * light_vec_x * directional_intensity
light_map_y = y_grid * light_vec_y * directional_intensity
return torch.stack([light_map_x, light_map_y], dim=1)
else:
raise ValueError("num_channels for light map must be 1 or 2")
class Pix2PixModel(torch.nn.Module):
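    """Day-to-night image translation conditioned on a continuous time-of-day phase ``phi``.

    The generator receives day/night weights derived from phi plus a per-pixel
    light-direction map; a single conditional discriminator (netD2) supplies
    the adversarial signal.
    """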
@staticmethod
def modify_commandline_options(parser, is_train):
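        """Register model-specific CLI options (conditioning, spatial-correlative, and loss weights)."""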
networks.modify_commandline_options(parser, is_train)
parser.add_argument('--daylight_curve_steepness', type=float, default=0.75)
parser.add_argument('--lambda_vgg_interp', type=float, default=15.0)
parser.add_argument('--attn_layers', type=str, default='4,7,9')
parser.add_argument('--patch_nums', type=float, default=128)
parser.add_argument('--patch_size', type=int, default=32)
parser.add_argument('--loss_mode', type=str, default='cos')
parser.add_argument('--use_norm', action='store_true')
parser.add_argument('--learned_attn', action='store_true', default=False)
parser.add_argument('--T', type=float, default=0.07)
parser.add_argument('--lambda_spatial', type=float, default=10.0)
parser.add_argument('--lambda_G', type=float, default=1.0)
if is_train:
parser.add_argument('--night_loss_warmup_iters', type=int, default=10000)
parser.add_argument('--lambda_content_focus', type=float, default=5.0)
parser.add_argument('--lambda_continuity', type=float, default=10.0)
parser.add_argument('--lambda_style', type=float, default=30.0)
parser.add_argument('--lambda_identity', type=float, default=10.0)
parser.add_argument('--lambda_r1', type=float, default=10.0)
parser.add_argument('--lambda_latent_diversity', type=float, default=1.0)
parser.add_argument('--use_patch_vgg_loss', action='store_true')
parser.add_argument('--lambda_patch_vgg', type=float, default=10.0)
parser.add_argument('--patch_vgg_num_patches', type=int, default=64)
parser.add_argument('--patch_vgg_size', type=int, default=64)
parser.add_argument('--lambda_ppl', type=float, default=2.0)
parser.add_argument('--ppl_reg_every', type=int, default=4)
parser.add_argument('--r1_reg_every', type=int, default=16)
parser.add_argument('--lambda_ssim', type=float, default=1.0)
if not hasattr(parser.parse_known_args()[0], 'lambda_z_reg'):
parser.add_argument('--lambda_z_reg', type=float, default=0.0)
if not hasattr(parser.parse_known_args()[0], 'lambda_night_content'):
parser.add_argument('--lambda_night_content', type=float, default=0.0)
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
self.attn_layers = [int(i) for i in self.opt.attn_layers.split(',')]
self.device = self._get_current_device_for_loss_init()
self.netG, self.netD, self.netD2, self.netE = self.initialize_networks(opt)
if opt.isTrain:
self.criterionGAN = networks.GANLoss(opt.gan_mode,
tensor=torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor,
opt=self.opt)
self.criterionFeat = torch.nn.L1Loss()
self.criterionIdentity = torch.nn.L1Loss()
if (not opt.no_vgg_loss and opt.lambda_vgg > 0) or \
(hasattr(opt, 'lambda_style') and opt.lambda_style > 0) or \
(hasattr(opt, 'lambda_vgg_interp') and opt.lambda_vgg_interp > 0) or \
(hasattr(opt, 'lambda_night_content') and opt.lambda_night_content > 0) or \
(hasattr(opt, 'lambda_continuity') and opt.lambda_continuity > 0):
self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
if self.use_gpu(): self.criterionVGG.to(self.device)
if opt.lambda_ssim > 0:
self.criterionSSIM = SSIMLoss().to(self.device)
self.vgg_for_spatial = None
self.criterionSpatial = None
if opt.lambda_spatial > 0:
self.vgg_for_spatial = VGG16().to(self.device)
self.vgg_for_spatial.eval()
for param in self.vgg_for_spatial.parameters(): param.requires_grad = False
self.criterionSpatial = SpatialCorrelativeLoss(loss_mode=opt.loss_mode, patch_nums=opt.patch_nums,
patch_size=opt.patch_size, norm=opt.use_norm,
use_conv=opt.learned_attn, T=opt.T).to(self.device)
self.g_reg_interval = opt.ppl_reg_every
self.d_reg_interval = opt.r1_reg_every
self.pl_mean = torch.tensor(0.0, device=self.device)
def _get_current_device_for_loss_init(self):
return torch.device(f'cuda:{self.opt.gpu_ids[0]}' if self.use_gpu() and self.opt.gpu_ids else 'cpu')
def forward(self, data, mode, iter_count=0, arbitrary_input=False):
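        """Dispatch one step.

        ``mode`` selects the computation: 'generator' returns (loss, loss dict,
        images), 'discriminator' returns (d1_losses, d2_losses), and
        'inference' returns only the generated image. Expects ``data['day']``,
        an optional ``data['night']``, and ``self.opt.phi`` (per-sample phase
        angles, set by the caller before each step).
        """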
content_image_orig = data['day'].to(self.device)
night_image = data.get('night', content_image_orig).to(self.device)
phi2_for_mixing = data.get('phi2', None)
phi_for_interp = self.opt.phi.to(self.device, dtype=content_image_orig.dtype)
w_day, w_night = get_day_night_weights(phi_for_interp, self.opt.daylight_curve_steepness)
cos_phi = torch.cos(phi_for_interp)
sin_phi = torch.sin(phi_for_interp)
_, _, H, W = content_image_orig.shape
light_map = create_light_direction_map(phi_for_interp, size=(H, W), num_channels=self.opt.light_map_channels,
device=self.device, w_day=w_day)
content_image_with_light = torch.cat([content_image_orig, light_map], dim=1)
net_g_module = self.netG.module if isinstance(self.netG, torch.nn.DataParallel) else self.netG
if hasattr(net_g_module, 'opt'): net_g_module.opt.phi = self.opt.phi
if mode == 'inference':
with torch.no_grad():
fake_image, _ = net_g_module(content_image_with_light, w_day, w_night, cos_phi, sin_phi, phi2=None,
arbitrary_input=arbitrary_input)
return fake_image
w_day_pix = w_day.view(-1, 1, 1, 1).expand_as(content_image_orig)
style_image_target = w_day_pix * content_image_orig + (1.0 - w_day_pix) * night_image
if mode == 'generator':
g_loss_for_backward, g_losses_for_display, generated_images = self.compute_generator_loss(
content_image_with_light, content_image_orig, style_image_target,
night_image, w_day, w_night, cos_phi, sin_phi, iter_count, phi2_for_mixing,
arbitrary_input=arbitrary_input
)
generated_images['style_target_interpolated'] = style_image_target
generated_images['light_map_viz'] = light_map
return g_loss_for_backward, g_losses_for_display, generated_images
elif mode == 'discriminator':
d1_losses, d2_losses = self.compute_discriminator_loss(
content_image_with_light, content_image_orig, style_image_target,
w_day, w_night, cos_phi, sin_phi, iter_count, phi2_for_mixing, arbitrary_input=arbitrary_input
)
return d1_losses, d2_losses
else:
raise ValueError(f"|mode| is invalid: {mode}")
def g_main_loss(self, content_with_light, content_orig, style_target, night_image, w_day, w_night, cos_phi, sin_phi,
phi2_for_mixing, arbitrary_input, iter_count):
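        """Main generator objective: adversarial and feature-matching terms from
        netD2, a VGG feature loss interpolating between day and night anchors as
        a function of the sun angle, a day-weighted identity term, and an
        optional latent (z) regularizer."""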
G_losses = {}
fake_image, identity_images_dict = self.netG(content_with_light, w_day, w_night, cos_phi, sin_phi,
phi2=phi2_for_mixing, arbitrary_input=arbitrary_input)
if fake_image is None: raise RuntimeError("Generator output is None.")
warmup_factor = min(1.0, iter_count / self.opt.night_loss_warmup_iters)
G_losses['Night_Loss_Warmup'] = torch.tensor(warmup_factor, device=self.device)
if self.opt.lambda_vgg_interp > 0 and self.criterionVGG is not None:
with torch.no_grad():
vgg_feats_day = self.criterionVGG.vgg(content_orig)
soft_night_anchor_img = (content_orig * 0.1) - 0.85
vgg_feats_soft_night = self.criterionVGG.vgg(soft_night_anchor_img)
hard_night_anchor_img = torch.full_like(content_orig, -0.9, device=self.device)
vgg_feats_hard_night = self.criterionVGG.vgg(hard_night_anchor_img)
extreme_night_anchor_img = torch.full_like(content_orig, -1.0, device=self.device)
vgg_feats_extreme_night = self.criterionVGG.vgg(extreme_night_anchor_img)
vgg_feats_fake = self.criterionVGG.vgg(fake_image)
vgg_interp_loss = 0
phi_deg = self.opt.phi * (180.0 / math.pi)
phi_deg = phi_deg.view(-1, 1, 1, 1)
phi_deg_sym = torch.where(phi_deg > 180.0, 360.0 - phi_deg, phi_deg)
alpha_p1 = (phi_deg_sym / 15.0).clamp(0.0, 1.0)
alpha_p2 = ((phi_deg_sym - 15.0) / 30.0).clamp(0.0, 1.0)
alpha_p3 = ((phi_deg_sym - 45.0) / 135.0).clamp(0.0, 1.0)
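            # Piecewise-linear schedule over the symmetric sun angle (degrees):
            #   0-15:  extreme-night -> hard-night
            #   15-45: hard-night   -> soft-night
            #   45-180: soft-night  -> day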
for i in range(len(vgg_feats_fake)):
interp_p1 = (1.0 - alpha_p1) * vgg_feats_extreme_night[i] + alpha_p1 * vgg_feats_hard_night[i]
interp_p2 = (1.0 - alpha_p2) * vgg_feats_hard_night[i] + alpha_p2 * vgg_feats_soft_night[i]
interp_p3 = (1.0 - alpha_p3) * vgg_feats_soft_night[i] + alpha_p3 * vgg_feats_day[i]
mask_p1 = (phi_deg_sym <= 15.0).expand_as(vgg_feats_fake[i])
mask_p2 = ((phi_deg_sym > 15.0) & (phi_deg_sym <= 45.0)).expand_as(vgg_feats_fake[i])
                target_feat = torch.where(mask_p1, interp_p1,
                                          torch.where(mask_p2, interp_p2, interp_p3))
vgg_interp_loss += self.criterionFeat(vgg_feats_fake[i], target_feat.detach())
G_losses['VGG_Interp'] = vgg_interp_loss * self.opt.lambda_vgg_interp
with torch.no_grad():
hard_night_anchor = torch.full_like(content_orig, -1.0, device=self.device)
soft_night_anchor = (content_orig * 0.1) - 0.85
phi_deg_blend = self.opt.phi * (180.0 / math.pi)
blend_weight = (F.relu(20.0 - phi_deg_blend) / 20.0).clamp(max=1.0).view(-1, 1, 1, 1)
night_anchor_image = blend_weight * hard_night_anchor + (1.0 - blend_weight) * soft_night_anchor
if self.netD2 is not None:
pred_fake_d2_features, _ = self.discriminate(fake_image, content_orig, 'D2')
G_losses['GAN_D2'] = self.criterionGAN(pred_fake_d2_features, True,
for_discriminator=False) * self.opt.lambda_G
if not self.opt.no_ganFeat_loss and self.opt.lambda_feat > 0:
                real_target_for_d = (w_day.view(-1, 1, 1, 1) * content_orig
                                     + w_night.view(-1, 1, 1, 1) * night_anchor_image)
with torch.no_grad():
pred_real_d2_features, _ = self.discriminate(real_target_for_d, content_orig, 'D2')
G_losses['GAN_Feat_D2'] = self.feat_matching_loss(pred_fake_d2_features,
pred_real_d2_features) * self.opt.lambda_feat
if self.opt.lambda_identity > 0:
i_cc = identity_images_dict.get('identity_content')
if i_cc is not None:
                # L1Loss() reduces to a scalar, so multiplying by w_day would only scale by the
                # batch-mean day weight; assuming per-sample weighting was intended, compute it explicitly.
                per_sample_l1 = torch.abs(i_cc - content_orig).mean(dim=[1, 2, 3])
                G_losses['Identity'] = (per_sample_l1 * w_day).mean() * self.opt.lambda_identity
if hasattr(self.opt, 'lambda_z_reg') and self.opt.lambda_z_reg > 0:
netG_module = self.netG.module if isinstance(self.netG, torch.nn.DataParallel) else self.netG
if hasattr(netG_module, 'z_day') and hasattr(netG_module, 'z_night'):
z_day, z_night = netG_module.z_day, netG_module.z_night
G_losses['Z_Reg'] = self.criterionFeat(z_day, z_night) * self.opt.lambda_z_reg
        # Sum only genuine loss tensors; 'Night_Loss_Warmup' is a logging scalar, not part of the objective.
        total_loss = sum(l for k, l in G_losses.items()
                         if isinstance(l, torch.Tensor) and k != 'Night_Loss_Warmup')
return total_loss, G_losses, fake_image, identity_images_dict
def compute_generator_loss(self, content_with_light, content_orig, style_target, night_image, w_day, w_night,
cos_phi, sin_phi, iter_count,
phi2_for_mixing, arbitrary_input=False):
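        """Generator step: main losses plus an optional continuity loss tying the
        output at phi to a frozen second pass at a nearby phase phi + delta
        (L1 + VGG between the two outputs)."""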
if self.netD2:
for p in self.netD2.parameters(): p.requires_grad = False
g_loss_main, g_losses_for_display, fake_image_1, identity_images_dict = self.g_main_loss(
content_with_light, content_orig, style_target, night_image, w_day, w_night, cos_phi, sin_phi,
phi2_for_mixing, arbitrary_input=arbitrary_input, iter_count=iter_count
)
if self.opt.lambda_continuity > 0:
netG_module = self.netG.module if isinstance(self.netG, torch.nn.DataParallel) else self.netG
phi_1 = self.opt.phi.clone()
delta_phi = (torch.rand_like(phi_1) * 0.2) - 0.1
phi_2 = (phi_1 + delta_phi)
w_day_2, w_night_2 = get_day_night_weights(phi_2, self.opt.daylight_curve_steepness)
cos_phi_2 = torch.cos(phi_2)
sin_phi_2 = torch.sin(phi_2)
_, _, H, W = content_orig.shape
light_map_2 = create_light_direction_map(phi_2, (H, W), self.opt.light_map_channels, self.device, w_day_2)
content_with_light_2 = torch.cat([content_orig, light_map_2], dim=1)
with torch.no_grad():
fake_image_2, _ = netG_module(content_with_light_2, w_day_2, w_night_2, cos_phi_2, sin_phi_2, phi2=None,
arbitrary_input=arbitrary_input)
loss_l1 = self.criterionIdentity(fake_image_1, fake_image_2.detach())
            loss_vgg = (self.criterionVGG(fake_image_1, fake_image_2.detach())
                        if self.criterionVGG is not None
                        else torch.tensor(0.0, device=self.device))
continuity_loss = loss_l1 + loss_vgg
g_losses_for_display["Continuity"] = continuity_loss.clone().detach()
g_loss_main += continuity_loss * self.opt.lambda_continuity
        # PPL regularization is currently a logging placeholder; no path-length penalty is applied.
        g_losses_for_display["PPL"] = torch.tensor(0.0, device=self.device)
g_losses_for_display['G_total'] = g_loss_main.clone().detach()
return g_loss_main, g_losses_for_display, {**identity_images_dict, 'synthesized_image': fake_image_1}
def d_main_loss(self, content_with_light, content_orig, style_target, w_day, w_night, cos_phi, sin_phi,
phi2_for_mixing, arbitrary_input=False):
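        """Adversarial loss for netD2 on a detached fake versus the interpolated
        day/night style target."""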
if not self.netD2: return torch.tensor(0.0, device=self.device), {}
for p in self.netD2.parameters(): p.requires_grad = True
D2_losses = {}
with torch.no_grad():
fake_image, _ = self.netG(content_with_light, w_day, w_night, cos_phi, sin_phi, phi2=phi2_for_mixing,
arbitrary_input=arbitrary_input)
fake_image = fake_image.detach()
pred_fake_d2_full, _ = self.discriminate(fake_image, content_orig, 'D2')
D2_losses['D2_Fake'] = self.criterionGAN(pred_fake_d2_full, False, for_discriminator=True)
pred_real_d2_full, _ = self.discriminate(style_target, content_orig, 'D2')
D2_losses['D2_real'] = self.criterionGAN(pred_real_d2_full, True, for_discriminator=True)
return sum(D2_losses.values()), D2_losses
def compute_discriminator_loss(self, content_with_light, content_orig, style_target, w_day, w_night, cos_phi,
sin_phi, iter_count, phi2_for_mixing, arbitrary_input=False):
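        """Discriminator step: main real/fake losses plus lazy R1 regularization
        every ``r1_reg_every`` iterations, scaled by the interval."""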
if not self.netD2: return {}, {}
d_loss_main, d2_losses = self.d_main_loss(content_with_light, content_orig, style_target, w_day, w_night,
cos_phi, sin_phi, phi2_for_mixing, arbitrary_input=arbitrary_input)
if self.opt.lambda_r1 > 0 and (iter_count + 1) % self.d_reg_interval == 0:
r1_loss = self.d_r1_regularize(style_target.clone(), content_orig)
d2_losses['D2_R1'] = r1_loss.clone().detach()
weighted_r1_loss = self.opt.lambda_r1 / 2 * r1_loss * self.d_reg_interval
d_loss_main += weighted_r1_loss
else:
d2_losses['D2_R1'] = torch.tensor(0.0, device=self.device)
return {}, d2_losses
def d_r1_regularize(self, style_target, content_orig):
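        """R1 penalty: gradient penalty of the real logit w.r.t. the real input."""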
if not self.netD2: return torch.tensor(0.0, device=self.device)
style_target.requires_grad = True
_, pred_real_d2_logit = self.discriminate(style_target, content_orig, 'D2')
r1_loss = self.r1_penalty(pred_real_d2_logit, style_target)
return r1_loss
def create_optimizers(self, opt):
G_params = list(self.netG.parameters())
optimizer_G = torch.optim.Adam(G_params, lr=opt.lr, betas=(opt.beta1, opt.beta2))
D_params = list(self.netD2.parameters()) if self.netD2 else []
optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, opt.beta2)) if D_params else None
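        # Only netD2 exists, so the same optimizer is returned for both discriminator slots.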
return optimizer_G, optimizer_D, optimizer_D
def r1_penalty(self, real_pred, real_img):
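        """Mean squared gradient norm of ``real_pred`` w.r.t. ``real_img``."""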
        grad_real = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img,
                                        create_graph=True, retain_graph=True)[0]
        return grad_real.pow(2).reshape(real_img.shape[0], -1).sum(1).mean()
    def discriminate(self, image_to_judge, condition, d_type):
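        """Run netD2 on ``image_to_judge`` conditioned on ``condition``; returns
        (multi-scale features, final logit). ``d_type`` is kept for interface
        compatibility and is currently unused."""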
target_D = self.netD2
if 'swin' in self.opt.netD.lower():
return target_D(image_to_judge, condition)
else:
d_input = torch.cat((image_to_judge, condition), dim=1)
discriminator_out = target_D(d_input)
if isinstance(discriminator_out, list):
logit = discriminator_out[-1][-1] if isinstance(discriminator_out[-1], list) else discriminator_out[-1]
return discriminator_out, logit
else:
return [discriminator_out], discriminator_out
def feat_matching_loss(self, pred_fake_scales, pred_real_scales):
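        """L1 feature-matching loss over all discriminator scales, skipping the
        final logit of each scale."""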
loss = 0.0
for fake_feat_scale, real_feat_scale in zip(pred_fake_scales, pred_real_scales):
if isinstance(fake_feat_scale, list):
for f, r in zip(fake_feat_scale[:-1], real_feat_scale[:-1]): loss += self.criterionFeat(f, r.detach())
else:
loss += self.criterionFeat(fake_feat_scale, real_feat_scale.detach())
return loss
def use_gpu(self):
return len(self.opt.gpu_ids) > 0
def save(self, epoch):
util.save_network(self.netG, 'G', epoch, self.opt)
if self.netD2: util.save_network(self.netD2, 'D2', epoch, self.opt)
def initialize_networks(self, opt):
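        """Build G and the conditional discriminator D2 (train only); the netD
        and netE slots are unused. Weights are loaded when testing or resuming."""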
netG = networks.define_G(opt)
netD2 = networks.define_D(opt) if opt.isTrain else None
if not opt.isTrain or opt.continue_train:
netG = util.load_network(netG, 'G', opt.which_epoch, opt)
if opt.isTrain: netD2 = util.load_network(netD2, 'D2', opt.which_epoch, opt)
return netG, None, netD2, None
def calculate_patch_vgg_loss(self, fake, real):
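        """VGG loss over random crops, with crop locations shared between ``fake`` and ``real``."""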
        use_patch = (hasattr(self.opt, 'use_patch_vgg_loss') and self.opt.use_patch_vgg_loss
                     and self.opt.lambda_patch_vgg > 0 and self.criterionVGG is not None)
        if not use_patch:
            return torch.tensor(0.0, device=fake.device)
patch_size, num_patches = self.opt.patch_vgg_size, self.opt.patch_vgg_num_patches
B, C, H, W = fake.shape
if patch_size > H or patch_size > W: return torch.tensor(0.0, device=fake.device)
rand_h = torch.randint(0, H - patch_size + 1, (num_patches,), device=fake.device)
rand_w = torch.randint(0, W - patch_size + 1, (num_patches,), device=fake.device)
fake_patches, real_patches = [], []
for i in range(num_patches):
for b in range(B):
fake_patches.append(
fake[b:b + 1, :, rand_h[i]:rand_h[i] + patch_size, rand_w[i]:rand_w[i] + patch_size])
real_patches.append(
real[b:b + 1, :, rand_h[i]:rand_h[i] + patch_size, rand_w[i]:rand_w[i] + patch_size])
if not fake_patches: return torch.tensor(0.0, device=fake.device)
return self.criterionVGG(torch.cat(fake_patches, dim=0), torch.cat(real_patches, dim=0).detach())
def calculate_style_loss(self, fake, target):
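        """Gram-matrix style loss between VGG features of ``fake`` and ``target``."""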
if self.criterionVGG is None: return torch.tensor(0.0, device=fake.device)
fake_vgg_feats = self.criterionVGG.vgg(fake)
target_vgg_feats = self.criterionVGG.vgg(target)
style_loss = 0.0
for f_f, f_t in zip(fake_vgg_feats, target_vgg_feats):
style_loss += self.criterionFeat(self.gram_matrix(f_f), self.gram_matrix(f_t).detach())
return style_loss
def gram_matrix(self, input_tensor):
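        """Batched Gram matrix of the feature maps, normalized by c * h * w."""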
b, c, h, w = input_tensor.size()
features = input_tensor.view(b, c, h * w)
features_t = features.transpose(1, 2)
G = features.bmm(features_t)
return G.div(c * h * w)