| """This file contains the definition of utility functions for GANs.""" | |
| import torch | |
| import torch.nn.functional as F | |
| from . import OriginalNLayerDiscriminator, NLayerDiscriminatorv2 | |
def toggle_off_gradients(model: torch.nn.Module):
    """Disables gradient tracking for every parameter of a model.

    The parameters are modified in place and nothing is returned.

    Args:
        model -> torch.nn.Module: The model whose parameters are frozen.
    """
    for parameter in model.parameters():
        parameter.requires_grad_(False)
def toggle_on_gradients(model: torch.nn.Module):
    """Enables gradient tracking for every parameter of a model.

    The parameters are modified in place and nothing is returned.

    Args:
        model -> torch.nn.Module: The model whose parameters are unfrozen.
    """
    for parameter in model.parameters():
        parameter.requires_grad_(True)
def discriminator_weights_init(m):
    """Initialize weights for convolutions in the discriminator.

    Written to be used as the callback of ``torch.nn.Module.apply``, which
    calls it once for every submodule. A module is treated as a convolution
    when its class name contains ``"Conv"``; its ``weight`` is then re-drawn
    in place from a normal distribution with mean 0.0 and std 0.02. Modules
    that match by name but carry no weight tensor (for example a container
    class called ``ConvBlock``) are skipped instead of raising.

    Args:
        m -> torch.nn.Module: The module visited by ``apply``.
    """
    if "Conv" not in m.__class__.__name__:
        return

    weight = getattr(m, "weight", None)
    if weight is None:
        # Name matched, but there is nothing to initialize on this module.
        return

    # torch.nn.init functions already run under no_grad, so the parameter can
    # be passed directly without going through `.data`.
    torch.nn.init.normal_(weight, 0.0, 0.02)
def adopt_weight(
    weight: float, global_step: int, threshold: int = 0, value: float = 0.0
) -> float:
    """Selects between a weight and a placeholder value based on the step.

    Args:
        weight -> float: Returned once global_step has reached threshold.
        global_step -> int: The current step.
        threshold -> int: The step from which weight is returned.
        value -> float: Returned while global_step is strictly below threshold.

    Returns:
        float: value if global_step < threshold, otherwise weight.
    """
    below_threshold = global_step < threshold
    return value if below_threshold else weight
def compute_lecam_loss(
    logits_real_mean: torch.Tensor,
    logits_fake_mean: torch.Tensor,
    ema_logits_real_mean: torch.Tensor,
    ema_logits_fake_mean: torch.Tensor,
) -> torch.Tensor:
    """Computes the LeCam loss from average logits and their EMAs.

    The loss is the sum of two squared hinge terms:
    mean(relu(real - ema_fake) ** 2) + mean(relu(ema_real - fake) ** 2).

    Args:
        logits_real_mean -> torch.Tensor: The average real logits.
        logits_fake_mean -> torch.Tensor: The average fake logits.
        ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
        ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.

    Returns:
        lecam_loss -> torch.Tensor: The LeCam loss.
    """
    # Real logits are only penalized where they exceed the EMA of the fake logits.
    real_term = F.relu(logits_real_mean - ema_logits_fake_mean).pow(2).mean()
    # Fake logits are only penalized where they fall below the EMA of the real logits.
    fake_term = F.relu(ema_logits_real_mean - logits_fake_mean).pow(2).mean()
    return real_term + fake_term
