"""Discriminators (2D, 3D, causal-3D) and the combined VideoAutoencoderLoss for a MAGVIT2-style video autoencoder."""
| from typing import Any, Union | |
| from math import log2 | |
| from beartype import beartype | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from torch.autograd import grad as torch_grad | |
| from torch.cuda.amp import autocast | |
| import torchvision | |
| from torchvision.models import VGG16_Weights | |
| from einops import rearrange, einsum, repeat | |
| from einops.layers.torch import Rearrange | |
| from kornia.filters import filter3d | |
| from ..magvit2_pytorch import Residual, FeedForward, LinearSpaceAttention | |
| from .lpips import LPIPS | |
| from sgm.modules.autoencoding.vqvae.movq_enc_3d import CausalConv3d, DownSample3D | |
| from sgm.util import instantiate_from_config | |
def exists(v):
    """Return True when *v* is not None (einops-style existence check)."""
    return not (v is None)
def pair(t):
    """Promote a scalar to a 2-tuple; tuples pass through unchanged."""
    if isinstance(t, tuple):
        return t
    return (t, t)
def leaky_relu(p=0.1):
    """Build a LeakyReLU module with negative slope *p* (default 0.1)."""
    return nn.LeakyReLU(negative_slope=p)
def hinge_discr_loss(fake, real):
    """Hinge discriminator loss: push real logits above +1 and fake logits below -1."""
    fake_term = F.relu(1 + fake)
    real_term = F.relu(1 - real)
    return (fake_term + real_term).mean()
def hinge_gen_loss(fake):
    """Hinge generator loss: maximize the discriminator's logits on fakes."""
    return -(fake.mean())
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter) -> Tensor:
    """Detached gradient of *loss* w.r.t. *layer*.

    The graph is retained so further backward passes (e.g. the actual
    optimizer step) remain possible.
    """
    (gradient,) = torch_grad(
        outputs=loss,
        inputs=layer,
        grad_outputs=torch.ones_like(loss),
        retain_graph=True,
    )
    return gradient.detach()
def pick_video_frame(video, frame_indices):
    """Gather one frame per batch element from a (b, c, f, ...) video.

    `frame_indices` is assumed to be shaped (b, 1) — one index per video
    (TODO confirm against callers). Returns a (b, c, ...) batch of frames.
    """
    num_videos = video.shape[0]
    frames_first = video.transpose(1, 2)  # (b, f, c, ...)
    row_ids = torch.arange(num_videos, device=video.device).unsqueeze(-1)  # (b, 1)
    picked = frames_first[row_ids, frame_indices]  # (b, 1, c, ...)
    return picked.squeeze(1)
