XCodec2_24khz / criterions /gan_loss.py
Respair's picture
Upload folder using huggingface_hub
59b7eeb verified
import torch
import torch.nn as nn
import torch.nn.functional as F
class GANLoss(nn.Module):
def __init__(self):
super(GANLoss, self).__init__()
def disc_loss(self, real, fake):
real_loss = F.mse_loss(real, torch.ones_like(real))
fake_loss = F.mse_loss(fake, torch.zeros_like(fake))
return real_loss, fake_loss
def gen_loss(self, fake):
gen_loss = F.mse_loss(fake, torch.ones_like(fake))
return gen_loss