| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from models import iresnet
|
| from lpips.lpips import LPIPS
|
| from pytorch_msssim import SSIM
|
|
|
|
|
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Gate a loss weight on training progress.

    Returns `weight` once `global_step` has reached `threshold`,
    otherwise returns the stand-in `value` (default 0, i.e. disabled).
    """
    return weight if global_step >= threshold else value
|
|
|
|
|
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss: penalize real logits below +1 and
    fake logits above -1, averaged and halved."""
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
|
|
|
|
def mse_d_loss(logits_real, logits_fake):
    """Least-squares (LSGAN) discriminator loss: drive real logits
    toward 1 and fake logits toward 0."""
    real_term = torch.mean(torch.square(logits_real - 1.))
    fake_term = torch.mean(torch.square(logits_fake))
    return 0.5 * (real_term + fake_term)
|
|
|
|
|
def vanilla_d_loss(logits_real, logits_fake):
    """Non-saturating (vanilla GAN) discriminator loss via softplus:
    softplus(-real) + softplus(fake), averaged and halved."""
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
|
|
|
|
def create_fr_model(model_path, depth="100"):
    """Build an iresnet face-recognition backbone of the given depth and
    load its weights from the checkpoint at `model_path`.

    NOTE(review): `torch.load` is called without `map_location`, so the
    checkpoint is restored onto the device it was saved from — confirm
    this is intended for CPU-only setups.
    """
    fr_net = iresnet(depth)
    state_dict = torch.load(model_path)
    fr_net.load_state_dict(state_dict)
    return fr_net
|
|
|
|
|
def downscale(img: torch.Tensor) -> torch.Tensor:
    """Downscale a batch of square images by a factor of 8.

    Args:
        img: image batch of shape (N, C, H, W); assumes H == W, since the
            target size is derived from the last dimension only — TODO confirm
            callers never pass non-square inputs.

    Returns:
        Bicubically resized tensor of shape (N, C, H // 8, W // 8).
    """
    # FIX: annotation corrected from `torch.tensor` (a factory function) to the
    # type `torch.Tensor`; local renamed from the misleading `half_size` — the
    # reduction factor is 8, not 2.
    target_size = img.shape[-1] // 8
    return F.interpolate(img, size=(target_size, target_size),
                         mode='bicubic', align_corners=False)
|
|
|
|
|
class VQLPIPSWithDiscriminator(nn.Module):
    """Composite VQ-autoencoder training loss for face reconstruction.

    The generator ("ae") loss combines pixel MSE, LPIPS perceptual loss,
    SSIM, a token-prediction loss, an ArcFace identity term, and the codebook
    embedding loss, plus an adversarial term whose weight is balanced
    adaptively against the reconstruction gradient. The discriminator loss
    (hinge / vanilla / mse) is returned alongside, gated by `disc_start`.
    """

    def __init__(self, disc_start=1000, disc_factor=1.0, disc_weight=1.0,
                 disc_conditional=False, disc_loss="mse", id_loss="mse",
                 fr_model="./models/arcface-r100-glint360k.pth"):
        super().__init__()
        self.loss_name = disc_loss
        self.perceptual_loss = LPIPS().eval()
        # Epoch threshold before the adversarial term is switched on.
        self.discriminator_iter_start = disc_start
        # FIX: the original asserted disc_loss in {"hinge","vanilla","mse",
        # "smooth"} but then unconditionally raised ValueError for "smooth".
        # The contradictory assert is dropped; every unsupported value now
        # fails with the same ValueError below.
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "mse":
            self.disc_loss = mse_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        # Frozen face-recognition backbone used for the identity term.
        self.fr_model = create_fr_model(fr_model).eval()
        # NOTE(review): self.feature_loss is configured here but forward()
        # computes the identity term with cosine_similarity directly, so this
        # module is currently unused — confirm whether it can be removed.
        if id_loss == "mse":
            self.feature_loss = nn.MSELoss()
        elif id_loss == "cosine":
            self.feature_loss = nn.CosineSimilarity()
        else:
            # FIX: an unknown id_loss previously left self.feature_loss unset,
            # deferring the failure to first use; fail fast instead.
            raise ValueError(f"Unknown id loss '{id_loss}'.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.ssim_loss = SSIM(data_range=1, size_average=True, channel=3)

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Adaptive GAN weight (VQGAN-style): ratio of the reconstruction
        gradient norm to the adversarial gradient norm at the last decoder
        layer, clamped to [0, 1e4], detached, and scaled by
        `discriminator_weight`.

        Falls back to `self.last_layer[0]` when `last_layer` is not given —
        that attribute is expected to be set externally (not in __init__).
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        # 1e-4 guards against division by a vanishing adversarial gradient.
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * self.discriminator_weight

    def forward(self, im_features, gt_indices, logits, gt_img, image, discriminator, emb_loss,
                epoch, last_layer=None, cond=None, mask=None):
        """Compute generator and discriminator losses for one step.

        Returns (ae_loss, d_loss, token_loss, rec_loss, ssim_loss, p_loss,
        feature_loss). `cond` is accepted for interface compatibility but
        unused here.
        """
        # Per-pixel squared error; kept unreduced for the returned diagnostic.
        rec_loss = (image - gt_img) ** 2

        # Identity term: cosine distance between the target embedding and the
        # embedding of the reconstruction. `epoch >= 0` holds for any normal
        # training run, so this branch is effectively always taken.
        if epoch >= 0:
            gen_feature = self.fr_model(image)
            feature_loss = torch.mean(1 - torch.cosine_similarity(im_features, gen_feature))
        else:
            feature_loss = 0

        p_loss = self.perceptual_loss(image, gt_img) * 2

        # SSIM in fp32; inputs are remapped from [-1, 1] to [0, 1] to match
        # the SSIM module's data_range=1.
        # NOTE(review): source indentation was ambiguous — the discriminator
        # calls are placed outside the autocast block; confirm against the
        # original training behavior.
        with torch.cuda.amp.autocast(enabled=False):
            ssim_loss = 1 - self.ssim_loss((image.float() + 1) / 2, (gt_img + 1) / 2)

        logits_fake = discriminator(image)
        # Detached copies feed the discriminator loss only.
        logits_real_d = discriminator(gt_img.detach())
        logits_fake_d = discriminator(image.detach())

        # Token-prediction loss over positions 1..N (position 0 is skipped):
        # plain MSE, or mask-weighted L1 when a mask is supplied.
        if mask is None:
            token_loss = torch.mean((logits[:, 1:, :] - gt_indices[:, 1:, :]) ** 2)
        else:
            token_loss = torch.abs((logits[:, 1:, :] - gt_indices[:, 1:, :])) * mask[:, 1:, None]
            token_loss = token_loss.sum() / mask[:, 1:].sum()

        nll_loss = torch.mean(rec_loss + p_loss) + \
            ssim_loss + \
            token_loss + feature_loss + emb_loss

        # Non-saturating generator objective (hinge-style -E[D(G)]); used for
        # all disc_loss variants, including "mse".
        g_loss = -torch.mean(logits_fake)

        d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
        # FIX: disc_factor was computed twice with identical arguments; it is
        # now computed once and reused for both ae_loss and d_loss.
        disc_factor = adopt_weight(self.disc_factor, epoch, threshold=self.discriminator_iter_start)
        ae_loss = nll_loss + d_weight * disc_factor * g_loss

        d_loss = disc_factor * self.disc_loss(logits_real_d, logits_fake_d)
        return ae_loss, d_loss, token_loss, rec_loss, ssim_loss, p_loss, feature_loss
|
|
|