# Loss modules for image synthesis: GAN adversarial losses, WGAN-GP gradient
# penalty, learned LPIPS metric, and VGG-based perceptual loss.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| #################################################################################################### | |
| # adversarial loss for different gan mode | |
| #################################################################################################### | |
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label
    tensor that has the same size as the discriminator prediction.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective: vanilla | lsgan |
                              hinge | wgangp | nonsaturating.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label for a fake image

        Note: Do not use sigmoid as the last layer of the Discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # Buffers follow the module across .to(device) and are saved in
        # state_dict, but are not trainable parameters.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            self.loss = nn.ReLU()
        elif gan_mode in ['wgangp', 'nonsaturating']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground-truth label is for
                                     real or fake examples

        Returns:
            A label tensor filled with the ground-truth label, expanded to
            the size of the input.
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def calculate_loss(self, prediction, target_is_real, is_dis=False):
        """Calculate loss given the discriminator's output and ground-truth labels.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground-truth label is for
                                     real or fake examples
            is_dis (bool) -- True when computing the discriminator's loss,
                             False for the generator's.

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
            if self.gan_mode == 'lsgan':
                loss = loss * 0.5
        else:
            if is_dis:
                # Negate real-sample logits so a single formula per mode
                # covers both the real and fake discriminator terms.
                if target_is_real:
                    prediction = -prediction
                if self.gan_mode == 'wgangp':
                    loss = prediction.mean()
                elif self.gan_mode == 'nonsaturating':
                    loss = F.softplus(prediction).mean()
                elif self.gan_mode == 'hinge':
                    loss = self.loss(1 + prediction).mean()
            else:
                if self.gan_mode == 'nonsaturating':
                    loss = F.softplus(-prediction).mean()
                else:
                    # hinge and wgangp generators both maximize D's output.
                    loss = -prediction.mean()
        return loss

    def forward(self, predictions, target_is_real, is_dis=False):
        """Calculate the loss for multi-scale GAN outputs.

        FIX: this was originally defined as __call__, which overrode
        nn.Module.__call__ and silently bypassed module hooks. Defining
        forward keeps every gan_loss(...) call site working (Module.__call__
        dispatches here) while restoring standard nn.Module behavior.

        Parameters:
            predictions (tensor or list of tensors) -- discriminator output(s),
                one tensor per scale when a list is given
            target_is_real (bool) -- whether the labels are for real examples
            is_dis (bool) -- True for the discriminator loss

        Returns:
            the loss, summed over scales when predictions is a list.
        """
        if isinstance(predictions, list):
            return sum(self.calculate_loss(p, target_is_real, is_dis)
                       for p in predictions)
        return self.calculate_loss(predictions, target_is_real, is_dis)
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP (https://arxiv.org/abs/1704.00028).

    Arguments:
        netD (network)          -- discriminator network
        real_data (tensor)      -- real examples
        fake_data (tensor)      -- generated examples from the generator
        device (str)            -- device the penalty tensors should live on
        type (str)              -- which points to evaluate D's gradient at:
                                   'real' | 'fake' | 'mixed' (random interpolation)
        constant (float)        -- the c in (||gradient||_2 - c)^2
        lambda_gp (float)       -- weight for this loss; <= 0 disables it

    Returns:
        (penalty, gradients) -- scalar penalty and the flattened per-sample
        gradients, or (0.0, None) when lambda_gp <= 0.
    """
    if lambda_gp <= 0.0:
        return 0.0, None

    if type == 'real':
        samples = real_data
    elif type == 'fake':
        samples = fake_data
    elif type == 'mixed':
        batch = real_data.shape[0]
        # One interpolation coefficient per sample, broadcast to the full shape.
        mix = torch.rand(batch, 1, device=device)
        mix = mix.expand(batch, real_data.nelement() // batch).contiguous().view(*real_data.shape)
        samples = mix * real_data + (1 - mix) * fake_data
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    samples.requires_grad_(True)
    outputs = netD(samples)
    # Multi-scale discriminators return a list of outputs; sum their gradients.
    if not isinstance(outputs, list):
        outputs = [outputs]
    grads = 0
    for out in outputs:
        grads = grads + torch.autograd.grad(
            outputs=out, inputs=samples,
            grad_outputs=torch.ones(out.size()).to(device),
            create_graph=True, retain_graph=True, only_inputs=True)[0]

    grads = grads.view(real_data.size(0), -1)  # one row per sample
    # eps keeps norm() differentiable at exactly-zero gradients
    penalty = (((grads + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return penalty, grads
| #################################################################################################### | |
| # trained LPIPS loss | |
| #################################################################################################### | |
def normalize_tensor(x, eps=1e-10):
    """Scale x to unit L2 norm along dim 1 (eps guards against division by zero)."""
    magnitude = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (magnitude + eps)
def spatial_average(x, keepdim=True):
    """Average a NCHW tensor over its spatial dimensions (H, W)."""
    return x.mean(dim=[2, 3], keepdim=keepdim)
class NetLinLayer(nn.Module):
    """A single linear layer implemented as a 1x1 convolution.

    Used by LPIPS to weight per-channel feature differences.

    Parameters:
        chn_in (int)       -- number of input channels
        chn_out (int)      -- number of output channels (default 1)
        use_dropout (bool) -- prepend an nn.Dropout layer when True
    """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # FIX: the original class defined no forward, so calling the module
        # raised NotImplementedError and callers had to reach into .model
        # directly. Providing forward keeps .model usage working while making
        # the layer callable like any other nn.Module.
        return self.model(x)
class LPIPSLoss(nn.Module):
    """
    Learned perceptual metric (LPIPS).
    https://github.com/richzhang/PerceptualSimilarity

    Compares VGG16 features of two images; squared per-channel differences
    are weighted by learned 1x1 conv layers loaded from a checkpoint.

    Parameters:
        use_dropout (bool) -- include dropout in the linear layers
        ckpt_path (str or None) -- path to the learned-weights checkpoint;
            when None no weights are loaded (layers keep random init)
    """
    def __init__(self, use_dropout=True, ckpt_path=None):
        super(LPIPSLoss, self).__init__()
        self.path = ckpt_path
        self.net = VGG16()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channel counts
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        # LPIPS is a fixed metric: freeze all parameters.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self):
        """Load the learned linear-layer weights from self.path, if given.

        FIX: the original called torch.load(self.path) unconditionally, so
        constructing LPIPSLoss with the default ckpt_path=None crashed.
        """
        if self.path is None:
            return
        self.load_state_dict(torch.load(self.path, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(self.path))

    def _get_features(self, vgg_f):
        # Select the activations LPIPS uses, ordered to match self.chns.
        names = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        return [vgg_f[name] for name in names]

    def forward(self, x, y):
        """Return the LPIPS distance map between images x and y.

        Features are channel-normalized, differenced, squared, weighted by
        the learned 1x1 convs, then spatially averaged and summed over layers.
        """
        x_vgg, y_vgg = self._get_features(self.net(x)), self._get_features(self.net(y))
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        loss = 0
        for i in range(len(self.chns)):
            x_feats, y_feats = normalize_tensor(x_vgg[i]), normalize_tensor(y_vgg[i])
            diffs = (x_feats - y_feats) ** 2
            loss += spatial_average(lins[i].model(diffs))
        return loss
class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based.
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py

    Sums weighted L1 distances between VGG16 activations of two images at
    five layers (relu1_2, relu2_2, relu3_3, relu4_3, relu5_3).

    Parameters:
        weights (sequence of 5 floats) -- per-layer weights; a layer with
            weight <= 0 contributes nothing.
    """
    # FIX: default changed from a mutable list literal to a tuple so the
    # default cannot be accidentally shared/mutated across instances.
    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 0.0)):
        super(PerceptualLoss, self).__init__()
        self.add_module('vgg', VGG16())
        self.criterion = nn.L1Loss()
        self.weights = weights

    def forward(self, x, y):
        """Return the weighted sum of per-layer L1 feature distances.

        FIX: originally defined as __call__, which overrode
        nn.Module.__call__ and bypassed module hooks; forward keeps
        loss(x, y) call sites working via Module.__call__.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        layer_names = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        for weight, name in zip(self.weights, layer_names):
            if weight > 0:
                content_loss += weight * self.criterion(x_vgg[name], y_vgg[name])
        return content_loss
class Normalization(nn.Module):
    """Standardize an image batch with ImageNet channel statistics.

    The mean and std are reshaped to [C x 1 x 1] so they broadcast over
    image tensors shaped [B x C x H x W] (B batch, C channels, H x W spatial).
    """
    def __init__(self, device):
        super(Normalization, self).__init__()
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        self.mean = imagenet_mean.view(-1, 1, 1)
        self.std = imagenet_std.view(-1, 1, 1)

    def forward(self, img):
        """Return the channel-wise standardized image: (img - mean) / std."""
        return (img - self.mean) / self.std
class VGG16(nn.Module):
    """Frozen VGG16 feature extractor.

    forward() returns a dict mapping stage names ('relu1_1' ... 'relu5_3')
    to the activation after that stage. Weights come from torchvision's
    ImageNet-pretrained VGG16 and are frozen (feature extraction only).
    """

    # (stage name, start, stop): each stage is the half-open slice
    # [start, stop) of torchvision's vgg16().features, ending at a ReLU.
    _STAGES = [
        ('relu1_1', 0, 2), ('relu1_2', 2, 4),
        ('relu2_1', 4, 7), ('relu2_2', 7, 9),
        ('relu3_1', 9, 12), ('relu3_2', 12, 14), ('relu3_3', 14, 16),
        ('relu4_1', 16, 18), ('relu4_2', 18, 21), ('relu4_3', 21, 23),
        ('relu5_1', 23, 26), ('relu5_2', 26, 28), ('relu5_3', 28, 30),
    ]

    def __init__(self):
        super(VGG16, self).__init__()
        features = models.vgg16(pretrained=True).features
        # Build one Sequential per stage, preserving the original submodule
        # names (the numeric indices into `features`).
        for name, start, stop in self._STAGES:
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])
            setattr(self, name, stage)
        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run x through all stages in order; collect every activation."""
        out = {}
        feat = x
        for name, _, _ in self._STAGES:
            feat = getattr(self, name)(feat)
            out[name] = feat
        return out