import os

import torch
import torch.nn.functional as F
import torch.distributed as dist

from models import register
from models.ldm.ldm_base import LDMBase
from models.ldm.vqgan.lpips import LPIPS
from models.ldm.vqgan.discriminator import make_discriminator

@register('glp_to')  # registry key assumed from the class name; adjust if the repo uses another
class GLPTo(LDMBase):
    """Latent diffusion model with a coordinate-based renderer, trained with
    L1 + LPIPS reconstruction losses and an optional VQGAN-style patch GAN."""

    def __init__(self, lpips=True, disc=True, adaptive_gan_weight=True,
                 noise_render=False, **kwargs):
        super().__init__(**kwargs)
        # LPIPS is a frozen perceptual metric; it is never trained.
        self.lpips_loss = LPIPS().eval() if lpips else None
        self.disc = make_discriminator(input_nc=3) if disc else None
        self.adaptive_gan_weight = adaptive_gan_weight
        self.noise_render = noise_render
    def get_parameters(self, name):
        # Route the discriminator's parameters to their own optimizer group;
        # everything else is handled by the base class.
        if name == 'disc':
            return self.disc.parameters()
        else:
            return super().get_parameters(name)
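
    # Hedged usage note: the trainer (outside this file) presumably builds a
    # separate optimizer from this parameter group, e.g.
    #   disc_opt = torch.optim.Adam(model.get_parameters('disc'), lr=...)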
    def render(self, z_dec, coord, scale):
        if not self.noise_render:
            return self.renderer(z_dec, coord=coord, scale=scale)
        else:
            # Render from pure noise, conditioning the renderer on z_dec
            # instead of decoding z_dec directly.
            shape = (coord.shape[0], 3, coord.shape[2], coord.shape[3])
            noise = torch.randn(shape, device=z_dec.device)
            return self.renderer(noise, coord=coord, scale=scale, z_dec=z_dec)
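
    # Layout of data['gt'], inferred from the channel slicing in forward()
    # below (shape (B, 7, h, w)):
    #   channels 0-2: RGB ground-truth patch
    #   channels 3-4: per-pixel rendering coordinates
    #   channels 5-6: per-pixel cell scales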
    def forward(self, data, mode, has_optimizer=None, use_gan=False):
        # 'z' / 'z_dec' modes are handled entirely by the base class.
        if mode in ['z', 'z_dec']:
            ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer)
            return ret_z

        grad = self.get_grad_plan(has_optimizer)
        loss_config = self.loss_config

        if mode == 'pred':
            # Decode a latent and render it at the requested coordinates.
            z_dec, _ = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            coord = data['gt'][:, 3:5, ...]
            scale = data['gt'][:, 5:7, ...]
            if grad['renderer']:
                return self.render(z_dec, coord, scale)
            else:
                with torch.no_grad():
                    return self.render(z_dec, coord, scale)
        elif mode == 'loss':
            if not grad['renderer']:  # only the latent diffusion model (zdm) is trained
                _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer)
                return ret

            gt_patch = data['gt'][:, :3, ...]
            coord = data['gt'][:, 3:5, ...]
            scale = data['gt'][:, 5:7, ...]
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            pred = self.render(z_dec, coord, scale)

            # Pixel reconstruction loss.
            l1_loss = torch.abs(pred - gt_patch).mean()
            ret['l1_loss'] = l1_loss.item()
            l1_loss_w = loss_config.get('l1_loss', 1)
            ret['loss'] = ret['loss'] + l1_loss * l1_loss_w

            # Perceptual loss.
            lpips_loss = self.lpips_loss(pred, gt_patch).mean()
            ret['lpips_loss'] = lpips_loss.item()
            lpips_loss_w = loss_config.get('lpips_loss', 1)
            ret['loss'] = ret['loss'] + lpips_loss * lpips_loss_w

            if use_gan:
                # Generator loss: push the discriminator's logits on
                # rendered patches up.
                logits_fake = self.disc(pred)
                gan_g_loss = -torch.mean(logits_fake)
                ret['gan_g_loss'] = gan_g_loss.item()
                weight = loss_config.get('gan_g_loss', 1)
                if self.training and self.adaptive_gan_weight:
                    # Rescale the GAN term so it cannot dominate the
                    # reconstruction (nll) term.
                    nll_loss = l1_loss * l1_loss_w + lpips_loss * lpips_loss_w
                    adaptive_gan_w = self.calculate_adaptive_gan_w(
                        nll_loss, gan_g_loss, self.renderer.get_last_layer_weight())
                    ret['adaptive_gan_w'] = adaptive_gan_w.item()
                    weight = weight * adaptive_gan_w
                ret['loss'] = ret['loss'] + gan_g_loss * weight
            return ret
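
        # Sketch of a generator step under these assumptions (optimizer and
        # batch names are illustrative, not part of this file):
        #   ret = model(batch, mode='loss', has_optimizer=has_opt, use_gan=True)
        #   gen_opt.zero_grad(); ret['loss'].backward(); gen_opt.step()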
        elif mode == 'disc_loss':
            gt_patch = data['gt'][:, :3, ...]
            coord = data['gt'][:, 3:5, ...]
            scale = data['gt'][:, 5:7, ...]
            # The generator is frozen while the discriminator is updated.
            with torch.no_grad():
                z_dec, _ = super().forward(data, mode='z_dec', has_optimizer=None)
                pred = self.render(z_dec, coord, scale)
            logits_real = self.disc(gt_patch)
            logits_fake = self.disc(pred)

            # Both variants penalize low logits on real patches and high
            # logits on rendered patches.
            disc_loss_type = loss_config.get('disc_loss_type', 'hinge')
            if disc_loss_type == 'hinge':
                loss_real = torch.mean(F.relu(1. - logits_real))
                loss_fake = torch.mean(F.relu(1. + logits_fake))
                loss = (loss_real + loss_fake) / 2
            elif disc_loss_type == 'vanilla':
                loss_real = torch.mean(F.softplus(-logits_real))
                loss_fake = torch.mean(F.softplus(logits_fake))
                loss = (loss_real + loss_fake) / 2
            else:
                raise ValueError(f'unknown disc_loss_type: {disc_loss_type}')

            return {
                'loss': loss,
                'disc_logits_real': logits_real.mean().item(),
                'disc_logits_fake': logits_fake.mean().item(),
            }
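
        # Sketch of the matching discriminator step (optimizer and batch
        # names are illustrative assumptions, not part of this file):
        #   ret = model(batch, mode='disc_loss')
        #   disc_opt.zero_grad(); ret['loss'].backward(); disc_opt.step()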
    def calculate_adaptive_gan_w(self, nll_loss, g_loss, last_layer):
        """Balance the GAN loss against the reconstruction loss by comparing
        their gradient norms at the renderer's last layer."""
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        # Under multi-GPU training, average the per-rank gradients so every
        # rank computes the same weight.
        world_size = int(os.environ.get('WORLD_SIZE', '1'))
        if world_size > 1:
            dist.all_reduce(nll_grads, op=dist.ReduceOp.SUM)
            nll_grads.div_(world_size)
            dist.all_reduce(g_grads, op=dist.ReduceOp.SUM)
            g_grads.div_(world_size)

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight
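
# The adaptive weight above follows VQGAN (Esser et al., 2021):
#   d_weight = ||grad_w L_rec|| / (||grad_w L_gan|| + 1e-4),
# evaluated at the renderer's last layer w and clamped to [0, 1e4], so the
# GAN gradient is rescaled to match the reconstruction gradient's magnitude.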