import torch
import torch.nn as nn
import torch.nn.functional as F


class GANLoss(nn.Module):
    """Binary cross-entropy (with logits) adversarial loss.

    Accepts either a single prediction tensor or a list/tuple of predictions
    (one per discriminator scale). When an element of that list is itself a
    list, only its last entry is used as the prediction.

    Args:
        target_real_label: value used to fill the target tensor when
            ``target_is_real`` is True (e.g. 0.9 for one-sided label smoothing).
        target_fake_label: value used to fill the target tensor when
            ``target_is_real`` is False.
        tensor: stored on ``self.Tensor``; not used by the code in this class,
            kept for backward compatibility.
        opt: stored on ``self.opt``; not used by the code in this class.
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor, opt=None):
        super().__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # The cached-tensor attributes below are never populated by this class;
        # they are kept so external code that reads them does not break.
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.opt = opt

    def get_target_tensor(self, input, target_is_real):
        """Return a detached tensor shaped like ``input`` filled with the label.

        The fill value is ``self.real_label`` when ``target_is_real`` is truthy,
        otherwise ``self.fake_label``. dtype and device follow ``input``.
        """
        label = self.real_label if target_is_real else self.fake_label
        return torch.full_like(input, label).detach()

    def get_zero_tensor(self, input):
        """Return a detached all-zeros tensor shaped like ``input``."""
        return torch.zeros_like(input).detach()

    def loss(self, inputs, target_is_real, for_discriminator=True):
        """Compute mean BCE-with-logits between ``inputs`` and the target labels.

        ``for_discriminator`` is accepted for interface compatibility but does
        not affect the result here. Returns a 0-dim tensor (default 'mean'
        reduction of ``binary_cross_entropy_with_logits``).
        """
        target_tensor = self.get_target_tensor(inputs, target_is_real)
        return F.binary_cross_entropy_with_logits(inputs, target_tensor)

    def forward(self, inputs, target_is_real, for_discriminator=True):
        """Compute the GAN loss for one prediction or a multiscale list of them.

        Implemented as ``forward`` (not ``__call__``) so that calling the module
        goes through ``nn.Module.__call__`` and registered hooks run.

        Args:
            inputs: a tensor of logits, or a list/tuple of such tensors (one per
                discriminator scale). A nested list contributes only its last
                element.
            target_is_real: whether the targets are the real or the fake label.
            for_discriminator: forwarded to :meth:`loss`.

        Returns:
            For tensor input, the value of :meth:`loss`. For list/tuple input,
            the average over scales of each scale's loss reduced to shape
            ``(bs,)`` (``bs`` is 1 when :meth:`loss` returns a 0-dim tensor).

        Raises:
            ValueError: if ``inputs`` is an empty list/tuple.
        """
        # |inputs| may be a single tensor or, for a multiscale discriminator,
        # a list of tensors (or of lists whose last entry is the prediction).
        if not isinstance(inputs, (list, tuple)):
            return self.loss(inputs, target_is_real, for_discriminator)

        if len(inputs) == 0:
            raise ValueError("GANLoss received an empty list of predictions")

        total = 0
        for pred_i in inputs:
            if isinstance(pred_i, list):
                pred_i = pred_i[-1]
            loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
            # Reduce each scale's loss to one value per batch element; a 0-dim
            # loss is treated as a batch of one.
            bs = 1 if loss_tensor.dim() == 0 else loss_tensor.size(0)
            total = total + torch.mean(loss_tensor.view(bs, -1), dim=1)
        return total / len(inputs)