Spaces:
Build error
Build error
| #!/usr/bin/python | |
| # | |
| # Copyright 2018 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import models | |
| from torch import nn | |
def get_gan_losses(gan_type):
    """
    Look up the (generator loss, discriminator loss) pair for a GAN variant.

    Supported values of ``gan_type`` are 'gan', 'wgan' and 'lsgan'. The
    returned callables are used as:
      loss_g = g_loss(scores_fake)
      loss_d = d_loss(scores_real, scores_fake)

    Raises:
    - ValueError: if ``gan_type`` is not one of the supported names.
    """
    if gan_type not in ('gan', 'wgan', 'lsgan'):
        raise ValueError('Unrecognized GAN type "%s"' % gan_type)
    loss_pairs = {
        'gan': (gan_g_loss, gan_d_loss),
        'wgan': (wgan_g_loss, wgan_d_loss),
        'lsgan': (lsgan_g_loss, lsgan_d_loss),
    }
    return loss_pairs[gan_type]
def bce_loss(input, target):
    """
    Mean binary cross-entropy computed directly from logits, in a numerically
    stable form (see https://github.com/pytorch/pytorch/issues/751 and the
    TensorFlow ``sigmoid_cross_entropy_with_logits`` docs for the derivation).

    Per element the loss is  max(x, 0) - x * z + log(1 + exp(-|x|)),
    which never exponentiates a large positive number.

    Inputs:
    - input: Tensor of raw scores (logits), e.g. of shape (N,).
    - target: Tensor of the same shape holding 0/1 targets.

    Returns:
    - Scalar Tensor: the mean loss over all elements.
    """
    positive_part = input.clamp(min=0)
    softplus_tail = (1 + (-input.abs()).exp()).log()
    per_element = positive_part - input * target + softplus_tail
    return per_element.mean()
| def _make_targets(x, y): | |
| """ | |
| Inputs: | |
| - x: PyTorch Tensor | |
| - y: Python scalar | |
| Outputs: | |
| - out: PyTorch Variable with same shape and dtype as x, but filled with y | |
| """ | |
| return torch.full_like(x, y) | |
def gan_g_loss(scores_fake):
    """
    Standard GAN generator loss: BCE of the fake-sample logits against an
    all-ones target, i.e. the generator is rewarded when fakes score as real.

    Input:
    - scores_fake: Tensor of logits for fake samples; inputs with more than
      one dimension are flattened first.

    Output:
    - loss: scalar Tensor.
    """
    flat = scores_fake.view(-1) if scores_fake.dim() > 1 else scores_fake
    return bce_loss(flat, _make_targets(flat, 1))
def gan_d_loss(scores_real, scores_fake):
    """
    Standard GAN discriminator loss: BCE(real logits, 1) + BCE(fake logits, 0).

    Input:
    - scores_real: Tensor of logits for real samples.
    - scores_fake: Tensor of logits for fake samples; must have the same size
      as ``scores_real``. Inputs with more than one dimension are flattened.

    Output:
    - loss: scalar Tensor.
    """
    assert scores_real.size() == scores_fake.size()
    if scores_real.dim() > 1:
        scores_real, scores_fake = scores_real.view(-1), scores_fake.view(-1)
    real_term = bce_loss(scores_real, _make_targets(scores_real, 1))
    fake_term = bce_loss(scores_fake, _make_targets(scores_fake, 0))
    return real_term + fake_term
def wgan_g_loss(scores_fake):
    """
    WGAN generator loss: the negated mean critic score of the fake samples.

    Input:
    - scores_fake: Tensor of critic scores for fake samples.

    Output:
    - loss: scalar Tensor.
    """
    return scores_fake.mean().neg()
def wgan_d_loss(scores_real, scores_fake):
    """
    WGAN critic loss: mean fake score minus mean real score.

    Input:
    - scores_real: Tensor of critic scores for real samples.
    - scores_fake: Tensor of critic scores for fake samples.

    Output:
    - loss: scalar Tensor.
    """
    real_mean = scores_real.mean()
    fake_mean = scores_fake.mean()
    return fake_mean - real_mean
def lsgan_g_loss(scores_fake):
    """
    Least-squares generator loss: MSE between sigmoid(scores_fake) and an
    all-ones target.

    Input:
    - scores_fake: Tensor of discriminator scores for fake samples; inputs
      with more than one dimension are flattened first.

    Output:
    - loss: scalar Tensor.
    """
    flat = scores_fake.view(-1) if scores_fake.dim() > 1 else scores_fake
    ones = _make_targets(flat, 1)
    return F.mse_loss(flat.sigmoid(), ones)
