| import torch |
|
|
| from audiozen.acoustics.audio_feature import stft |
| from audiozen.loss import freq_MAE, mag_MAE |
|
|
|
|
class SNRLoss(torch.nn.Module):
    """Negative SNR loss (in dB) with a soft ceiling on the attainable SNR.

    For an estimate ``x`` and a reference ``s`` the loss of one signal is::

        10 * log10(||s - x||^2 + tao * ||s||^2 + eps) - 10 * log10(||s||^2 + eps)

    where the norms are taken over the last dimension. This is the negated SNR
    of ``x`` with respect to ``s``. The ``tao * ||s||^2`` term keeps the loss
    bounded below by roughly ``10 * log10(tao)`` dB (about -30 dB for the
    default ``tao``), so a perfect estimate does not drive it to -inf.

    Args:
        tao: Relative weight of the reference energy added to the error energy.
    """

    def __init__(self, tao=1e-3):
        super().__init__()
        self.tao = tao
        # Guards both log10 arguments against zero energy.
        self.eps = 1e-8

    def l2norm(self, mat, keepdim=False):
        """Return the L2 norm of ``mat`` over its last dimension."""
        return torch.norm(mat, dim=-1, keepdim=keepdim)

    def _snr_loss(self, x, s):
        """Return the unreduced loss of estimate ``x`` against reference ``s``.

        The last dimension is reduced.
        """
        ref_energy = self.l2norm(s) ** 2
        err_energy = self.l2norm(s - x) ** 2
        return 10 * torch.log10(err_energy + self.tao * ref_energy + self.eps) - 10 * torch.log10(
            ref_energy + self.eps
        )

    def forward(self, x, s, length_list=None):
        """Compute the batch-averaged loss.

        Args:
            x: Estimated signals; must have the same shape as ``s``.
            s: Reference signals.
            length_list: Optional per-item valid lengths. When given, item ``i``
                is cropped to ``x[i, :length]`` / ``s[i, :length]`` before its
                loss is computed (presumably to ignore padding -- confirm with
                callers), and the per-item losses are averaged.

        Returns:
            Scalar tensor: the mean of the unreduced loss when ``length_list``
            is None, otherwise the mean over the cropped items.

        Raises:
            RuntimeError: If ``x`` and ``s`` have different shapes.
            ValueError: If ``length_list`` is given but empty.
        """
        if x.shape != s.shape:
            raise RuntimeError(f"Dimension mismatch when calculate snr, {x.shape} vs {s.shape}")

        if length_list is None:
            return self._snr_loss(x, s).mean()

        if len(length_list) == 0:
            raise ValueError("length_list must contain at least one length")

        # Average (not sum) over items so that the scale matches the un-cropped
        # branch. The inner .mean() keeps each entry scalar even if an item's
        # loss is not 0-dim.
        per_item = [
            self._snr_loss(x[i, :length], s[i, :length]).mean()
            for i, length in enumerate(length_list)
        ]
        return torch.stack(per_item).mean()
