|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class GANLoss(nn.Module):
    """Binary cross-entropy (with logits) adversarial loss.

    Discriminator predictions are compared against a constant target tensor
    filled with the configured "real" or "fake" label value.

    Args:
        target_real_label: Value used to fill the target for real samples.
        target_fake_label: Value used to fill the target for fake samples.
        tensor: Tensor constructor; stored as ``self.Tensor`` but not used
            by any method of this class.
        opt: Arbitrary options object; stored as ``self.opt`` but not used
            by any method of this class.
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor, opt=None):
        super().__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # The three cache slots below are never populated by this class; they
        # are kept so that external code reading these attributes still works.
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.opt = opt

    def get_target_tensor(self, input, target_is_real):
        """Return a gradient-free target tensor shaped like ``input``.

        The tensor is filled with ``self.real_label`` when ``target_is_real``
        is truthy and with ``self.fake_label`` otherwise, so the label values
        given to the constructor are honoured. With the default labels
        (1.0 / 0.0) this is the same as ``ones_like`` / ``zeros_like``.
        """
        label = self.real_label if target_is_real else self.fake_label
        # full_like allocates a fresh tensor matching input's dtype/device;
        # a freshly created tensor does not require grad.
        return torch.full_like(input, label)

    def get_zero_tensor(self, input):
        """Return an all-zeros, gradient-free tensor shaped like ``input``."""
        return torch.zeros_like(input)

    def loss(self, inputs, target_is_real, for_discriminator=True):
        """Compute BCE-with-logits between ``inputs`` and the chosen target.

        Args:
            inputs: Tensor of raw (pre-sigmoid) predictions.
            target_is_real: Selects the real or the fake label as target.
            for_discriminator: Accepted for interface compatibility; it does
                not affect the result.

        Returns:
            The loss reduced with the default reduction of
            ``F.binary_cross_entropy_with_logits``.
        """
        target_tensor = self.get_target_tensor(inputs, target_is_real)
        return F.binary_cross_entropy_with_logits(inputs, target_tensor)

    # NOTE: this overrides ``nn.Module.__call__`` directly instead of defining
    # ``forward``, so hooks registered on the module are not invoked. Kept
    # as-is to preserve the existing call behaviour.
    def __call__(self, inputs, target_is_real, for_discriminator=True):
        """Evaluate the loss on a single prediction or a list of predictions.

        Args:
            inputs: Either a tensor, or a list whose elements are tensors or
                lists of tensors. For an element that is itself a list, only
                its last entry is used.
            target_is_real: Selects the real or the fake label as target.
            for_discriminator: Forwarded to :meth:`loss`.

        Returns:
            For a tensor input, the result of :meth:`loss`. For a list input,
            the per-element losses (each flattened and averaged to one value
            per leading entry) summed and divided by ``len(inputs)``.

        Raises:
            ZeroDivisionError: If ``inputs`` is an empty list.
        """
        if not isinstance(inputs, list):
            return self.loss(inputs, target_is_real, for_discriminator)

        loss = 0
        for pred_i in inputs:
            if isinstance(pred_i, list):
                # Nested list: only the final entry is scored.
                pred_i = pred_i[-1]
            loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
            # A 0-dim loss is treated as a batch of one so that the view/mean
            # below yields a shape-(1,) tensor.
            bs = 1 if loss_tensor.dim() == 0 else loss_tensor.size(0)
            loss += torch.mean(loss_tensor.view(bs, -1), dim=1)
        return loss / len(inputs)
|
|
|