Spaces:
Build error
Build error
| import torch | |
| from torch import nn | |
class GANLoss(nn.Module):
    """Adversarial loss that builds its own real/fake target tensor.

    The caller passes discriminator predictions and a boolean saying
    whether they should be scored against the "real" or the "fake" label;
    the matching label value is broadcast to the shape of the predictions
    and handed to the underlying criterion.

    Args:
        gan_mode: ``'vanilla'`` selects ``nn.BCEWithLogitsLoss``;
            ``'lsgan'`` selects ``nn.MSELoss``.
        real_label: Target value used when ``target_is_real`` is true.
        fake_label: Target value used when ``target_is_real`` is false.

    Raises:
        ValueError: If ``gan_mode`` is not one of the supported modes.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Registered as buffers (not plain attributes) so the label tensors
        # are carried along by the module's device/dtype moves.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Fail at construction time; otherwise ``self.loss`` would be
            # missing and the error would only surface on the first call.
            raise ValueError(
                f"Unsupported gan_mode {gan_mode!r}; "
                "expected 'vanilla' or 'lsgan'."
            )

    def get_labels(self, preds, target_is_real):
        """Return the real or fake label expanded to the shape of ``preds``.

        Args:
            preds: Tensor of predictions whose shape the labels must match.
            target_is_real: Truthy to use ``real_label``, falsy to use
                ``fake_label``.

        Returns:
            The selected label buffer expanded (via ``expand_as``) to
            ``preds``' shape.
        """
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the configured loss of ``preds`` against the chosen label.

        Defined as ``forward`` rather than ``__call__`` so that invoking the
        module goes through ``nn.Module.__call__`` and registered hooks run;
        ``criterion(preds, target_is_real)`` works exactly as before.

        Args:
            preds: Tensor of predictions passed to the criterion.
            target_is_real: Truthy to compare against the real label,
                falsy to compare against the fake label.

        Returns:
            The value returned by the underlying criterion.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)