def gradient_penalty(images, output):
    """WGAN-GP style penalty: ((||d output / d images||_2 - 1)^2).mean().

    `create_graph=True` keeps the penalty differentiable so it can itself be
    trained through.
    """
    (grads,) = torch_grad(
        outputs=output,
        inputs=images,
        grad_outputs=torch.ones(output.size(), device=images.device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    per_sample_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()
| # discriminator with anti-aliased downsampling (blurpool Zhang et al.) | |
class Blur(nn.Module):
    """Anti-aliasing blur with a binomial [1, 2, 1] tap, as in blurpool (Zhang et al.).

    Builds a separable 1D/2D/3D kernel on the fly:
      * space_only -> (1, 1, 3, 3) kernel (blur h, w only)
      * time_only  -> (1, 3, 1, 1) kernel (blur f only)
      * default    -> (1, 3, 3, 3) kernel (blur f, h, w)
    Accepts 4-d image batches (b c h w) by inserting a singleton frame axis.
    """

    def __init__(self):
        super().__init__()
        # buffer name "f" is part of the state dict; keep it stable
        self.register_buffer("f", torch.Tensor([1, 2, 1]))

    def forward(self, x, space_only=False, time_only=False):
        assert not (space_only and time_only)

        taps = self.f
        if space_only:
            kernel = einsum("i, j -> i j", taps, taps)
            kernel = rearrange(kernel, "... -> 1 1 ...")
        elif time_only:
            kernel = rearrange(taps, "f -> 1 f 1 1")
        else:
            kernel = einsum("i, j, k -> i j k", taps, taps, taps)
            kernel = rearrange(kernel, "... -> 1 ...")

        squeeze_back = x.ndim == 4
        if squeeze_back:
            x = rearrange(x, "b c h w -> b c 1 h w")

        out = filter3d(x, kernel, normalized=True)
        return rearrange(out, "b c 1 h w -> b c h w") if squeeze_back else out
class DiscriminatorBlock(nn.Module):
    """2D discriminator stage: two 3x3 convs plus a 1x1 residual path.

    With downsample=True the spatial resolution is halved (residual via
    stride-2 conv, main path via pixel-unshuffle + 1x1 conv), optionally
    preceded by an anti-aliasing blur.
    """

    def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu(),
        )
        self.maybe_blur = Blur() if antialiased_downsample else None
        if downsample:
            self.downsample = nn.Sequential(
                Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
                nn.Conv2d(filters * 4, filters, 1),
            )
        else:
            self.downsample = None

    def forward(self, x):
        shortcut = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            if exists(self.maybe_blur):
                x = self.maybe_blur(x, space_only=True)
            x = self.downsample(x)
        # 1/sqrt(2) keeps the variance of the residual sum roughly constant
        return (x + shortcut) * (2 ** -0.5)
class Discriminator(nn.Module):
    """2D image discriminator with linear-attention stages.

    Halves spatial resolution at every stage except the last, then flattens the
    final feature map through a linear head into one logit per image.

    Cleanups vs. the previous version: removed the duplicated `blocks = []`,
    the never-used `attn_blocks` list, and the unused `num_layer` /
    `image_resolution` loop counters.
    """

    def __init__(
        self,
        *,
        dim,
        image_size,
        channels=3,
        max_dim=512,
        attn_heads=8,  # NOTE(review): currently unused; kept for interface compatibility
        attn_dim_head=32,  # NOTE(review): currently unused; kept for interface compatibility
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        # one downsampling stage per spatial halving, stopping near a 4x4 map
        num_layers = int(log2(min_image_resolution) - 2)

        layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []
        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            block = DiscriminatorBlock(
                in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
            )
            attn_block = nn.Sequential(
                Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
                Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
            )
            blocks.append(nn.ModuleList([block, attn_block]))

        self.blocks = nn.ModuleList(blocks)

        dim_last = layer_dims[-1]
        downsample_factor = 2**num_layers
        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange("b ... -> b (...)"),
            nn.Linear(latent_dim, 1),
            Rearrange("b 1 -> b"),
        )

    def forward(self, x):
        """Map a (b, c, h, w) image batch to a (b,) vector of logits."""
        for block, attn_block in self.blocks:
            x = block(x)
            x = attn_block(x)
        return self.to_logits(x)
class DiscriminatorBlock3D(nn.Module):
    """3D discriminator stage halving frames, height and width.

    Residual path: stride-2 1x1x1 conv. Main path: two 3x3x3 convs followed by
    a voxel-unshuffle + 1x1x1 conv downsample, optionally preceded by an
    anti-aliasing blur (applied spatially only).
    """

    def __init__(
        self,
        input_channels,
        filters,
        antialiased_downsample=True,
    ):
        super().__init__()
        self.conv_res = nn.Conv3d(input_channels, filters, 1, stride=2)
        self.net = nn.Sequential(
            nn.Conv3d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv3d(filters, filters, 3, padding=1),
            leaky_relu(),
        )
        self.maybe_blur = Blur() if antialiased_downsample else None
        self.downsample = nn.Sequential(
            Rearrange("b c (f p1) (h p2) (w p3) -> b (c p1 p2 p3) f h w", p1=2, p2=2, p3=2),
            nn.Conv3d(filters * 8, filters, 1),
        )

    def forward(self, x):
        shortcut = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            if exists(self.maybe_blur):
                x = self.maybe_blur(x, space_only=True)
            x = self.downsample(x)
        # 1/sqrt(2) keeps the variance of the residual sum roughly constant
        return (x + shortcut) * (2 ** -0.5)
class DiscriminatorBlock3DWithfirstframe(nn.Module):
    """3D discriminator stage built on causal convolutions.

    Both the residual path and the main path downsample via DownSample3D with
    compress_time=True. pad_mode="first" presumably pads causally using the
    first frame — confirm against CausalConv3d.
    """

    def __init__(
        self,
        input_channels,
        filters,
        antialiased_downsample=True,
        pad_mode="first",
    ):
        super().__init__()
        self.downsample_res = DownSample3D(
            in_channels=input_channels,
            out_channels=filters,
            with_conv=True,
            compress_time=True,
        )
        self.net = nn.Sequential(
            CausalConv3d(input_channels, filters, kernel_size=3, pad_mode=pad_mode),
            leaky_relu(),
            CausalConv3d(filters, filters, kernel_size=3, pad_mode=pad_mode),
            leaky_relu(),
        )
        self.maybe_blur = Blur() if antialiased_downsample else None
        self.downsample = DownSample3D(
            in_channels=filters,
            out_channels=filters,
            with_conv=True,
            compress_time=True,
        )

    def forward(self, x):
        shortcut = self.downsample_res(x)
        x = self.net(x)
        if exists(self.downsample):
            if exists(self.maybe_blur):
                x = self.maybe_blur(x, space_only=True)
            x = self.downsample(x)
        # 1/sqrt(2) keeps the variance of the residual sum roughly constant
        return (x + shortcut) * (2 ** -0.5)
class Discriminator3D(nn.Module):
    """Video discriminator: 3D stages first, then 2D attention stages.

    The first log2(frame_num) stages use DiscriminatorBlock3D and halve frames
    as well as space; after the last of them the frame axis (length 1 when
    frame_num is a power of two — TODO confirm that assumption holds for all
    configs) is folded into the batch, and the remaining stages are 2D blocks
    with linear attention. The head emits one logit per (batch x remaining
    frame) element.

    Cleanups vs. the previous version: removed the unused `num_layer`,
    `frame_resolution` and `image_resolution` loop counters.
    """

    def __init__(
        self,
        *,
        dim,
        image_size,
        frame_num,
        channels=3,
        max_dim=512,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        # one stage per spatial halving, stopping near a 4x4 map
        num_layers = int(log2(min_image_resolution) - 2)
        temporal_num_layers = int(log2(frame_num))
        self.temporal_num_layers = temporal_num_layers

        layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []
        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            if ind < temporal_num_layers:
                # early stages: spatio-temporal downsampling
                blocks.append(
                    DiscriminatorBlock3D(
                        in_chan,
                        out_chan,
                        antialiased_downsample=antialiased_downsample,
                    )
                )
            else:
                # later stages: purely spatial, with linear attention
                block = DiscriminatorBlock(
                    in_chan,
                    out_chan,
                    downsample=is_not_last,
                    antialiased_downsample=antialiased_downsample,
                )
                attn_block = nn.Sequential(
                    Residual(
                        LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
                    ),
                    Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
                )
                blocks.append(nn.ModuleList([block, attn_block]))

        self.blocks = nn.ModuleList(blocks)

        dim_last = layer_dims[-1]
        downsample_factor = 2**num_layers
        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange("b ... -> b (...)"),
            nn.Linear(latent_dim, 1),
            Rearrange("b 1 -> b"),
        )

    def forward(self, x):
        """Map a (b, c, f, h, w) video batch to a flat vector of logits."""
        for i, layer in enumerate(self.blocks):
            if i < self.temporal_num_layers:
                x = layer(x)
                if i == self.temporal_num_layers - 1:
                    # fold the remaining frame axis into the batch for the 2D stages
                    x = rearrange(x, "b c f h w -> (b f) c h w")
            else:
                block, attn_block = layer
                x = block(x)
                x = attn_block(x)
        return self.to_logits(x)
class Discriminator3DWithfirstframe(nn.Module):
    """Video discriminator using causal 3D stages, then 2D attention stages.

    The first log2(frame_num) stages use DiscriminatorBlock3DWithfirstframe;
    after the last of them the frame axis is averaged away (mean over dim 2)
    rather than folded into the batch, so the head emits one logit per video.

    Cleanups vs. the previous version: removed the unused `num_layer`,
    `frame_resolution` and `image_resolution` loop counters and the
    commented-out fold-into-batch alternative.
    """

    def __init__(
        self,
        *,
        dim,
        image_size,
        frame_num,
        channels=3,
        max_dim=512,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        # one stage per spatial halving, stopping near a 4x4 map
        num_layers = int(log2(min_image_resolution) - 2)
        temporal_num_layers = int(log2(frame_num))
        self.temporal_num_layers = temporal_num_layers

        layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []
        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            if ind < temporal_num_layers:
                # early stages: causal spatio-temporal downsampling
                blocks.append(
                    DiscriminatorBlock3DWithfirstframe(
                        in_chan,
                        out_chan,
                        antialiased_downsample=antialiased_downsample,
                    )
                )
            else:
                # later stages: purely spatial, with linear attention
                block = DiscriminatorBlock(
                    in_chan,
                    out_chan,
                    downsample=is_not_last,
                    antialiased_downsample=antialiased_downsample,
                )
                attn_block = nn.Sequential(
                    Residual(
                        LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
                    ),
                    Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
                )
                blocks.append(nn.ModuleList([block, attn_block]))

        self.blocks = nn.ModuleList(blocks)

        dim_last = layer_dims[-1]
        downsample_factor = 2**num_layers
        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange("b ... -> b (...)"),
            nn.Linear(latent_dim, 1),
            Rearrange("b 1 -> b"),
        )

    def forward(self, x):
        """Map a (b, c, f, h, w) video batch to a (b,) vector of logits."""
        for i, layer in enumerate(self.blocks):
            if i < self.temporal_num_layers:
                x = layer(x)
                if i == self.temporal_num_layers - 1:
                    # collapse the remaining frame axis by averaging
                    x = x.mean(dim=2)
            else:
                block, attn_block = layer
                x = block(x)
                x = attn_block(x)
        return self.to_logits(x)
class VideoAutoencoderLoss(nn.Module):
    """Training objective for a 3D video autoencoder with an optional GAN term.

    optimizer_idx == 0 (generator/autoencoder):
        MSE reconstruction + weighted LPIPS perceptual loss on one random frame
        + weighted quantizer auxiliary losses + weighted hinge generator loss.
    optimizer_idx == 1 (discriminator):
        hinge discriminator loss on the 3D discriminator, plus an optional
        gradient penalty on the real inputs.

    Args:
        disc_start: global step at which the adversarial term becomes active.
        perceptual_weight: weight of the LPIPS term (LPIPS built only if > 0).
        adversarial_loss_weight: weight of the hinge generator loss.
        multiscale_adversarial_loss_weight: kept for interface compatibility;
            the multiscale discriminators are currently disabled.
        grad_penalty_loss_weight: weight of the gradient penalty (0 disables it).
        quantizer_aux_loss_weight: weight of `aux_losses` passed to forward().
        vgg_weights: kept for interface compatibility; unused in this version.
        discr_kwargs: kwargs for an optional 2D Discriminator (not used by forward).
        discr_3d_kwargs: config dict for the 3D discriminator, instantiated via
            `instantiate_from_config`.
    """

    def __init__(
        self,
        disc_start,
        perceptual_weight=1,
        adversarial_loss_weight=0,
        multiscale_adversarial_loss_weight=0,
        grad_penalty_loss_weight=0,
        quantizer_aux_loss_weight=0,
        vgg_weights=VGG16_Weights.DEFAULT,
        discr_kwargs=None,
        discr_3d_kwargs=None,
    ):
        super().__init__()
        self.disc_start = disc_start
        self.perceptual_weight = perceptual_weight
        self.adversarial_loss_weight = adversarial_loss_weight
        self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
        self.grad_penalty_loss_weight = grad_penalty_loss_weight
        self.quantizer_aux_loss_weight = quantizer_aux_loss_weight

        # LPIPS is heavy; only build it when the perceptual term is used
        if self.perceptual_weight > 0:
            self.perceptual_model = LPIPS().eval()

        self.discr = Discriminator(**discr_kwargs) if discr_kwargs is not None else None
        self.discr_3d = instantiate_from_config(discr_3d_kwargs) if discr_3d_kwargs is not None else None

        # reusable scalar zero on the module's device; excluded from checkpoints
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

    def get_trainable_params(self) -> Any:
        """Parameters of the discriminator(s), i.e. what optimizer_idx == 1 updates."""
        params = []
        if self.discr is not None:
            params += list(self.discr.parameters())
        if self.discr_3d is not None:
            params += list(self.discr_3d.parameters())
        return params

    def get_trainable_parameters(self) -> Any:
        """Alias of get_trainable_params() kept for API compatibility."""
        return self.get_trainable_params()

    def forward(
        self,
        inputs,
        reconstructions,
        optimizer_idx,
        global_step,
        aux_losses=None,
        last_layer=None,
        split="train",
    ):
        """Return (total_loss, log_dict) for the selected optimizer.

        inputs / reconstructions are assumed to be (batch, channels, frames, ...)
        videos — TODO confirm against callers. Returns None for any
        optimizer_idx other than 0 or 1 (as before).
        """
        batch, channels, frames = inputs.shape[:3]

        if optimizer_idx == 0:
            recon_loss = F.mse_loss(inputs, reconstructions)

            if self.perceptual_weight > 0:
                # argmax over gaussian noise == one uniformly random frame per video
                frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
                input_frames = pick_video_frame(inputs, frame_indices)
                recon_frames = pick_video_frame(reconstructions, frame_indices)
                perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean()
            else:
                perceptual_loss = self.zero

            # BUG FIX: the adversarial term must start AT `disc_start`. The
            # previous condition tested `global_step >= self.disc_start`, which
            # disabled the generator loss after warm-up instead of before it.
            if global_step < self.disc_start or not self.training or self.adversarial_loss_weight == 0:
                gen_loss = self.zero
                adaptive_weight = 0
            else:
                fake_logits = self.discr_3d(reconstructions)
                gen_loss = hinge_gen_loss(fake_logits)

                adaptive_weight = 1
                if self.perceptual_weight > 0 and last_layer is not None:
                    # balance adversarial vs. perceptual gradients at the last layer
                    norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2)
                    norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)
                    adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
                    adaptive_weight.clamp_(max=1e3)
                    if torch.isnan(adaptive_weight).any():
                        adaptive_weight = 1

            if aux_losses is None:
                aux_losses = self.zero

            # NOTE(review): adaptive_weight is computed but, as in the original,
            # not applied to total_loss; confirm whether gen_loss should be
            # scaled by it (classic VQGAN recipe).
            total_loss = (
                recon_loss
                + aux_losses * self.quantizer_aux_loss_weight
                + perceptual_loss * self.perceptual_weight
                + gen_loss * self.adversarial_loss_weight
            )

            log = {
                "{}/total_loss".format(split): total_loss.detach(),
                "{}/recon_loss".format(split): recon_loss.detach(),
                "{}/perceptual_loss".format(split): perceptual_loss.detach(),
                "{}/gen_loss".format(split): gen_loss.detach(),
                "{}/aux_losses".format(split): aux_losses.detach(),
                "{}/adaptive_weight".format(split): adaptive_weight,
            }
            return total_loss, log

        if optimizer_idx == 1:
            apply_gradient_penalty = self.grad_penalty_loss_weight > 0
            if apply_gradient_penalty:
                # real inputs need grads so the penalty can differentiate through them
                inputs = inputs.requires_grad_()

            real_logits = self.discr_3d(inputs)
            fake_logits = self.discr_3d(reconstructions.detach())

            discr_loss = hinge_discr_loss(fake_logits, real_logits)

            if apply_gradient_penalty:
                gradient_penalty_loss = gradient_penalty(inputs, real_logits)
            else:
                gradient_penalty_loss = self.zero

            total_loss = discr_loss + self.grad_penalty_loss_weight * gradient_penalty_loss

            log = {
                "{}/total_disc_loss".format(split): total_loss.detach(),
                "{}/discr_loss".format(split): discr_loss.detach(),
                "{}/grad_penalty_loss".format(split): gradient_penalty_loss.detach(),
                "{}/logits_real".format(split): real_logits.detach().mean(),
                "{}/logits_fake".format(split): fake_logits.detach().mean(),
            }
            return total_loss, log