def hinge_g_loss(logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the generator hinge loss, i.e. the negated mean of the fake logits.

    Args:
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        g_loss -> torch.Tensor: The hinge loss.
    """
    mean_fake_logit = logits_fake.mean()
    return torch.neg(mean_fake_logit)
def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the discriminator hinge loss from real and fake logits.

    The result is 0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake))).

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        d_loss -> torch.Tensor: The hinge loss.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def sigmoid_cross_entropy_with_logits(
    logits: torch.Tensor, label: torch.Tensor
) -> torch.Tensor:
    """Computes the element-wise sigmoid cross entropy from logits.

    Credits to Magvit.

    We use a stable formulation that is equivalent to the one used in TensorFlow.
    With x = logits and z = label, the derivation is:

    .. math::
        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
        = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
        = (1 - z) * x + log(1 + exp(-x))
        = x - x * z + log(1 + exp(-x))

    For x < 0, the following formula is more stable:

    .. math::
        x - x * z + log(1 + exp(-x))
        = log(exp(x)) - x * z + log(1 + exp(-x))
        = - x * z + log(1 + exp(x))

    Both cases (x < 0, x >= 0) are combined into a single formula:

    .. math::
        max(x, 0) - x * z + log(1 + exp(-abs(x)))

    Args:
        logits -> torch.Tensor: The logits x.
        label -> torch.Tensor: The targets z.

    Returns:
        loss -> torch.Tensor: The element-wise loss; no reduction is applied.
    """
    zero = torch.zeros_like(logits)
    non_negative = torch.ge(logits, zero)
    # max(x, 0)
    positive_part = torch.where(non_negative, logits, zero)
    # -abs(x): the argument of exp below is never positive, so it cannot overflow.
    minus_abs = torch.where(non_negative, -logits, logits)
    log_term = torch.log1p(torch.exp(minus_abs))
    return positive_part - logits * label + log_term
def non_saturating_d_loss(
    logits_real: torch.Tensor, logits_fake: torch.Tensor
) -> torch.Tensor:
    """Computes the non-saturating discriminator loss from real and fake logits.

    Real logits are scored against an all-ones target and fake logits against
    an all-zeros target; the two mean cross entropies are added together.

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        loss -> torch.Tensor: The non-saturating loss.
    """
    real_targets = torch.ones_like(logits_real)
    fake_targets = torch.zeros_like(logits_fake)
    real_loss = sigmoid_cross_entropy_with_logits(
        logits_real, label=real_targets
    ).mean()
    fake_loss = sigmoid_cross_entropy_with_logits(
        logits_fake, label=fake_targets
    ).mean()
    # Both terms are already scalars, so they are summed directly.
    return real_loss + fake_loss
def non_saturating_g_loss(logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the non-saturating generator loss from the fake logits.

    The fake logits are scored against an all-ones target and the resulting
    cross entropy is averaged.

    Args:
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        loss -> torch.Tensor: The non-saturating loss.
    """
    real_targets = torch.ones_like(logits_fake)
    per_element = sigmoid_cross_entropy_with_logits(logits_fake, label=real_targets)
    return per_element.mean()
def vanilla_d_loss(
    logits_real: torch.Tensor, logits_fake: torch.Tensor
) -> torch.Tensor:
    """Computes the vanilla discriminator loss from real and fake logits.

    The result is 0.5 * (mean(softplus(-real)) + mean(softplus(fake))).

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        loss -> torch.Tensor: The vanilla loss.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def create_discriminator(discriminator_config) -> torch.nn.Module:
    """Builds the discriminator selected by ``discriminator_config.name``.

    Two names are handled: "Original" builds an OriginalNLayerDiscriminator and
    applies discriminator_weights_init to it, and "VQGAN+Discriminator" builds
    an NLayerDiscriminatorv2 without that extra initialization.

    Args:
        discriminator_config: The config for the discriminator. It is read via
            attribute access, and blur_kernel_size is read via
            ``.get("blur_kernel_size", 4)``.

    Returns:
        discriminator -> torch.nn.Module: The discriminator.

    Raises:
        ValueError: If the configured name is not one of the two handled names.
    """
    if discriminator_config.name == "Original":
        original_kwargs = dict(
            num_channels=discriminator_config.num_channels,
            num_stages=discriminator_config.num_stages,
            hidden_channels=discriminator_config.hidden_channels,
        )
        discriminator = OriginalNLayerDiscriminator(**original_kwargs)
        return discriminator.apply(discriminator_weights_init)

    if discriminator_config.name == "VQGAN+Discriminator":
        v2_kwargs = dict(
            num_channels=discriminator_config.num_channels,
            num_stages=discriminator_config.num_stages,
            hidden_channels=discriminator_config.hidden_channels,
            blur_resample=discriminator_config.blur_resample,
            blur_kernel_size=discriminator_config.get("blur_kernel_size", 4),
        )
        return NLayerDiscriminatorv2(**v2_kwargs)

    raise ValueError(
        f"Discriminator {discriminator_config.name} is not implemented."
    )