# ChristophSchuhmann's picture
# Add model code, inference script, and examples
# dfd1909 verified
import torch.nn as nn
from Train.Loss.LossFunction.SingleScaleSpectralLoss import SingleScaleSpectralLoss
class MultiScaleSpectralLoss(nn.Module):
    """Sum of ``SingleScaleSpectralLoss`` terms evaluated at several FFT sizes.

    One ``SingleScaleSpectralLoss`` is built per entry of ``n_ffts``, all
    sharing the same ``alpha``, ``overlap`` and ``eps`` settings.
    """

    def __init__(
            self,
            n_ffts=(2048, 1024, 512, 256),
            alpha=1.0,
            overlap=0.75,
            eps=1e-7):
        """Build one single-scale loss per FFT size.

        Args:
            n_ffts: Iterable of FFT sizes; one sub-loss is created per entry.
                Defaults to an immutable tuple so the default is never shared
                mutable state between instances.
            alpha: Forwarded positionally to every ``SingleScaleSpectralLoss``.
            overlap: Forwarded positionally to every ``SingleScaleSpectralLoss``.
            eps: Forwarded positionally to every ``SingleScaleSpectralLoss``.

        Raises:
            ValueError: If ``n_ffts`` is empty. An empty loss list would make
                ``forward`` fail with an opaque ``AttributeError``, because
                ``sum([])`` is the int ``0``, which has no ``.sum()``.
        """
        super().__init__()
        # Materialise first so generators / other one-shot iterables can be
        # both validated and iterated.
        n_ffts = list(n_ffts)
        if not n_ffts:
            raise ValueError("n_ffts must contain at least one FFT size")
        self.losses = nn.ModuleList(
            [SingleScaleSpectralLoss(n_fft, alpha, overlap, eps) for n_fft in n_ffts]
        )

    def forward(self, x_pred, x_true):
        """Return the summed multi-scale spectral loss.

        Args:
            x_pred: Predicted signal. Only its last axis is trimmed here.
            x_true: Target signal. Its last-axis length sets the trim length.

        Returns:
            The sum of all per-scale losses, reduced with ``.sum()``.
        """
        # Cut the reverberation tail off: trim the prediction's last axis to
        # the target's length. A prediction that is *shorter* than the target
        # is left untouched; any mismatch is up to the sub-losses.
        x_pred = x_pred[..., : x_true.shape[-1]]
        losses = [loss(x_pred, x_true) for loss in self.losses]
        # ``.sum()`` collapses the result to a scalar in case a sub-loss
        # returns a non-scalar tensor (NOTE(review): confirm what
        # SingleScaleSpectralLoss returns).
        return sum(losses).sum()