|
|
|
|
class FreqLoss(torch.nn.Module):
    """Sum of ``freq_MAE`` and ``mag_MAE`` between a reference and an estimate.

    Both helpers come from ``audiozen.loss`` and are called as
    ``helper(reference, estimate)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, s, length_list=None):
        """Compute the combined frequency/magnitude MAE.

        Args:
            x: Estimated signals; must have the same shape as ``s``.
            s: Reference signals.
            length_list: Optional per-item valid lengths. When given, item ``i``
                is cropped to ``x[i, :length]`` / ``s[i, :length]`` before the
                losses are computed (presumably to ignore padding -- confirm
                with callers), and the per-item losses are averaged.

        Returns:
            The loss returned by the helpers when ``length_list`` is None,
            otherwise a scalar tensor averaged over the cropped items.

        Raises:
            RuntimeError: If ``x`` and ``s`` have different shapes.
            ValueError: If ``length_list`` is given but empty.
        """
        if x.shape != s.shape:
            raise RuntimeError(f"Dimension mismatch when calculate freq loss, {x.shape} vs {s.shape}")

        if length_list is None:
            return freq_MAE(s, x) + mag_MAE(s, x)

        if len(length_list) == 0:
            raise ValueError("length_list must contain at least one length")

        per_item = []
        for i, length in enumerate(length_list):
            x_i = x[i, :length]
            s_i = s[i, :length]
            # .mean() reduces each item to a scalar so that items can be stacked.
            per_item.append((freq_MAE(s_i, x_i) + mag_MAE(s_i, x_i)).mean())

        # Average (not sum) over items so that the scale does not grow with the
        # batch size.
        return torch.stack(per_item).mean()
|
|
|
|
class MultiResolutionL1SpecLoss(torch.nn.Module):
    """Scale-invariant L1 waveform loss plus a multi-resolution L1 spectral loss.

    The time-domain term works as follows. The estimate is rescaled by the
    per-signal least-squares gain ``<est, ref> / (||est||^2 + eps)``, computed
    over the last dimension. The mean absolute error to the reference is then
    taken.

    The spectral term is the mean absolute difference between the ``stft``
    outputs of ``est`` (not rescaled) and ``ref``, averaged over all configured
    resolutions. The two terms are weighted 0.5 each.

    Args:
        window_sz: STFT window / FFT sizes, one per resolution.
        hop_sz: STFT hop lengths, paired element-wise with ``window_sz``.
        eps: Added to the estimate energy to avoid division by zero.

    Raises:
        ValueError: If ``window_sz`` and ``hop_sz`` differ in length, or are empty.
    """

    def __init__(
        self,
        window_sz: "list[int] | tuple[int, ...]" = (512, 1024, 2048, 256, 128),
        hop_sz: "list[int] | tuple[int, ...]" = (256, 512, 1024, 128, 64),
        eps=1e-8,
    ):
        super().__init__()
        # Tuple defaults avoid the shared-mutable-default pitfall. Copying into
        # fresh lists keeps the attribute type and decouples it from caller lists.
        window_sz = list(window_sz)
        hop_sz = list(hop_sz)
        # forward() divides by len(window_sz) but zips with hop_sz. A length
        # mismatch would silently mis-weight the spectral term, and empty lists
        # would produce 0/0.
        if len(window_sz) != len(hop_sz):
            raise ValueError(
                f"window_sz and hop_sz must have the same length, got {len(window_sz)} and {len(hop_sz)}"
            )
        if not window_sz:
            raise ValueError("at least one STFT resolution is required")
        self.eps = eps
        self.window_sz = window_sz
        self.hop_sz = hop_sz

    def forward(self, est, ref):
        """Return the batch-mean combined loss.

        Args:
            est: Estimated waveforms. Assumed to be (batch, samples) -- TODO
                confirm. The spectral mean over dims (1, 2) assumes ``stft``
                returns 3-D outputs.
            ref: Reference waveforms, same shape as ``est``.

        Returns:
            Scalar tensor.
        """
        # Least-squares gain that best matches est to ref, one per signal.
        scaling_factor = torch.sum(est * ref, -1, keepdim=True) / (torch.sum(est**2, -1, keepdim=True) + self.eps)
        time_domain_loss = torch.mean((est * scaling_factor - ref).abs(), dim=-1)

        spectral_loss = torch.zeros_like(time_domain_loss)
        for win, hop in zip(self.window_sz, self.hop_sz):
            # Only the first output of stft is used. It is assumed to be the
            # magnitude, going by the variable naming -- TODO confirm.
            est_mag, *_ = stft(est, n_fft=win, hop_length=hop, win_length=win, window=None)
            ref_mag, *_ = stft(ref, n_fft=win, hop_length=hop, win_length=win, window=None)
            spectral_loss += torch.mean((est_mag - ref_mag).abs(), dim=(1, 2))

        loss = 0.5 * time_domain_loss + 0.5 * spectral_loss / len(self.window_sz)
        return loss.mean()
|
|
|
|
if __name__ == "__main__":
    import numpy as np

    # Smoke test: run the default multi-resolution loss on two random
    # 2 x 16000-sample signals (estimate first, then reference) and print it.
    est_np = np.random.rand(2, 16000)
    ref_np = np.random.rand(2, 16000)
    est, ref = (torch.as_tensor(arr, dtype=torch.float32) for arr in (est_np, ref_np))

    criterion = MultiResolutionL1SpecLoss()
    print(criterion(est, ref))
|
|