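"""GLPTo: an LDMBase variant whose renderer is driven by per-pixel coord/scale
maps. Training combines L1 and LPIPS reconstruction losses with an optional
GAN loss; the GAN weight can be set adaptively, VQGAN-style, as the ratio of
the reconstruction and GAN gradient norms at the renderer's last layer.
"""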
import os

import torch
import torch.nn.functional as F
import torch.distributed as dist

from models import register
from models.ldm.ldm_base import LDMBase
from models.ldm.vqgan.lpips import LPIPS
from models.ldm.vqgan.discriminator import make_discriminator

@register('glpto')
class GLPTo(LDMBase):

    def __init__(self, lpips=True, disc=True, adaptive_gan_weight=True,
                 noise_render=False, **kwargs):
        super().__init__(**kwargs)
        # LPIPS is a frozen perceptual loss; keep it in eval mode.
        self.lpips_loss = LPIPS().eval() if lpips else None
        self.disc = make_discriminator(input_nc=3) if disc else None
        self.adaptive_gan_weight = adaptive_gan_weight
        self.noise_render = noise_render
    def get_parameters(self, name):
        if name == 'disc':
            return self.disc.parameters()
        else:
            return super().get_parameters(name)
    def render(self, z_dec, coord, scale):
        if not self.noise_render:
            return self.renderer(z_dec, coord=coord, scale=scale)
        else:
            # Render from Gaussian noise, conditioning the renderer on z_dec.
            shape = (coord.shape[0], 3, coord.shape[2], coord.shape[3])
            noise = torch.randn(shape, device=z_dec.device)
            return self.renderer(noise, coord=coord, scale=scale, z_dec=z_dec)
    def forward(self, data, mode, has_optimizer=None, use_gan=False):
        if mode in ['z', 'z_dec']:
            ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer)
            return ret_z

        grad = self.get_grad_plan(has_optimizer)
        loss_config = self.loss_config
        if mode == 'pred':
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            # data['gt'] packs [RGB(3) | coord(2) | scale(2)] along channels.
            coord = data['gt'][:, 3:5, ...]
            scale = data['gt'][:, 5:7, ...]
            if grad['renderer']:
                return self.render(z_dec, coord, scale)
            else:
                with torch.no_grad():
                    return self.render(z_dec, coord, scale)
        elif mode == 'loss':
            if not grad['renderer']:  # only the zdm is being trained
                _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer)
                return ret
            gt_patch = data['gt'][:, :3, ...]
            coord = data['gt'][:, 3:5, ...]
            scale = data['gt'][:, 5:7, ...]
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            pred = self.render(z_dec, coord, scale)

            # L1 reconstruction loss.
            l1_loss = torch.abs(pred - gt_patch).mean()
            ret['l1_loss'] = l1_loss.item()
            l1_loss_w = loss_config.get('l1_loss', 1)
            ret['loss'] = ret['loss'] + l1_loss * l1_loss_w

            # LPIPS perceptual loss (skipped if the model was built with lpips=False).
            if self.lpips_loss is not None:
                lpips_loss = self.lpips_loss(pred, gt_patch).mean()
                ret['lpips_loss'] = lpips_loss.item()
                lpips_loss_w = loss_config.get('lpips_loss', 1)
            else:
                lpips_loss, lpips_loss_w = 0., 0.
            ret['loss'] = ret['loss'] + lpips_loss * lpips_loss_w
            if use_gan:
                # Generator-side GAN loss: raise the discriminator's logits on preds.
                logits_fake = self.disc(pred)
                gan_g_loss = -torch.mean(logits_fake)
                ret['gan_g_loss'] = gan_g_loss.item()
                weight = loss_config.get('gan_g_loss', 1)
                if self.training and self.adaptive_gan_weight:
                    # Rescale the GAN term so its gradient at the renderer's last
                    # layer matches the reconstruction gradient (VQGAN-style).
                    nll_loss = l1_loss * l1_loss_w + lpips_loss * lpips_loss_w
                    adaptive_gan_w = self.calculate_adaptive_gan_w(
                        nll_loss, gan_g_loss, self.renderer.get_last_layer_weight())
                    ret['adaptive_gan_w'] = adaptive_gan_w.item()
                    weight = weight * adaptive_gan_w
                ret['loss'] = ret['loss'] + gan_g_loss * weight
            return ret
        elif mode == 'disc_loss':
            gt_patch = data['gt'][:, :3, ...]
            coord = data['gt'][:, 3:5, ...]
            scale = data['gt'][:, 5:7, ...]
            # Keep the generator frozen while producing discriminator inputs.
            with torch.no_grad():
                z_dec, _ = super().forward(data, mode='z_dec', has_optimizer=None)
                pred = self.render(z_dec, coord, scale)
            logits_real = self.disc(gt_patch)
            logits_fake = self.disc(pred)

            disc_loss_type = loss_config.get('disc_loss_type', 'hinge')
            if disc_loss_type == 'hinge':
                loss_real = torch.mean(F.relu(1. - logits_real))
                loss_fake = torch.mean(F.relu(1. + logits_fake))
                loss = (loss_real + loss_fake) / 2
            elif disc_loss_type == 'vanilla':
                # Non-saturating logistic loss (softplus form of BCE with logits).
                loss_real = torch.mean(F.softplus(-logits_real))
                loss_fake = torch.mean(F.softplus(logits_fake))
                loss = (loss_real + loss_fake) / 2
            else:
                raise ValueError(f'Unknown disc_loss_type: {disc_loss_type}')
            return {
                'loss': loss,
                'disc_logits_real': logits_real.mean().item(),
                'disc_logits_fake': logits_fake.mean().item(),
            }
    def calculate_adaptive_gan_w(self, nll_loss, g_loss, last_layer):
        # Adaptive GAN weight: the ratio of the gradient norms of the
        # reconstruction (NLL) loss and the GAN loss w.r.t. the renderer's
        # last layer, clamped for stability.
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        # Average gradients across ranks so every worker uses the same weight.
        world_size = int(os.environ.get('WORLD_SIZE', '1'))
        if world_size > 1:
            dist.all_reduce(nll_grads, op=dist.ReduceOp.SUM)
            nll_grads.div_(world_size)
            dist.all_reduce(g_grads, op=dist.ReduceOp.SUM)
            g_grads.div_(world_size)
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight
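
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one alternating GAN training step. The
# data['gt'] layout of [RGB(3) | coord(2) | scale(2)] channels follows the
# slicing in forward() above; the constructor kwargs and the has_optimizer
# plan passed here are hypothetical, since LDMBase's interface is not shown.
#
#   model = GLPTo(lpips=True, disc=True)
#   ret = model(data, mode='loss', has_optimizer=opt_plan, use_gan=True)
#   ret['loss'].backward()            # generator / renderer update
#
#   disc_ret = model(data, mode='disc_loss')
#   disc_ret['loss'].backward()       # discriminator update (self.disc only)
# ---------------------------------------------------------------------------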