def lsgan_d_loss(scores_real, scores_fake):
    """
    Least-squares discriminator loss:
    MSE(sigmoid(scores_real), 1) + MSE(sigmoid(scores_fake), 0).

    Input:
    - scores_real: Tensor of discriminator scores for real samples.
    - scores_fake: Tensor of discriminator scores for fake samples; must have
      the same size as ``scores_real``. Inputs with more than one dimension
      are flattened.

    Output:
    - loss: scalar Tensor.
    """
    assert scores_real.size() == scores_fake.size()
    if scores_real.dim() > 1:
        scores_real, scores_fake = scores_real.view(-1), scores_fake.view(-1)
    real_term = F.mse_loss(scores_real.sigmoid(), _make_targets(scores_real, 1))
    fake_term = F.mse_loss(scores_fake.sigmoid(), _make_targets(scores_fake, 0))
    return real_term + fake_term
def gradient_penalty(x_real, x_fake, f, gamma=1.0):
    """
    Gradient penalty on random interpolates between real and fake samples
    (WGAN-GP style, with a configurable target gradient norm ``gamma``).

    Inputs:
    - x_real: Tensor of shape (N, ...) holding real samples.
    - x_fake: Tensor of the same shape holding fake samples.
    - f: Callable mapping a batch to scores. Its output may be (N,) or have
      extra dimensions, in which case scores are averaged per sample.
    - gamma: Target gradient norm; the squared deviation is divided by
      gamma ** 2.

    Returns:
    - Scalar Tensor: mean over the batch of (||grad_x f(x_hat)|| - gamma)^2 / gamma^2.
    """
    N = x_real.size(0)
    device, dtype = x_real.device, x_real.dtype
    # One interpolation weight per sample, drawn uniformly from [0, 1] so that
    # x_hat lies on the segment between x_real and x_fake, and broadcast over
    # every non-batch dimension (works for any input rank, not just 4-D).
    eps_shape = (N,) + (1,) * (x_real.dim() - 1)
    eps = torch.rand(eps_shape, device=device, dtype=dtype)
    x_hat = eps * x_real + (1 - eps) * x_fake
    # autograd.grad needs x_hat in the graph; plain real data and detached
    # fakes do not require grad on their own.
    if not x_hat.requires_grad:
        x_hat.requires_grad_(True)
    x_hat_score = f(x_hat)
    if x_hat_score.dim() > 1:
        x_hat_score = x_hat_score.reshape(x_hat_score.size(0), -1).mean(dim=1)
    # Summing per-sample scores yields per-sample gradients in one grad call.
    x_hat_score = x_hat_score.sum()
    grad_x_hat, = torch.autograd.grad(x_hat_score, x_hat, create_graph=True)
    grad_x_hat_norm = grad_x_hat.reshape(N, -1).norm(p=2, dim=1)
    gp_loss = (grad_x_hat_norm - gamma).pow(2).div(gamma * gamma).mean()
    return gp_loss
| # VGG Features matching | |
class Vgg19(torch.nn.Module):
    """
    VGG-19 feature extractor for perceptual / feature-matching losses.

    torchvision's pretrained ``vgg19(...).features`` is split into five
    sequential slices covering layer indices [0, 2), [2, 7), [7, 12),
    [12, 21) and [21, 30). ``forward`` returns the output of every slice.
    The cut points presumably correspond to relu1_1 ... relu5_1 of VGG-19
    (the pix2pixHD convention) -- confirm against torchvision's layer layout.
    """

    def __init__(self, requires_grad=False):
        """
        Args:
            requires_grad: if False (the default), all VGG parameters are
                frozen so the network acts as a fixed feature extractor.
        """
        super(Vgg19, self).__init__()
        # pretrained=True loads ImageNet weights through torchvision, which
        # may trigger a download on first use.
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        # Child names keep the original layer index (str(x)) so state_dict
        # keys stay stable.
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """
        Args:
            X: input batch; presumably (N, 3, H, W) images normalised as
               torchvision's VGG expects -- TODO confirm with callers.

        Returns:
            List of the five slice outputs, ordered shallowest to deepest.
        """
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        # Shallow-to-deep order; the first entry must be h_relu1 (it was
        # previously h_relu5, duplicating the deepest map and never using
        # the shallowest one).
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
class VGGLoss(nn.Module):
    """
    Perceptual loss: a weighted sum of L1 distances between VGG feature maps.

    The i-th feature map produced by ``self.vgg`` is scaled by
    ``self.weights[i]``. The reference features of ``y`` are detached, so
    gradients flow only through ``x``.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        # Put the (frozen by default) extractor on the GPU whenever one exists.
        self.vgg = Vgg19().cuda() if torch.cuda.is_available() else Vgg19()
        self.criterion = nn.L1Loss()
        # Weights increase along the feature list returned by self.vgg.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted L1 feature distance between ``x`` and ``y``."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        total = 0
        for idx, fx in enumerate(feats_x):
            # y's features act as a fixed target.
            total += self.weights[idx] * self.criterion(fx, feats_y[idx].detach())
        return total