MusicLSTMDemo / MQGAN /losses.py
ZDisket
add1
827b824
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSGANLoss(nn.Module):
def __init__(self, real_label=1.0, fake_label=0.0, decay=0.99, use_lecam=True):
super().__init__()
self.real_label = real_label
self.fake_label = fake_label
self.decay = decay
self.use_lecam = use_lecam
# we'll compute per‑element MSE and reduce manually
self.criterion = nn.MSELoss(reduction='none')
# EMA buffers
self.register_buffer("ema_real", torch.tensor(0.0))
self.register_buffer("ema_fake", torch.tensor(0.0))
self.ema_initialized = False
def _masked_mse(self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor=None):
# pred/target: (..., N) same shape. mask: same shape, bool or float, or None.
err = self.criterion(pred, target)
if mask is not None:
# ensure float mask
m = mask.float()
err = err * m
valid = m.sum()
if valid.item() > 0:
return err.sum() / valid
else:
return torch.tensor(0., device=pred.device)
else:
# default mean over all elements
return err.mean()
def update_ema(self, real_out: torch.Tensor, fake_out: torch.Tensor,
real_mask: torch.Tensor=None, fake_mask: torch.Tensor=None):
# compute means
if real_mask is not None:
m = real_mask.float()
real_mean = (real_out * m).sum() / m.sum().clamp(min=1)
else:
real_mean = real_out.mean()
if fake_mask is not None:
m = fake_mask.float()
fake_mean = (fake_out * m).sum() / m.sum().clamp(min=1)
else:
fake_mean = fake_out.mean()
if not self.ema_initialized:
self.ema_real.copy_(real_mean.detach())
self.ema_fake.copy_(fake_mean.detach())
self.ema_initialized = True
else:
self.ema_real.mul_(self.decay).add_((1 - self.decay) * real_mean.detach())
self.ema_fake.mul_(self.decay).add_((1 - self.decay) * fake_mean.detach())
def lecam_loss(self, real_out: torch.Tensor, fake_out: torch.Tensor,
real_mask: torch.Tensor=None, fake_mask: torch.Tensor=None):
# Ensure EMA tensors are on the same device as the outputs
device = real_out.device
self.ema_real = self.ema_real.to(device)
self.ema_fake = self.ema_fake.to(device)
ema_r = self.ema_real.detach()
ema_f = self.ema_fake.detach()
# LeCam on real
if real_mask is not None:
diff_r = (real_out - ema_f).clamp(min=0) * real_mask.float()
term_r = diff_r.pow(2).sum() / real_mask.float().sum().clamp(min=1)
else:
term_r = ((real_out - ema_f).clamp(min=0) ** 2).mean()
# LeCam on fake
if fake_mask is not None:
diff_f = (ema_r - fake_out).clamp(min=0) * fake_mask.float()
term_f = diff_f.pow(2).sum() / fake_mask.float().sum().clamp(min=1)
else:
term_f = ((ema_r - fake_out).clamp(min=0) ** 2).mean()
return term_r + term_f
def discriminator_loss(
self,
real_output: torch.Tensor,
fake_output: torch.Tensor,
real_mask: torch.Tensor=None,
fake_mask: torch.Tensor=None
):
"""
real_output, fake_output: (B, 1, L)
real_mask, fake_mask: same shape bool mask, or None.
"""
# prepare targets
real_tgt = torch.full_like(real_output, self.real_label)
fake_tgt = torch.full_like(fake_output, self.fake_label)
# LSGAN losses
real_loss = self._masked_mse(real_output, real_tgt, real_mask)
fake_loss = self._masked_mse(fake_output, fake_tgt, fake_mask)
loss = 0.5 * (real_loss + fake_loss)
if self.use_lecam:
self.update_ema(real_output, fake_output, real_mask, fake_mask)
loss = loss + self.lecam_loss(real_output, fake_output, real_mask, fake_mask)
return loss
def generator_loss(
self,
fake_output: torch.Tensor,
fake_mask: torch.Tensor=None
):
real_tgt = torch.full_like(fake_output, self.real_label)
return self._masked_mse(fake_output, real_tgt, fake_mask)