Spaces:
Build error
Build error
Upload gan_losses.py
Browse files- gan_losses.py +213 -0
gan_losses.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2018 Google LLC
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torchvision import models
|
| 20 |
+
from torch import nn
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_gan_losses(gan_type):
|
| 24 |
+
"""
|
| 25 |
+
Returns the generator and discriminator loss for a particular GAN type.
|
| 26 |
+
|
| 27 |
+
The returned functions have the following API:
|
| 28 |
+
loss_g = g_loss(scores_fake)
|
| 29 |
+
loss_d = d_loss(scores_real, scores_fake)
|
| 30 |
+
"""
|
| 31 |
+
if gan_type == 'gan':
|
| 32 |
+
return gan_g_loss, gan_d_loss
|
| 33 |
+
elif gan_type == 'wgan':
|
| 34 |
+
return wgan_g_loss, wgan_d_loss
|
| 35 |
+
elif gan_type == 'lsgan':
|
| 36 |
+
return lsgan_g_loss, lsgan_d_loss
|
| 37 |
+
else:
|
| 38 |
+
raise ValueError('Unrecognized GAN type "%s"' % gan_type)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def bce_loss(input, target):
|
| 42 |
+
"""
|
| 43 |
+
Numerically stable version of the binary cross-entropy loss function.
|
| 44 |
+
|
| 45 |
+
As per https://github.com/pytorch/pytorch/issues/751
|
| 46 |
+
See the TensorFlow docs for a derivation of this formula:
|
| 47 |
+
https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
|
| 48 |
+
|
| 49 |
+
Inputs:
|
| 50 |
+
- input: PyTorch Tensor of shape (N, ) giving scores.
|
| 51 |
+
- target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
- A PyTorch Tensor containing the mean BCE loss over the minibatch of
|
| 55 |
+
input data.
|
| 56 |
+
"""
|
| 57 |
+
neg_abs = -input.abs()
|
| 58 |
+
loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
|
| 59 |
+
return loss.mean()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _make_targets(x, y):
|
| 63 |
+
"""
|
| 64 |
+
Inputs:
|
| 65 |
+
- x: PyTorch Tensor
|
| 66 |
+
- y: Python scalar
|
| 67 |
+
|
| 68 |
+
Outputs:
|
| 69 |
+
- out: PyTorch Variable with same shape and dtype as x, but filled with y
|
| 70 |
+
"""
|
| 71 |
+
return torch.full_like(x, y)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def gan_g_loss(scores_fake):
|
| 75 |
+
"""
|
| 76 |
+
Input:
|
| 77 |
+
- scores_fake: Tensor of shape (N,) containing scores for fake samples
|
| 78 |
+
|
| 79 |
+
Output:
|
| 80 |
+
- loss: Variable of shape (,) giving GAN generator loss
|
| 81 |
+
"""
|
| 82 |
+
if scores_fake.dim() > 1:
|
| 83 |
+
scores_fake = scores_fake.view(-1)
|
| 84 |
+
y_fake = _make_targets(scores_fake, 1)
|
| 85 |
+
return bce_loss(scores_fake, y_fake)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def gan_d_loss(scores_real, scores_fake):
|
| 89 |
+
"""
|
| 90 |
+
Input:
|
| 91 |
+
- scores_real: Tensor of shape (N,) giving scores for real samples
|
| 92 |
+
- scores_fake: Tensor of shape (N,) giving scores for fake samples
|
| 93 |
+
|
| 94 |
+
Output:
|
| 95 |
+
- loss: Tensor of shape (,) giving GAN discriminator loss
|
| 96 |
+
"""
|
| 97 |
+
assert scores_real.size() == scores_fake.size()
|
| 98 |
+
if scores_real.dim() > 1:
|
| 99 |
+
scores_real = scores_real.view(-1)
|
| 100 |
+
scores_fake = scores_fake.view(-1)
|
| 101 |
+
y_real = _make_targets(scores_real, 1)
|
| 102 |
+
y_fake = _make_targets(scores_fake, 0)
|
| 103 |
+
loss_real = bce_loss(scores_real, y_real)
|
| 104 |
+
loss_fake = bce_loss(scores_fake, y_fake)
|
| 105 |
+
return loss_real + loss_fake
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def wgan_g_loss(scores_fake):
|
| 109 |
+
"""
|
| 110 |
+
Input:
|
| 111 |
+
- scores_fake: Tensor of shape (N,) containing scores for fake samples
|
| 112 |
+
|
| 113 |
+
Output:
|
| 114 |
+
- loss: Tensor of shape (,) giving WGAN generator loss
|
| 115 |
+
"""
|
| 116 |
+
return -scores_fake.mean()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def wgan_d_loss(scores_real, scores_fake):
|
| 120 |
+
"""
|
| 121 |
+
Input:
|
| 122 |
+
- scores_real: Tensor of shape (N,) giving scores for real samples
|
| 123 |
+
- scores_fake: Tensor of shape (N,) giving scores for fake samples
|
| 124 |
+
|
| 125 |
+
Output:
|
| 126 |
+
- loss: Tensor of shape (,) giving WGAN discriminator loss
|
| 127 |
+
"""
|
| 128 |
+
return scores_fake.mean() - scores_real.mean()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def lsgan_g_loss(scores_fake):
|
| 132 |
+
if scores_fake.dim() > 1:
|
| 133 |
+
scores_fake = scores_fake.view(-1)
|
| 134 |
+
y_fake = _make_targets(scores_fake, 1)
|
| 135 |
+
return F.mse_loss(scores_fake.sigmoid(), y_fake)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def lsgan_d_loss(scores_real, scores_fake):
|
| 139 |
+
assert scores_real.size() == scores_fake.size()
|
| 140 |
+
if scores_real.dim() > 1:
|
| 141 |
+
scores_real = scores_real.view(-1)
|
| 142 |
+
scores_fake = scores_fake.view(-1)
|
| 143 |
+
y_real = _make_targets(scores_real, 1)
|
| 144 |
+
y_fake = _make_targets(scores_fake, 0)
|
| 145 |
+
loss_real = F.mse_loss(scores_real.sigmoid(), y_real)
|
| 146 |
+
loss_fake = F.mse_loss(scores_fake.sigmoid(), y_fake)
|
| 147 |
+
return loss_real + loss_fake
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def gradient_penalty(x_real, x_fake, f, gamma=1.0):
|
| 151 |
+
N = x_real.size(0)
|
| 152 |
+
device, dtype = x_real.device, x_real.dtype
|
| 153 |
+
eps = torch.randn(N, 1, 1, 1, device=device, dtype=dtype)
|
| 154 |
+
x_hat = eps * x_real + (1 - eps) * x_fake
|
| 155 |
+
x_hat_score = f(x_hat)
|
| 156 |
+
if x_hat_score.dim() > 1:
|
| 157 |
+
x_hat_score = x_hat_score.view(x_hat_score.size(0), -1).mean(dim=1)
|
| 158 |
+
x_hat_score = x_hat_score.sum()
|
| 159 |
+
grad_x_hat, = torch.autograd.grad(x_hat_score, x_hat, create_graph=True)
|
| 160 |
+
grad_x_hat_norm = grad_x_hat.contiguous().view(N, -1).norm(p=2, dim=1)
|
| 161 |
+
gp_loss = (grad_x_hat_norm - gamma).pow(2).div(gamma * gamma).mean()
|
| 162 |
+
return gp_loss
|
| 163 |
+
|
| 164 |
+
# VGG Features matching
|
| 165 |
+
class Vgg19(torch.nn.Module):
|
| 166 |
+
def __init__(self, requires_grad=False):
|
| 167 |
+
super(Vgg19, self).__init__()
|
| 168 |
+
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
| 169 |
+
self.slice1 = torch.nn.Sequential()
|
| 170 |
+
self.slice2 = torch.nn.Sequential()
|
| 171 |
+
self.slice3 = torch.nn.Sequential()
|
| 172 |
+
self.slice4 = torch.nn.Sequential()
|
| 173 |
+
self.slice5 = torch.nn.Sequential()
|
| 174 |
+
for x in range(2):
|
| 175 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
| 176 |
+
for x in range(2, 7):
|
| 177 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
| 178 |
+
for x in range(7, 12):
|
| 179 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
| 180 |
+
for x in range(12, 21):
|
| 181 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
| 182 |
+
for x in range(21, 30):
|
| 183 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
| 184 |
+
if not requires_grad:
|
| 185 |
+
for param in self.parameters():
|
| 186 |
+
param.requires_grad = False
|
| 187 |
+
|
| 188 |
+
def forward(self, X):
|
| 189 |
+
h_relu1 = self.slice1(X)
|
| 190 |
+
h_relu2 = self.slice2(h_relu1)
|
| 191 |
+
h_relu3 = self.slice3(h_relu2)
|
| 192 |
+
h_relu4 = self.slice4(h_relu3)
|
| 193 |
+
h_relu5 = self.slice5(h_relu4)
|
| 194 |
+
out = [h_relu5, h_relu2, h_relu3, h_relu4, h_relu5]
|
| 195 |
+
return out
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class VGGLoss(nn.Module):
|
| 199 |
+
def __init__(self):
|
| 200 |
+
super(VGGLoss, self).__init__()
|
| 201 |
+
if torch.cuda.is_available():
|
| 202 |
+
self.vgg = Vgg19().cuda()
|
| 203 |
+
else:
|
| 204 |
+
self.vgg = Vgg19()
|
| 205 |
+
self.criterion = nn.L1Loss()
|
| 206 |
+
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
|
| 207 |
+
|
| 208 |
+
def forward(self, x, y):
|
| 209 |
+
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
|
| 210 |
+
loss = 0
|
| 211 |
+
for i in range(len(x_vgg)):
|
| 212 |
+
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
|
| 213 |
+
return loss
|