| import numpy as np |
|
|
| import torch |
| import torch.nn as nn |
| import torchaudio |
| from torch.nn import functional as F |
| from .core import upsample |
| |
class SSSLoss(nn.Module):
    """
    Single-scale Spectral Loss.

    Compares the magnitude spectrograms of two signals at one FFT size.
    The loss is the sum of

    * a spectral-convergence term,
      ``||S_true - S_pred|| / ||S_true + S_pred||``, where the norm is taken
      over dims 1 and 2 of the spectrogram and the result is averaged, and
    * ``alpha`` times the mean absolute difference of the log spectrograms.

    Args:
        n_fft: FFT size handed to the spectrogram transform.
        alpha: Weight of the log-spectrogram L1 term.
        overlap: Fraction of ``n_fft`` shared by consecutive frames. The hop
            length is ``int(n_fft * (1 - overlap))``, but never less than 1.
        eps: Constant added to both spectrograms before the division and
            the logarithm, so neither sees an exact zero.
    """

    def __init__(self, n_fft=111, alpha=1.0, overlap=0, eps=1e-7):
        super().__init__()
        self.n_fft = n_fft
        self.alpha = alpha
        self.eps = eps
        self.hop_length = self._hop_length(n_fft, overlap)
        self.spec = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            power=1,
            normalized=True,
            center=False,
        )

    @staticmethod
    def _hop_length(n_fft, overlap):
        """Return the frame hop for ``n_fft`` and ``overlap``, at least 1."""
        # int() truncates, so a small n_fft or a large overlap would give a
        # hop of 0, which is not a usable hop length for a spectrogram.
        return max(1, int(n_fft * (1 - overlap)))

    def forward(self, x_true, x_pred):
        """
        Return the scalar spectral loss between ``x_true`` and ``x_pred``.

        Both inputs are passed through the same spectrogram transform. A 1-D
        (unbatched) signal is treated as a batch of one.
        NOTE(review): batched inputs are presumably (batch, time) -- confirm
        against callers; the norm below reduces over dims 1 and 2 of the
        spectrogram.
        """
        # The norms below index dims 1 and 2, so the spectrogram needs a
        # leading batch axis; add one for unbatched signals.
        if x_true.dim() == 1:
            x_true = x_true.unsqueeze(0)
        if x_pred.dim() == 1:
            x_pred = x_pred.unsqueeze(0)

        S_true = self.spec(x_true) + self.eps
        S_pred = self.spec(x_pred) + self.eps

        # Spectral convergence: norm of the difference relative to the norm
        # of the sum, averaged over the remaining dims.
        diff_norm = torch.linalg.norm(S_true - S_pred, dim=(1, 2))
        sum_norm = torch.linalg.norm(S_true + S_pred, dim=(1, 2))
        converge_term = torch.mean(diff_norm / sum_norm)

        # L1 distance between the log spectrograms.
        log_term = F.l1_loss(S_true.log(), S_pred.log())

        return converge_term + self.alpha * log_term
| |
| |
class RSSLoss(nn.Module):
    """
    Random-scale Spectral Loss.

    Holds one ``SSSLoss`` per FFT size in ``range(fft_min, fft_max)`` and, on
    every call, averages ``n_scale`` of them chosen uniformly at random (with
    replacement) via ``torch.randint``.

    Args:
        fft_min: Smallest FFT size (inclusive).
        fft_max: Largest FFT size (exclusive).
        n_scale: Number of random FFT sizes drawn per forward call.
        alpha: Forwarded to every ``SSSLoss``.
        overlap: Forwarded to every ``SSSLoss``.
        eps: Forwarded to every ``SSSLoss``.
        device: Device each ``SSSLoss`` is moved to. If a CUDA device is
            requested but CUDA is not available, CPU is used instead and a
            warning is issued.

    Raises:
        ValueError: If the FFT range is empty or ``n_scale`` is less than 1.
    """

    def __init__(self, fft_min, fft_max, n_scale, alpha=1.0, overlap=0, eps=1e-7, device='cuda'):
        super().__init__()
        # Fail here with a clear message rather than later inside forward(),
        # where an empty range breaks torch.randint and n_scale == 0 divides
        # by zero.
        if fft_min >= fft_max:
            raise ValueError(
                f"fft_min ({fft_min}) must be smaller than fft_max ({fft_max})"
            )
        if n_scale < 1:
            raise ValueError(f"n_scale must be at least 1, got {n_scale}")

        # The default device is 'cuda'; do not crash on CPU-only machines.
        if (
            device is not None
            and torch.device(device).type == 'cuda'
            and not torch.cuda.is_available()
        ):
            import warnings

            warnings.warn(
                "CUDA was requested for RSSLoss but is not available; "
                "falling back to CPU.",
                RuntimeWarning,
            )
            device = 'cpu'

        self.fft_min = fft_min
        self.fft_max = fft_max
        self.n_scale = n_scale
        # The sub-losses are kept in a plain dict keyed by FFT size, so they
        # are not registered as submodules of this module; each one is moved
        # to `device` explicitly here.
        self.lossdict = {
            n_fft: SSSLoss(n_fft, alpha, overlap, eps).to(device)
            for n_fft in range(fft_min, fft_max)
        }

    def forward(self, x_pred, x_true):
        """
        Return the mean single-scale loss over ``n_scale`` random FFT sizes.

        Note the argument order is ``(x_pred, x_true)``, the reverse of
        ``SSSLoss.forward``; the inputs are handed to each sub-loss as
        ``(x_true, x_pred)``.
        """
        # torch.randint's upper bound is exclusive, matching the range() used
        # to build lossdict, so every drawn size has an entry.
        n_ffts = torch.randint(self.fft_min, self.fft_max, (self.n_scale,)).tolist()
        total = sum((self.lossdict[n_fft](x_true, x_pred) for n_fft in n_ffts), 0.)
        return total / self.n_scale
| |
| |
| |
|
|