ChristophSchuhmann's picture
Add model code, inference script, and examples
dfd1909 verified
import torch.nn as nn
from torchaudio.transforms import Spectrogram
import torch.nn.functional as F
class SingleScaleSpectralLoss(nn.Module):
def __init__(self, n_fft, alpha=1.0, overlap=0.75, eps=1e-7):
super(SingleScaleSpectralLoss,self).__init__()
self.n_fft = n_fft
self.alpha = alpha
self.eps = eps
self.hop_length = int(n_fft * (1 - overlap)) # 25% of the length
self.spec = Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length)
def forward(self, x_pred, x_true):
#spec = Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length)
#spec.to(x_pred.device)
S_true = self.spec(x_true)
S_pred = self.spec(x_pred)
linear_term = F.l1_loss(S_pred, S_true)
log_term = F.l1_loss((S_true + self.eps).log2(), (S_pred + self.eps).log2())
loss = linear_term + self.alpha * log_term
return loss