import os

import torch
import torch.nn.functional as F
import torch.distributed as dist

from models import register
from models.ldm.ldm_base import LDMBase
from models.ldm.vqgan.lpips import LPIPS
from models.ldm.vqgan.discriminator import make_discriminator


@register('glpto')
class GLPTo(LDMBase):
    """LDM variant trained with L1 + LPIPS reconstruction losses on rendered
    patches, plus an optional adversarial loss with a patch discriminator.

    The adaptive GAN weighting follows the VQGAN (taming-transformers)
    gradient-norm ratio trick, averaged across DDP ranks when WORLD_SIZE > 1.
    """

    def __init__(self, lpips=True, disc=True, adaptive_gan_weight=True,
                 noise_render=False, **kwargs):
        """
        Args:
            lpips: build a frozen LPIPS network for the perceptual loss.
                NOTE(review): forward(mode='loss') uses self.lpips_loss
                unconditionally when the renderer is trained, so lpips=False
                is only safe if that path is never taken — confirm.
            disc: build a 3-channel-input discriminator; None otherwise.
            adaptive_gan_weight: rescale the generator GAN loss by the
                gradient-norm ratio computed in calculate_adaptive_gan_w.
            noise_render: feed Gaussian noise to the renderer and pass the
                decoded latent as conditioning instead (see render()).
            **kwargs: forwarded to LDMBase.
        """
        super().__init__(**kwargs)
        if lpips:
            self.lpips_loss = LPIPS().eval()  # frozen perceptual metric
        self.disc = make_discriminator(input_nc=3) if disc else None
        self.adaptive_gan_weight = adaptive_gan_weight
        self.noise_render = noise_render

    def get_parameters(self, name):
        """Return the parameter group `name`; 'disc' is handled here, every
        other group is delegated to the base class."""
        if name == 'disc':
            return self.disc.parameters()
        return super().get_parameters(name)

    def render(self, z_dec, coord, scale):
        """Render an image patch from the decoded latent at (coord, scale).

        When noise_render is set, the renderer input is i.i.d. Gaussian noise
        of the target patch shape (B, 3, H, W inferred from `coord`) and
        z_dec is passed as conditioning instead.
        """
        if not self.noise_render:
            return self.renderer(z_dec, coord=coord, scale=scale)
        shape = (coord.shape[0], 3, coord.shape[2], coord.shape[3])
        noise = torch.randn(shape, device=z_dec.device)
        return self.renderer(noise, coord=coord, scale=scale, z_dec=z_dec)

    @staticmethod
    def _unpack_gt(data):
        """Split data['gt'] channels into (rgb patch, coord map, scale map).

        Layout assumed from the slicing below: channels 0-2 are RGB, 3-4 are
        coordinates, 5-6 are scales.
        """
        gt = data['gt']
        return gt[:, :3, ...], gt[:, 3:5, ...], gt[:, 5:7, ...]

    def forward(self, data, mode, has_optimizer=None, use_gan=False):
        """Dispatch on `mode`.

        Modes:
            'z' / 'z_dec': delegate to the base class; return its first output.
            'pred':        render a patch from the decoded latent (no_grad
                           when the renderer is not being trained).
            'loss':        base loss + weighted L1 + LPIPS, plus an optional
                           generator GAN loss when use_gan is set.
            'disc_loss':   hinge or vanilla discriminator loss on real vs
                           rendered patches; generator is frozen via no_grad.

        Returns None for any other mode (same as the original fall-through).
        """
        if mode in ('z', 'z_dec'):
            ret_z, _ = super().forward(data, mode=mode,
                                       has_optimizer=has_optimizer)
            return ret_z

        grad = self.get_grad_plan(has_optimizer)
        loss_config = self.loss_config

        if mode == 'pred':
            z_dec, _ = super().forward(data, mode='z_dec',
                                       has_optimizer=has_optimizer)
            _, coord, scale = self._unpack_gt(data)
            if grad['renderer']:
                return self.render(z_dec, coord, scale)
            with torch.no_grad():
                return self.render(z_dec, coord, scale)

        if mode == 'loss':
            if not grad['renderer']:
                # Only the latent diffusion part is training; the base
                # class loss is sufficient.
                _, ret = super().forward(data, mode='z',
                                         has_optimizer=has_optimizer)
                return ret
            gt_patch, coord, scale = self._unpack_gt(data)
            z_dec, ret = super().forward(data, mode='z_dec',
                                         has_optimizer=has_optimizer)
            pred = self.render(z_dec, coord, scale)

            l1_loss = torch.abs(pred - gt_patch).mean()
            ret['l1_loss'] = l1_loss.item()
            l1_loss_w = loss_config.get('l1_loss', 1)
            ret['loss'] = ret['loss'] + l1_loss * l1_loss_w

            # NOTE(review): requires lpips=True at construction (see __init__).
            lpips_loss = self.lpips_loss(pred, gt_patch).mean()
            ret['lpips_loss'] = lpips_loss.item()
            lpips_loss_w = loss_config.get('lpips_loss', 1)
            ret['loss'] = ret['loss'] + lpips_loss * lpips_loss_w

            if use_gan:
                # Generator side of the GAN loss: push fake logits up.
                logits_fake = self.disc(pred)
                gan_g_loss = -torch.mean(logits_fake)
                ret['gan_g_loss'] = gan_g_loss.item()
                weight = loss_config.get('gan_g_loss', 1)
                if self.training and self.adaptive_gan_weight:
                    nll_loss = l1_loss * l1_loss_w + lpips_loss * lpips_loss_w
                    adaptive_gan_w = self.calculate_adaptive_gan_w(
                        nll_loss, gan_g_loss,
                        self.renderer.get_last_layer_weight())
                    ret['adaptive_gan_w'] = adaptive_gan_w.item()
                    weight = weight * adaptive_gan_w
                ret['loss'] = ret['loss'] + gan_g_loss * weight
            return ret

        if mode == 'disc_loss':
            gt_patch, coord, scale = self._unpack_gt(data)
            # Freeze the generator: no gradients flow into the decoder or
            # renderer while training the discriminator.
            with torch.no_grad():
                z_dec, _ = super().forward(data, mode='z_dec',
                                           has_optimizer=None)
                pred = self.render(z_dec, coord, scale)
            logits_real = self.disc(gt_patch)
            logits_fake = self.disc(pred)
            disc_loss_type = loss_config.get('disc_loss_type', 'hinge')
            if disc_loss_type == 'hinge':
                loss_real = torch.mean(F.relu(1. - logits_real))
                loss_fake = torch.mean(F.relu(1. + logits_fake))
            elif disc_loss_type == 'vanilla':
                loss_real = torch.mean(F.softplus(-logits_real))
                loss_fake = torch.mean(F.softplus(logits_fake))
            else:
                # Bug fix: an unknown type previously fell through both
                # branches and raised NameError on `loss`; fail fast instead.
                raise ValueError(
                    f'Unknown disc_loss_type: {disc_loss_type}')
            loss = (loss_real + loss_fake) / 2
            return {
                'loss': loss,
                'disc_logits_real': logits_real.mean().item(),
                'disc_logits_fake': logits_fake.mean().item(),
            }

    def calculate_adaptive_gan_w(self, nll_loss, g_loss, last_layer):
        """VQGAN-style adaptive GAN weight.

        Computes ||d nll / d last_layer|| / (||d g_loss / d last_layer|| + 1e-4),
        clamped to [0, 1e4] and detached. Under distributed training
        (WORLD_SIZE > 1) the gradients are all-reduce-averaged first so every
        rank derives the same weight.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer,
                                        retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer,
                                      retain_graph=True)[0]
        world_size = int(os.environ.get('WORLD_SIZE', '1'))
        if world_size > 1:
            dist.all_reduce(nll_grads, op=dist.ReduceOp.SUM)
            nll_grads.div_(world_size)
            dist.all_reduce(g_grads, op=dist.ReduceOp.SUM)
            g_grads.div_(world_size)
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        return torch.clamp(d_weight, 0.0, 1e4).detach()